Skip to content

Commit

Permalink
net: convert sock.sk_refcnt from atomic_t to refcount_t
Browse files Browse the repository at this point in the history
refcount_t type and corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This allows to avoid accidental
refcounter overflows that might lead to use-after-free
situations.

This patch uses refcount_inc_not_zero() instead of
atomic_inc_not_zero_hint() due to absense of a _hint()
version of refcount API. If the hint() version must
be used, we might need to revisit API.

Signed-off-by: Elena Reshetova <[email protected]>
Signed-off-by: Hans Liljestrand <[email protected]>
Signed-off-by: Kees Cook <[email protected]>
Signed-off-by: David Windsor <[email protected]>
Signed-off-by: David S. Miller <[email protected]>
  • Loading branch information
ereshetova authored and davem330 committed Jul 1, 2017
1 parent 14afee4 commit 41c6d65
Show file tree
Hide file tree
Showing 35 changed files with 70 additions and 69 deletions.
2 changes: 1 addition & 1 deletion crypto/algif_aead.c
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ static void aead_sock_destruct(struct sock *sk)
unsigned int ivlen = crypto_aead_ivsize(
crypto_aead_reqtfm(&ctx->aead_req));

WARN_ON(atomic_read(&sk->sk_refcnt) != 0);
WARN_ON(refcount_read(&sk->sk_refcnt) != 0);
aead_put_sgl(sk);
sock_kzfree_s(sk, ctx->iv, ivlen);
sock_kfree_s(sk, ctx, ctx->len);
Expand Down
4 changes: 2 additions & 2 deletions include/net/inet_hashtables.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <net/tcp_states.h>
#include <net/netns/hash.h>

#include <linux/atomic.h>
#include <linux/refcount.h>
#include <asm/byteorder.h>

/* This is for all connections with a full identity, no wildcards.
Expand Down Expand Up @@ -334,7 +334,7 @@ static inline struct sock *inet_lookup(struct net *net,
sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
dport, dif, &refcounted);

if (sk && !refcounted && !atomic_inc_not_zero(&sk->sk_refcnt))
if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL;
return sk;
}
Expand Down
9 changes: 5 additions & 4 deletions include/net/request_sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <linux/spinlock.h>
#include <linux/types.h>
#include <linux/bug.h>
#include <linux/refcount.h>

#include <net/sock.h>

Expand Down Expand Up @@ -89,7 +90,7 @@ reqsk_alloc(const struct request_sock_ops *ops, struct sock *sk_listener,
return NULL;
req->rsk_listener = NULL;
if (attach_listener) {
if (unlikely(!atomic_inc_not_zero(&sk_listener->sk_refcnt))) {
if (unlikely(!refcount_inc_not_zero(&sk_listener->sk_refcnt))) {
kmem_cache_free(ops->slab, req);
return NULL;
}
Expand All @@ -100,15 +101,15 @@ reqsk_alloc(const struct request_sock_ops *ops, struct sock *sk_listener,
sk_node_init(&req_to_sk(req)->sk_node);
sk_tx_queue_clear(req_to_sk(req));
req->saved_syn = NULL;
atomic_set(&req->rsk_refcnt, 0);
refcount_set(&req->rsk_refcnt, 0);

return req;
}

static inline void reqsk_free(struct request_sock *req)
{
/* temporary debugging */
WARN_ON_ONCE(atomic_read(&req->rsk_refcnt) != 0);
WARN_ON_ONCE(refcount_read(&req->rsk_refcnt) != 0);

req->rsk_ops->destructor(req);
if (req->rsk_listener)
Expand All @@ -119,7 +120,7 @@ static inline void reqsk_free(struct request_sock *req)

static inline void reqsk_put(struct request_sock *req)
{
if (atomic_dec_and_test(&req->rsk_refcnt))
if (refcount_dec_and_test(&req->rsk_refcnt))
reqsk_free(req);
}

Expand Down
17 changes: 9 additions & 8 deletions include/net/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#include <linux/poll.h>

#include <linux/atomic.h>
#include <linux/refcount.h>
#include <net/dst.h>
#include <net/checksum.h>
#include <net/tcp_states.h>
Expand Down Expand Up @@ -219,7 +220,7 @@ struct sock_common {
u32 skc_tw_rcv_nxt; /* struct tcp_timewait_sock */
};

