mirror of
				https://github.com/torvalds/linux.git
				synced 2025-11-04 10:40:15 +02:00 
			
		
		
		
	net/tcp: Merge TCP-MD5 inbound callbacks
The functions do essentially the same work to verify TCP-MD5 sign. Code can be merged into one family-independent function in order to reduce copy'n'paste and generated code. Later with TCP-AO option added, this will allow to create one function that's responsible for segment verification, that will have all the different checks for MD5/AO/non-signed packets, which in turn will help to see checks for all corner-cases in one function, rather than spread around different families and functions. Cc: Eric Dumazet <edumazet@google.com> Cc: Hideaki YOSHIFUJI <yoshfuji@linux-ipv6.org> Signed-off-by: Dmitry Safonov <dima@arista.com> Reviewed-by: David Ahern <dsahern@kernel.org> Link: https://lore.kernel.org/r/20220223175740.452397-1-dima@arista.com Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
		
							parent
							
								
									53110c67e3
								
							
						
					
					
						commit
						7bbb765b73
					
				
					 4 changed files with 92 additions and 131 deletions
				
			
		| 
						 | 
				
			
			@ -1674,6 +1674,11 @@ tcp_md5_do_lookup(const struct sock *sk, int l3index,
 | 
			
		|||
		return NULL;
 | 
			
		||||
	return __tcp_md5_do_lookup(sk, l3index, addr, family);
 | 
			
		||||
}
 | 
			
		||||
bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
 | 
			
		||||
			  enum skb_drop_reason *reason,
 | 
			
		||||
			  const void *saddr, const void *daddr,
 | 
			
		||||
			  int family, int dif, int sdif);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#define tcp_twsk_md5_key(twsk)	((twsk)->tw_md5_key)
 | 
			
		||||
#else
 | 
			
		||||
| 
						 | 
				
			
			@ -1683,6 +1688,14 @@ tcp_md5_do_lookup(const struct sock *sk, int l3index,
 | 
			
		|||
{
 | 
			
		||||
	return NULL;
 | 
			
		||||
}
 | 
			
		||||
static inline bool tcp_inbound_md5_hash(const struct sock *sk,
 | 
			
		||||
					const struct sk_buff *skb,
 | 
			
		||||
					enum skb_drop_reason *reason,
 | 
			
		||||
					const void *saddr, const void *daddr,
 | 
			
		||||
					int family, int dif, int sdif)
 | 
			
		||||
{
 | 
			
		||||
	return false;
 | 
			
		||||
}
 | 
			
		||||
#define tcp_twsk_md5_key(twsk)	NULL
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4431,6 +4431,76 @@ int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *ke
 | 
			
		|||
}
 | 
			
		||||
EXPORT_SYMBOL(tcp_md5_hash_key);
 | 
			
		||||
 | 
			
		||||
/* Called with rcu_read_lock() */
 | 
			
		||||
bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
 | 
			
		||||
			  enum skb_drop_reason *reason,
 | 
			
		||||
			  const void *saddr, const void *daddr,
 | 
			
		||||
			  int family, int dif, int sdif)
 | 
			
		||||
{
 | 
			
		||||
	/*
 | 
			
		||||
	 * This gets called for each TCP segment that arrives
 | 
			
		||||
	 * so we want to be efficient.
 | 
			
		||||
	 * We have 3 drop cases:
 | 
			
		||||
	 * o No MD5 hash and one expected.
 | 
			
		||||
	 * o MD5 hash and we're not expecting one.
 | 
			
		||||
	 * o MD5 hash and its wrong.
 | 
			
		||||
	 */
 | 
			
		||||
	const __u8 *hash_location = NULL;
 | 
			
		||||
	struct tcp_md5sig_key *hash_expected;
 | 
			
		||||
	const struct tcphdr *th = tcp_hdr(skb);
 | 
			
		||||
	struct tcp_sock *tp = tcp_sk(sk);
 | 
			
		||||
	int genhash, l3index;
 | 
			
		||||
	u8 newhash[16];
 | 
			
		||||
 | 
			
		||||
	/* sdif set, means packet ingressed via a device
 | 
			
		||||
	 * in an L3 domain and dif is set to the l3mdev
 | 
			
		||||
	 */
 | 
			
		||||
	l3index = sdif ? dif : 0;
 | 
			
		||||
 | 
			
		||||
	hash_expected = tcp_md5_do_lookup(sk, l3index, saddr, family);
 | 
			
		||||
	hash_location = tcp_parse_md5sig_option(th);
 | 
			
		||||
 | 
			
		||||
	/* We've parsed the options - do we have a hash? */
 | 
			
		||||
	if (!hash_expected && !hash_location)
 | 
			
		||||
		return false;
 | 
			
		||||
 | 
			
		||||
	if (hash_expected && !hash_location) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (!hash_expected && hash_location) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/* check the signature */
 | 
			
		||||
	genhash = tp->af_specific->calc_md5_hash(newhash, hash_expected,
 | 
			
		||||
						 NULL, skb);
 | 
			
		||||
 | 
			
		||||
	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
 | 
			
		||||
		if (family == AF_INET) {
 | 
			
		||||
			net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
 | 
			
		||||
					saddr, ntohs(th->source),
 | 
			
		||||
					daddr, ntohs(th->dest),
 | 
			
		||||
					genhash ? " tcp_v4_calc_md5_hash failed"
 | 
			
		||||
					: "", l3index);
 | 
			
		||||
		} else {
 | 
			
		||||
			net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
 | 
			
		||||
					genhash ? "failed" : "mismatch",
 | 
			
		||||
					saddr, ntohs(th->source),
 | 
			
		||||
					daddr, ntohs(th->dest), l3index);
 | 
			
		||||
		}
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
	return false;
 | 
			
		||||
}
 | 
			
		||||
