Skip to content

Commit

Permalink
net/tls: use RCU protection on icsk->icsk_ulp_data
Browse files Browse the repository at this point in the history
We need to make sure context does not get freed while diag
code is interrogating it. Free struct tls_context with
kfree_rcu().

We add the __rcu annotation directly in icsk, and cast it
away in the datapath accessor. Presumably all ULPs will
do a similar thing.

Signed-off-by: Jakub Kicinski <[email protected]>
Signed-off-by: David S. Miller <[email protected]>
  • Loading branch information
Jakub Kicinski authored and davem330 committed Sep 1, 2019
1 parent ed6e810 commit 15a7dea
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion include/net/inet_connection_sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct inet_connection_sock {
const struct tcp_congestion_ops *icsk_ca_ops;
const struct inet_connection_sock_af_ops *icsk_af_ops;
const struct tcp_ulp_ops *icsk_ulp_ops;
void *icsk_ulp_data;
void __rcu *icsk_ulp_data;
void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq);
struct hlist_node icsk_listen_portaddr_node;
unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu);
Expand Down
9 changes: 7 additions & 2 deletions include/net/tls.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <linux/tcp.h>
#include <linux/skmsg.h>
#include <linux/netdevice.h>
#include <linux/rcupdate.h>

#include <net/tcp.h>
#include <net/strparser.h>
Expand Down Expand Up @@ -290,6 +291,7 @@ struct tls_context {

struct list_head list;
refcount_t refcount;
struct rcu_head rcu;
};

enum tls_offload_ctx_dir {
Expand Down Expand Up @@ -348,7 +350,7 @@ struct tls_offload_context_rx {
#define TLS_OFFLOAD_CONTEXT_SIZE_RX \
(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)

void tls_ctx_free(struct tls_context *ctx);
void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval,
int __user *optlen);
Expand Down Expand Up @@ -467,7 +469,10 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);

return icsk->icsk_ulp_data;
/* Use RCU on icsk_ulp_data only for sock diag code,
* TLS data path doesn't need rcu_dereference().
*/
return (__force void *)icsk->icsk_ulp_data;
}

static inline void tls_advance_record_sn(struct sock *sk,
Expand Down
2 changes: 1 addition & 1 deletion net/core/sock_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
return -EINVAL;
if (unlikely(idx >= map->max_entries))
return -E2BIG;
if (unlikely(icsk->icsk_ulp_data))
if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
return -EINVAL;

link = sk_psock_init_link();
Expand Down
2 changes: 1 addition & 1 deletion net/tls/tls_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ static void tls_device_free_ctx(struct tls_context *ctx)
if (ctx->rx_conf == TLS_HW)
kfree(tls_offload_ctx_rx(ctx));

tls_ctx_free(ctx);
tls_ctx_free(NULL, ctx);
}

static void tls_device_gc_task(struct work_struct *work)
Expand Down
26 changes: 19 additions & 7 deletions net/tls/tls_main.c
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,26 @@ static void tls_write_space(struct sock *sk)
ctx->sk_write_space(sk);
}

void tls_ctx_free(struct tls_context *ctx)
/**
* tls_ctx_free() - free TLS ULP context
* @sk: socket to with @ctx is attached
* @ctx: TLS context structure
*
* Free TLS context. If @sk is %NULL caller guarantees that the socket
* to which @ctx was attached has no outstanding references.
*/
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
if (!ctx)
return;

memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
kfree(ctx);

if (sk)
kfree_rcu(ctx, rcu);
else
kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
Expand Down Expand Up @@ -306,7 +318,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)

write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
icsk->icsk_ulp_data = NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
sk->sk_prot = ctx->sk_proto;
if (sk->sk_write_space == tls_write_space)
sk->sk_write_space = ctx->sk_write_space;
Expand All @@ -321,7 +333,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
ctx->sk_proto_close(sk, timeout);

if (free_ctx)
tls_ctx_free(ctx);
tls_ctx_free(sk, ctx);
}

static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
Expand Down Expand Up @@ -610,7 +622,7 @@ static struct tls_context *create_ctx(struct sock *sk)
if (!ctx)
return NULL;

icsk->icsk_ulp_data = ctx;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
Expand Down Expand Up @@ -651,8 +663,8 @@ static void tls_hw_sk_destruct(struct sock *sk)

ctx->sk_destruct(sk);
/* Free ctx */
tls_ctx_free(ctx);
icsk->icsk_ulp_data = NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
tls_ctx_free(sk, ctx);
}

static int tls_hw_prot(struct sock *sk)
Expand Down

0 comments on commit 15a7dea

Please sign in to comment.