atomic_t skc_refcnt;
refcount_t skc_refcnt;
/* private: */
int skc_dontcopy_end[0];
union {
Expand Down Expand Up @@ -611,15 +612,15 @@ static inline bool __sk_del_node_init(struct sock *sk)

static __always_inline void sock_hold(struct sock *sk)
{
atomic_inc(&sk->sk_refcnt);
refcount_inc(&sk->sk_refcnt);
}

/* Ungrab socket in the context, which assumes that socket refcnt
cannot hit zero, f.e. it is true in context of any socketcall.
*/
static __always_inline void __sock_put(struct sock *sk)
{
atomic_dec(&sk->sk_refcnt);
refcount_dec(&sk->sk_refcnt);
}

static inline bool sk_del_node_init(struct sock *sk)
Expand All @@ -628,7 +629,7 @@ static inline bool sk_del_node_init(struct sock *sk)

if (rc) {
/* paranoid for a while -acme */
WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
WARN_ON(refcount_read(&sk->sk_refcnt) == 1);
__sock_put(sk);
}
return rc;
Expand All @@ -650,7 +651,7 @@ static inline bool sk_nulls_del_node_init_rcu(struct sock *sk)

if (rc) {
/* paranoid for a while -acme */
WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
WARN_ON(refcount_read(&sk->sk_refcnt) == 1);
__sock_put(sk);
}
return rc;
Expand Down Expand Up @@ -1144,9 +1145,9 @@ static inline void sk_refcnt_debug_dec(struct sock *sk)

static inline void sk_refcnt_debug_release(const struct sock *sk)
{
if (atomic_read(&sk->sk_refcnt) != 1)
if (refcount_read(&sk->sk_refcnt) != 1)
printk(KERN_DEBUG "Destruction of the %s socket %p delayed, refcnt=%d\n",
sk->sk_prot->name, sk, atomic_read(&sk->sk_refcnt));
sk->sk_prot->name, sk, refcount_read(&sk->sk_refcnt));
}
#else /* SOCK_REFCNT_DEBUG */
#define sk_refcnt_debug_inc(sk) do { } while (0)
Expand Down Expand Up @@ -1636,7 +1637,7 @@ void sock_init_data(struct socket *sock, struct sock *sk);
/* Ungrab socket and destroy it, if it was the last reference. */
static inline void sock_put(struct sock *sk)
{
if (atomic_dec_and_test(&sk->sk_refcnt))
if (refcount_dec_and_test(&sk->sk_refcnt))
sk_free(sk);
}
/* Generic version of sock_put(), dealing with all sockets
Expand Down
2 changes: 1 addition & 1 deletion net/atm/proc.c
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ static void vcc_info(struct seq_file *seq, struct atm_vcc *vcc)
vcc->flags, sk->sk_err,
sk_wmem_alloc_get(sk), sk->sk_sndbuf,
sk_rmem_alloc_get(sk), sk->sk_rcvbuf,
atomic_read(&sk->sk_refcnt));
refcount_read(&sk->sk_refcnt));
}

static void svc_info(struct seq_file *seq, struct atm_vcc *vcc)
Expand Down
2 changes: 1 addition & 1 deletion net/bluetooth/af_bluetooth.c
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ static int bt_seq_show(struct seq_file *seq, void *v)
seq_printf(seq,
"%pK %-6d %-6u %-6u %-6u %-6lu %-6lu",
sk,
atomic_read(&sk->sk_refcnt),
refcount_read(&sk->sk_refcnt),
sk_rmem_alloc_get(sk),
sk_wmem_alloc_get(sk),
from_kuid(seq_user_ns(seq), sock_i_uid(sk)),
Expand Down
2 changes: 1 addition & 1 deletion net/bluetooth/rfcomm/sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ static void rfcomm_sock_kill(struct sock *sk)
if (!sock_flag(sk, SOCK_ZAPPED) || sk->sk_socket)
return;

BT_DBG("sk %p state %d refcnt %d", sk, sk->sk_state, atomic_read(&sk->sk_refcnt));
BT_DBG("sk %p state %d refcnt %d", sk, sk->sk_state, refcount_read(&sk->sk_refcnt));

/* Kill poor orphan */
bt_sock_unlink(&rfcomm_sk_list, sk);
Expand Down
6 changes: 3 additions & 3 deletions net/core/skbuff.c
Original file line number Diff line number Diff line change
Expand Up @@ -3844,7 +3844,7 @@ struct sk_buff *skb_clone_sk(struct sk_buff *skb)
struct sock *sk = skb->sk;
struct sk_buff *clone;

if (!sk || !atomic_inc_not_zero(&sk->sk_refcnt))
if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt))
return NULL;

clone = skb_clone(skb, GFP_ATOMIC);
Expand Down Expand Up @@ -3915,7 +3915,7 @@ void skb_complete_tx_timestamp(struct sk_buff *skb,
/* Take a reference to prevent skb_orphan() from freeing the socket,
* but only if the socket refcount is not zero.
*/
if (likely(atomic_inc_not_zero(&sk->sk_refcnt))) {
if (likely(refcount_inc_not_zero(&sk->sk_refcnt))) {
*skb_hwtstamps(skb) = *hwtstamps;
__skb_complete_tx_timestamp(skb, sk, SCM_TSTAMP_SND, false);
sock_put(sk);
Expand Down Expand Up @@ -3997,7 +3997,7 @@ void skb_complete_wifi_ack(struct sk_buff *skb, bool acked)
/* Take a reference to prevent skb_orphan() from freeing the socket,
* but only if the socket refcount is not zero.
*/
if (likely(atomic_inc_not_zero(&sk->sk_refcnt))) {
if (likely(refcount_inc_not_zero(&sk->sk_refcnt))) {
err = sock_queue_err_skb(sk, skb);
sock_put(sk);
}
Expand Down
6 changes: 3 additions & 3 deletions net/core/sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
* (Documentation/RCU/rculist_nulls.txt for details)
*/
smp_wmb();
atomic_set(&newsk->sk_refcnt, 2);
refcount_set(&newsk->sk_refcnt, 2);

/*
* Increment the counter in the same struct proto as the master
Expand Down Expand Up @@ -1851,7 +1851,7 @@ void skb_orphan_partial(struct sk_buff *skb)
) {
struct sock *sk = skb->sk;

if (atomic_inc_not_zero(&sk->sk_refcnt)) {
if (refcount_inc_not_zero(&sk->sk_refcnt)) {
WARN_ON(refcount_sub_and_test(skb->truesize, &sk->sk_wmem_alloc));
skb->destructor = sock_efree;
}
Expand Down Expand Up @@ -2687,7 +2687,7 @@ void sock_init_data(struct socket *sock, struct sock *sk)
* (Documentation/RCU/rculist_nulls.txt for details)
*/
smp_wmb();
atomic_set(&sk->sk_refcnt, 1);
refcount_set(&sk->sk_refcnt, 1);
atomic_set(&sk->sk_drops, 0);
}
EXPORT_SYMBOL(sock_init_data);
Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/inet_connection_sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ static void reqsk_queue_hash_req(struct request_sock *req,
* are committed to memory and refcnt initialized.
*/
smp_wmb();
atomic_set(&req->rsk_refcnt, 2 + 1);
refcount_set(&req->rsk_refcnt, 2 + 1);
}

void inet_csk_reqsk_queue_hash_add(struct sock *sk, struct request_sock *req,
Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/inet_hashtables.c
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ EXPORT_SYMBOL_GPL(__inet_lookup_listener);
/* All sockets share common refcount, but have different destructors */
void sock_gen_put(struct sock *sk)
{
if (!atomic_dec_and_test(&sk->sk_refcnt))
if (!refcount_dec_and_test(&sk->sk_refcnt))
return;

if (sk->sk_state == TCP_TIME_WAIT)
Expand Down Expand Up @@ -287,7 +287,7 @@ struct sock *__inet_lookup_established(struct net *net,
continue;
if (likely(INET_MATCH(sk, net, acookie,
saddr, daddr, ports, dif))) {
if (unlikely(!atomic_inc_not_zero(&sk->sk_refcnt)))
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out;
if (unlikely(!INET_MATCH(sk, net, acookie,
saddr, daddr, ports, dif))) {
Expand Down
8 changes: 4 additions & 4 deletions net/ipv4/inet_timewait_sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void inet_twsk_free(struct inet_timewait_sock *tw)

void inet_twsk_put(struct inet_timewait_sock *tw)
{
if (atomic_dec_and_test(&tw->tw_refcnt))
if (refcount_dec_and_test(&tw->tw_refcnt))
inet_twsk_free(tw);
}
EXPORT_SYMBOL_GPL(inet_twsk_put);
Expand Down Expand Up @@ -131,7 +131,7 @@ void __inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk,
* We can use atomic_set() because prior spin_lock()/spin_unlock()
* committed into memory all tw fields.
*/
atomic_set(&tw->tw_refcnt, 4);
refcount_set(&tw->tw_refcnt, 4);
inet_twsk_add_node_rcu(tw, &ehead->chain);

/* Step 3: Remove SK from hash chain */
Expand Down Expand Up @@ -195,7 +195,7 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
* to a non null value before everything is setup for this
* timewait socket.
*/
atomic_set(&tw->tw_refcnt, 0);
refcount_set(&tw->tw_refcnt, 0);

__module_get(tw->tw_prot->owner);
}
Expand Down Expand Up @@ -278,7 +278,7 @@ void inet_twsk_purge(struct inet_hashinfo *hashinfo, int family)
atomic_read(&twsk_net(tw)->count))
continue;

if (unlikely(!atomic_inc_not_zero(&tw->tw_refcnt)))
if (unlikely(!refcount_inc_not_zero(&tw->tw_refcnt)))
continue;

if (unlikely((tw->tw_family != family) ||
Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/ping.c
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ void ping_close(struct sock *sk, long timeout)
{
pr_debug("ping_close(sk=%p,sk->num=%u)\n",
inet_sk(sk), inet_sk(sk)->inet_num);
pr_debug("isk->refcnt = %d\n", sk->sk_refcnt.counter);
pr_debug("isk->refcnt = %d\n", refcount_read(&sk->sk_refcnt));

sk_common_release(sk);
}
Expand Down Expand Up @@ -1127,7 +1127,7 @@ static void ping_v4_format_sock(struct sock *sp, struct seq_file *f,
0, 0L, 0,
from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)),
0, sock_i_ino(sp),
atomic_read(&sp->sk_refcnt), sp,
refcount_read(&sp->sk_refcnt), sp,
atomic_read(&sp->sk_drops));
}

Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ static void raw_sock_seq_show(struct seq_file *seq, struct sock *sp, int i)
0, 0L, 0,
from_kuid_munged(seq_user_ns(seq), sock_i_uid(sp)),
0, sock_i_ino(sp),
atomic_read(&sp->sk_refcnt), sp, atomic_read(&sp->sk_drops));
refcount_read(&sp->sk_refcnt), sp, atomic_read(&sp->sk_drops));
}

static int raw_seq_show(struct seq_file *seq, void *v)
Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/syncookies.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ struct sock *tcp_get_cookie_sock(struct sock *sk, struct sk_buff *skb,
child = icsk->icsk_af_ops->syn_recv_sock(sk, skb, req, dst,
NULL, &own_req);
if (child) {
atomic_set(&req->rsk_refcnt, 1);
refcount_set(&req->rsk_refcnt, 1);
tcp_sk(child)->tsoffset = tsoff;
sock_rps_save_rxhash(child, skb);
inet_csk_reqsk_queue_add(sk, req, child);
Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/tcp_fastopen.c
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ static struct sock *tcp_fastopen_create_child(struct sock *sk,
inet_csk_reset_xmit_timer(child, ICSK_TIME_RETRANS,
TCP_TIMEOUT_INIT, TCP_RTO_MAX);

atomic_set(&req->rsk_refcnt, 2);
refcount_set(&req->rsk_refcnt, 2);

/* Now finish processing the fastopen child socket. */
inet_csk(child)->icsk_af_ops->rebuild_header(child);
Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/tcp_ipv4.c
Original file line number Diff line number Diff line change
Expand Up @@ -2323,7 +2323,7 @@ static void get_tcp4_sock(struct sock *sk, struct seq_file *f, int i)
from_kuid_munged(seq_user_ns(f), sock_i_uid(sk)),
icsk->icsk_probes_out,
sock_i_ino(sk),
atomic_read(&sk->sk_refcnt), sk,
refcount_read(&sk->sk_refcnt), sk,
jiffies_to_clock_t(icsk->icsk_rto),
jiffies_to_clock_t(icsk->icsk_ack.ato),
(icsk->icsk_ack.quick << 1) | icsk->icsk_ack.pingpong,
Expand All @@ -2349,7 +2349,7 @@ static void get_timewait4_sock(const struct inet_timewait_sock *tw,
" %02X %08X:%08X %02X:%08lX %08X %5d %8d %d %d %pK",
i, src, srcp, dest, destp, tw->tw_substate, 0, 0,
3, jiffies_delta_to_clock_t(delta), 0, 0, 0, 0,
atomic_read(&tw->tw_refcnt), tw);
refcount_read(&tw->tw_refcnt), tw);
}

#define TMPSZ 150
Expand Down
6 changes: 3 additions & 3 deletions net/ipv4/udp.c
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,

sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport,
dif, &udp_table, NULL);
if (sk && !atomic_inc_not_zero(&sk->sk_refcnt))
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL;
return sk;
}
Expand Down Expand Up @@ -2242,7 +2242,7 @@ void udp_v4_early_demux(struct sk_buff *skb)
uh->source, iph->saddr, dif);
}

if (!sk || !atomic_inc_not_zero_hint(&sk->sk_refcnt, 2))
if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt))
return;

skb->sk = sk;
Expand Down Expand Up @@ -2691,7 +2691,7 @@ static void udp4_format_sock(struct sock *sp, struct seq_file *f,
0, 0L, 0,
from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)),
0, sock_i_ino(sp),
atomic_read(&sp->sk_refcnt), sp,
refcount_read(&sp->sk_refcnt), sp,
atomic_read(&sp->sk_drops));
}

Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/udp_diag.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
req->id.idiag_dport,
req->id.idiag_if, tbl, NULL);
#endif
if (sk && !atomic_inc_not_zero(&sk->sk_refcnt))
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL;
rcu_read_unlock();
err = -ENOENT;
Expand Down Expand Up @@ -206,7 +206,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
return -EINVAL;
}

if (sk && !atomic_inc_not_zero(&sk->sk_refcnt))
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL;

rcu_read_unlock();
Expand Down
Loading

0 comments on commit 41c6d65

Please sign in to comment.