mirror of
				https://github.com/torvalds/linux.git
				synced 2025-11-04 10:40:15 +02:00 
			
		
		
		
	sock: Introduce sk->sk_prot->psock_update_sk_prot()
Currently sockmap calls into each protocol to update the struct proto and replace it. This certainly won't work when the protocol is implemented as a module, for example, AF_UNIX. Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each protocol can implement its own way to replace the struct proto. This also helps get rid of symbol dependencies on CONFIG_INET. Signed-off-by: Cong Wang <cong.wang@bytedance.com> Signed-off-by: Alexei Starovoitov <ast@kernel.org> Link: https://lore.kernel.org/bpf/20210331023237.41094-11-xiyou.wangcong@gmail.com
This commit is contained in:
		
							parent
							
								
									a7ba4558e6
								
							
						
					
					
						commit
						8a59f9d1e3
					
				
					 12 changed files with 58 additions and 45 deletions
				
			
		| 
						 | 
					@ -99,6 +99,7 @@ struct sk_psock {
 | 
				
			||||||
	void (*saved_close)(struct sock *sk, long timeout);
 | 
						void (*saved_close)(struct sock *sk, long timeout);
 | 
				
			||||||
	void (*saved_write_space)(struct sock *sk);
 | 
						void (*saved_write_space)(struct sock *sk);
 | 
				
			||||||
	void (*saved_data_ready)(struct sock *sk);
 | 
						void (*saved_data_ready)(struct sock *sk);
 | 
				
			||||||
 | 
						int  (*psock_update_sk_prot)(struct sock *sk, bool restore);
 | 
				
			||||||
	struct proto			*sk_proto;
 | 
						struct proto			*sk_proto;
 | 
				
			||||||
	struct mutex			work_mutex;
 | 
						struct mutex			work_mutex;
 | 
				
			||||||
	struct sk_psock_work_state	work_state;
 | 
						struct sk_psock_work_state	work_state;
 | 
				
			||||||
| 
						 | 
					@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static inline void sk_psock_update_proto(struct sock *sk,
 | 
					 | 
				
			||||||
					 struct sk_psock *psock,
 | 
					 | 
				
			||||||
					 struct proto *ops)
 | 
					 | 
				
			||||||
{
 | 
					 | 
				
			||||||
	/* Pairs with lockless read in sk_clone_lock() */
 | 
					 | 
				
			||||||
	WRITE_ONCE(sk->sk_prot, ops);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static inline void sk_psock_restore_proto(struct sock *sk,
 | 
					static inline void sk_psock_restore_proto(struct sock *sk,
 | 
				
			||||||
					  struct sk_psock *psock)
 | 
										  struct sk_psock *psock)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	sk->sk_prot->unhash = psock->saved_unhash;
 | 
						sk->sk_prot->unhash = psock->saved_unhash;
 | 
				
			||||||
	if (inet_csk_has_ulp(sk)) {
 | 
						if (psock->psock_update_sk_prot)
 | 
				
			||||||
		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
 | 
							psock->psock_update_sk_prot(sk, true);
 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		sk->sk_write_space = psock->saved_write_space;
 | 
					 | 
				
			||||||
		/* Pairs with lockless read in sk_clone_lock() */
 | 
					 | 
				
			||||||
		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static inline void sk_psock_set_state(struct sk_psock *psock,
 | 
					static inline void sk_psock_set_state(struct sk_psock *psock,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1184,6 +1184,9 @@ struct proto {
 | 
				
			||||||
	void			(*unhash)(struct sock *sk);
 | 
						void			(*unhash)(struct sock *sk);
 | 
				
			||||||
	void			(*rehash)(struct sock *sk);
 | 
						void			(*rehash)(struct sock *sk);
 | 
				
			||||||
	int			(*get_port)(struct sock *sk, unsigned short snum);
 | 
						int			(*get_port)(struct sock *sk, unsigned short snum);
 | 
				
			||||||
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
 | 
						int			(*psock_update_sk_prot)(struct sock *sk, bool restore);
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* Keeping track of sockets in use */
 | 
						/* Keeping track of sockets in use */
 | 
				
			||||||
#ifdef CONFIG_PROC_FS
 | 
					#ifdef CONFIG_PROC_FS
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2203,6 +2203,7 @@ struct sk_psock;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef CONFIG_BPF_SYSCALL
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
 | 
					struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
 | 
				
			||||||
 | 
					int tcp_bpf_update_proto(struct sock *sk, bool restore);
 | 
				
			||||||
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 | 
					void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 | 
				
			||||||
#endif /* CONFIG_BPF_SYSCALL */
 | 
					#endif /* CONFIG_BPF_SYSCALL */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
 | 
				
			||||||
#ifdef CONFIG_BPF_SYSCALL
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
struct sk_psock;
 | 
					struct sk_psock;
 | 
				
			||||||
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
 | 
					struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
 | 
				
			||||||
 | 
					int udp_bpf_update_proto(struct sock *sk, bool restore);
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif	/* _UDP_H */
 | 
					#endif	/* _UDP_H */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	write_lock_bh(&sk->sk_callback_lock);
 | 
						write_lock_bh(&sk->sk_callback_lock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if (inet_csk_has_ulp(sk)) {
 | 
					 | 
				
			||||||
		psock = ERR_PTR(-EINVAL);
 | 
					 | 
				
			||||||
		goto out;
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if (sk->sk_user_data) {
 | 
						if (sk->sk_user_data) {
 | 
				
			||||||
		psock = ERR_PTR(-EBUSY);
 | 
							psock = ERR_PTR(-EBUSY);
 | 
				
			||||||
		goto out;
 | 
							goto out;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 | 
					static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	struct proto *prot;
 | 
						if (!sk->sk_prot->psock_update_sk_prot)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	switch (sk->sk_type) {
 | 
					 | 
				
			||||||
	case SOCK_STREAM:
 | 
					 | 
				
			||||||
		prot = tcp_bpf_get_proto(sk, psock);
 | 
					 | 
				
			||||||
		break;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	case SOCK_DGRAM:
 | 
					 | 
				
			||||||
		prot = udp_bpf_get_proto(sk, psock);
 | 
					 | 
				
			||||||
		break;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		return -EINVAL;
 | 
							return -EINVAL;
 | 
				
			||||||
	}
 | 
						psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
 | 
				
			||||||
 | 
						return sk->sk_prot->psock_update_sk_prot(sk, false);
 | 
				
			||||||
	if (IS_ERR(prot))
 | 
					 | 
				
			||||||
		return PTR_ERR(prot);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sk_psock_update_proto(sk, psock, prot);
 | 
					 | 
				
			||||||
	return 0;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
 | 
					static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
 | 
				
			||||||
| 
						 | 
					@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static bool sock_map_sk_is_suitable(const struct sock *sk)
 | 
					static bool sock_map_sk_is_suitable(const struct sock *sk)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	return sk_is_tcp(sk) || sk_is_udp(sk);
 | 
						return !!sk->sk_prot->psock_update_sk_prot;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static bool sock_map_sk_state_allowed(const struct sock *sk)
 | 
					static bool sock_map_sk_state_allowed(const struct sock *sk)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
 | 
				
			||||||
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 | 
						       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 | 
					int tcp_bpf_update_proto(struct sock *sk, bool restore)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
 | 
						struct sk_psock *psock = sk_psock(sk);
 | 
				
			||||||
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 | 
						int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 | 
				
			||||||
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 | 
						int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if (restore) {
 | 
				
			||||||
 | 
							if (inet_csk_has_ulp(sk)) {
 | 
				
			||||||
 | 
								tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								sk->sk_write_space = psock->saved_write_space;
 | 
				
			||||||
 | 
								/* Pairs with lockless read in sk_clone_lock() */
 | 
				
			||||||
 | 
								WRITE_ONCE(sk->sk_prot, psock->sk_proto);
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return 0;
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if (inet_csk_has_ulp(sk))
 | 
				
			||||||
 | 
							return -EINVAL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if (sk->sk_family == AF_INET6) {
 | 
						if (sk->sk_family == AF_INET6) {
 | 
				
			||||||
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
 | 
							if (tcp_bpf_assert_proto_ops(psock->sk_proto))
 | 
				
			||||||
			return ERR_PTR(-EINVAL);
 | 
								return -EINVAL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 | 
							tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &tcp_bpf_prots[family][config];
 | 
						/* Pairs with lockless read in sk_clone_lock() */
 | 
				
			||||||
 | 
						WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
 | 
				
			||||||
 | 
						return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* If a child got cloned from a listening socket that had tcp_bpf
 | 
					/* If a child got cloned from a listening socket that had tcp_bpf
 | 
				
			||||||
 * protocol callbacks installed, we need to restore the callbacks to
 | 
					 * protocol callbacks installed, we need to restore the callbacks to
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
 | 
				
			||||||
	.hash			= inet_hash,
 | 
						.hash			= inet_hash,
 | 
				
			||||||
	.unhash			= inet_unhash,
 | 
						.unhash			= inet_unhash,
 | 
				
			||||||
	.get_port		= inet_csk_get_port,
 | 
						.get_port		= inet_csk_get_port,
 | 
				
			||||||
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
 | 
						.psock_update_sk_prot	= tcp_bpf_update_proto,
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
	.enter_memory_pressure	= tcp_enter_memory_pressure,
 | 
						.enter_memory_pressure	= tcp_enter_memory_pressure,
 | 
				
			||||||
	.leave_memory_pressure	= tcp_leave_memory_pressure,
 | 
						.leave_memory_pressure	= tcp_leave_memory_pressure,
 | 
				
			||||||
	.stream_memory_free	= tcp_stream_memory_free,
 | 
						.stream_memory_free	= tcp_stream_memory_free,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2849,6 +2849,9 @@ struct proto udp_prot = {
 | 
				
			||||||
	.unhash			= udp_lib_unhash,
 | 
						.unhash			= udp_lib_unhash,
 | 
				
			||||||
	.rehash			= udp_v4_rehash,
 | 
						.rehash			= udp_v4_rehash,
 | 
				
			||||||
	.get_port		= udp_v4_get_port,
 | 
						.get_port		= udp_v4_get_port,
 | 
				
			||||||
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
 | 
						.psock_update_sk_prot	= udp_bpf_update_proto,
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
	.memory_allocated	= &udp_memory_allocated,
 | 
						.memory_allocated	= &udp_memory_allocated,
 | 
				
			||||||
	.sysctl_mem		= sysctl_udp_mem,
 | 
						.sysctl_mem		= sysctl_udp_mem,
 | 
				
			||||||
	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min),
 | 
						.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
core_initcall(udp_bpf_v4_build_proto);
 | 
					core_initcall(udp_bpf_v4_build_proto);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 | 
					int udp_bpf_update_proto(struct sock *sk, bool restore)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
 | 
						int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
 | 
				
			||||||
 | 
						struct sk_psock *psock = sk_psock(sk);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if (restore) {
 | 
				
			||||||
 | 
							sk->sk_write_space = psock->saved_write_space;
 | 
				
			||||||
 | 
							/* Pairs with lockless read in sk_clone_lock() */
 | 
				
			||||||
 | 
							WRITE_ONCE(sk->sk_prot, psock->sk_proto);
 | 
				
			||||||
 | 
							return 0;
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if (sk->sk_family == AF_INET6)
 | 
						if (sk->sk_family == AF_INET6)
 | 
				
			||||||
		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 | 
							udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &udp_bpf_prots[family];
 | 
						/* Pairs with lockless read in sk_clone_lock() */
 | 
				
			||||||
 | 
						WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
 | 
				
			||||||
 | 
						return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
 | 
				
			||||||
	.hash			= inet6_hash,
 | 
						.hash			= inet6_hash,
 | 
				
			||||||
	.unhash			= inet_unhash,
 | 
						.unhash			= inet_unhash,
 | 
				
			||||||
	.get_port		= inet_csk_get_port,
 | 
						.get_port		= inet_csk_get_port,
 | 
				
			||||||
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
 | 
						.psock_update_sk_prot	= tcp_bpf_update_proto,
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
	.enter_memory_pressure	= tcp_enter_memory_pressure,
 | 
						.enter_memory_pressure	= tcp_enter_memory_pressure,
 | 
				
			||||||
	.leave_memory_pressure	= tcp_leave_memory_pressure,
 | 
						.leave_memory_pressure	= tcp_leave_memory_pressure,
 | 
				
			||||||
	.stream_memory_free	= tcp_stream_memory_free,
 | 
						.stream_memory_free	= tcp_stream_memory_free,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
 | 
				
			||||||
	.unhash			= udp_lib_unhash,
 | 
						.unhash			= udp_lib_unhash,
 | 
				
			||||||
	.rehash			= udp_v6_rehash,
 | 
						.rehash			= udp_v6_rehash,
 | 
				
			||||||
	.get_port		= udp_v6_get_port,
 | 
						.get_port		= udp_v6_get_port,
 | 
				
			||||||
 | 
					#ifdef CONFIG_BPF_SYSCALL
 | 
				
			||||||
 | 
						.psock_update_sk_prot	= udp_bpf_update_proto,
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
	.memory_allocated	= &udp_memory_allocated,
 | 
						.memory_allocated	= &udp_memory_allocated,
 | 
				
			||||||
	.sysctl_mem		= sysctl_udp_mem,
 | 
						.sysctl_mem		= sysctl_udp_mem,
 | 
				
			||||||
	.sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
 | 
						.sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue