riscv: make unsafe user copy routines use existing assembly routines

The current implementation underperforms and, in addition, triggers
misaligned access traps on platforms that do not handle misaligned
accesses in hardware.

Use the existing assembly routines to solve both problems at once.

Signed-off-by: Alexandre Ghiti <alexghiti@rivosinc.com>
Link: https://lore.kernel.org/r/20250602193918.868962-2-cleger@rivosinc.com
Signed-off-by: Palmer Dabbelt <palmer@dabbelt.com>
Alexandre Ghiti 2025-06-02 21:39:14 +02:00 committed by Palmer Dabbelt
parent 259aaf03d7
commit a434854633
5 changed files with 63 additions and 48 deletions
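
For context: the unsafe_* user copy primitives are specified to run between user_access_begin() and user_access_end(), which on RISC-V set and clear the SUM bit that permits kernel access to user mappings. That is why this patch can introduce *_sum_enabled assembly entry points that skip the CSR writes: the caller has already done them. A minimal, hypothetical caller of this API (the function name, buffer names, and error handling are illustrative, not part of this patch):

    /* Hypothetical user of the unsafe copy API, for illustration only. */
    static int copy_blob_to_user(void __user *udst, const void *ksrc, size_t len)
    {
    	if (!user_access_begin(udst, len))	/* sets SR_SUM on RISC-V */
    		return -EFAULT;
    	unsafe_copy_to_user(udst, ksrc, len, Efault);
    	user_access_end();			/* clears SR_SUM */
    	return 0;
    Efault:
    	user_access_end();
    	return -EFAULT;
    }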


@@ -12,7 +12,7 @@ long long __ashlti3(long long a, int b);
 #ifdef CONFIG_RISCV_ISA_V
 #ifdef CONFIG_MMU
-asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n);
+asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n, bool enable_sum);
 #endif /* CONFIG_MMU */
 void xor_regs_2_(unsigned long bytes, unsigned long *__restrict p1,


@@ -450,35 +450,18 @@ static inline void user_access_restore(unsigned long enabled) { }
 	(x) = (__force __typeof__(*(ptr)))__gu_val;	\
 } while (0)
 
-#define unsafe_copy_loop(dst, src, len, type, op, label)	\
-	while (len >= sizeof(type)) {	\
-		op(*(type *)(src), (type __user *)(dst), label);	\
-		dst += sizeof(type);	\
-		src += sizeof(type);	\
-		len -= sizeof(type);	\
-	}
+unsigned long __must_check __asm_copy_to_user_sum_enabled(void __user *to,
+	const void *from, unsigned long n);
+unsigned long __must_check __asm_copy_from_user_sum_enabled(void *to,
+	const void __user *from, unsigned long n);
 
 #define unsafe_copy_to_user(_dst, _src, _len, label)	\
-do {	\
-	char __user *__ucu_dst = (_dst);	\
-	const char *__ucu_src = (_src);	\
-	size_t __ucu_len = (_len);	\
-	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u64, unsafe_put_user, label);	\
-	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u32, unsafe_put_user, label);	\
-	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u16, unsafe_put_user, label);	\
-	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u8, unsafe_put_user, label);	\
-} while (0)
+	if (__asm_copy_to_user_sum_enabled(_dst, _src, _len))	\
+		goto label;
 
 #define unsafe_copy_from_user(_dst, _src, _len, label)	\
-do {	\
-	char *__ucu_dst = (_dst);	\
-	const char __user *__ucu_src = (_src);	\
-	size_t __ucu_len = (_len);	\
-	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u64, unsafe_get_user, label);	\
-	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u32, unsafe_get_user, label);	\
-	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u16, unsafe_get_user, label);	\
-	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u8, unsafe_get_user, label);	\
-} while (0)
 
 #else /* CONFIG_MMU */
 #include <asm-generic/uaccess.h>
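
The deleted unsafe_copy_loop() chose its access width from the remaining length alone (u64, then u32, u16, u8), so an unaligned buffer still received full-width accesses; that is the source of the misaligned access traps named in the commit message. The replacement helpers follow the usual uaccess convention of returning the number of bytes left uncopied, so a non-zero return branches to the fault label. A plain-C model of the old loop's problem (ordinary userspace C, for illustration only, with the u32/u16 stages elided):

    #include <stdint.h>
    #include <stddef.h>

    /* Model of the old unsafe_copy_loop(): the access width depends only on
     * 'len', never on pointer alignment, so an odd 'dst' gets 8-byte stores,
     * which trap on hardware without misaligned-access support. */
    static void old_style_copy(char *dst, const char *src, size_t len)
    {
    	while (len >= sizeof(uint64_t)) {
    		*(uint64_t *)dst = *(const uint64_t *)src; /* may be misaligned */
    		dst += sizeof(uint64_t);
    		src += sizeof(uint64_t);
    		len -= sizeof(uint64_t);
    	}
    	while (len--)	/* byte tail */
    		*dst++ = *src++;
    }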


@@ -16,8 +16,11 @@
 #ifdef CONFIG_MMU
 size_t riscv_v_usercopy_threshold = CONFIG_RISCV_ISA_V_UCOPY_THRESHOLD;
 int __asm_vector_usercopy(void *dst, void *src, size_t n);
+int __asm_vector_usercopy_sum_enabled(void *dst, void *src, size_t n);
 int fallback_scalar_usercopy(void *dst, void *src, size_t n);
-asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n)
+int fallback_scalar_usercopy_sum_enabled(void *dst, void *src, size_t n);
+asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n,
+				     bool enable_sum)
 {
 	size_t remain, copied;
@@ -26,7 +29,8 @@ asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n)
 		goto fallback;
 
 	kernel_vector_begin();
-	remain = __asm_vector_usercopy(dst, src, n);
+	remain = enable_sum ? __asm_vector_usercopy(dst, src, n) :
+			      __asm_vector_usercopy_sum_enabled(dst, src, n);
 	kernel_vector_end();
 
 	if (remain) {
@@ -40,6 +44,7 @@ asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n)
 	return remain;
 
 fallback:
-	return fallback_scalar_usercopy(dst, src, n);
+	return enable_sum ? fallback_scalar_usercopy(dst, src, n) :
+			    fallback_scalar_usercopy_sum_enabled(dst, src, n);
 }
 #endif
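
One detail worth spelling out: the new bool enable_sum is the fourth function argument, and the RISC-V psABI passes integer arguments in registers a0-a7, so it arrives in a3. That is what the li a3, 1 / li a3, 0 instructions in the assembly diff below stage before tail-calling enter_vector_usercopy(). Condensed correspondence (illustrative, assembled from the hunks in this commit):

    /* C view:     enter_vector_usercopy(dst, src, n, enable_sum)
     * psABI view:                       a0   a1   a2  a3
     *
     * __asm_copy_to_user             -> li a3, 1  (SUM still off: toggle it)
     * __asm_copy_to_user_sum_enabled -> li a3, 0  (SUM already on: skip CSRs)
     */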


@@ -17,14 +17,43 @@ SYM_FUNC_START(__asm_copy_to_user)
 	ALTERNATIVE("j fallback_scalar_usercopy", "nop", 0, RISCV_ISA_EXT_ZVE32X, CONFIG_RISCV_ISA_V)
 	REG_L	t0, riscv_v_usercopy_threshold
 	bltu	a2, t0, fallback_scalar_usercopy
-	tail	enter_vector_usercopy
+	li	a3, 1
+	tail	enter_vector_usercopy
 #endif
+SYM_FUNC_END(__asm_copy_to_user)
+EXPORT_SYMBOL(__asm_copy_to_user)
+SYM_FUNC_ALIAS(__asm_copy_from_user, __asm_copy_to_user)
+EXPORT_SYMBOL(__asm_copy_from_user)
 
 SYM_FUNC_START(fallback_scalar_usercopy)
 	/* Enable access to user memory */
 	li	t6, SR_SUM
 	csrs	CSR_STATUS, t6
+	mv	t6, ra
 
+	call	fallback_scalar_usercopy_sum_enabled
+
+	/* Disable access to user memory */
+	mv	ra, t6
+	li	t6, SR_SUM
+	csrc	CSR_STATUS, t6
+	ret
+SYM_FUNC_END(fallback_scalar_usercopy)
+
+SYM_FUNC_START(__asm_copy_to_user_sum_enabled)
+#ifdef CONFIG_RISCV_ISA_V
+	ALTERNATIVE("j fallback_scalar_usercopy_sum_enabled", "nop", 0, RISCV_ISA_EXT_ZVE32X, CONFIG_RISCV_ISA_V)
+	REG_L	t0, riscv_v_usercopy_threshold
+	bltu	a2, t0, fallback_scalar_usercopy_sum_enabled
+	li	a3, 0
+	tail	enter_vector_usercopy
+#endif
+SYM_FUNC_END(__asm_copy_to_user_sum_enabled)
+SYM_FUNC_ALIAS(__asm_copy_from_user_sum_enabled, __asm_copy_to_user_sum_enabled)
+EXPORT_SYMBOL(__asm_copy_from_user_sum_enabled)
+EXPORT_SYMBOL(__asm_copy_to_user_sum_enabled)
+
+SYM_FUNC_START(fallback_scalar_usercopy_sum_enabled)
 	/*
 	 * Save the terminal address which will be used to compute the number
 	 * of bytes copied in case of a fixup exception.
@@ -178,23 +207,12 @@ SYM_FUNC_START(fallback_scalar_usercopy)
 	bltu	a0, t0, 4b	/* t0 - end of dst */
 
 .Lout_copy_user:
-	/* Disable access to user memory */
-	csrc	CSR_STATUS, t6
 	li	a0, 0
 	ret
 
 	/* Exception fixup code */
 10:
-	/* Disable access to user memory */
-	csrc	CSR_STATUS, t6
 	sub	a0, t5, a0
 	ret
-SYM_FUNC_END(__asm_copy_to_user)
-SYM_FUNC_END(fallback_scalar_usercopy)
-EXPORT_SYMBOL(__asm_copy_to_user)
-SYM_FUNC_ALIAS(__asm_copy_from_user, __asm_copy_to_user)
-EXPORT_SYMBOL(__asm_copy_from_user)
+SYM_FUNC_END(fallback_scalar_usercopy_sum_enabled)
 
 SYM_FUNC_START(__clear_user)
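
The restructuring above is one pattern: the exported entry point keeps its ABI but becomes a thin wrapper that sets SR_SUM, calls the new _sum_enabled body, and clears SR_SUM again. Stashing ra in t6 avoids touching the stack; it works because, after this patch, the _sum_enabled body no longer writes t6 (the SR_SUM handling that used it moved into the wrapper). A C model of the wrapper, assuming csr_set()/csr_clear() from asm/csr.h as stand-ins for the csrs/csrc instructions (the real code is assembly, and the sketch ignores the ra-in-t6 trick):

    /* C model of the new fallback_scalar_usercopy() wrapper, for illustration. */
    int fallback_scalar_usercopy(void *dst, void *src, size_t n)
    {
    	int remain;

    	csr_set(CSR_STATUS, SR_SUM);	/* enable access to user memory */
    	remain = fallback_scalar_usercopy_sum_enabled(dst, src, n);
    	csr_clear(CSR_STATUS, SR_SUM);	/* disable access to user memory */
    	return remain;			/* 0 on success, else bytes not copied */
    }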


@@ -24,7 +24,18 @@ SYM_FUNC_START(__asm_vector_usercopy)
 	/* Enable access to user memory */
 	li	t6, SR_SUM
 	csrs	CSR_STATUS, t6
+	mv	t6, ra
 
+	call	__asm_vector_usercopy_sum_enabled
+
+	/* Disable access to user memory */
+	mv	ra, t6
+	li	t6, SR_SUM
+	csrc	CSR_STATUS, t6
+	ret
+SYM_FUNC_END(__asm_vector_usercopy)
+
+SYM_FUNC_START(__asm_vector_usercopy_sum_enabled)
 loop:
 	vsetvli	iVL, iNum, e8, ELEM_LMUL_SETTING, ta, ma
 	fixup vle8.v vData, (pSrc), 10f
@@ -36,8 +47,6 @@ loop:
 
 	/* Exception fixup for vector load is shared with normal exit */
 10:
-	/* Disable access to user memory */
-	csrc	CSR_STATUS, t6
 	mv	a0, iNum
 	ret
 
@@ -49,4 +58,4 @@ loop:
 	csrr	t2, CSR_VSTART
 	sub	iNum, iNum, t2
 	j	10b
-SYM_FUNC_END(__asm_vector_usercopy)
+SYM_FUNC_END(__asm_vector_usercopy_sum_enabled)
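
The vector path mirrors the scalar split modeled earlier: __asm_vector_usercopy keeps its ABI but is reduced to the same SUM-toggling wrapper, the copy loop itself moves into __asm_vector_usercopy_sum_enabled, and the exception-fixup path no longer clears SR_SUM itself, since the wrapper (or the user_access_begin() caller) now owns that bit.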