Skip to content

Commit

Permalink
net/tcp: Merge TCP-MD5 inbound callbacks
Browse files Browse the repository at this point in the history
The functions do essentially the same work to verify TCP-MD5 sign.
Code can be merged into one family-independent function in order to
reduce copy'n'paste and generated code.
Later with TCP-AO option added, this will allow to create one function
that's responsible for segment verification, that will have all the
different checks for MD5/AO/non-signed packets, which in turn will help
to see checks for all corner-cases in one function, rather than spread
around different families and functions.

Cc: Eric Dumazet <[email protected]>
Cc: Hideaki YOSHIFUJI <[email protected]>
Signed-off-by: Dmitry Safonov <[email protected]>
Reviewed-by: David Ahern <[email protected]>
Link: https://lore.kernel.org/r/[email protected]
Signed-off-by: Jakub Kicinski <[email protected]>
  • Loading branch information
0x7f454c46 authored and kuba-moo committed Feb 25, 2022
1 parent 53110c6 commit 7bbb765
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 131 deletions.
13 changes: 13 additions & 0 deletions include/net/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,11 @@ tcp_md5_do_lookup(const struct sock *sk, int l3index,
return NULL;
return __tcp_md5_do_lookup(sk, l3index, addr, family);
}
bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
enum skb_drop_reason *reason,
const void *saddr, const void *daddr,
int family, int dif, int sdif);


#define tcp_twsk_md5_key(twsk) ((twsk)->tw_md5_key)
#else
Expand All @@ -1683,6 +1688,14 @@ tcp_md5_do_lookup(const struct sock *sk, int l3index,
{
return NULL;
}
static inline bool tcp_inbound_md5_hash(const struct sock *sk,
const struct sk_buff *skb,
enum skb_drop_reason *reason,
const void *saddr, const void *daddr,
int family, int dif, int sdif)
{
return false;
}
#define tcp_twsk_md5_key(twsk) NULL
#endif

Expand Down
70 changes: 70 additions & 0 deletions net/ipv4/tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -4431,6 +4431,76 @@ int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *ke
}
EXPORT_SYMBOL(tcp_md5_hash_key);

/* Called with rcu_read_lock() */
bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
enum skb_drop_reason *reason,
const void *saddr, const void *daddr,
int family, int dif, int sdif)
{
/*
* This gets called for each TCP segment that arrives
* so we want to be efficient.
* We have 3 drop cases:
* o No MD5 hash and one expected.
* o MD5 hash and we're not expecting one.
* o MD5 hash and its wrong.
*/
const __u8 *hash_location = NULL;
struct tcp_md5sig_key *hash_expected;
const struct tcphdr *th = tcp_hdr(skb);
struct tcp_sock *tp = tcp_sk(sk);
int genhash, l3index;
u8 newhash[16];

/* sdif set, means packet ingressed via a device
* in an L3 domain and dif is set to the l3mdev
*/
l3index = sdif ? dif : 0;

hash_expected = tcp_md5_do_lookup(sk, l3index, saddr, family);
hash_location = tcp_parse_md5sig_option(th);

/* We've parsed the options - do we have a hash? */
if (!hash_expected && !hash_location)
return false;

if (hash_expected && !hash_location) {
*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
return true;
}

if (!hash_expected && hash_location) {
*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
return true;
}

/* check the signature */
genhash = tp->af_specific->calc_md5_hash(newhash, hash_expected,
NULL, skb);

if (genhash || memcmp(hash_location, newhash, 16) != 0) {
*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
if (family == AF_INET) {
net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
saddr, ntohs(th->source),
daddr, ntohs(th->dest),
genhash ? " tcp_v4_calc_md5_hash failed"
: "", l3index);
} else {
net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
genhash ? "failed" : "mismatch",
saddr, ntohs(th->source),
daddr, ntohs(th->dest), l3index);
}
return true;
}
return false;
}
EXPORT_SYMBOL(tcp_inbound_md5_hash);

#endif

void tcp_done(struct sock *sk)
Expand Down
78 changes: 5 additions & 73 deletions net/ipv4/tcp_ipv4.c
Original file line number Diff line number Diff line change
Expand Up @@ -1409,76 +1409,6 @@ EXPORT_SYMBOL(tcp_v4_md5_hash_skb);

#endif

/* Called with rcu_read_lock() */
static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
const struct sk_buff *skb,
int dif, int sdif,
enum skb_drop_reason *reason)
{
#ifdef CONFIG_TCP_MD5SIG
/*
* This gets called for each TCP segment that arrives
* so we want to be efficient.
* We have 3 drop cases:
* o No MD5 hash and one expected.
* o MD5 hash and we're not expecting one.
* o MD5 hash and its wrong.
*/
const __u8 *hash_location = NULL;
struct tcp_md5sig_key *hash_expected;
const struct iphdr *iph = ip_hdr(skb);
const struct tcphdr *th = tcp_hdr(skb);
const union tcp_md5_addr *addr;
unsigned char newhash[16];
int genhash, l3index;

/* sdif set, means packet ingressed via a device
* in an L3 domain and dif is set to the l3mdev
*/
l3index = sdif ? dif : 0;

addr = (union tcp_md5_addr *)&iph->saddr;
hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
hash_location = tcp_parse_md5sig_option(th);

/* We've parsed the options - do we have a hash? */
if (!hash_expected && !hash_location)
return false;

if (hash_expected && !hash_location) {
*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
return true;
}

if (!hash_expected && hash_location) {
*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
return true;
}

/* Okay, so this is hash_expected and hash_location -
* so we need to calculate the checksum.
*/
genhash = tcp_v4_md5_hash_skb(newhash,
hash_expected,
NULL, skb);

if (genhash || memcmp(hash_location, newhash, 16) != 0) {
*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
&iph->saddr, ntohs(th->source),
&iph->daddr, ntohs(th->dest),
genhash ? " tcp_v4_calc_md5_hash failed"
: "", l3index);
return true;
}
return false;
#endif
return false;
}

static void tcp_v4_init_req(struct request_sock *req,
const struct sock *sk_listener,
struct sk_buff *skb)
Expand Down Expand Up @@ -2035,8 +1965,9 @@ int tcp_v4_rcv(struct sk_buff *skb)
struct sock *nsk;

sk = req->rsk_listener;
if (unlikely(tcp_v4_inbound_md5_hash(sk, skb, dif, sdif,
&drop_reason))) {
if (unlikely(tcp_inbound_md5_hash(sk, skb, &drop_reason,
&iph->saddr, &iph->daddr,
AF_INET, dif, sdif))) {
sk_drops_add(sk, skb);
reqsk_put(req);
goto discard_it;
Expand Down Expand Up @@ -2110,7 +2041,8 @@ int tcp_v4_rcv(struct sk_buff *skb)
goto discard_and_relse;
}

if (tcp_v4_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &iph->saddr,
&iph->daddr, AF_INET, dif, sdif))
goto discard_and_relse;

nf_reset_ct(skb);
Expand Down
62 changes: 4 additions & 58 deletions net/ipv6/tcp_ipv6.c
Original file line number Diff line number Diff line change
Expand Up @@ -773,61 +773,6 @@ static int tcp_v6_md5_hash_skb(char *md5_hash,

#endif

static bool tcp_v6_inbound_md5_hash(const struct sock *sk,
const struct sk_buff *skb,
int dif, int sdif,
enum skb_drop_reason *reason)
{
#ifdef CONFIG_TCP_MD5SIG
const __u8 *hash_location = NULL;
struct tcp_md5sig_key *hash_expected;
const struct ipv6hdr *ip6h = ipv6_hdr(skb);
const struct tcphdr *th = tcp_hdr(skb);
int genhash, l3index;
u8 newhash[16];

/* sdif set, means packet ingressed via a device
* in an L3 domain and dif is set to the l3mdev
*/
l3index = sdif ? dif : 0;

hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr, l3index);
hash_location = tcp_parse_md5sig_option(th);

/* We've parsed the options - do we have a hash? */
if (!hash_expected && !hash_location)
return false;

if (hash_expected && !hash_location) {
*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
return true;
}

if (!hash_expected && hash_location) {
*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
return true;
}

/* check the signature */
genhash = tcp_v6_md5_hash_skb(newhash,
hash_expected,
NULL, skb);

if (genhash || memcmp(hash_location, newhash, 16) != 0) {
*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
genhash ? "failed" : "mismatch",
&ip6h->saddr, ntohs(th->source),
&ip6h->daddr, ntohs(th->dest), l3index);
return true;
}
#endif
return false;
}

static void tcp_v6_init_req(struct request_sock *req,
const struct sock *sk_listener,
struct sk_buff *skb)
Expand Down Expand Up @@ -1687,8 +1632,8 @@ INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
struct sock *nsk;

sk = req->rsk_listener;
if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif,
&drop_reason)) {
if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
&hdr->daddr, AF_INET6, dif, sdif)) {
sk_drops_add(sk, skb);
reqsk_put(req);
goto discard_it;
Expand Down Expand Up @@ -1759,7 +1704,8 @@ INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
goto discard_and_relse;
}

if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
&hdr->daddr, AF_INET6, dif, sdif))
goto discard_and_relse;

if (tcp_filter(sk, skb)) {
Expand Down

0 comments on commit 7bbb765

Please sign in to comment.