EXPORT_SYMBOL(tcp_inbound_md5_hash);
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
void tcp_done(struct sock *sk)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1409,76 +1409,6 @@ EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
 | 
			
		|||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
/* Called with rcu_read_lock() */
 | 
			
		||||
static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
 | 
			
		||||
				    const struct sk_buff *skb,
 | 
			
		||||
				    int dif, int sdif,
 | 
			
		||||
				    enum skb_drop_reason *reason)
 | 
			
		||||
{
 | 
			
		||||
#ifdef CONFIG_TCP_MD5SIG
 | 
			
		||||
	/*
 | 
			
		||||
	 * This gets called for each TCP segment that arrives
 | 
			
		||||
	 * so we want to be efficient.
 | 
			
		||||
	 * We have 3 drop cases:
 | 
			
		||||
	 * o No MD5 hash and one expected.
 | 
			
		||||
	 * o MD5 hash and we're not expecting one.
 | 
			
		||||
	 * o MD5 hash and its wrong.
 | 
			
		||||
	 */
 | 
			
		||||
	const __u8 *hash_location = NULL;
 | 
			
		||||
	struct tcp_md5sig_key *hash_expected;
 | 
			
		||||
	const struct iphdr *iph = ip_hdr(skb);
 | 
			
		||||
	const struct tcphdr *th = tcp_hdr(skb);
 | 
			
		||||
	const union tcp_md5_addr *addr;
 | 
			
		||||
	unsigned char newhash[16];
 | 
			
		||||
	int genhash, l3index;
 | 
			
		||||
 | 
			
		||||
	/* sdif set, means packet ingressed via a device
 | 
			
		||||
	 * in an L3 domain and dif is set to the l3mdev
 | 
			
		||||
	 */
 | 
			
		||||
	l3index = sdif ? dif : 0;
 | 
			
		||||
 | 
			
		||||
	addr = (union tcp_md5_addr *)&iph->saddr;
 | 
			
		||||
	hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
 | 
			
		||||
	hash_location = tcp_parse_md5sig_option(th);
 | 
			
		||||
 | 
			
		||||
	/* We've parsed the options - do we have a hash? */
 | 
			
		||||
	if (!hash_expected && !hash_location)
 | 
			
		||||
		return false;
 | 
			
		||||
 | 
			
		||||
	if (hash_expected && !hash_location) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (!hash_expected && hash_location) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/* Okay, so this is hash_expected and hash_location -
 | 
			
		||||
	 * so we need to calculate the checksum.
 | 
			
		||||
	 */
 | 
			
		||||
	genhash = tcp_v4_md5_hash_skb(newhash,
 | 
			
		||||
				      hash_expected,
 | 
			
		||||
				      NULL, skb);
 | 
			
		||||
 | 
			
		||||
	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
 | 
			
		||||
		net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
 | 
			
		||||
				     &iph->saddr, ntohs(th->source),
 | 
			
		||||
				     &iph->daddr, ntohs(th->dest),
 | 
			
		||||
				     genhash ? " tcp_v4_calc_md5_hash failed"
 | 
			
		||||
				     : "", l3index);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
	return false;
 | 
			
		||||
#endif
 | 
			
		||||
	return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void tcp_v4_init_req(struct request_sock *req,
 | 
			
		||||
			    const struct sock *sk_listener,
 | 
			
		||||
			    struct sk_buff *skb)
 | 
			
		||||
| 
						 | 
				
			
			@ -2035,8 +1965,9 @@ int tcp_v4_rcv(struct sk_buff *skb)
 | 
			
		|||
		struct sock *nsk;
 | 
			
		||||
 | 
			
		||||
		sk = req->rsk_listener;
 | 
			
		||||
		if (unlikely(tcp_v4_inbound_md5_hash(sk, skb, dif, sdif,
 | 
			
		||||
						     &drop_reason))) {
 | 
			
		||||
		if (unlikely(tcp_inbound_md5_hash(sk, skb, &drop_reason,
 | 
			
		||||
						  &iph->saddr, &iph->daddr,
 | 
			
		||||
						  AF_INET, dif, sdif))) {
 | 
			
		||||
			sk_drops_add(sk, skb);
 | 
			
		||||
			reqsk_put(req);
 | 
			
		||||
			goto discard_it;
 | 
			
		||||
| 
						 | 
				
			
			@ -2110,7 +2041,8 @@ int tcp_v4_rcv(struct sk_buff *skb)
 | 
			
		|||
		goto discard_and_relse;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (tcp_v4_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
 | 
			
		||||
	if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &iph->saddr,
 | 
			
		||||
				 &iph->daddr, AF_INET, dif, sdif))
 | 
			
		||||
		goto discard_and_relse;
 | 
			
		||||
 | 
			
		||||
	nf_reset_ct(skb);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -773,61 +773,6 @@ static int tcp_v6_md5_hash_skb(char *md5_hash,
 | 
			
		|||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
static bool tcp_v6_inbound_md5_hash(const struct sock *sk,
 | 
			
		||||
				    const struct sk_buff *skb,
 | 
			
		||||
				    int dif, int sdif,
 | 
			
		||||
				    enum skb_drop_reason *reason)
 | 
			
		||||
{
 | 
			
		||||
#ifdef CONFIG_TCP_MD5SIG
 | 
			
		||||
	const __u8 *hash_location = NULL;
 | 
			
		||||
	struct tcp_md5sig_key *hash_expected;
 | 
			
		||||
	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
 | 
			
		||||
	const struct tcphdr *th = tcp_hdr(skb);
 | 
			
		||||
	int genhash, l3index;
 | 
			
		||||
	u8 newhash[16];
 | 
			
		||||
 | 
			
		||||
	/* sdif set, means packet ingressed via a device
 | 
			
		||||
	 * in an L3 domain and dif is set to the l3mdev
 | 
			
		||||
	 */
 | 
			
		||||
	l3index = sdif ? dif : 0;
 | 
			
		||||
 | 
			
		||||
	hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr, l3index);
 | 
			
		||||
	hash_location = tcp_parse_md5sig_option(th);
 | 
			
		||||
 | 
			
		||||
	/* We've parsed the options - do we have a hash? */
 | 
			
		||||
	if (!hash_expected && !hash_location)
 | 
			
		||||
		return false;
 | 
			
		||||
 | 
			
		||||
	if (hash_expected && !hash_location) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (!hash_expected && hash_location) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/* check the signature */
 | 
			
		||||
	genhash = tcp_v6_md5_hash_skb(newhash,
 | 
			
		||||
				      hash_expected,
 | 
			
		||||
				      NULL, skb);
 | 
			
		||||
 | 
			
		||||
	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
 | 
			
		||||
		*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
 | 
			
		||||
		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
 | 
			
		||||
		net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
 | 
			
		||||
				     genhash ? "failed" : "mismatch",
 | 
			
		||||
				     &ip6h->saddr, ntohs(th->source),
 | 
			
		||||
				     &ip6h->daddr, ntohs(th->dest), l3index);
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
#endif
 | 
			
		||||
	return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void tcp_v6_init_req(struct request_sock *req,
 | 
			
		||||
			    const struct sock *sk_listener,
 | 
			
		||||
			    struct sk_buff *skb)
 | 
			
		||||
| 
						 | 
				
			
			@ -1687,8 +1632,8 @@ INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
 | 
			
		|||
		struct sock *nsk;
 | 
			
		||||
 | 
			
		||||
		sk = req->rsk_listener;
 | 
			
		||||
		if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif,
 | 
			
		||||
					    &drop_reason)) {
 | 
			
		||||
		if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
 | 
			
		||||
					 &hdr->daddr, AF_INET6, dif, sdif)) {
 | 
			
		||||
			sk_drops_add(sk, skb);
 | 
			
		||||
			reqsk_put(req);
 | 
			
		||||
			goto discard_it;
 | 
			
		||||
| 
						 | 
				
			
			@ -1759,7 +1704,8 @@ INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
 | 
			
		|||
		goto discard_and_relse;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
 | 
			
		||||
	if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
 | 
			
		||||
				 &hdr->daddr, AF_INET6, dif, sdif))
 | 
			
		||||
		goto discard_and_relse;
 | 
			
		||||
 | 
			
		||||
	if (tcp_filter(sk, skb)) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue