forked from mirrors/linux
		
	bpf: helper to pop data from messages
This adds a BPF SK_MSG program helper so that we can pop data from a msg. We use this to pop metadata from a previous push data call. Signed-off-by: John Fastabend <john.fastabend@gmail.com> Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
This commit is contained in:
		
							parent
							
								
									17d95e4225
								
							
						
					
					
						commit
						7246d8ed4d
					
				
					 4 changed files with 209 additions and 6 deletions
				
			
		| 
						 | 
					@ -2268,6 +2268,19 @@ union bpf_attr {
 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
 *	Return
 | 
					 *	Return
 | 
				
			||||||
 *		0 on success, or a negative error in case of failure.
 | 
					 *		0 on success, or a negative error in case of failure.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * int bpf_msg_pop_data(struct sk_msg_buff *msg, u32 start, u32 pop, u64 flags)
 | 
				
			||||||
 | 
					 *	 Description
 | 
				
			||||||
 | 
					 *		Will remove *pop* bytes from a *msg* starting at byte *start*.
 | 
				
			||||||
 | 
					 *		This may result in **ENOMEM** errors under certain situations if
 | 
				
			||||||
 | 
					 *		an allocation and copy are required due to a full ring buffer.
 | 
				
			||||||
 | 
					 *		However, the helper will try to avoid doing the allocation
 | 
				
			||||||
 | 
					 *		if possible. Other errors can occur if input parameters are
 | 
				
			||||||
 | 
					 *		invalid either due to *start* byte not being valid part of msg
 | 
				
			||||||
 | 
					 *		payload and/or *pop* value being too large.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *	Return
 | 
				
			||||||
 | 
					 *		0 on success, or a negative error in case of failure.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
#define __BPF_FUNC_MAPPER(FN)		\
 | 
					#define __BPF_FUNC_MAPPER(FN)		\
 | 
				
			||||||
	FN(unspec),			\
 | 
						FN(unspec),			\
 | 
				
			||||||
| 
						 | 
					@ -2360,7 +2373,8 @@ union bpf_attr {
 | 
				
			||||||
	FN(map_push_elem),		\
 | 
						FN(map_push_elem),		\
 | 
				
			||||||
	FN(map_pop_elem),		\
 | 
						FN(map_pop_elem),		\
 | 
				
			||||||
	FN(map_peek_elem),		\
 | 
						FN(map_peek_elem),		\
 | 
				
			||||||
	FN(msg_push_data),
 | 
						FN(msg_push_data),		\
 | 
				
			||||||
 | 
						FN(msg_pop_data),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* integer value in 'imm' field of BPF_CALL instruction selects which helper
 | 
					/* integer value in 'imm' field of BPF_CALL instruction selects which helper
 | 
				
			||||||
 * function eBPF program intends to call
 | 
					 * function eBPF program intends to call
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2425,6 +2425,174 @@ static const struct bpf_func_proto bpf_msg_push_data_proto = {
 | 
				
			||||||
	.arg4_type	= ARG_ANYTHING,
 | 
						.arg4_type	= ARG_ANYTHING,
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static void sk_msg_shift_left(struct sk_msg *msg, int i)
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						int prev;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						do {
 | 
				
			||||||
 | 
							prev = i;
 | 
				
			||||||
 | 
							sk_msg_iter_var_next(i);
 | 
				
			||||||
 | 
							msg->sg.data[prev] = msg->sg.data[i];
 | 
				
			||||||
 | 
						} while (i != msg->sg.end);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sk_msg_iter_prev(msg, end);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static void sk_msg_shift_right(struct sk_msg *msg, int i)
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						struct scatterlist tmp, sge;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sk_msg_iter_next(msg, end);
 | 
				
			||||||
 | 
						sge = sk_msg_elem_cpy(msg, i);
 | 
				
			||||||
 | 
						sk_msg_iter_var_next(i);
 | 
				
			||||||
 | 
						tmp = sk_msg_elem_cpy(msg, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						while (i != msg->sg.end) {
 | 
				
			||||||
 | 
							msg->sg.data[i] = sge;
 | 
				
			||||||
 | 
							sk_msg_iter_var_next(i);
 | 
				
			||||||
 | 
							sge = tmp;
 | 
				
			||||||
 | 
							tmp = sk_msg_elem_cpy(msg, i);
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* bpf_msg_pop_data - remove @len bytes from @msg starting at byte @start.
					 *
					 * @msg:   message whose scatterlist ring is edited in place
					 * @start: offset of the first byte to remove, relative to the start of
					 *         the message
					 * @len:   number of bytes to remove
					 * @flags: must be zero
					 *
					 * Return: 0 on success, -EINVAL if @flags is non-zero or the range
					 * [@start, @start + @len) is not fully inside the message, -ENOMEM if
					 * a replacement page was needed and could not be allocated.
					 */
					BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
						   u32, len, u64, flags)
					{
						u32 i = 0, l = 0, space, offset = 0;
						/* Widen before adding: start and len are both u32, so a plain
						 * start + len would wrap and defeat the bounds check below.
						 */
						u64 last = (u64)start + len;
						int pop;

						if (unlikely(flags))
							return -EINVAL;

						/* First find the starting scatterlist element. On exit, offset is
						 * the message offset of element i and l is its length. If the
						 * ring is exhausted, offset + l is the sum of all element lengths,
						 * so a start beyond the data is rejected below.
						 */
						i = msg->sg.start;
						do {
							offset += l;
							l = sk_msg_elem(msg, i)->length;

							if (start < offset + l)
								break;
							sk_msg_iter_var_next(i);
						} while (i != msg->sg.end);

						/* Bounds checks: start and pop must be inside message. A pop that
						 * ends exactly at the last byte of the message is allowed.
						 */
						if (start >= offset + l || last > msg->sg.size)
							return -EINVAL;

						/* Nothing to remove; avoid splitting an element for no reason. */
						if (unlikely(!len))
							return 0;

						space = MAX_MSG_FRAGS - sk_msg_elem_used(msg);

						pop = len;
						/* --------------| offset
						 * -| start      |-------- len -------|
						 *
						 *  |----- a ----|-------- pop -------|----- b ----|
						 *  |______________________________________________| length
						 *
						 *
						 * a:   region at front of scatter element to save
						 * b:   region at back of scatter element to save when length > A + pop
						 * pop: region to pop from element, same as input 'pop' here will be
						 *      decremented below per iteration.
						 *
						 * Two top-level cases to handle when start != offset, first B is non
						 * zero and second B is zero corresponding to when a pop includes more
						 * than one element.
						 *
						 * Then if B is non-zero AND there is no space allocate space and
						 * compact A, B regions into page. If there is space shift ring to
						 * the right free'ing the next element in ring to place B, leaving
						 * A untouched except to reduce length.
						 */
						if (start != offset) {
							struct scatterlist *nsge, *sge = sk_msg_elem(msg, i);
							/* a is relative to this element, start is relative to
							 * the whole message.
							 */
							int a = start - offset;
							int b = sge->length - pop - a;

							sk_msg_iter_var_next(i);

							if (b > 0) {
								if (space) {
									/* Keep A in place and describe B with a
									 * second element on the same page.
									 */
									sge->length = a;
									sk_msg_shift_right(msg, i);
									nsge = sk_msg_elem(msg, i);
									get_page(sg_page(sge));
									sg_set_page(nsge,
										    sg_page(sge),
										    b, sge->offset + pop + a);
								} else {
									struct page *page, *orig;
									u8 *to, *from;

									/* No free slot: copy A and B back to back
									 * into a fresh page and swap it in.
									 */
									page = alloc_pages(__GFP_NOWARN |
											   __GFP_COMP   | GFP_ATOMIC,
											   get_order(a + b));
									if (unlikely(!page))
										return -ENOMEM;

									orig = sg_page(sge);
									from = sg_virt(sge);
									to = page_address(page);
									memcpy(to, from, a);
									memcpy(to + a, from + a + pop, b);
									sg_set_page(sge, page, a + b, 0);
									put_page(orig);
								}
								pop = 0;
							} else {
								/* The pop runs to (or past) the end of this
								 * element. Account for the bytes dropped here
								 * before the length is trimmed to A.
								 */
								pop -= (sge->length - a);
								sge->length = a;
							}
						}

						/* From above the current layout _must_ be as follows,
						 *
						 * -| offset
						 * -| start
						 *
						 *  |---- pop ---|---------------- b ------------|
						 *  |____________________________________________| length
						 *
						 * Offset and start of the current msg elem are equal because in the
						 * previous case we handled offset != start and either consumed the
						 * entire element and advanced to the next element OR pop == 0.
						 *
						 * Two cases to handle here are first pop is less than the length
						 * leaving some remainder b above. Simply adjust the element's layout
						 * in this case. Or pop >= length of the element so that b = 0. In this
						 * case remove the element and decrement pop. sk_msg_shift_left()
						 * moves the following element into slot i, so i must not be advanced.
						 */
						while (pop) {
							struct scatterlist *sge = sk_msg_elem(msg, i);

							if (pop < sge->length) {
								sge->length -= pop;
								sge->offset += pop;
								pop = 0;
							} else {
								pop -= sge->length;
								sk_msg_shift_left(msg, i);
							}
						}

						sk_mem_uncharge(msg->sk, len - pop);
						msg->sg.size -= (len - pop);
						sk_msg_compute_data_pointers(msg);
						return 0;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static const struct bpf_func_proto bpf_msg_pop_data_proto = {
 | 
				
			||||||
 | 
						.func		= bpf_msg_pop_data,
 | 
				
			||||||
 | 
						.gpl_only	= false,
 | 
				
			||||||
 | 
						.ret_type	= RET_INTEGER,
 | 
				
			||||||
 | 
						.arg1_type	= ARG_PTR_TO_CTX,
 | 
				
			||||||
 | 
						.arg2_type	= ARG_ANYTHING,
 | 
				
			||||||
 | 
						.arg3_type	= ARG_ANYTHING,
 | 
				
			||||||
 | 
						.arg4_type	= ARG_ANYTHING,
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 | 
					BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	return task_get_classid(skb);
 | 
						return task_get_classid(skb);
 | 
				
			||||||
| 
						 | 
					@ -5098,6 +5266,7 @@ bool bpf_helper_changes_pkt_data(void *func)
 | 
				
			||||||
	    func == bpf_xdp_adjust_meta ||
 | 
						    func == bpf_xdp_adjust_meta ||
 | 
				
			||||||
	    func == bpf_msg_pull_data ||
 | 
						    func == bpf_msg_pull_data ||
 | 
				
			||||||
	    func == bpf_msg_push_data ||
 | 
						    func == bpf_msg_push_data ||
 | 
				
			||||||
 | 
						    func == bpf_msg_pop_data ||
 | 
				
			||||||
	    func == bpf_xdp_adjust_tail ||
 | 
						    func == bpf_xdp_adjust_tail ||
 | 
				
			||||||
#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
 | 
					#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
 | 
				
			||||||
	    func == bpf_lwt_seg6_store_bytes ||
 | 
						    func == bpf_lwt_seg6_store_bytes ||
 | 
				
			||||||
| 
						 | 
					@ -5394,6 +5563,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 | 
				
			||||||
		return &bpf_msg_pull_data_proto;
 | 
							return &bpf_msg_pull_data_proto;
 | 
				
			||||||
	case BPF_FUNC_msg_push_data:
 | 
						case BPF_FUNC_msg_push_data:
 | 
				
			||||||
		return &bpf_msg_push_data_proto;
 | 
							return &bpf_msg_push_data_proto;
 | 
				
			||||||
 | 
						case BPF_FUNC_msg_pop_data:
 | 
				
			||||||
 | 
							return &bpf_msg_pop_data_proto;
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return bpf_base_func_proto(func_id);
 | 
							return bpf_base_func_proto(func_id);
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -289,12 +289,23 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
 | 
						bool cork = false, enospc = msg->sg.start == msg->sg.end;
 | 
				
			||||||
	struct sock *sk_redir;
 | 
						struct sock *sk_redir;
 | 
				
			||||||
	u32 tosend;
 | 
						u32 tosend, delta = 0;
 | 
				
			||||||
	int ret;
 | 
						int ret;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
more_data:
 | 
					more_data:
 | 
				
			||||||
	if (psock->eval == __SK_NONE)
 | 
						if (psock->eval == __SK_NONE) {
 | 
				
			||||||
 | 
							/* Track delta in msg size to add/subtract it on SK_DROP from
 | 
				
			||||||
 | 
							 * returned to user copied size. This ensures user doesn't
 | 
				
			||||||
 | 
							 * get a positive return code with msg_cut_data and SK_DROP
 | 
				
			||||||
 | 
							 * verdict.
 | 
				
			||||||
 | 
							 */
 | 
				
			||||||
 | 
							delta = msg->sg.size;
 | 
				
			||||||
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 | 
							psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 | 
				
			||||||
 | 
							if (msg->sg.size < delta)
 | 
				
			||||||
 | 
								delta -= msg->sg.size;
 | 
				
			||||||
 | 
							else
 | 
				
			||||||
 | 
								delta = 0;
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if (msg->cork_bytes &&
 | 
						if (msg->cork_bytes &&
 | 
				
			||||||
	    msg->cork_bytes > msg->sg.size && !enospc) {
 | 
						    msg->cork_bytes > msg->sg.size && !enospc) {
 | 
				
			||||||
| 
						 | 
					@ -350,7 +361,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		sk_msg_free_partial(sk, msg, tosend);
 | 
							sk_msg_free_partial(sk, msg, tosend);
 | 
				
			||||||
		sk_msg_apply_bytes(psock, tosend);
 | 
							sk_msg_apply_bytes(psock, tosend);
 | 
				
			||||||
		*copied -= tosend;
 | 
							*copied -= (tosend + delta);
 | 
				
			||||||
		return -EACCES;
 | 
							return -EACCES;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -687,6 +687,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 | 
				
			||||||
	struct sock *sk_redir;
 | 
						struct sock *sk_redir;
 | 
				
			||||||
	struct tls_rec *rec;
 | 
						struct tls_rec *rec;
 | 
				
			||||||
	int err = 0, send;
 | 
						int err = 0, send;
 | 
				
			||||||
 | 
						u32 delta = 0;
 | 
				
			||||||
	bool enospc;
 | 
						bool enospc;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	psock = sk_psock_get(sk);
 | 
						psock = sk_psock_get(sk);
 | 
				
			||||||
| 
						 | 
					@ -694,8 +695,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 | 
				
			||||||
		return tls_push_record(sk, flags, record_type);
 | 
							return tls_push_record(sk, flags, record_type);
 | 
				
			||||||
more_data:
 | 
					more_data:
 | 
				
			||||||
	enospc = sk_msg_full(msg);
 | 
						enospc = sk_msg_full(msg);
 | 
				
			||||||
	if (psock->eval == __SK_NONE)
 | 
						if (psock->eval == __SK_NONE) {
 | 
				
			||||||
 | 
							delta = msg->sg.size;
 | 
				
			||||||
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 | 
							psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 | 
				
			||||||
 | 
							if (delta < msg->sg.size)
 | 
				
			||||||
 | 
								delta -= msg->sg.size;
 | 
				
			||||||
 | 
							else
 | 
				
			||||||
 | 
								delta = 0;
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
 | 
						if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
 | 
				
			||||||
	    !enospc && !full_record) {
 | 
						    !enospc && !full_record) {
 | 
				
			||||||
		err = -ENOSPC;
 | 
							err = -ENOSPC;
 | 
				
			||||||
| 
						 | 
					@ -743,7 +750,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 | 
				
			||||||
			msg->apply_bytes -= send;
 | 
								msg->apply_bytes -= send;
 | 
				
			||||||
		if (msg->sg.size == 0)
 | 
							if (msg->sg.size == 0)
 | 
				
			||||||
			tls_free_open_rec(sk);
 | 
								tls_free_open_rec(sk);
 | 
				
			||||||
		*copied -= send;
 | 
							*copied -= (send + delta);
 | 
				
			||||||
		err = -EACCES;
 | 
							err = -EACCES;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue