forked from mirrors/linux
		
	sock: Introduce sk->sk_prot->psock_update_sk_prot()
Currently sockmap calls into each protocol to update the struct proto and replace it. This certainly won't work when the protocol is implemented as a module, for example, AF_UNIX. Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each protocol can implement its own way to replace the struct proto. This also helps get rid of symbol dependencies on CONFIG_INET. Signed-off-by: Cong Wang <cong.wang@bytedance.com> Signed-off-by: Alexei Starovoitov <ast@kernel.org> Link: https://lore.kernel.org/bpf/20210331023237.41094-11-xiyou.wangcong@gmail.com
This commit is contained in:
		
							parent
							
								
									a7ba4558e6
								
							
						
					
					
						commit
						8a59f9d1e3
					
				
					 12 changed files with 58 additions and 45 deletions
				
			
		|  | @ -99,6 +99,7 @@ struct sk_psock { | |||
| 	void (*saved_close)(struct sock *sk, long timeout); | ||||
| 	void (*saved_write_space)(struct sock *sk); | ||||
| 	void (*saved_data_ready)(struct sock *sk); | ||||
| 	int  (*psock_update_sk_prot)(struct sock *sk, bool restore); | ||||
| 	struct proto			*sk_proto; | ||||
| 	struct mutex			work_mutex; | ||||
| 	struct sk_psock_work_state	work_state; | ||||
|  | @ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock) | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| static inline void sk_psock_update_proto(struct sock *sk, | ||||
| 					 struct sk_psock *psock, | ||||
| 					 struct proto *ops) | ||||
| { | ||||
| 	/* Pairs with lockless read in sk_clone_lock() */ | ||||
| 	WRITE_ONCE(sk->sk_prot, ops); | ||||
| } | ||||
| 
 | ||||
| static inline void sk_psock_restore_proto(struct sock *sk, | ||||
| 					  struct sk_psock *psock) | ||||
| { | ||||
| 	sk->sk_prot->unhash = psock->saved_unhash; | ||||
| 	if (inet_csk_has_ulp(sk)) { | ||||
| 		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space); | ||||
| 	} else { | ||||
| 		sk->sk_write_space = psock->saved_write_space; | ||||
| 		/* Pairs with lockless read in sk_clone_lock() */ | ||||
| 		WRITE_ONCE(sk->sk_prot, psock->sk_proto); | ||||
| 	} | ||||
| 	if (psock->psock_update_sk_prot) | ||||
| 		psock->psock_update_sk_prot(sk, true); | ||||
| } | ||||
| 
 | ||||
| static inline void sk_psock_set_state(struct sk_psock *psock, | ||||
|  |  | |||
|  | @ -1184,6 +1184,9 @@ struct proto { | |||
| 	void			(*unhash)(struct sock *sk); | ||||
| 	void			(*rehash)(struct sock *sk); | ||||
| 	int			(*get_port)(struct sock *sk, unsigned short snum); | ||||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| 	int			(*psock_update_sk_prot)(struct sock *sk, bool restore); | ||||
| #endif | ||||
| 
 | ||||
| 	/* Keeping track of sockets in use */ | ||||
| #ifdef CONFIG_PROC_FS | ||||
|  |  | |||
|  | @ -2203,6 +2203,7 @@ struct sk_psock; | |||
| 
 | ||||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock); | ||||
| int tcp_bpf_update_proto(struct sock *sk, bool restore); | ||||
| void tcp_bpf_clone(const struct sock *sk, struct sock *newsk); | ||||
| #endif /* CONFIG_BPF_SYSCALL */ | ||||
| 
 | ||||
|  |  | |||
|  | @ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk, | |||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| struct sk_psock; | ||||
| struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock); | ||||
| int udp_bpf_update_proto(struct sock *sk, bool restore); | ||||
| #endif | ||||
| 
 | ||||
| #endif	/* _UDP_H */ | ||||
|  |  | |||
|  | @ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node) | |||
| 
 | ||||
| 	write_lock_bh(&sk->sk_callback_lock); | ||||
| 
 | ||||
| 	if (inet_csk_has_ulp(sk)) { | ||||
| 		psock = ERR_PTR(-EINVAL); | ||||
| 		goto out; | ||||
| 	} | ||||
| 
 | ||||
| 	if (sk->sk_user_data) { | ||||
| 		psock = ERR_PTR(-EBUSY); | ||||
| 		goto out; | ||||
|  |  | |||
|  | @ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw) | |||
| 
 | ||||
| static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock) | ||||
| { | ||||
| 	struct proto *prot; | ||||
| 
 | ||||
| 	switch (sk->sk_type) { | ||||
| 	case SOCK_STREAM: | ||||
| 		prot = tcp_bpf_get_proto(sk, psock); | ||||
| 		break; | ||||
| 
 | ||||
| 	case SOCK_DGRAM: | ||||
| 		prot = udp_bpf_get_proto(sk, psock); | ||||
| 		break; | ||||
| 
 | ||||
| 	default: | ||||
| 	if (!sk->sk_prot->psock_update_sk_prot) | ||||
| 		return -EINVAL; | ||||
| 	} | ||||
| 
 | ||||
| 	if (IS_ERR(prot)) | ||||
| 		return PTR_ERR(prot); | ||||
| 
 | ||||
| 	sk_psock_update_proto(sk, psock, prot); | ||||
| 	return 0; | ||||
| 	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot; | ||||
| 	return sk->sk_prot->psock_update_sk_prot(sk, false); | ||||
| } | ||||
| 
 | ||||
| static struct sk_psock *sock_map_psock_get_checked(struct sock *sk) | ||||
|  | @ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk) | |||
| 
 | ||||
| static bool sock_map_sk_is_suitable(const struct sock *sk) | ||||
| { | ||||
| 	return sk_is_tcp(sk) || sk_is_udp(sk); | ||||
| 	return !!sk->sk_prot->psock_update_sk_prot; | ||||
| } | ||||
| 
 | ||||
| static bool sock_map_sk_state_allowed(const struct sock *sk) | ||||
|  |  | |||
|  | @ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops) | |||
| 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; | ||||
| } | ||||
| 
 | ||||
| struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) | ||||
| int tcp_bpf_update_proto(struct sock *sk, bool restore) | ||||
| { | ||||
| 	struct sk_psock *psock = sk_psock(sk); | ||||
| 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; | ||||
| 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE; | ||||
| 
 | ||||
| 	if (restore) { | ||||
| 		if (inet_csk_has_ulp(sk)) { | ||||
| 			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space); | ||||
| 		} else { | ||||
| 			sk->sk_write_space = psock->saved_write_space; | ||||
| 			/* Pairs with lockless read in sk_clone_lock() */ | ||||
| 			WRITE_ONCE(sk->sk_prot, psock->sk_proto); | ||||
| 		} | ||||
| 		return 0; | ||||
| 	} | ||||
| 
 | ||||
| 	if (inet_csk_has_ulp(sk)) | ||||
| 		return -EINVAL; | ||||
| 
 | ||||
| 	if (sk->sk_family == AF_INET6) { | ||||
| 		if (tcp_bpf_assert_proto_ops(psock->sk_proto)) | ||||
| 			return ERR_PTR(-EINVAL); | ||||
| 			return -EINVAL; | ||||
| 
 | ||||
| 		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto); | ||||
| 	} | ||||
| 
 | ||||
| 	return &tcp_bpf_prots[family][config]; | ||||
| 	/* Pairs with lockless read in sk_clone_lock() */ | ||||
| 	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]); | ||||
| 	return 0; | ||||
| } | ||||
| EXPORT_SYMBOL_GPL(tcp_bpf_update_proto); | ||||
| 
 | ||||
| /* If a child got cloned from a listening socket that had tcp_bpf
 | ||||
|  * protocol callbacks installed, we need to restore the callbacks to | ||||
|  |  | |||
|  | @ -2806,6 +2806,9 @@ struct proto tcp_prot = { | |||
| 	.hash			= inet_hash, | ||||
| 	.unhash			= inet_unhash, | ||||
| 	.get_port		= inet_csk_get_port, | ||||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| 	.psock_update_sk_prot	= tcp_bpf_update_proto, | ||||
| #endif | ||||
| 	.enter_memory_pressure	= tcp_enter_memory_pressure, | ||||
| 	.leave_memory_pressure	= tcp_leave_memory_pressure, | ||||
| 	.stream_memory_free	= tcp_stream_memory_free, | ||||
|  |  | |||
|  | @ -2849,6 +2849,9 @@ struct proto udp_prot = { | |||
| 	.unhash			= udp_lib_unhash, | ||||
| 	.rehash			= udp_v4_rehash, | ||||
| 	.get_port		= udp_v4_get_port, | ||||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| 	.psock_update_sk_prot	= udp_bpf_update_proto, | ||||
| #endif | ||||
| 	.memory_allocated	= &udp_memory_allocated, | ||||
| 	.sysctl_mem		= sysctl_udp_mem, | ||||
| 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min), | ||||
|  |  | |||
|  | @ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void) | |||
| } | ||||
| core_initcall(udp_bpf_v4_build_proto); | ||||
| 
 | ||||
| struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) | ||||
| int udp_bpf_update_proto(struct sock *sk, bool restore) | ||||
| { | ||||
| 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; | ||||
| 	struct sk_psock *psock = sk_psock(sk); | ||||
| 
 | ||||
| 	if (restore) { | ||||
| 		sk->sk_write_space = psock->saved_write_space; | ||||
| 		/* Pairs with lockless read in sk_clone_lock() */ | ||||
| 		WRITE_ONCE(sk->sk_prot, psock->sk_proto); | ||||
| 		return 0; | ||||
| 	} | ||||
| 
 | ||||
| 	if (sk->sk_family == AF_INET6) | ||||
| 		udp_bpf_check_v6_needs_rebuild(psock->sk_proto); | ||||
| 
 | ||||
| 	return &udp_bpf_prots[family]; | ||||
| 	/* Pairs with lockless read in sk_clone_lock() */ | ||||
| 	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]); | ||||
| 	return 0; | ||||
| } | ||||
| EXPORT_SYMBOL_GPL(udp_bpf_update_proto); | ||||
|  |  | |||
|  | @ -2139,6 +2139,9 @@ struct proto tcpv6_prot = { | |||
| 	.hash			= inet6_hash, | ||||
| 	.unhash			= inet_unhash, | ||||
| 	.get_port		= inet_csk_get_port, | ||||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| 	.psock_update_sk_prot	= tcp_bpf_update_proto, | ||||
| #endif | ||||
| 	.enter_memory_pressure	= tcp_enter_memory_pressure, | ||||
| 	.leave_memory_pressure	= tcp_leave_memory_pressure, | ||||
| 	.stream_memory_free	= tcp_stream_memory_free, | ||||
|  |  | |||
|  | @ -1713,6 +1713,9 @@ struct proto udpv6_prot = { | |||
| 	.unhash			= udp_lib_unhash, | ||||
| 	.rehash			= udp_v6_rehash, | ||||
| 	.get_port		= udp_v6_get_port, | ||||
| #ifdef CONFIG_BPF_SYSCALL | ||||
| 	.psock_update_sk_prot	= udp_bpf_update_proto, | ||||
| #endif | ||||
| 	.memory_allocated	= &udp_memory_allocated, | ||||
| 	.sysctl_mem		= sysctl_udp_mem, | ||||
| 	.sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min), | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue
	
	 Cong Wang
						Cong Wang