forked from mirrors/linux
		
	tls: rx: wrap decrypt params in a struct
The max size of iv + aad + tail is 22B. That's smaller than a single sg entry (32B). Don't bother with the memory packing, just create a struct which holds the max size of those members. Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
		
							parent
							
								
									50a07aa531
								
							
						
					
					
						commit
						b89fec54fd
					
				
					 1 changed files with 30 additions and 30 deletions
				
			
		| 
						 | 
					@ -50,6 +50,13 @@ struct tls_decrypt_arg {
 | 
				
			||||||
	u8 tail;
 | 
						u8 tail;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct tls_decrypt_ctx {
 | 
				
			||||||
 | 
						u8 iv[MAX_IV_SIZE];
 | 
				
			||||||
 | 
						u8 aad[TLS_MAX_AAD_SIZE];
 | 
				
			||||||
 | 
						u8 tail;
 | 
				
			||||||
 | 
						struct scatterlist sg[];
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
noinline void tls_err_abort(struct sock *sk, int err)
 | 
					noinline void tls_err_abort(struct sock *sk, int err)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	WARN_ON_ONCE(err >= 0);
 | 
						WARN_ON_ONCE(err >= 0);
 | 
				
			||||||
| 
						 | 
					@ -1414,17 +1421,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 | 
				
			||||||
	struct tls_context *tls_ctx = tls_get_ctx(sk);
 | 
						struct tls_context *tls_ctx = tls_get_ctx(sk);
 | 
				
			||||||
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 | 
						struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 | 
				
			||||||
	struct tls_prot_info *prot = &tls_ctx->prot_info;
 | 
						struct tls_prot_info *prot = &tls_ctx->prot_info;
 | 
				
			||||||
 | 
						int n_sgin, n_sgout, aead_size, err, pages = 0;
 | 
				
			||||||
	struct strp_msg *rxm = strp_msg(skb);
 | 
						struct strp_msg *rxm = strp_msg(skb);
 | 
				
			||||||
	struct tls_msg *tlm = tls_msg(skb);
 | 
						struct tls_msg *tlm = tls_msg(skb);
 | 
				
			||||||
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
 | 
					 | 
				
			||||||
	u8 *aad, *iv, *tail, *mem = NULL;
 | 
					 | 
				
			||||||
	struct aead_request *aead_req;
 | 
						struct aead_request *aead_req;
 | 
				
			||||||
	struct sk_buff *unused;
 | 
						struct sk_buff *unused;
 | 
				
			||||||
	struct scatterlist *sgin = NULL;
 | 
						struct scatterlist *sgin = NULL;
 | 
				
			||||||
	struct scatterlist *sgout = NULL;
 | 
						struct scatterlist *sgout = NULL;
 | 
				
			||||||
	const int data_len = rxm->full_len - prot->overhead_size;
 | 
						const int data_len = rxm->full_len - prot->overhead_size;
 | 
				
			||||||
	int tail_pages = !!prot->tail_size;
 | 
						int tail_pages = !!prot->tail_size;
 | 
				
			||||||
 | 
						struct tls_decrypt_ctx *dctx;
 | 
				
			||||||
	int iv_offset = 0;
 | 
						int iv_offset = 0;
 | 
				
			||||||
 | 
						u8 *mem;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if (darg->zc && (out_iov || out_sg)) {
 | 
						if (darg->zc && (out_iov || out_sg)) {
 | 
				
			||||||
		if (out_iov)
 | 
							if (out_iov)
 | 
				
			||||||
| 
						 | 
					@ -1446,38 +1454,30 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 | 
				
			||||||
	/* Increment to accommodate AAD */
 | 
						/* Increment to accommodate AAD */
 | 
				
			||||||
	n_sgin = n_sgin + 1;
 | 
						n_sgin = n_sgin + 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	nsg = n_sgin + n_sgout;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 | 
					 | 
				
			||||||
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
 | 
					 | 
				
			||||||
	mem_size = mem_size + TLS_MAX_AAD_SIZE;
 | 
					 | 
				
			||||||
	mem_size = mem_size + MAX_IV_SIZE;
 | 
					 | 
				
			||||||
	mem_size = mem_size + prot->tail_size;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	/* Allocate a single block of memory which contains
 | 
						/* Allocate a single block of memory which contains
 | 
				
			||||||
	 * aead_req || sgin[] || sgout[] || aad || iv || tail.
 | 
						 *   aead_req || tls_decrypt_ctx.
 | 
				
			||||||
	 * This order achieves correct alignment for aead_req, sgin, sgout.
 | 
						 * Both structs are variable length.
 | 
				
			||||||
	 */
 | 
						 */
 | 
				
			||||||
	mem = kmalloc(mem_size, sk->sk_allocation);
 | 
						aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 | 
				
			||||||
 | 
						mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
 | 
				
			||||||
 | 
							      sk->sk_allocation);
 | 
				
			||||||
	if (!mem)
 | 
						if (!mem)
 | 
				
			||||||
		return -ENOMEM;
 | 
							return -ENOMEM;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* Segment the allocated memory */
 | 
						/* Segment the allocated memory */
 | 
				
			||||||
	aead_req = (struct aead_request *)mem;
 | 
						aead_req = (struct aead_request *)mem;
 | 
				
			||||||
	sgin = (struct scatterlist *)(mem + aead_size);
 | 
						dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
 | 
				
			||||||
	sgout = sgin + n_sgin;
 | 
						sgin = &dctx->sg[0];
 | 
				
			||||||
	aad = (u8 *)(sgout + n_sgout);
 | 
						sgout = &dctx->sg[n_sgin];
 | 
				
			||||||
	iv = aad + TLS_MAX_AAD_SIZE;
 | 
					 | 
				
			||||||
	tail = iv + MAX_IV_SIZE;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* For CCM based ciphers, first byte of nonce+iv is a constant */
 | 
						/* For CCM based ciphers, first byte of nonce+iv is a constant */
 | 
				
			||||||
	switch (prot->cipher_type) {
 | 
						switch (prot->cipher_type) {
 | 
				
			||||||
	case TLS_CIPHER_AES_CCM_128:
 | 
						case TLS_CIPHER_AES_CCM_128:
 | 
				
			||||||
		iv[0] = TLS_AES_CCM_IV_B0_BYTE;
 | 
							dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
 | 
				
			||||||
		iv_offset = 1;
 | 
							iv_offset = 1;
 | 
				
			||||||
		break;
 | 
							break;
 | 
				
			||||||
	case TLS_CIPHER_SM4_CCM:
 | 
						case TLS_CIPHER_SM4_CCM:
 | 
				
			||||||
		iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
 | 
							dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
 | 
				
			||||||
		iv_offset = 1;
 | 
							iv_offset = 1;
 | 
				
			||||||
		break;
 | 
							break;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -1485,28 +1485,28 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 | 
				
			||||||
	/* Prepare IV */
 | 
						/* Prepare IV */
 | 
				
			||||||
	if (prot->version == TLS_1_3_VERSION ||
 | 
						if (prot->version == TLS_1_3_VERSION ||
 | 
				
			||||||
	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
 | 
						    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
 | 
				
			||||||
		memcpy(iv + iv_offset, tls_ctx->rx.iv,
 | 
							memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
 | 
				
			||||||
		       prot->iv_size + prot->salt_size);
 | 
							       prot->iv_size + prot->salt_size);
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
 | 
							err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
 | 
				
			||||||
				    iv + iv_offset + prot->salt_size,
 | 
									    &dctx->iv[iv_offset] + prot->salt_size,
 | 
				
			||||||
				    prot->iv_size);
 | 
									    prot->iv_size);
 | 
				
			||||||
		if (err < 0) {
 | 
							if (err < 0) {
 | 
				
			||||||
			kfree(mem);
 | 
								kfree(mem);
 | 
				
			||||||
			return err;
 | 
								return err;
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
 | 
							memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
 | 
						xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* Prepare AAD */
 | 
						/* Prepare AAD */
 | 
				
			||||||
	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
 | 
						tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
 | 
				
			||||||
		     prot->tail_size,
 | 
							     prot->tail_size,
 | 
				
			||||||
		     tls_ctx->rx.rec_seq, tlm->control, prot);
 | 
							     tls_ctx->rx.rec_seq, tlm->control, prot);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* Prepare sgin */
 | 
						/* Prepare sgin */
 | 
				
			||||||
	sg_init_table(sgin, n_sgin);
 | 
						sg_init_table(sgin, n_sgin);
 | 
				
			||||||
	sg_set_buf(&sgin[0], aad, prot->aad_size);
 | 
						sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
 | 
				
			||||||
	err = skb_to_sgvec(skb, &sgin[1],
 | 
						err = skb_to_sgvec(skb, &sgin[1],
 | 
				
			||||||
			   rxm->offset + prot->prepend_size,
 | 
								   rxm->offset + prot->prepend_size,
 | 
				
			||||||
			   rxm->full_len - prot->prepend_size);
 | 
								   rxm->full_len - prot->prepend_size);
 | 
				
			||||||
| 
						 | 
					@ -1518,7 +1518,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 | 
				
			||||||
	if (n_sgout) {
 | 
						if (n_sgout) {
 | 
				
			||||||
		if (out_iov) {
 | 
							if (out_iov) {
 | 
				
			||||||
			sg_init_table(sgout, n_sgout);
 | 
								sg_init_table(sgout, n_sgout);
 | 
				
			||||||
			sg_set_buf(&sgout[0], aad, prot->aad_size);
 | 
								sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			err = tls_setup_from_iter(out_iov, data_len,
 | 
								err = tls_setup_from_iter(out_iov, data_len,
 | 
				
			||||||
						  &pages, &sgout[1],
 | 
											  &pages, &sgout[1],
 | 
				
			||||||
| 
						 | 
					@ -1528,7 +1528,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if (prot->tail_size) {
 | 
								if (prot->tail_size) {
 | 
				
			||||||
				sg_unmark_end(&sgout[pages]);
 | 
									sg_unmark_end(&sgout[pages]);
 | 
				
			||||||
				sg_set_buf(&sgout[pages + 1], tail,
 | 
									sg_set_buf(&sgout[pages + 1], &dctx->tail,
 | 
				
			||||||
					   prot->tail_size);
 | 
										   prot->tail_size);
 | 
				
			||||||
				sg_mark_end(&sgout[pages + 1]);
 | 
									sg_mark_end(&sgout[pages + 1]);
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -1545,13 +1545,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* Prepare and submit AEAD request */
 | 
						/* Prepare and submit AEAD request */
 | 
				
			||||||
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
 | 
						err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
 | 
				
			||||||
				data_len + prot->tail_size, aead_req, darg);
 | 
									data_len + prot->tail_size, aead_req, darg);
 | 
				
			||||||
	if (darg->async)
 | 
						if (darg->async)
 | 
				
			||||||
		return 0;
 | 
							return 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if (prot->tail_size)
 | 
						if (prot->tail_size)
 | 
				
			||||||
		darg->tail = *tail;
 | 
							darg->tail = dctx->tail;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/* Release the pages in case iov was mapped to pages */
 | 
						/* Release the pages in case iov was mapped to pages */
 | 
				
			||||||
	for (; pages > 0; pages--)
 | 
						for (; pages > 0; pages--)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue