diff --git a/third_party/gemmology/gemmology.h b/third_party/gemmology/gemmology.h index 54f7c19b3ea0..d774c5338896 100644 --- a/third_party/gemmology/gemmology.h +++ b/third_party/gemmology/gemmology.h @@ -209,6 +209,25 @@ maddw(xsimd::batch x, xsimd::batch y, } #endif +#ifdef __AVX512VNNI__ + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch>) { + return _mm512_dpbusd_epi32(z, x, y); +} + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch>) { + return _mm512_dpbusd_epi32(z, x, y); +} +#endif + #endif #ifdef __SSSE3__ @@ -233,7 +252,7 @@ template std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - return {_mm_unpacklo_epi8(first, second), _mm_unpackhi_epi8(first, second)}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -241,7 +260,7 @@ std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - return {_mm_unpacklo_epi16(first, second), _mm_unpackhi_epi16(first, second)}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -249,7 +268,7 @@ std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - return {_mm_unpacklo_epi32(first, second), _mm_unpackhi_epi32(first, second)}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -257,7 +276,7 @@ std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - return {_mm_unpacklo_epi64(first, second), _mm_unpackhi_epi64(first, second)}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -362,14 +381,7 @@ template std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - int8x8_t first_lo = vget_low_s8(first); - int8x8_t second_lo = vget_low_s8(second); - int8x8x2_t result_lo = vzip_s8(first_lo, second_lo); - int8x8_t first_hi = vget_high_s8(first); - int8x8_t second_hi = vget_high_s8(second); - int8x8x2_t result_hi = vzip_s8(first_hi, second_hi); - return {vcombine_s8(result_lo.val[0], result_lo.val[1]), - vcombine_s8(result_hi.val[0], result_hi.val[1])}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -377,14 +389,7 @@ std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - int16x4_t first_lo = vget_low_s16(first); - int16x4_t second_lo = vget_low_s16(second); - int16x4x2_t result_lo = vzip_s16(first_lo, second_lo); - int16x4_t first_hi = vget_high_s16(first); - int16x4_t second_hi = vget_high_s16(second); - int16x4x2_t result_hi = vzip_s16(first_hi, second_hi); - return {vcombine_s16(result_lo.val[0], result_lo.val[1]), - vcombine_s16(result_hi.val[0], result_hi.val[1])}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -392,14 +397,7 @@ std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - int32x2_t first_lo = vget_low_s32(first); - int32x2_t second_lo = vget_low_s32(second); - int32x2x2_t result_lo = vzip_s32(first_lo, second_lo); - int32x2_t first_hi = vget_high_s32(first); - int32x2_t second_hi = vget_high_s32(second); - int32x2x2_t result_hi = vzip_s32(first_hi, second_hi); - return {vcombine_s32(result_lo.val[0], result_lo.val[1]), - vcombine_s32(result_hi.val[0], result_hi.val[1])}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -407,11 +405,7 @@ std::tuple, xsimd::batch> interleave(xsimd::batch first, xsimd::batch second, xsimd::kernel::requires_arch) { - int64x1_t first_lo = vget_low_s64(first); - int64x1_t second_lo = vget_low_s64(second); - int64x1_t first_hi = vget_high_s64(first); - int64x1_t second_hi = vget_high_s64(second); - return {vcombine_s64(first_lo, second_lo), vcombine_s64(first_hi, second_hi)}; + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; } template @@ -554,10 +548,9 @@ inline xsimd::batch madd(xsimd::batch x, xsimd::batch y, xsimd::kernel::requires_arch) { - int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))), - vmovl_s8(vget_low_s8(y))); - int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), - vmovl_s8(vget_high_s8(y))); + int16x8_t tl = vmull_s8(vreinterpret_s8_u8(vget_low_u8(x)), + vget_low_s8(y)); + int16x8_t th = vmull_high_s8(vreinterpretq_s8_u8(x), y); return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th)); } @@ -566,14 +559,12 @@ inline xsimd::batch maddw(xsimd::batch x, xsimd::batch y, xsimd::batch z, xsimd::kernel::requires_arch) { - int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))), vmovl_s8(vget_low_s8(y))); int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), vmovl_s8(vget_high_s8(y))); - int32x4_t pl = vpaddlq_s16(tl); - int32x4_t ph = vpaddlq_s16(th); - return vpaddq_s32(z, vpaddq_s32(pl, ph)); + return vpadalq_s16(vpadalq_s16(z, tl), th); + //TODO: investigate using vdotq_s32 } template diff --git a/third_party/gemmology/moz.yaml b/third_party/gemmology/moz.yaml index dd4664ef6c11..d9f9472da75e 100644 --- a/third_party/gemmology/moz.yaml +++ b/third_party/gemmology/moz.yaml @@ -10,8 +10,8 @@ origin: url: https://github.com/mozilla/gemmology - release: c04bacb101e020d9e6b51f20c92d7f63af50dd01 (2023-12-18T13:47:06Z). - revision: c04bacb101e020d9e6b51f20c92d7f63af50dd01 + release: ec535e87d0ab9d1457ff6d2af247cc8113e74694 (2024-02-05T09:05:20Z). + revision: ec535e87d0ab9d1457ff6d2af247cc8113e74694 license: MIT