Skip to content

Commit

Permalink
sock: Introduce sk->sk_prot->psock_update_sk_prot()
Browse files Browse the repository at this point in the history
Currently sockmap calls into each protocol to update the struct
proto and replace it. This certainly won't work when the protocol
is implemented as a module, for example, AF_UNIX.

Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
protocol can implement its own way to replace the struct proto.
This also helps get rid of symbol dependencies on CONFIG_INET.

Signed-off-by: Cong Wang <[email protected]>
Signed-off-by: Alexei Starovoitov <[email protected]>
Link: https://lore.kernel.org/bpf/[email protected]
  • Loading branch information
Cong Wang authored and Alexei Starovoitov committed Apr 1, 2021
1 parent a7ba455 commit 8a59f9d
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 45 deletions.
18 changes: 3 additions & 15 deletions include/linux/skmsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ struct sk_psock {
void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk);
void (*saved_data_ready)(struct sock *sk);
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
struct proto *sk_proto;
struct mutex work_mutex;
struct sk_psock_work_state work_state;
Expand Down Expand Up @@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
}
}

static inline void sk_psock_update_proto(struct sock *sk,
struct sk_psock *psock,
struct proto *ops)
{
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, ops);
}

static inline void sk_psock_restore_proto(struct sock *sk,
struct sk_psock *psock)
{
sk->sk_prot->unhash = psock->saved_unhash;
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
if (psock->psock_update_sk_prot)
psock->psock_update_sk_prot(sk, true);
}

static inline void sk_psock_set_state(struct sk_psock *psock,
Expand Down
3 changes: 3 additions & 0 deletions include/net/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,9 @@ struct proto {
void (*unhash)(struct sock *sk);
void (*rehash)(struct sock *sk);
int (*get_port)(struct sock *sk, unsigned short snum);
#ifdef CONFIG_BPF_SYSCALL
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
#endif

/* Keeping track of sockets in use */
#ifdef CONFIG_PROC_FS
Expand Down
1 change: 1 addition & 0 deletions include/net/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2203,6 +2203,7 @@ struct sk_psock;

#ifdef CONFIG_BPF_SYSCALL
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int tcp_bpf_update_proto(struct sock *sk, bool restore);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#endif /* CONFIG_BPF_SYSCALL */

Expand Down
1 change: 1 addition & 0 deletions include/net/udp.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
#ifdef CONFIG_BPF_SYSCALL
struct sk_psock;
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int udp_bpf_update_proto(struct sock *sk, bool restore);
#endif

#endif /* _UDP_H */
5 changes: 0 additions & 5 deletions net/core/skmsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)

write_lock_bh(&sk->sk_callback_lock);

if (inet_csk_has_ulp(sk)) {
psock = ERR_PTR(-EINVAL);
goto out;
}

if (sk->sk_user_data) {
psock = ERR_PTR(-EBUSY);
goto out;
Expand Down
24 changes: 4 additions & 20 deletions net/core/sock_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
struct proto *prot;

switch (sk->sk_type) {
case SOCK_STREAM:
prot = tcp_bpf_get_proto(sk, psock);
break;

case SOCK_DGRAM:
prot = udp_bpf_get_proto(sk, psock);
break;

default:
if (!sk->sk_prot->psock_update_sk_prot)
return -EINVAL;
}

if (IS_ERR(prot))
return PTR_ERR(prot);

sk_psock_update_proto(sk, psock, prot);
return 0;
psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
return sk->sk_prot->psock_update_sk_prot(sk, false);
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
Expand Down Expand Up @@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
return sk_is_tcp(sk) || sk_is_udp(sk);
return !!sk->sk_prot->psock_update_sk_prot;
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
Expand Down
24 changes: 21 additions & 3 deletions net/ipv4/tcp_bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int tcp_bpf_update_proto(struct sock *sk, bool restore)
{
struct sk_psock *psock = sk_psock(sk);
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

if (restore) {
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
return 0;
}

if (inet_csk_has_ulp(sk))
return -EINVAL;

if (sk->sk_family == AF_INET6) {
if (tcp_bpf_assert_proto_ops(psock->sk_proto))
return ERR_PTR(-EINVAL);
return -EINVAL;

tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
}

return &tcp_bpf_prots[family][config];
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
* protocol callbacks installed, we need to restore the callbacks to
Expand Down
3 changes: 3 additions & 0 deletions net/ipv4/tcp_ipv4.c
Original file line number Diff line number Diff line change
Expand Up @@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
.hash = inet_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
Expand Down
3 changes: 3 additions & 0 deletions net/ipv4/udp.c
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,9 @@ struct proto udp_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v4_rehash,
.get_port = udp_v4_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
Expand Down
15 changes: 13 additions & 2 deletions net/ipv4/udp_bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
}
core_initcall(udp_bpf_v4_build_proto);

struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int udp_bpf_update_proto(struct sock *sk, bool restore)
{
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
struct sk_psock *psock = sk_psock(sk);

if (restore) {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
return 0;
}

if (sk->sk_family == AF_INET6)
udp_bpf_check_v6_needs_rebuild(psock->sk_proto);

return &udp_bpf_prots[family];
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
return 0;
}
EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
3 changes: 3 additions & 0 deletions net/ipv6/tcp_ipv6.c
Original file line number Diff line number Diff line change
Expand Up @@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
.hash = inet6_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
Expand Down
3 changes: 3 additions & 0 deletions net/ipv6/udp.c
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v6_rehash,
.get_port = udp_v6_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
Expand Down

0 comments on commit 8a59f9d

Please sign in to comment.