Skip to content

Commit

Permalink
selftests/bpf: Introduce __attribute__((cleanup)) in create_pair()
Browse files Browse the repository at this point in the history
Rewrite function to have (unneeded) socket descriptors automatically
close()d when leaving the scope. Make sure the "ownership" of fds is
correctly passed via take_fd(); i.e. descriptor returned to caller will
remain valid.

Reviewed-by: Jakub Sitnicki <[email protected]>
Tested-by: Jakub Sitnicki <[email protected]>
Suggested-by: Jakub Sitnicki <[email protected]>
Signed-off-by: Michal Luczaj <[email protected]>
Link: https://lore.kernel.org/r/[email protected]
Signed-off-by: Martin KaFai Lau <[email protected]>
  • Loading branch information
mmhal authored and Martin KaFai Lau committed Aug 19, 2024
1 parent c9c70b2 commit 86149b4
Showing 1 changed file with 36 additions and 25 deletions.
61 changes: 36 additions & 25 deletions tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@

#define __always_unused __attribute__((__unused__))

/* include/linux/cleanup.h */
#define __get_and_null(p, nullvalue) \
({ \
__auto_type __ptr = &(p); \
__auto_type __val = *__ptr; \
*__ptr = nullvalue; \
__val; \
})

#define take_fd(fd) __get_and_null(fd, -EBADF)

#define _FAIL(errnum, fmt...) \
({ \
error_at_line(0, (errnum), __func__, __LINE__, fmt); \
Expand Down Expand Up @@ -182,6 +193,14 @@
__ret; \
})

static inline void close_fd(int *fd)
{
if (*fd >= 0)
xclose(*fd);
}

#define __close_fd __attribute__((cleanup(close_fd)))

static inline int poll_connect(int fd, unsigned int timeout_sec)
{
struct timeval timeout = { .tv_sec = timeout_sec };
Expand Down Expand Up @@ -369,72 +388,64 @@ static inline int socket_loopback(int family, int sotype)

static inline int create_pair(int family, int sotype, int *p0, int *p1)
{
__close_fd int s, c = -1, p = -1;
struct sockaddr_storage addr;
socklen_t len = sizeof(addr);
int s, c, p, err;
int err;

s = socket_loopback(family, sotype);
if (s < 0)
return s;

err = xgetsockname(s, sockaddr(&addr), &len);
if (err)
goto close_s;
return err;

c = xsocket(family, sotype, 0);
if (c < 0) {
err = c;
goto close_s;
}
if (c < 0)
return c;

err = connect(c, sockaddr(&addr), len);
if (err) {
if (errno != EINPROGRESS) {
FAIL_ERRNO("connect");
goto close_c;
return err;
}

err = poll_connect(c, IO_TIMEOUT_SEC);
if (err) {
FAIL_ERRNO("poll_connect");
goto close_c;
return err;
}
}

switch (sotype & SOCK_TYPE_MASK) {
case SOCK_DGRAM:
err = xgetsockname(c, sockaddr(&addr), &len);
if (err)
goto close_c;
return err;

err = xconnect(s, sockaddr(&addr), len);
if (!err) {
*p0 = s;
*p1 = c;
if (err)
return err;
}

*p0 = take_fd(s);
break;
case SOCK_STREAM:
case SOCK_SEQPACKET:
p = xaccept_nonblock(s, NULL, NULL);
if (p >= 0) {
*p0 = p;
*p1 = c;
goto close_s;
}
if (p < 0)
return p;

err = p;
*p0 = take_fd(p);
break;
default:
FAIL("Unsupported socket type %#x", sotype);
err = -EOPNOTSUPP;
return -EOPNOTSUPP;
}

close_c:
close(c);
close_s:
close(s);
return err;
*p1 = take_fd(c);
return 0;
}

static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
Expand Down

0 comments on commit 86149b4

Please sign in to comment.