diff --git a/media/highway/moz.build b/media/highway/moz.build index f2282bbe8baa..ea5a6bff4fbe 100644 --- a/media/highway/moz.build +++ b/media/highway/moz.build @@ -22,6 +22,7 @@ EXPORTS.hwy += [ "/third_party/highway/hwy/detect_targets.h", "/third_party/highway/hwy/foreach_target.h", "/third_party/highway/hwy/highway.h", + "/third_party/highway/hwy/highway_export.h", "/third_party/highway/hwy/targets.h", ] diff --git a/media/highway/moz.yaml b/media/highway/moz.yaml index f9be38e279ee..e26c091e5a4e 100644 --- a/media/highway/moz.yaml +++ b/media/highway/moz.yaml @@ -20,11 +20,11 @@ origin: # Human-readable identifier for this version/release # Generally "version NNN", "tag SSS", "bookmark SSS" - release: commit e69083a12a05caf037cabecdf1b248b7579705a5 (2021-11-11T08:20:00Z). + release: commit f13e3b956eb226561ac79427893ec0afd66f91a8 (2022-02-15T18:19:21Z). # Revision to pull in # Must be a long or short commit SHA (long preferred) - revision: e69083a12a05caf037cabecdf1b248b7579705a5 + revision: f13e3b956eb226561ac79427893ec0afd66f91a8 # The package's license, where possible using the mnemonic from # https://spdx.org/licenses/ diff --git a/media/libjxl/moz.yaml b/media/libjxl/moz.yaml index d3ebc20c4c3b..f553c18a740b 100644 --- a/media/libjxl/moz.yaml +++ b/media/libjxl/moz.yaml @@ -10,9 +10,9 @@ origin: url: https://github.com/libjxl/libjxl - release: commit 4322679b1c418addc2284c5ea84fc2c3935b4a75 (2022-02-07T20:56:39Z). + release: commit 89875cba4d18485ec9692c80b747b59b73ce712e (2022-02-28T16:03:42Z). - revision: 4322679b1c418addc2284c5ea84fc2c3935b4a75 + revision: 89875cba4d18485ec9692c80b747b59b73ce712e license: Apache-2.0 diff --git a/third_party/highway/.github/workflows/build_test.yml b/third_party/highway/.github/workflows/build_test.yml new file mode 100644 index 000000000000..bab1630bdadd --- /dev/null +++ b/third_party/highway/.github/workflows/build_test.yml @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Build / test +on: [push, pull_request] +jobs: + cmake: + name: Build and test ${{ matrix.name }} + runs-on: ubuntu-18.04 + strategy: + matrix: + include: + - name: Clang-5.0 + extra_deps: clang-5.0 + c_compiler: clang-5.0 + cxx_compiler: clang++-5.0 + + - name: Clang-6.0 + extra_deps: clang-6.0 + c_compiler: clang-6.0 + cxx_compiler: clang++-6.0 + + steps: + - uses: actions/checkout@v2 + + - name: Install deps + run: sudo apt-get install ${{ matrix.extra_deps }} + + - name: Build and test + run: | + export CMAKE_BUILD_PARALLEL_LEVEL=2 + export CTEST_PARALLEL_LEVEL=2 + CXXFLAGS=-Werror CC=${{ matrix.c_compiler }} CXX=${{ matrix.cxx_compiler }} cmake -B out . + cmake --build out + ctest --test-dir out + + bazel: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: bazelbuild/setup-bazelisk@v1 + - uses: actions/cache@v2 + with: + path: ~/.cache/bazel + key: bazel-${{ runner.os }} + - run: bazel build //... 
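Editor's note: the moz.build hunk above begins exporting hwy/highway_export.h, but the header itself is not shown in this patch. Based on the HWY_SHARED_DEFINE / HWY_STATIC_DEFINE compile definitions introduced in CMakeLists.txt below and the HWY_DLLEXPORT / HWY_CONTRIB_DLLEXPORT annotations added to the headers, an export-macro header of this kind typically looks roughly like the following sketch. It is illustrative only, not the upstream file; in particular, hwy_EXPORTS is assumed here as the conventional macro CMake defines while building the shared target.

```cpp
// Sketch only -- NOT the actual hwy/highway_export.h shipped by upstream.
// HWY_SHARED_DEFINE / HWY_STATIC_DEFINE are set via target_compile_definitions()
// in CMakeLists.txt; hwy_EXPORTS is assumed to be defined by CMake while
// compiling the hwy library itself.
#ifndef HWY_DLLEXPORT
#if defined(HWY_STATIC_DEFINE)  // static library: annotations expand to nothing
#define HWY_DLLEXPORT
#define HWY_CONTRIB_DLLEXPORT
#elif defined(_WIN32) || defined(__CYGWIN__)
#ifdef hwy_EXPORTS  // building the DLL: export symbols
#define HWY_DLLEXPORT __declspec(dllexport)
#else  // consuming the DLL: import symbols
#define HWY_DLLEXPORT __declspec(dllimport)
#endif
#define HWY_CONTRIB_DLLEXPORT HWY_DLLEXPORT
#else  // ELF/Mach-O shared object with hidden-by-default visibility
#define HWY_DLLEXPORT __attribute__((visibility("default")))
#define HWY_CONTRIB_DLLEXPORT HWY_DLLEXPORT
#endif
#endif  // HWY_DLLEXPORT
```

This pairs with the CMAKE_CXX_VISIBILITY_PRESET hidden setting and the hwy/hwy.version linker script referenced later in the patch: only explicitly annotated symbols remain visible from the shared library.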
diff --git a/third_party/highway/BUILD b/third_party/highway/BUILD index 0bd3acd3d603..080c16bd9db8 100644 --- a/third_party/highway/BUILD +++ b/third_party/highway/BUILD @@ -1,6 +1,6 @@ load("@bazel_skylib//lib:selects.bzl", "selects") -load("@rules_cc//cc:defs.bzl", "cc_test") +load("@rules_cc//cc:defs.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -18,6 +18,11 @@ config_setting( flag_values = {"@bazel_tools//tools/cpp:compiler": "msvc"}, ) +config_setting( + name = "compiler_emscripten", + values = {"cpu": "wasm32"}, +) + # See https://github.com/bazelbuild/bazel/issues/12707 config_setting( name = "compiler_gcc_bug", @@ -41,13 +46,6 @@ selects.config_setting_group( ], ) -config_setting( - name = "emulate_sve", - values = { - "copt": "-DHWY_EMULATE_SVE", - }, -) - # Additional warnings for Clang OR GCC (skip for MSVC) CLANG_GCC_COPTS = [ "-Wunused-parameter", @@ -82,7 +80,7 @@ COPTS = select({ "//conditions:default": CLANG_GCC_COPTS + CLANG_ONLY_COPTS, }) + select({ "@platforms//cpu:riscv64": [ - "-march=rv64gcv0p10", + "-march=rv64gcv1p0", "-menable-experimental-extensions", ], "//conditions:default": [ @@ -112,28 +110,32 @@ cc_library( "hwy/base.h", "hwy/cache_control.h", "hwy/detect_compiler_arch.h", # private - "hwy/detect_targets.h", # private - "hwy/targets.h", + "hwy/highway_export.h", ], compatible_with = [], copts = COPTS, textual_hdrs = [ + # These are textual because config macros influence them: + "hwy/detect_targets.h", # private + "hwy/targets.h", + # End of list "hwy/highway.h", # public "hwy/foreach_target.h", # public "hwy/ops/arm_neon-inl.h", "hwy/ops/arm_sve-inl.h", "hwy/ops/generic_ops-inl.h", - "hwy/ops/rvv-inl.h", "hwy/ops/scalar-inl.h", "hwy/ops/set_macros-inl.h", "hwy/ops/shared-inl.h", - "hwy/ops/wasm_128-inl.h", "hwy/ops/x86_128-inl.h", "hwy/ops/x86_256-inl.h", "hwy/ops/x86_512-inl.h", - ], - deps = select({ - ":emulate_sve": ["//third_party/farm_sve"], + # Select avoids recompiling native arch if only non-native changed + ] + select({ + ":compiler_emscripten": ["hwy/ops/wasm_128-inl.h"], + "//conditions:default": [], + }) + select({ + "@platforms//cpu:riscv64": ["hwy/ops/rvv-inl.h"], "//conditions:default": [], }), ) @@ -144,7 +146,9 @@ cc_library( textual_hdrs = [ "hwy/contrib/dot/dot-inl.h", ], - deps = [":hwy"], + deps = [ + ":hwy", + ], ) cc_library( @@ -156,7 +160,9 @@ cc_library( "hwy/contrib/image/image.h", ], compatible_with = [], - deps = [":hwy"], + deps = [ + ":hwy", + ], ) cc_library( @@ -165,16 +171,9 @@ cc_library( textual_hdrs = [ "hwy/contrib/math/math-inl.h", ], - deps = [":hwy"], -) - -cc_library( - name = "sort", - compatible_with = [], - textual_hdrs = [ - "hwy/contrib/sort/sort-inl.h", + deps = [ + ":hwy", ], - deps = [":hwy"], ) # Everything required for tests that use Highway. @@ -188,7 +187,9 @@ cc_library( ], # Must not depend on a gtest variant, which can conflict with the # GUNIT_INTERNAL_BUILD_MODE defined by the test. - deps = [":hwy"], + deps = [ + ":hwy", + ], ) cc_library( @@ -212,7 +213,9 @@ cc_library( srcs = ["hwy/examples/skeleton.cc"], hdrs = ["hwy/examples/skeleton.h"], textual_hdrs = ["hwy/examples/skeleton-inl.h"], - deps = [":hwy"], + deps = [ + ":hwy", + ], ) cc_binary( @@ -226,7 +229,7 @@ HWY_TESTS = [ ("hwy/contrib/dot/", "dot_test"), ("hwy/contrib/image/", "image_test"), ("hwy/contrib/math/", "math_test"), - ("hwy/contrib/sort/", "sort_test"), + # contrib/sort has its own BUILD, we add it to GUITAR_TESTS. 
("hwy/examples/", "skeleton_test"), ("hwy/", "nanobenchmark_test"), ("hwy/", "aligned_allocator_test"), @@ -239,13 +242,27 @@ HWY_TESTS = [ ("hwy/tests/", "compare_test"), ("hwy/tests/", "convert_test"), ("hwy/tests/", "crypto_test"), + ("hwy/tests/", "demote_test"), ("hwy/tests/", "logical_test"), ("hwy/tests/", "mask_test"), ("hwy/tests/", "memory_test"), + ("hwy/tests/", "shift_test"), ("hwy/tests/", "swizzle_test"), ("hwy/tests/", "test_util_test"), ] +HWY_TEST_DEPS = [ + ":dot", + ":hwy", + ":hwy_test_util", + ":image", + ":math", + ":nanobenchmark", + ":skeleton", + "//hwy/contrib/sort:vqsort", + "@com_google_googletest//:gtest_main", +] + [ [ cc_test( @@ -265,6 +282,18 @@ HWY_TESTS = [ "@platforms//cpu:riscv64": ["fully_static_link"], "//conditions:default": [], }), + linkopts = select({ + ":compiler_emscripten": [ + "-s ASSERTIONS=2", + "-s ENVIRONMENT=node,shell,web", + "-s ERROR_ON_UNDEFINED_SYMBOLS=1", + "-s DEMANGLE_SUPPORT=1", + "-s EXIT_RUNTIME=1", + "-s ALLOW_MEMORY_GROWTH=1", + "--pre-js $(location :preamble.js.lds)", + ], + "//conditions:default": [], + }), linkstatic = select({ "@platforms//cpu:riscv64": True, "//conditions:default": False, @@ -272,17 +301,10 @@ HWY_TESTS = [ local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["hwy_ops_test"], - deps = [ - ":dot", - ":hwy", - ":hwy_test_util", - ":image", - ":math", - ":nanobenchmark", - ":skeleton", - ":sort", - "@com_google_googletest//:gtest_main", - ], + deps = HWY_TEST_DEPS + select({ + ":compiler_emscripten": [":preamble.js.lds"], + "//conditions:default": [], + }), ), ] for subdir, test in HWY_TESTS @@ -293,3 +315,5 @@ test_suite( name = "hwy_ops_tests", tags = ["hwy_ops_test"], ) + +# Placeholder for integration test, do not remove diff --git a/third_party/highway/CMakeLists.txt b/third_party/highway/CMakeLists.txt index d910e3c9ea5f..604eb74320ff 100644 --- a/third_party/highway/CMakeLists.txt +++ b/third_party/highway/CMakeLists.txt @@ -19,11 +19,13 @@ if(POLICY CMP0083) cmake_policy(SET CMP0083 NEW) endif() -project(hwy VERSION 0.15.0) # Keep in sync with highway.h version +project(hwy VERSION 0.16.0) # Keep in sync with highway.h version + +# Directly define the ABI version from the cmake project() version values: +set(LIBRARY_VERSION "${hwy_VERSION}") +set(LIBRARY_SOVERSION ${hwy_VERSION_MAJOR}) -set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_EXTENSIONS OFF) -set(CMAKE_CXX_STANDARD_REQUIRED YES) # Enabled PIE binaries by default if supported. include(CheckPIESupported OPTIONAL RESULT_VARIABLE CHECK_PIE_SUPPORTED) @@ -40,13 +42,14 @@ if (NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE RelWithDebInfo) endif() -set(HWY_CMAKE_ARM7 OFF CACHE BOOL "Set copts for ARMv7 with NEON?") +set(HWY_CMAKE_ARM7 OFF CACHE BOOL "Set copts for ARMv7 with NEON (requires vfpv4)?") # Unconditionally adding -Werror risks breaking the build when new warnings # arise due to compiler/platform changes. Enable this in CI/tests. 
set(HWY_WARNINGS_ARE_ERRORS OFF CACHE BOOL "Add -Werror flag?") -set(HWY_EXAMPLES_TESTS_INSTALL ON CACHE BOOL "Build examples, tests, install?") +set(HWY_ENABLE_EXAMPLES ON CACHE BOOL "Build examples") +set(HWY_ENABLE_INSTALL ON CACHE BOOL "Install library") include(CheckCXXSourceCompiles) check_cxx_source_compiles( @@ -64,7 +67,32 @@ set(HWY_CONTRIB_SOURCES hwy/contrib/image/image.cc hwy/contrib/image/image.h hwy/contrib/math/math-inl.h - hwy/contrib/sort/sort-inl.h + hwy/contrib/sort/disabled_targets.h + hwy/contrib/sort/shared-inl.h + hwy/contrib/sort/sorting_networks-inl.h + hwy/contrib/sort/traits-inl.h + hwy/contrib/sort/traits128-inl.h + hwy/contrib/sort/vqsort-inl.h + hwy/contrib/sort/vqsort.cc + hwy/contrib/sort/vqsort.h + hwy/contrib/sort/vqsort_128a.cc + hwy/contrib/sort/vqsort_128d.cc + hwy/contrib/sort/vqsort_f32a.cc + hwy/contrib/sort/vqsort_f32d.cc + hwy/contrib/sort/vqsort_f64a.cc + hwy/contrib/sort/vqsort_f64d.cc + hwy/contrib/sort/vqsort_i16a.cc + hwy/contrib/sort/vqsort_i16d.cc + hwy/contrib/sort/vqsort_i32a.cc + hwy/contrib/sort/vqsort_i32d.cc + hwy/contrib/sort/vqsort_i64a.cc + hwy/contrib/sort/vqsort_i64d.cc + hwy/contrib/sort/vqsort_u16a.cc + hwy/contrib/sort/vqsort_u16d.cc + hwy/contrib/sort/vqsort_u32a.cc + hwy/contrib/sort/vqsort_u32d.cc + hwy/contrib/sort/vqsort_u64a.cc + hwy/contrib/sort/vqsort_u64d.cc ) set(HWY_SOURCES @@ -76,6 +104,7 @@ set(HWY_SOURCES hwy/detect_targets.h # private hwy/foreach_target.h hwy/highway.h + hwy/highway_export.h hwy/nanobenchmark.cc hwy/nanobenchmark.h hwy/ops/arm_neon-inl.h @@ -192,20 +221,59 @@ else() endif() # !MSVC -add_library(hwy STATIC ${HWY_SOURCES}) +# By default prefer STATIC build (legacy behavior) +option(BUILD_SHARED_LIBS "Build shared libraries" OFF) +option(HWY_FORCE_STATIC_LIBS "Ignore BUILD_SHARED_LIBS" OFF) +# only expose shared/static options to advanced users: +mark_as_advanced(BUILD_SHARED_LIBS) +mark_as_advanced(HWY_FORCE_STATIC_LIBS) +# Define visibility settings globally: +set(CMAKE_CXX_VISIBILITY_PRESET hidden) +set(CMAKE_VISIBILITY_INLINES_HIDDEN 1) + +# Copy-cat "add_library" logic + add override. 
+set(HWY_LIBRARY_TYPE "SHARED") +if (NOT BUILD_SHARED_LIBS OR HWY_FORCE_STATIC_LIBS) + set(HWY_LIBRARY_TYPE "STATIC") +endif() + +# This preprocessor define will drive the build, also used in the *.pc files: +if("${HWY_LIBRARY_TYPE}" STREQUAL "SHARED") + set(DLLEXPORT_TO_DEFINE "HWY_SHARED_DEFINE") +else() + set(DLLEXPORT_TO_DEFINE "HWY_STATIC_DEFINE") +endif() + +add_library(hwy ${HWY_LIBRARY_TYPE} ${HWY_SOURCES}) +target_compile_definitions(hwy PUBLIC "${DLLEXPORT_TO_DEFINE}") target_compile_options(hwy PRIVATE ${HWY_FLAGS}) set_property(TARGET hwy PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(hwy PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION}) target_include_directories(hwy PUBLIC ${CMAKE_CURRENT_LIST_DIR}) +target_compile_features(hwy PUBLIC cxx_std_11) +set_target_properties(hwy PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version) +# not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations) +if(UNIX AND NOT APPLE) + set_property(TARGET hwy APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version") +endif() -add_library(hwy_contrib STATIC ${HWY_CONTRIB_SOURCES}) +add_library(hwy_contrib ${HWY_LIBRARY_TYPE} ${HWY_CONTRIB_SOURCES}) +target_link_libraries(hwy_contrib hwy) target_compile_options(hwy_contrib PRIVATE ${HWY_FLAGS}) set_property(TARGET hwy_contrib PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(hwy_contrib PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION}) target_include_directories(hwy_contrib PUBLIC ${CMAKE_CURRENT_LIST_DIR}) +target_compile_features(hwy_contrib PUBLIC cxx_std_11) -add_library(hwy_test STATIC ${HWY_TEST_SOURCES}) +add_library(hwy_test ${HWY_LIBRARY_TYPE} ${HWY_TEST_SOURCES}) +target_link_libraries(hwy_test hwy) target_compile_options(hwy_test PRIVATE ${HWY_FLAGS}) set_property(TARGET hwy_test PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(hwy_test PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION}) target_include_directories(hwy_test PUBLIC ${CMAKE_CURRENT_LIST_DIR}) +target_compile_features(hwy_test PUBLIC cxx_std_11) # -------------------------------------------------------- hwy_list_targets # Generate a tool to print the compiled-in targets as defined by the current @@ -219,17 +287,22 @@ target_include_directories(hwy_list_targets PRIVATE # Naked target also not always could be run (due to the lack of '.\' prefix) # Thus effective command to run should contain the full path # and emulator prefix (if any). +if (NOT CMAKE_CROSSCOMPILING OR CMAKE_CROSSCOMPILING_EMULATOR) add_custom_command(TARGET hwy_list_targets POST_BUILD COMMAND ${CMAKE_CROSSCOMPILING_EMULATOR} $<TARGET_FILE:hwy_list_targets> || (exit 0)) +endif() # -------------------------------------------------------- # Allow skipping the following sections for projects that do not need them: # tests, examples, benchmarks and installation. -if (HWY_EXAMPLES_TESTS_INSTALL) # -------------------------------------------------------- install library +if (HWY_ENABLE_INSTALL) + install(TARGETS hwy - DESTINATION "${CMAKE_INSTALL_LIBDIR}") + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") # Install all the headers keeping the relative path to the current directory # when installing them.
foreach (source ${HWY_SOURCES}) @@ -241,7 +314,9 @@ foreach (source ${HWY_SOURCES}) endforeach() install(TARGETS hwy_contrib - DESTINATION "${CMAKE_INSTALL_LIBDIR}") + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") # Install all the headers keeping the relative path to the current directory # when installing them. foreach (source ${HWY_CONTRIB_SOURCES}) @@ -253,7 +328,9 @@ foreach (source ${HWY_CONTRIB_SOURCES}) endforeach() install(TARGETS hwy_test - DESTINATION "${CMAKE_INSTALL_LIBDIR}") + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") # Install all the headers keeping the relative path to the current directory # when installing them. foreach (source ${HWY_TEST_SOURCES}) @@ -272,7 +349,9 @@ foreach (pc libhwy.pc libhwy-contrib.pc libhwy-test.pc) DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") endforeach() +endif() # HWY_ENABLE_INSTALL # -------------------------------------------------------- Examples +if (HWY_ENABLE_EXAMPLES) # Avoids mismatch between GTest's static CRT and our dynamic. set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) @@ -280,7 +359,6 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) # Programming exercise with integrated benchmark add_executable(hwy_benchmark hwy/examples/benchmark.cc) target_sources(hwy_benchmark PRIVATE - hwy/nanobenchmark.cc hwy/nanobenchmark.h) # Try adding either -DHWY_COMPILE_ONLY_SCALAR or -DHWY_COMPILE_ONLY_STATIC to # observe the difference in targets printed. @@ -289,6 +367,7 @@ target_link_libraries(hwy_benchmark hwy) set_target_properties(hwy_benchmark PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/") +endif() # HWY_ENABLE_EXAMPLES # -------------------------------------------------------- Tests include(CTest) @@ -352,9 +431,11 @@ set(HWY_TEST_FILES hwy/tests/compare_test.cc hwy/tests/convert_test.cc hwy/tests/crypto_test.cc + hwy/tests/demote_test.cc hwy/tests/logical_test.cc hwy/tests/mask_test.cc hwy/tests/memory_test.cc + hwy/tests/shift_test.cc hwy/tests/swizzle_test.cc hwy/tests/test_util_test.cc ) @@ -377,7 +458,7 @@ foreach (TESTFILE IN LISTS HWY_TEST_FILES) target_link_libraries(${TESTNAME} hwy hwy_contrib hwy_test gtest gtest_main) endif() # Output test targets in the test directory. 
- set_target_properties(${TESTNAME} PROPERTIES PREFIX "tests/") + set_target_properties(${TESTNAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "tests") if (HWY_EMSCRIPTEN) set_target_properties(${TESTNAME} PROPERTIES LINK_FLAGS "-s SINGLE_FILE=1") @@ -394,5 +475,3 @@ endforeach () target_sources(skeleton_test PRIVATE hwy/examples/skeleton.cc) endif() # BUILD_TESTING - -endif() # HWY_EXAMPLES_TESTS_INSTALL diff --git a/third_party/highway/CMakeLists.txt.in b/third_party/highway/CMakeLists.txt.in index df401705ee99..a0260b82f7c0 100644 --- a/third_party/highway/CMakeLists.txt.in +++ b/third_party/highway/CMakeLists.txt.in @@ -5,11 +5,11 @@ project(googletest-download NONE) include(ExternalProject) ExternalProject_Add(googletest GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG master + GIT_TAG 43efa0a4efd40c78b9210d15373112081899a97c SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" TEST_COMMAND "" -) \ No newline at end of file +) diff --git a/third_party/highway/README.md b/third_party/highway/README.md index 4747e85bcc27..8bdcaa45bf69 100644 --- a/third_party/highway/README.md +++ b/third_party/highway/README.md @@ -1,30 +1,95 @@ -# Efficient and performance-portable SIMD +# Efficient and performance-portable vector software -Highway is a C++ library for SIMD (Single Instruction, Multiple Data), i.e. -applying the same operation to multiple 'lanes' using a single CPU instruction. +[//]: # (placeholder, do not remove) -## Why Highway? +Highway is a C++ library that provides portable SIMD/vector intrinsics. -- more portable (same source code) than platform-specific intrinsics, -- works on a wider range of compilers than compiler-specific vector extensions, -- more dependable than autovectorization, -- easier to write/maintain than assembly language, -- supports **runtime dispatch**, -- supports **variable-length vector** architectures. +## Why + +We are passionate about high-performance software. We see major untapped +potential in CPUs (servers, mobile, desktops). Highway is for engineers who want +to reliably and economically push the boundaries of what is possible in +software. + +## How + +CPUs provide SIMD/vector instructions that apply the same operation to multiple +data items. This can reduce energy usage e.g. *fivefold* because fewer +instructions are executed. We also often see *5-10x* speedups. + +Highway makes SIMD/vector programming practical and workable according to these +guiding principles: + +**Does what you expect**: Highway is a C++ library with carefully-chosen +functions that map well to CPU instructions without extensive compiler +transformations. The resulting code is more predictable and robust to code +changes/compiler updates than autovectorization. + +**Works on widely-used platforms**: Highway supports four architectures; the +same application code can target eight instruction sets, including those with +'scalable' vectors (size unknown at compile time). Highway only requires C++11 +and supports four families of compilers. If you would like to use Highway on +other platforms, please raise an issue. + +**Flexible to deploy**: Applications using Highway can run on heterogeneous +clouds or client devices, choosing the best available instruction set at +runtime. Alternatively, developers may choose to target a single instruction set +without any runtime overhead. 
In both cases, the application code is the same +except for swapping `HWY_STATIC_DISPATCH` with `HWY_DYNAMIC_DISPATCH` plus one +line of code. + +**Suitable for a variety of domains**: Highway provides an extensive set of +operations, used for image processing (floating-point), compression, video +analysis, linear algebra, cryptography, sorting and random generation. We +recognise that new use-cases may require additional ops and are happy to add +them where it makes sense (e.g. no performance cliffs on some architectures). If +you would like to discuss, please file an issue. + +**Rewards data-parallel design**: Highway provides tools such as Gather, +MaskedLoad, and FixedTag to enable speedups for legacy data structures. However, +the biggest gains are unlocked by designing algorithms and data structures for +scalable vectors. Helpful techniques include batching, structure-of-array +layouts, and aligned/padded allocations. + +## Examples + +Online demos using Compiler Explorer: + +- [generating code for multiple targets](https://gcc.godbolt.org/z/n6rx6xK5h) (recommended) +- [single target using -m flags](https://gcc.godbolt.org/z/rGnjMevKG) + +Projects using Highway: (to add yours, feel free to raise an issue or contact us +via the below email) + +* [iresearch database index](https://github.com/iresearch-toolkit/iresearch/blob/e7638e7a4b99136ca41f82be6edccf01351a7223/core/utils/simd_utils.hpp) +* [JPEG XL image codec](https://github.com/libjxl/libjxl) +* [Grok JPEG 2000 image codec](https://github.com/GrokImageCompression/grok) +* [vectorized Quicksort](https://github.com/google/highway/tree/master/hwy/contrib/sort) ## Current status +### Targets + Supported targets: scalar, S-SSE3, SSE4, AVX2, AVX-512, AVX3_DL (~Icelake, -requires opt-in by defining `HWY_WANT_AVX3_DL`), NEON (ARMv7 and v8), SVE, +requires opt-in by defining `HWY_WANT_AVX3_DL`), NEON (ARMv7 and v8), SVE, SVE2, WASM SIMD. -SVE is tested using farm_sve (see acknowledgments). SVE2 is implemented but not -yet validated. A subset of RVV is implemented and tested with GCC and QEMU. -Work is underway to compile using LLVM, which has different intrinsics with AVL. +SVE was initially tested using farm_sve (see acknowledgments). A subset of RVV +is implemented and tested with LLVM and QEMU. Work is underway to add RVV ops +which were not yet supported by GCC. -Version 0.11 is considered stable enough to use in other projects, and is -expected to remain backwards compatible unless serious issues are discovered -while finishing the RVV target. After that, Highway will reach version 1.0. +### Versioning + +Highway releases aim to follow the semver.org system (MAJOR.MINOR.PATCH), +incrementing MINOR after backward-compatible additions and PATCH after +backward-compatible fixes. We recommend using releases (rather than the Git tip) +because they are tested more extensively, see below. + +Version 0.11 is considered stable enough to use in other projects. +Version 1.0 will signal an increased focus on backwards compatibility and will +be reached after the RVV target is finished (planned for 2022H1). + +### Testing Continuous integration tests build with a recent version of Clang (running on x86 and QEMU for ARM) and MSVC from VS2015 (running on x86). @@ -33,13 +98,15 @@ Before releases, we also test on x86 with Clang and GCC, and ARMv7/8 via GCC cross-compile and QEMU. See the [testing process](g3doc/release_testing_process.md) for details. 
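Editor's note: the README's "Flexible to deploy" paragraph above says that switching between a single statically selected target and runtime dispatch amounts to swapping HWY_STATIC_DISPATCH for HWY_DYNAMIC_DISPATCH plus one line. A minimal sketch of the usual per-target compilation pattern follows, modeled on hwy/examples/skeleton.cc; the file name, namespace and function are illustrative, and the loop assumes `size` is padded to a multiple of the vector length.

```cpp
// mul_by_two.cc -- compiled once per enabled target via foreach_target.h.
#include <stddef.h>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "mul_by_two.cc"  // path of this file
#include "hwy/foreach_target.h"             // must come before highway.h
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace project {
namespace HWY_NAMESPACE {  // re-expanded to N_SSE4, N_AVX2, ... per target
namespace hn = hwy::HWY_NAMESPACE;

void MulByTwo(const float* HWY_RESTRICT in, float* HWY_RESTRICT out,
              size_t size) {
  const hn::ScalableTag<float> d;  // full vectors; width decided per target
  for (size_t i = 0; i < size; i += hn::Lanes(d)) {
    const auto v = hn::Load(d, in + i);
    hn::Store(hn::Add(v, v), d, out + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace project
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // non-SIMD part, compiled exactly once
namespace project {
HWY_EXPORT(MulByTwo);  // table of per-target function pointers

void CallMulByTwo(const float* in, float* out, size_t size) {
  // Swap in HWY_STATIC_DISPATCH(MulByTwo)(in, out, size) to target a single
  // instruction set with no runtime dispatch overhead.
  return HWY_DYNAMIC_DISPATCH(MulByTwo)(in, out, size);
}
}  // namespace project
#endif  // HWY_ONCE
```

foreach_target.h re-includes this translation unit once per target in HWY_TARGETS, and HWY_EXPORT records the resulting per-target symbols in the table consulted by HWY_DYNAMIC_DISPATCH.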
+### Related modules + The `contrib` directory contains SIMD-related utilities: an image class with -aligned rows, and a math library (16 functions already implemented, mostly -trigonometry). +aligned rows, a math library (16 functions already implemented, mostly +trigonometry), and functions for computing dot products and sorting. ## Installation -This project uses cmake to generate and build. In a Debian-based system you can +This project uses CMake to generate and build. In a Debian-based system you can install it via: ```bash @@ -55,7 +122,8 @@ installing gtest separately: sudo apt install libgtest-dev ``` -To build and test the library the standard cmake workflow can be used: +To build Highway as a shared or static library (depending on BUILD_SHARED_LIBS), +the standard CMake workflow can be used: ```bash mkdir -p build && cd build @@ -76,31 +144,40 @@ and their parameters, and the [instruction_matrix](g3doc/instruction_matrix.pdf) indicates the number of instructions per operation. We recommend using full SIMD vectors whenever possible for maximum performance -portability. To obtain them, pass a `HWY_FULL(float)` tag to functions such as -`Zero/Set/Load`. There is also the option of a vector of up to `N` (a power of -two <= 16/sizeof(T)) lanes of type `T`: `HWY_CAPPED(T, N)`. If `HWY_TARGET == -HWY_SCALAR`, the vector always has one lane. For all other targets, up to -128-bit vectors are guaranteed to be available. +portability. To obtain them, pass a `ScalableTag` (or equivalently +`HWY_FULL(float)`) tag to functions such as `Zero/Set/Load`. There are two +alternatives for use-cases requiring an upper bound on the lanes: -Functions using Highway must be inside `namespace HWY_NAMESPACE {` -(possibly nested in one or more other namespaces defined by the project), and -additionally either prefixed with `HWY_ATTR`, or residing between -`HWY_BEFORE_NAMESPACE()` and `HWY_AFTER_NAMESPACE()`. +- For up to a power of two `N`, specify `CappedTag` (or + equivalently `HWY_CAPPED(T, N)`). This is useful for data structures such as + a narrow matrix. A loop is still required because vectors may actually have + fewer than `N` lanes. + +- For exactly a power of two `N` lanes, specify `FixedTag`. The largest + supported `N` depends on the target, but is guaranteed to be at least + `16/sizeof(T)`. + +Functions using Highway must either be inside `namespace HWY_NAMESPACE {` +(possibly nested in one or more other namespaces defined by the project), OR +each op must be prefixed with `hn::`, e.g. `namespace hn = hwy::HWY_NAMESPACE; +hn::LoadDup128()`. Additionally, each function using Highway must either be +prefixed with `HWY_ATTR`, OR reside between `HWY_BEFORE_NAMESPACE()` and +`HWY_AFTER_NAMESPACE()`. * For static dispatch, `HWY_TARGET` will be the best available target among `HWY_BASELINE_TARGETS`, i.e. those allowed for use by the compiler (see - [quick-reference](g3doc/quick_reference.md)). Functions inside `HWY_NAMESPACE` - can be called using `HWY_STATIC_DISPATCH(func)(args)` within the same module - they are defined in. You can call the function from other modules by - wrapping it in a regular function and declaring the regular function in a - header. + [quick-reference](g3doc/quick_reference.md)). Functions inside + `HWY_NAMESPACE` can be called using `HWY_STATIC_DISPATCH(func)(args)` within + the same module they are defined in. You can call the function from other + modules by wrapping it in a regular function and declaring the regular + function in a header. 
* For dynamic dispatch, a table of function pointers is generated via the `HWY_EXPORT` macro that is used by `HWY_DYNAMIC_DISPATCH(func)(args)` to call the best function pointer for the current CPU's supported targets. A module is automatically compiled for each target in `HWY_TARGETS` (see [quick-reference](g3doc/quick_reference.md)) if `HWY_TARGET_INCLUDE` is - defined and foreach_target.h is included. + defined and `foreach_target.h` is included. ## Compiler flags @@ -123,17 +200,17 @@ ensure proper VEX code generation for AVX2 targets. To vectorize a loop, "strip-mining" transforms it into an outer loop and inner loop with number of iterations matching the preferred vector width. -In this section, let `T` denote the element type, `d = HWY_FULL(T)`, `count` the -number of elements to process, and `N = Lanes(d)` the number of lanes in a full -vector. Assume the loop body is given as a function `template void LoopBody(D d, size_t max_n)`. +In this section, let `T` denote the element type, `d = ScalableTag`, `count` +the number of elements to process, and `N = Lanes(d)` the number of lanes in a +full vector. Assume the loop body is given as a function `template void LoopBody(D d, size_t index, size_t max_n)`. Highway offers several ways to express loops where `N` need not divide `count`: * Ensure all inputs/outputs are padded. Then the loop is simply ``` - for (size_t i = 0; i < count; i += N) LoopBody(d, 0); + for (size_t i = 0; i < count; i += N) LoopBody(d, i, 0); ``` Here, the template parameter and second function argument are not needed. @@ -149,8 +226,8 @@ Highway offers several ways to express loops where `N` need not divide `count`: ``` size_t i = 0; - for (; i + N <= count; i += N) LoopBody(d, 0); - for (; i < count; ++i) LoopBody(HWY_CAPPED(T, 1)(), 0); + for (; i + N <= count; i += N) LoopBody(d, i, 0); + for (; i < count; ++i) LoopBody(HWY_CAPPED(T, 1)(), i, 0); ``` The template parameter and second function arguments are again not needed. @@ -163,18 +240,20 @@ Highway offers several ways to express loops where `N` need not divide `count`: ``` size_t i = 0; for (; i + N <= count; i += N) { - LoopBody(d, 0); + LoopBody(d, i, 0); } if (i < count) { - LoopBody(d, count - i); + LoopBody(d, i, count - i); } ``` - Now the template parameter and second function argument can be used inside + Now the template parameter and third function argument can be used inside `LoopBody` to 'blend' the new partial vector with previous memory contents: `Store(IfThenElse(FirstN(d, N), partial, prev_full), d, aligned_pointer);`. This is a good default when it is infeasible to ensure vectors are padded. In contrast to the scalar loop, only a single final iteration is needed. + The increased code size from two loop bodies is expected to be worthwhile + because it avoids the cost of masking in all but the final iteration. ## Additional resources diff --git a/third_party/highway/debian/changelog b/third_party/highway/debian/changelog index 516a15e4c7aa..b956a89682f2 100644 --- a/third_party/highway/debian/changelog +++ b/third_party/highway/debian/changelog @@ -1,3 +1,15 @@ +highway (0.16.0-1) UNRELEASED; urgency=medium + + * Add contrib/sort (vectorized quicksort) + * Add IfNegativeThenElse, IfVecThenElse + * Add Reverse2,4,8, ReverseBlocks, DupEven/Odd, AESLastRound + * Add OrAnd, Min128, Max128, Lt128, SumsOf8 + * Support capped/partial vectors on RVV/SVE, int64 in WASM + * Support SVE2, shared library build + * Remove deprecated overloads without the required d arg (UpperHalf etc.) 
+ + -- Jan Wassenberg Thu, 03 Feb 2022 11:00:00 +0100 + highway (0.15.0-1) UNRELEASED; urgency=medium * New ops: CompressBlendedStore, ConcatOdd/Even, IndicesFromVec diff --git a/third_party/highway/hwy/aligned_allocator.h b/third_party/highway/hwy/aligned_allocator.h index 1e76cefb6feb..ff8c08d39d25 100644 --- a/third_party/highway/hwy/aligned_allocator.h +++ b/third_party/highway/hwy/aligned_allocator.h @@ -18,8 +18,11 @@ // Memory allocator with support for alignment and offsets. #include + #include +#include "hwy/highway_export.h" + namespace hwy { // Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which @@ -36,15 +39,15 @@ using FreePtr = void (*)(void* opaque, void* memory); // bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and // the vector size. Calls `alloc` with the passed `opaque` pointer to obtain // memory or malloc() if it is null. -void* AllocateAlignedBytes(size_t payload_size, AllocPtr alloc_ptr, - void* opaque_ptr); +HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size, + AllocPtr alloc_ptr, void* opaque_ptr); // Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it // must have been returned from a previous call to `AllocateAlignedBytes`. // Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if // `free_ptr` function is null, uses the default free(). -void FreeAlignedBytes(const void* aligned_pointer, FreePtr free_ptr, - void* opaque_ptr); +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr); // Class that deletes the aligned pointer passed to operator() calling the // destructor before freeing the pointer. This is equivalent to the @@ -76,8 +79,10 @@ class AlignedDeleter { // array. TypeArrayDeleter would match this prototype. using ArrayDeleter = void (*)(void* t_ptr, size_t t_size); - static void DeleteAlignedArray(void* aligned_pointer, FreePtr free_ptr, - void* opaque_ptr, ArrayDeleter deleter); + HWY_DLLEXPORT static void DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter); FreePtr free_; void* opaque_ptr_; @@ -107,8 +112,8 @@ template AlignedUniquePtr MakeUniqueAligned(Args&&... args) { T* ptr = static_cast(AllocateAlignedBytes( sizeof(T), /*alloc_ptr=*/nullptr, /*opaque_ptr=*/nullptr)); - return AlignedUniquePtr( - new (ptr) T(std::forward(args)...), AlignedDeleter()); + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter()); } // Helpers for array allocators (avoids overflow) diff --git a/third_party/highway/hwy/base.h b/third_party/highway/hwy/base.h index 009373f884ba..4c3384f4db70 100644 --- a/third_party/highway/hwy/base.h +++ b/third_party/highway/hwy/base.h @@ -24,6 +24,7 @@ #include #include "hwy/detect_compiler_arch.h" +#include "hwy/highway_export.h" //------------------------------------------------------------------------------ // Compiler-specific definitions @@ -184,10 +185,6 @@ } while (0) #endif -#if defined(HWY_EMULATE_SVE) -class FarmFloat16; -#endif - namespace hwy { //------------------------------------------------------------------------------ @@ -205,7 +202,9 @@ static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; //------------------------------------------------------------------------------ // Alignment -// For stack-allocated partial arrays or LoadDup128. +// Potentially useful for LoadDup128 and capped vectors. 
In other cases, arrays +// should be allocated dynamically via aligned_allocator.h because Lanes() may +// exceed the stack size. #if HWY_ARCH_X86 #define HWY_ALIGN_MAX alignas(64) #elif HWY_ARCH_RVV && defined(__riscv_vector) @@ -228,9 +227,7 @@ static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; #pragma pack(push, 1) -#if defined(HWY_EMULATE_SVE) -using float16_t = FarmFloat16; -#elif HWY_NATIVE_FLOAT16 +#if HWY_NATIVE_FLOAT16 using float16_t = __fp16; // Clang does not allow __fp16 arguments, but scalar.h requires LaneType // arguments, so use a wrapper. @@ -253,15 +250,15 @@ using float64_t = double; //------------------------------------------------------------------------------ // Controlling overload resolution (SFINAE) -template +template struct EnableIfT {}; -template -struct EnableIfT { - using type = T; +template <> +struct EnableIfT { + using type = void; }; -template -using EnableIf = typename EnableIfT::type; +template +using EnableIf = typename EnableIfT::type; template struct IsSameT { @@ -283,7 +280,7 @@ HWY_API constexpr bool IsSame() { // // Note that enabling for exactly 128 bits is unnecessary because a function can // simply be overloaded with Vec128 and/or Full128 tag. Enabling for other -// sizes (e.g. 64 bit) can be achieved via Simd. +// sizes (e.g. 64 bit) can be achieved via Simd. #define HWY_IF_LE128(T, N) hwy::EnableIf* = nullptr #define HWY_IF_LE64(T, N) hwy::EnableIf* = nullptr #define HWY_IF_LE32(T, N) hwy::EnableIf* = nullptr @@ -319,102 +316,6 @@ struct RemoveConstT { template using RemoveConst = typename RemoveConstT::type; -//------------------------------------------------------------------------------ -// Type traits - -template -HWY_API constexpr bool IsFloat() { - // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or - // from a float, not compared. - return IsSame() || IsSame(); -} - -template -HWY_API constexpr bool IsSigned() { - return T(0) > T(-1); -} -template <> -constexpr bool IsSigned() { - return true; -} -template <> -constexpr bool IsSigned() { - return true; -} - -// Largest/smallest representable integer values. -template -HWY_API constexpr T LimitsMax() { - static_assert(!IsFloat(), "Only for integer types"); - return IsSigned() ? T((1ULL << (sizeof(T) * 8 - 1)) - 1) - : static_cast(~0ull); -} -template -HWY_API constexpr T LimitsMin() { - static_assert(!IsFloat(), "Only for integer types"); - return IsSigned() ? T(-1) - LimitsMax() : T(0); -} - -// Largest/smallest representable value (integer or float). This naming avoids -// confusion with numeric_limits::min() (the smallest positive value). -template -HWY_API constexpr T LowestValue() { - return LimitsMin(); -} -template <> -constexpr float LowestValue() { - return -FLT_MAX; -} -template <> -constexpr double LowestValue() { - return -DBL_MAX; -} - -template -HWY_API constexpr T HighestValue() { - return LimitsMax(); -} -template <> -constexpr float HighestValue() { - return FLT_MAX; -} -template <> -constexpr double HighestValue() { - return DBL_MAX; -} - -// Returns bitmask of the exponent field in IEEE binary32/64. -template -constexpr T ExponentMask() { - static_assert(sizeof(T) == 0, "Only instantiate the specializations"); - return 0; -} -template <> -constexpr uint32_t ExponentMask() { - return 0x7F800000; -} -template <> -constexpr uint64_t ExponentMask() { - return 0x7FF0000000000000ULL; -} - -// Returns 1 << mantissa_bits as a floating-point number. All integers whose -// absolute value are less than this can be represented exactly. 
-template -constexpr T MantissaEnd() { - static_assert(sizeof(T) == 0, "Only instantiate the specializations"); - return 0; -} -template <> -constexpr float MantissaEnd() { - return 8388608.0f; // 1 << 23 -} -template <> -constexpr double MantissaEnd() { - // floating point literal with p52 requires C++17. - return 4503599627370496.0; // 1 << 52 -} - //------------------------------------------------------------------------------ // Type relations @@ -556,6 +457,118 @@ using SignedFromSize = typename detail::TypeFromSize::Signed; template using FloatFromSize = typename detail::TypeFromSize::Float; +//------------------------------------------------------------------------------ +// Type traits + +template +HWY_API constexpr bool IsFloat() { + // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or + // from a float, not compared. + return IsSame() || IsSame(); +} + +template +HWY_API constexpr bool IsSigned() { + return T(0) > T(-1); +} +template <> +constexpr bool IsSigned() { + return true; +} +template <> +constexpr bool IsSigned() { + return true; +} + +// Largest/smallest representable integer values. +template +HWY_API constexpr T LimitsMax() { + static_assert(!IsFloat(), "Only for integer types"); + using TU = MakeUnsigned; + return static_cast(IsSigned() ? (static_cast(~0ull) >> 1) + : static_cast(~0ull)); +} +template +HWY_API constexpr T LimitsMin() { + static_assert(!IsFloat(), "Only for integer types"); + return IsSigned() ? T(-1) - LimitsMax() : T(0); +} + +// Largest/smallest representable value (integer or float). This naming avoids +// confusion with numeric_limits::min() (the smallest positive value). +template +HWY_API constexpr T LowestValue() { + return LimitsMin(); +} +template <> +constexpr float LowestValue() { + return -FLT_MAX; +} +template <> +constexpr double LowestValue() { + return -DBL_MAX; +} + +template +HWY_API constexpr T HighestValue() { + return LimitsMax(); +} +template <> +constexpr float HighestValue() { + return FLT_MAX; +} +template <> +constexpr double HighestValue() { + return DBL_MAX; +} + +// Returns bitmask of the exponent field in IEEE binary32/64. +template +constexpr T ExponentMask() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr uint32_t ExponentMask() { + return 0x7F800000; +} +template <> +constexpr uint64_t ExponentMask() { + return 0x7FF0000000000000ULL; +} + +// Returns bitmask of the mantissa field in IEEE binary32/64. +template +constexpr T MantissaMask() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr uint32_t MantissaMask() { + return 0x007FFFFF; +} +template <> +constexpr uint64_t MantissaMask() { + return 0x000FFFFFFFFFFFFFULL; +} + +// Returns 1 << mantissa_bits as a floating-point number. All integers whose +// absolute value are less than this can be represented exactly. +template +constexpr T MantissaEnd() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr float MantissaEnd() { + return 8388608.0f; // 1 << 23 +} +template <> +constexpr double MantissaEnd() { + // floating point literal with p52 requires C++17. + return 4503599627370496.0; // 1 << 52 +} + //------------------------------------------------------------------------------ // Helper functions @@ -661,14 +674,21 @@ HWY_API size_t PopCount(uint64_t x) { #endif } +// Skip HWY_API due to GCC "function not considered for inlining". 
Previously +// such errors were caused by underlying type mismatches, but it's not clear +// what is still mismatched despite all the casts. template -HWY_API constexpr size_t FloorLog2(TI x) { - return x == 1 ? 0 : FloorLog2(x >> 1) + 1; +/*HWY_API*/ constexpr size_t FloorLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x >> 1)) + 1); } template -HWY_API constexpr size_t CeilLog2(TI x) { - return x == 1 ? 0 : FloorLog2(x - 1) + 1; +/*HWY_API*/ constexpr size_t CeilLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x - 1)) + 1); } #if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 @@ -727,7 +747,7 @@ HWY_API bfloat16_t BF16FromF32(float f) { return bf; } -HWY_NORETURN void HWY_FORMAT(3, 4) +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) Abort(const char* file, int line, const char* format, ...); } // namespace hwy diff --git a/third_party/highway/hwy/cache_control.h b/third_party/highway/hwy/cache_control.h index 65f326a5f5f1..f00eaaed08e0 100644 --- a/third_party/highway/hwy/cache_control.h +++ b/third_party/highway/hwy/cache_control.h @@ -36,9 +36,7 @@ // undefine them in this header; these functions are anyway deprecated. // TODO(janwas): remove when these functions are removed. #pragma push_macro("LoadFence") -#pragma push_macro("StoreFence") #undef LoadFence -#undef StoreFence namespace hwy { @@ -72,9 +70,6 @@ HWY_INLINE HWY_ATTR_CACHE void FlushStream() { #endif } -// DEPRECATED, replace with `FlushStream`. -HWY_INLINE HWY_ATTR_CACHE void StoreFence() { FlushStream(); } - // Optionally begins loading the cache line containing "p" to reduce latency of // subsequent actual loads. template @@ -109,7 +104,6 @@ HWY_INLINE HWY_ATTR_CACHE void Pause() { } // namespace hwy // TODO(janwas): remove when these functions are removed. (See above.) -#pragma pop_macro("StoreFence") #pragma pop_macro("LoadFence") #endif // HIGHWAY_HWY_CACHE_CONTROL_H_ diff --git a/third_party/highway/hwy/contrib/image/image.cc b/third_party/highway/hwy/contrib/image/image.cc index 4b57cd35d88b..4703cacf5104 100644 --- a/third_party/highway/hwy/contrib/image/image.cc +++ b/third_party/highway/hwy/contrib/image/image.cc @@ -14,15 +14,14 @@ #include "hwy/contrib/image/image.h" +#include // swap #include #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "hwy/contrib/image/image.cc" - -#include // swap - #include "hwy/foreach_target.h" #include "hwy/highway.h" + HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { diff --git a/third_party/highway/hwy/contrib/image/image.h b/third_party/highway/hwy/contrib/image/image.h index 69a9a80fc8e1..da91481643c6 100644 --- a/third_party/highway/hwy/contrib/image/image.h +++ b/third_party/highway/hwy/contrib/image/image.h @@ -27,12 +27,13 @@ #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/highway_export.h" namespace hwy { // Type-independent parts of Image<> - reduces code duplication and facilitates // moving member function implementations to cc file. -struct ImageBase { +struct HWY_CONTRIB_DLLEXPORT ImageBase { // Returns required alignment in bytes for externally allocated memory. static size_t VectorSize(); @@ -100,8 +101,7 @@ struct ImageBase { protected: // Returns pointer to the start of a row. 
HWY_INLINE void* VoidRow(const size_t y) const { -#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ - defined(THREAD_SANITIZER) +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN if (y >= ysize_) { HWY_ABORT("Row(%" PRIu64 ") >= %u\n", static_cast(y), ysize_); } @@ -291,8 +291,7 @@ class Image3 { private: // Returns pointer to the start of a row. HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const { -#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ - defined(THREAD_SANITIZER) +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN if (c >= kNumPlanes || y >= ysize()) { HWY_ABORT("PlaneRow(%" PRIu64 ", %" PRIu64 ") >= %" PRIu64 "\n", static_cast(c), static_cast(y), diff --git a/third_party/highway/hwy/contrib/image/image_test.cc b/third_party/highway/hwy/contrib/image/image_test.cc index d5d64a3507a0..9b39d8b41d87 100644 --- a/third_party/highway/hwy/contrib/image/image_test.cc +++ b/third_party/highway/hwy/contrib/image/image_test.cc @@ -51,7 +51,7 @@ struct TestAlignedT { for (size_t y = 0; y < ysize; ++y) { T* HWY_RESTRICT row = img.MutableRow(y); for (size_t x = 0; x < xsize; x += Lanes(d)) { - const auto values = Iota(d, dist(rng)); + const auto values = Iota(d, static_cast(dist(rng))); Store(values, d, row + x); } } diff --git a/third_party/highway/hwy/contrib/math/math-inl.h b/third_party/highway/hwy/contrib/math/math-inl.h index 77f686de3874..6e0a00da873d 100644 --- a/third_party/highway/hwy/contrib/math/math-inl.h +++ b/third_party/highway/hwy/contrib/math/math-inl.h @@ -486,7 +486,7 @@ struct AsinImpl { } }; -#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 template <> struct AsinImpl { @@ -531,7 +531,7 @@ struct AtanImpl { } }; -#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 template <> struct AtanImpl { @@ -635,7 +635,7 @@ struct CosSinImpl { } }; -#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 template <> struct CosSinImpl { @@ -787,7 +787,7 @@ struct LogImpl { } }; -#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 template <> struct ExpImpl { // Rounds double toward zero and returns as int32_t. diff --git a/third_party/highway/hwy/contrib/math/math_test.cc b/third_party/highway/hwy/contrib/math/math_test.cc index cf9f8e339bf9..09ab02fc593f 100644 --- a/third_party/highway/hwy/contrib/math/math_test.cc +++ b/third_party/highway/hwy/contrib/math/math_test.cc @@ -61,7 +61,7 @@ HWY_NOINLINE void TestMath(const std::string name, T (*fx1)(T), uint64_t max_ulp = 0; // Emulation is slower, so cannot afford as many. 
- constexpr UintT kSamplesPerRange = static_cast(AdjustedReps(10000)); + constexpr UintT kSamplesPerRange = static_cast(AdjustedReps(4000)); for (int range_index = 0; range_index < range_count; ++range_index) { const UintT start = ranges[range_index][0]; const UintT stop = ranges[range_index][1]; @@ -96,24 +96,11 @@ HWY_NOINLINE void TestMath(const std::string name, T (*fx1)(T), HWY_ASSERT(max_ulp <= max_error_ulp); } -// TODO(janwas): remove once RVV supports fractional LMUL -#undef DEFINE_MATH_TEST_FUNC -#if HWY_TARGET == HWY_RVV - -#define DEFINE_MATH_TEST_FUNC(NAME) \ - HWY_NOINLINE void TestAll##NAME() { \ - ForFloatTypes(ForShrinkableVectors()); \ - } - -#else - #define DEFINE_MATH_TEST_FUNC(NAME) \ HWY_NOINLINE void TestAll##NAME() { \ ForFloatTypes(ForPartialVectors()); \ } -#endif - #undef DEFINE_MATH_TEST #define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \ F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \ diff --git a/third_party/highway/hwy/contrib/sort/BUILD b/third_party/highway/hwy/contrib/sort/BUILD new file mode 100644 index 000000000000..03a9d09d3554 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/BUILD @@ -0,0 +1,133 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +# Unused on Bazel builds, where this is not defined/known; Copybara replaces +# usages with an empty list. +COMPAT = [ + "//buildenv/target:non_prod", # includes mobile/vendor. +] + +cc_library( + name = "vqsort", + srcs = [ + # Split into separate files to reduce MSVC build time. + "vqsort.cc", + "vqsort_i16a.cc", + "vqsort_i16d.cc", + "vqsort_u16a.cc", + "vqsort_u16d.cc", + "vqsort_f32a.cc", + "vqsort_f32d.cc", + "vqsort_i32a.cc", + "vqsort_i32d.cc", + "vqsort_u32a.cc", + "vqsort_u32d.cc", + "vqsort_f64a.cc", + "vqsort_f64d.cc", + "vqsort_i64a.cc", + "vqsort_i64d.cc", + "vqsort_u64a.cc", + "vqsort_u64d.cc", + "vqsort_128a.cc", + "vqsort_128d.cc", + ], + hdrs = [ + "disabled_targets.h", + "vqsort.h", # public interface + ], + compatible_with = [], + textual_hdrs = [ + "shared-inl.h", + "sorting_networks-inl.h", + "traits-inl.h", + "traits128-inl.h", + "vqsort-inl.h", + ], + deps = [ + # Only if VQSORT_SECURE_RNG is set. + # "//third_party/absl/random", + "//:hwy", + ], +) + +# ----------------------------------------------------------------------------- +# Internal-only targets + +cc_library( + name = "helpers", + testonly = 1, + textual_hdrs = [ + "algo-inl.h", + "result-inl.h", + ], + deps = [ + ":vqsort", + "//:nanobenchmark", + # Required for HAVE_PDQSORT, but that is unused and this is + # unavailable to Bazel builds, hence commented out. + # "//third_party/boost/allowed", + # Avoid ips4o and thus TBB to work around hwloc build failure. + ], +) + +cc_binary( + name = "print_network", + testonly = 1, + srcs = ["print_network.cc"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + ], +) + +cc_test( + name = "sort_test", + size = "medium", + srcs = ["sort_test.cc"], + features = ["fully_static_link"], + linkstatic = True, + local_defines = ["HWY_IS_TEST"], + # for test_suite. 
+ tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) + +cc_binary( + name = "bench_sort", + testonly = 1, + srcs = ["bench_sort.cc"], + features = ["fully_static_link"], + linkstatic = True, + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) + +cc_binary( + name = "bench_parallel", + testonly = 1, + srcs = ["bench_parallel.cc"], + features = ["fully_static_link"], + linkstatic = True, + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) diff --git a/third_party/highway/hwy/contrib/sort/algo-inl.h b/third_party/highway/hwy/contrib/sort/algo-inl.h new file mode 100644 index 000000000000..db9e04a6037c --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/algo-inl.h @@ -0,0 +1,395 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +#include +#include // memcpy + +#include +#include // std::abs +#include + +#include "hwy/base.h" +#include "hwy/contrib/sort/vqsort.h" + +// Third-party algorithms +#define HAVE_AVX2SORT 0 +#define HAVE_IPS4O 0 +#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) +#define HAVE_PDQSORT 0 +#define HAVE_SORT512 0 + +#if HAVE_AVX2SORT +HWY_PUSH_ATTRIBUTES("avx2,avx") +#include "avx2sort.h" +HWY_POP_ATTRIBUTES +#endif +#if HAVE_IPS4O +#include "third_party/ips4o/include/ips4o.hpp" +#include "third_party/ips4o/include/ips4o/thread_pool.hpp" +#endif +#if HAVE_PDQSORT +#include "third_party/boost/allowed/sort/sort.hpp" +#endif +#if HAVE_SORT512 +#include "sort512.h" +#endif + +namespace hwy { + +enum class Dist { kUniform8, kUniform16, kUniform32 }; + +std::vector AllDist() { + return {/*Dist::kUniform8,*/ Dist::kUniform16, Dist::kUniform32}; +} + +const char* DistName(Dist dist) { + switch (dist) { + case Dist::kUniform8: + return "uniform8"; + case Dist::kUniform16: + return "uniform16"; + case Dist::kUniform32: + return "uniform32"; + } + return "unreachable"; +} + +template +class InputStats { + public: + void Notify(T value) { + min_ = std::min(min_, value); + max_ = std::max(max_, value); + sumf_ += static_cast(value); + count_ += 1; + } + + bool operator==(const InputStats& other) const { + if (count_ != other.count_) { + HWY_ABORT("count %d vs %d\n", static_cast(count_), + static_cast(other.count_)); + } + + if (min_ != other.min_ || max_ != other.max_) { + HWY_ABORT("minmax %f/%f vs %f/%f\n", double(min_), double(max_), + double(other.min_), double(other.max_)); + } + + // Sum helps detect duplicated/lost values + if (sumf_ != other.sumf_) { + // Allow some tolerance because kUniform32 * num can exceed double + // precision. 
+ const double mul = 1E-9; // prevent destructive cancellation + const double err = std::abs(sumf_ * mul - other.sumf_ * mul); + if (err > 1E-3) { + HWY_ABORT("Sum mismatch %.15e %.15e (%f) min %g max %g\n", sumf_, + other.sumf_, err, double(min_), double(max_)); + } + } + + return true; + } + + private: + T min_ = hwy::HighestValue(); + T max_ = hwy::LowestValue(); + double sumf_ = 0.0; + size_t count_ = 0; +}; + +enum class Algo { +#if HAVE_AVX2SORT + kSEA, +#endif +#if HAVE_IPS4O + kIPS4O, +#endif +#if HAVE_PARALLEL_IPS4O + kParallelIPS4O, +#endif +#if HAVE_PDQSORT + kPDQ, +#endif +#if HAVE_SORT512 + kSort512, +#endif + kStd, + kVQSort, + kHeap, +}; + +const char* AlgoName(Algo algo) { + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return "sea"; +#endif +#if HAVE_IPS4O + case Algo::kIPS4O: + return "ips4o"; +#endif +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + return "par_ips4o"; +#endif +#if HAVE_PDQSORT + case Algo::kPDQ: + return "pdq"; +#endif +#if HAVE_SORT512 + case Algo::kSort512: + return "sort512"; +#endif + case Algo::kStd: + return "std"; + case Algo::kVQSort: + return "vq"; + case Algo::kHeap: + return "heap"; + } + return "unreachable"; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#endif + +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +class Xorshift128Plus { + static HWY_INLINE uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + public: + // Generates two vectors of 64-bit seeds via SplitMix64 and stores into + // `seeds`. Generating these afresh in each ChoosePivot is too expensive. + template + static void GenerateSeeds(DU64 du64, TFromD* HWY_RESTRICT seeds) { + seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < 2 * Lanes(du64); ++i) { + seeds[i] = SplitMix64(seeds[i - 1]); + } + } + + // Need to pass in the state because vector cannot be class members. + template + static Vec RandomBits(DU64 /* tag */, Vec& state0, + Vec& state1) { + Vec s1 = state0; + Vec s0 = state1; + const Vec bits = Add(s1, s0); + state0 = s0; + s1 = Xor(s1, ShiftLeft<23>(s1)); + state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + return bits; + } +}; + +template +Vec RandomValues(DU64 du64, Vec& s0, Vec& s1, + const Vec mask) { + const Vec bits = Xorshift128Plus::RandomBits(du64, s0, s1); + return And(bits, mask); +} + +// Important to avoid denormals, which are flushed to zero by SIMD but not +// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD. +template +Vec RandomValues(DU64 du64, Vec& s0, Vec& s1, + const Vec mask) { + const Vec bits = Xorshift128Plus::RandomBits(du64, s0, s1); + const Vec values = And(bits, mask); +#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to i32 + const RebindToSigned di; +#else + const Repartition, DU64> di; +#endif + const RebindToFloat df; + // Avoid NaN/denormal by converting from (range-limited) integer. 
+ const Vec no_nan = + And(values, Set(du64, MantissaMask>())); + return BitCast(du64, ConvertTo(df, BitCast(di, no_nan))); +} + +template +Vec MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) { + switch (sizeof_t) { + case 2: + return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull + : 0xFFFFFFFFFFFFFFFFull); + case 4: + return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull + : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull + : 0xFFFFFFFFFFFFFFFFull); + case 8: + return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull + : (dist == Dist::kUniform16) ? 0x000000000000FFFFull + : 0x00000000FFFFFFFFull); + default: + HWY_ABORT("Logic error"); + return Zero(du64); + } +} + +template +InputStats GenerateInput(const Dist dist, T* v, size_t num) { + SortTag du64; + using VU64 = Vec; + const size_t N64 = Lanes(du64); + auto buf = hwy::AllocateAligned(2 * N64); + Xorshift128Plus::GenerateSeeds(du64, buf.get()); + auto s0 = Load(du64, buf.get()); + auto s1 = Load(du64, buf.get() + N64); + + const VU64 mask = MaskForDist(du64, dist, sizeof(T)); + + const Repartition d; + const size_t N = Lanes(d); + size_t i = 0; + for (; i + N <= num; i += N) { + const VU64 bits = RandomValues(du64, s0, s1, mask); +#if HWY_ARCH_RVV + // v may not be 64-bit aligned + StoreU(bits, du64, buf.get()); + memcpy(v + i, buf.get(), N64 * sizeof(uint64_t)); +#else + StoreU(bits, du64, reinterpret_cast(v + i)); +#endif + } + if (i < num) { + const VU64 bits = RandomValues(du64, s0, s1, mask); + StoreU(bits, du64, buf.get()); + memcpy(v + i, buf.get(), (num - i) * sizeof(T)); + } + + InputStats input_stats; + for (size_t i = 0; i < num; ++i) { + input_stats.Notify(v[i]); + } + return input_stats; +} + +struct ThreadLocal { + Sorter sorter; +}; + +struct SharedState { +#if HAVE_PARALLEL_IPS4O + ips4o::StdThreadPool pool{ + HWY_MIN(16, static_cast(std::thread::hardware_concurrency() / 2))}; +#endif + std::vector tls{1}; +}; + +template +void Run(Algo algo, T* HWY_RESTRICT inout, size_t num, SharedState& shared, + size_t thread) { + using detail::HeapSort; + using detail::LaneTraits; + using detail::SharedTraits; + + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return avx2::quicksort(inout, static_cast(num)); +#endif + +#if HAVE_IPS4O + case Algo::kIPS4O: + if (Order().IsAscending()) { + return ips4o::sort(inout, inout + num, std::less()); + } else { + return ips4o::sort(inout, inout + num, std::greater()); + } +#endif + +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + if (Order().IsAscending()) { + return ips4o::parallel::sort(inout, inout + num, std::less()); + } else { + return ips4o::parallel::sort(inout, inout + num, std::greater()); + } +#endif + +#if HAVE_SORT512 + case Algo::kSort512: + HWY_ABORT("not supported"); + // return Sort512::Sort(inout, num); +#endif + +#if HAVE_PDQSORT + case Algo::kPDQ: + if (Order().IsAscending()) { + return boost::sort::pdqsort_branchless(inout, inout + num, + std::less()); + } else { + return boost::sort::pdqsort_branchless(inout, inout + num, + std::greater()); + } +#endif + + case Algo::kStd: + if (Order().IsAscending()) { + return std::sort(inout, inout + num, std::less()); + } else { + return std::sort(inout, inout + num, std::greater()); + } + + case Algo::kVQSort: + return shared.tls[thread].sorter(inout, num, Order()); + + case Algo::kHeap: + HWY_ASSERT(sizeof(T) < 16); + if (Order().IsAscending()) { + const SharedTraits> st; + return HeapSort(st, inout, num); + } else { + const SharedTraits> st; + return HeapSort(st, 
inout, num); + } + + default: + HWY_ABORT("Not implemented"); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/bench_parallel.cc b/third_party/highway/hwy/contrib/sort/bench_parallel.cc new file mode 100644 index 000000000000..9c348c154021 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/bench_parallel.cc @@ -0,0 +1,243 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Concurrent, independent sorts for generating more memory traffic and testing +// scalability. + +// clang-format off +#include "hwy/contrib/sort/vqsort.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_parallel.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/aligned_allocator.h" +// Last +#include "hwy/tests/test_util-inl.h" +// clang-format on + +#include +#include + +#include //NOLINT +#include +#include +#include //NOLINT +#include //NOLINT +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +#if HWY_TARGET != HWY_SCALAR + +class ThreadPool { + public: + // Starts the given number of worker threads and blocks until they are ready. + explicit ThreadPool( + const size_t num_threads = std::thread::hardware_concurrency() / 2) + : num_threads_(num_threads) { + HWY_ASSERT(num_threads_ > 0); + threads_.reserve(num_threads_); + for (size_t i = 0; i < num_threads_; ++i) { + threads_.emplace_back(ThreadFunc, this, i); + } + + WorkersReadyBarrier(); + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + // Waits for all threads to exit. + ~ThreadPool() { + StartWorkers(kWorkerExit); + + for (std::thread& thread : threads_) { + thread.join(); + } + } + + size_t NumThreads() const { return threads_.size(); } + + template + void RunOnThreads(size_t max_threads, const Func& func) { + task_ = &CallClosure; + data_ = &func; + StartWorkers(max_threads); + WorkersReadyBarrier(); + } + + private: + // After construction and between calls to Run, workers are "ready", i.e. + // waiting on worker_start_cv_. They are "started" by sending a "command" + // and notifying all worker_start_cv_ waiters. (That is why all workers + // must be ready/waiting - otherwise, the notification will not reach all of + // them and the main thread waits in vain for them to report readiness.) + using WorkerCommand = uint64_t; + + static constexpr WorkerCommand kWorkerWait = ~1ULL; + static constexpr WorkerCommand kWorkerExit = ~2ULL; + + // Calls a closure (lambda with captures). 
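// How the type erasure above works: RunOnThreads stores (effectively)
//   task_ = &CallClosure<Func>;   // one instantiation per closure type
//   data_ = &func;                // caller's closure; valid until RunOnThreads returns
// and each worker invokes task_(data_, thread), which CallClosure (below)
// casts back to the closure's type and calls as func(thread). The closure is
// therefore never copied or heap-allocated.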
+ template + static void CallClosure(const void* f, size_t thread) { + (*reinterpret_cast(f))(thread); + } + + void WorkersReadyBarrier() { + std::unique_lock lock(mutex_); + // Typically only a single iteration. + while (workers_ready_ != threads_.size()) { + workers_ready_cv_.wait(lock); + } + workers_ready_ = 0; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + } + + // Precondition: all workers are ready. + void StartWorkers(const WorkerCommand worker_command) { + std::unique_lock lock(mutex_); + worker_start_command_ = worker_command; + // Workers will need this lock, so release it before they wake up. + lock.unlock(); + worker_start_cv_.notify_all(); + } + + static void ThreadFunc(ThreadPool* self, size_t thread) { + // Until kWorkerExit command received: + for (;;) { + std::unique_lock lock(self->mutex_); + // Notify main thread that this thread is ready. + if (++self->workers_ready_ == self->num_threads_) { + self->workers_ready_cv_.notify_one(); + } + RESUME_WAIT: + // Wait for a command. + self->worker_start_cv_.wait(lock); + const WorkerCommand command = self->worker_start_command_; + switch (command) { + case kWorkerWait: // spurious wakeup: + goto RESUME_WAIT; // lock still held, avoid incrementing ready. + case kWorkerExit: + return; // exits thread + default: + break; + } + + lock.unlock(); + // Command is the maximum number of threads that should run the task. + HWY_ASSERT(command < self->NumThreads()); + if (thread < command) { + self->task_(self->data_, thread); + } + } + } + + const size_t num_threads_; + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + std::mutex mutex_; // guards both cv and their variables. + std::condition_variable workers_ready_cv_; + size_t workers_ready_ = 0; + std::condition_variable worker_start_cv_; + WorkerCommand worker_start_command_; + + // Written by main thread, read by workers (after mutex lock/unlock). + std::function task_; // points to CallClosure + const void* data_; // points to caller's Func +}; + +template +void RunWithoutVerify(const Dist dist, const size_t num, const Algo algo, + SharedState& shared, size_t thread) { + auto aligned = hwy::AllocateAligned(num); + + (void)GenerateInput(dist, aligned.get(), num); + + const Timestamp t0; + Run(algo, aligned.get(), num, shared, thread); + HWY_ASSERT(aligned[0] < aligned[num - 1]); +} + +void BenchParallel() { + // Not interested in benchmark results for other targets + if (HWY_TARGET != HWY_AVX3) return; + + ThreadPool pool; + const size_t NT = pool.NumThreads(); + + using T = int64_t; + detail::SharedTraits> st; + + size_t num = 100 * 1000 * 1000; + +#if HAVE_IPS4O + const Algo algo = Algo::kIPS4O; +#else + const Algo algo = Algo::kVQSort; +#endif + const Dist dist = Dist::kUniform16; + + SharedState shared; + shared.tls.resize(NT); + + std::vector results; + for (size_t nt = 1; nt < NT; nt += HWY_MAX(1, NT / 16)) { + Timestamp t0; + // Default capture because MSVC wants algo/dist but clang does not. 
+ pool.RunOnThreads(nt, [=, &shared](size_t thread) { + RunWithoutVerify(dist, num, algo, shared, thread); + }); + const double sec = SecondsSince(t0); + results.push_back(MakeResult(algo, dist, st, num, nt, sec)); + results.back().Print(); + } +} + +#else +void BenchParallel() {} +#endif + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(BenchParallel); +HWY_EXPORT_AND_TEST_P(BenchParallel, BenchParallel); +} // namespace +} // namespace hwy + +// Ought not to be necessary, but without this, no tests run on RVV. +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/bench_sort.cc b/third_party/highway/hwy/contrib/sort/bench_sort.cc new file mode 100644 index 000000000000..095fa6ccb918 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/bench_sort.cc @@ -0,0 +1,259 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_sort.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/contrib/sort/vqsort.h" +#include "hwy/contrib/sort/sorting_networks-inl.h" // SharedTraits +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +#include +#include +#include // memcpy + +#include + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +using detail::LaneTraits; +using detail::OrderAscending; +using detail::OrderDescending; +using detail::SharedTraits; + +#if HWY_TARGET != HWY_SCALAR +using detail::OrderAscending128; +using detail::OrderDescending128; +using detail::Traits128; + +template +HWY_NOINLINE void BenchPartition() { + const SortTag d; + detail::SharedTraits st; + const Dist dist = Dist::kUniform8; + double sum = 0.0; + + const size_t max_log2 = AdjustedLog2Reps(20); + for (size_t log2 = max_log2; log2 < max_log2 + 1; ++log2) { + const size_t num = 1ull << log2; + auto aligned = hwy::AllocateAligned(num); + auto buf = + hwy::AllocateAligned(hwy::SortConstants::PartitionBufNum(Lanes(d))); + + std::vector seconds; + const size_t num_reps = (1ull << (14 - log2 / 2)) * kReps; + for (size_t rep = 0; rep < num_reps; ++rep) { + (void)GenerateInput(dist, aligned.get(), num); + + const Timestamp t0; + + detail::Partition(d, st, aligned.get(), 0, num - 1, Set(d, T(128)), + buf.get()); + seconds.push_back(SecondsSince(t0)); + // 'Use' the result to prevent optimizing out the partition. 
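// (Folding one partitioned element into `sum`, which is checked after the
// loop, creates a data dependency the compiler cannot remove; a portable
// stand-in for a DoNotOptimize()-style barrier from benchmark frameworks.)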
+ sum += static_cast(aligned.get()[num / 2]); + } + + MakeResult(Algo::kVQSort, dist, st, num, 1, + SummarizeMeasurements(seconds)) + .Print(); + } + HWY_ASSERT(sum != 999999); // Prevent optimizing out +} + +HWY_NOINLINE void BenchAllPartition() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 || + HWY_TARGET == HWY_AVX2) { + return; + } + + BenchPartition, float>(); + BenchPartition, int64_t>(); + BenchPartition, uint64_t>(); +} + +template +HWY_NOINLINE void BenchBase(std::vector& results) { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) { + return; + } + + const SortTag d; + detail::SharedTraits st; + const Dist dist = Dist::kUniform32; + + const size_t N = Lanes(d); + const size_t num = SortConstants::BaseCaseNum(N); + auto keys = hwy::AllocateAligned(num); + auto buf = hwy::AllocateAligned(num + N); + + std::vector seconds; + double sum = 0; // prevents elision + constexpr size_t kMul = AdjustedReps(600); // ensures long enough to measure + + for (size_t rep = 0; rep < kReps; ++rep) { + InputStats input_stats = GenerateInput(dist, keys.get(), num); + + const Timestamp t0; + for (size_t i = 0; i < kMul; ++i) { + detail::BaseCase(d, st, keys.get(), num, buf.get()); + sum += static_cast(keys[0]); + } + seconds.push_back(SecondsSince(t0)); + // printf("%f\n", seconds.back()); + + HWY_ASSERT(VerifySort(st, input_stats, keys.get(), num, "BenchBase")); + } + HWY_ASSERT(sum < 1E99); + results.push_back(MakeResult(Algo::kVQSort, dist, st, num * kMul, 1, + SummarizeMeasurements(seconds))); +} + +HWY_NOINLINE void BenchAllBase() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3) { + return; + } + + std::vector results; + BenchBase, float>(results); + BenchBase, int64_t>(results); + BenchBase, uint64_t>(results); + for (const Result& r : results) { + r.Print(); + } +} + +std::vector AlgoForBench() { + return { +#if HAVE_AVX2SORT + Algo::kSEA, +#endif +#if HAVE_PARALLEL_IPS4O + Algo::kParallelIPS4O, +#endif +#if HAVE_IPS4O + Algo::kIPS4O, +#endif +#if HAVE_PDQSORT + Algo::kPDQ, +#endif +#if HAVE_SORT512 + Algo::kSort512, +#endif + // Algo::kStd, // too slow to always benchmark + // Algo::kHeap, // too slow to always benchmark + Algo::kVQSort, + }; +} + +template +HWY_NOINLINE void BenchSort(size_t num) { + SharedState shared; + detail::SharedTraits st; + auto aligned = hwy::AllocateAligned(num); + for (Algo algo : AlgoForBench()) { + for (Dist dist : AllDist()) { + std::vector seconds; + for (size_t rep = 0; rep < kReps; ++rep) { + InputStats input_stats = GenerateInput(dist, aligned.get(), num); + + const Timestamp t0; + Run(algo, aligned.get(), num, shared, + /*thread=*/0); + seconds.push_back(SecondsSince(t0)); + // printf("%f\n", seconds.back()); + + HWY_ASSERT( + VerifySort(st, input_stats, aligned.get(), num, "BenchSort")); + } + MakeResult(algo, dist, st, num, 1, SummarizeMeasurements(seconds)) + .Print(); + } // dist + } // algo +} + +HWY_NOINLINE void BenchAllSort() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) { + return; + } + + constexpr size_t K = 1000; + constexpr size_t M = K * K; + (void)K; + (void)M; + for (size_t num : { +#if HAVE_PARALLEL_IPS4O + 100 * M, +#else + AdjustedReps(1 * M), +#endif + }) { + // BenchSort, float>(num); + // BenchSort, double>(num); + // BenchSort, int16_t>(num); + BenchSort, int32_t>(num); + BenchSort, 
int64_t>(num); + // BenchSort, uint16_t>(num); + // BenchSort, uint32_t>(num); + // BenchSort, uint64_t>(num); + + BenchSort, uint64_t>(num); + // BenchSort, uint64_t>(num); + } +} + +#else +void BenchAllPartition() {} +void BenchAllBase() {} +void BenchAllSort() {} +#endif + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(BenchSort); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllPartition); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllBase); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllSort); +} // namespace +} // namespace hwy + +// Ought not to be necessary, but without this, no tests run on RVV. +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/disabled_targets.h b/third_party/highway/hwy/contrib/sort/disabled_targets.h new file mode 100644 index 000000000000..b9927ba7bab4 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/disabled_targets.h @@ -0,0 +1,30 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Speed up MSVC builds by building fewer targets. This header must be included +// from all TUs that contain a HWY_DYNAMIC_DISPATCH to vqsort, i.e. vqsort_*.cc. +// However, users of vqsort.h are unaffected. + +#ifndef HIGHWAY_HWY_CONTRIB_SORT_DISABLED_TARGETS_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_DISABLED_TARGETS_H_ + +#include "hwy/base.h" + +#if HWY_COMPILER_MSVC +#undef HWY_DISABLED_TARGETS +// HWY_SCALAR remains, so there will still be a valid target to call. +#define HWY_DISABLED_TARGETS (HWY_SSSE3 | HWY_SSE4) +#endif // HWY_COMPILER_MSVC + +#endif // HIGHWAY_HWY_CONTRIB_SORT_DISABLED_TARGETS_H_ diff --git a/third_party/highway/hwy/contrib/sort/print_network.cc b/third_party/highway/hwy/contrib/sort/print_network.cc new file mode 100644 index 000000000000..6e1e49516418 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/print_network.cc @@ -0,0 +1,190 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "hwy/base.h" + +// Based on A.7 in "Entwurf und Implementierung vektorisierter +// Sortieralgorithmen" and code by Mark Blacher. 
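// Usage sketch for the disabled_targets.h header above (the TU name below is
// hypothetical; the intended consumers are the vqsort_*.cc files mentioned in
// that header). The include must precede foreach_target.h/highway.h so that
// HWY_DISABLED_TARGETS is already defined when target selection runs:
//
//   // vqsort_example.cc (hypothetical)
//   #include "hwy/contrib/sort/disabled_targets.h"
//   #undef HWY_TARGET_INCLUDE
//   #define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_example.cc"
//   #include "hwy/foreach_target.h"
//   #include "hwy/highway.h"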
+void PrintMergeNetwork16x2() { + for (int i = 8; i < 16; ++i) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } + printf("\n"); +} + +void PrintMergeNetwork16x4() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.Reverse4(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.Reverse4(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.Reverse4(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse4(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } +} + +void PrintMergeNetwork16x8() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse8(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance2(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, 
v%x);\n", i, i); + } +} + +void PrintMergeNetwork16x16() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse16(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance4(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance2(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } +} + +int main(int argc, char** argv) { + PrintMergeNetwork16x2(); + PrintMergeNetwork16x4(); + PrintMergeNetwork16x8(); + PrintMergeNetwork16x16(); + return 0; +} diff --git a/third_party/highway/hwy/contrib/sort/result-inl.h b/third_party/highway/hwy/contrib/sort/result-inl.h new file mode 100644 index 000000000000..9342911adab0 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/result-inl.h @@ -0,0 +1,149 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/algo-inl.h" + +// Normal include guard for non-SIMD parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +#include + +#include // std::sort +#include + +#include "hwy/base.h" +#include "hwy/nanobenchmark.h" + +namespace hwy { + +struct Timestamp { + Timestamp() { t = platform::Now(); } + double t; +}; + +double SecondsSince(const Timestamp& t0) { + const Timestamp t1; + return t1.t - t0.t; +} + +constexpr size_t kReps = 30; + +// Returns trimmed mean (we don't want to run an out-of-L3-cache sort often +// enough for the mode to be reliable). 
+double SummarizeMeasurements(std::vector& seconds) { + std::sort(seconds.begin(), seconds.end()); + double sum = 0; + int count = 0; + for (size_t i = kReps / 4; i < seconds.size() - kReps / 2; ++i) { + sum += seconds[i]; + count += 1; + } + return sum / count; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct Result { + Result() {} + Result(const uint32_t target, const Algo algo, Dist dist, bool is128, + size_t num, size_t num_threads, double sec, size_t sizeof_t, + const char* type_name) + : target(target), + algo(algo), + dist(dist), + is128(is128), + num(num), + num_threads(num_threads), + sec(sec), + sizeof_t(sizeof_t), + type_name(type_name) {} + + void Print() const { + const double bytes = static_cast(num) * + static_cast(num_threads) * + static_cast(sizeof_t); + printf("%10s: %12s: %7s: %9s: %.2E %4.0f MB/s (%2zu threads)\n", + hwy::TargetName(target), AlgoName(algo), + is128 ? "u128" : type_name.c_str(), DistName(dist), + static_cast(num), bytes * 1E-6 / sec, num_threads); + } + + uint32_t target; + Algo algo; + Dist dist; + bool is128; + size_t num = 0; + size_t num_threads = 0; + double sec = 0.0; + size_t sizeof_t = 0; + std::string type_name; +}; + +template +Result MakeResult(const Algo algo, Dist dist, Traits st, size_t num, + size_t num_threads, double sec) { + char string100[100]; + hwy::detail::TypeName(hwy::detail::MakeTypeInfo(), 1, string100); + return Result(HWY_TARGET, algo, dist, st.Is128(), num, num_threads, sec, + sizeof(T), string100); +} + +template +bool VerifySort(Traits st, const InputStats& input_stats, const T* out, + size_t num, const char* caller) { + constexpr size_t N1 = st.Is128() ? 2 : 1; + HWY_ASSERT(num >= N1); + + InputStats output_stats; + // Ensure it matches the sort order + for (size_t i = 0; i < num - N1; i += N1) { + output_stats.Notify(out[i]); + if (N1 == 2) output_stats.Notify(out[i + 1]); + // Reverse order instead of checking !Compare1 so we accept equal keys. + if (st.Compare1(out + i + N1, out + i)) { + printf("%s: i=%d of %d: N1=%d %5.0f %5.0f vs. %5.0f %5.0f\n\n", caller, + static_cast(i), static_cast(num), static_cast(N1), + double(out[i + 1]), double(out[i + 0]), double(out[i + N1 + 1]), + double(out[i + N1])); + HWY_ABORT("%d-bit sort is incorrect\n", + static_cast(sizeof(T) * 8 * N1)); + } + } + output_stats.Notify(out[num - N1]); + if (N1 == 2) output_stats.Notify(out[num - N1 + 1]); + + return input_stats == output_stats; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/shared-inl.h b/third_party/highway/hwy/contrib/sort/shared-inl.h new file mode 100644 index 000000000000..8f60613e6bc0 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/shared-inl.h @@ -0,0 +1,104 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions shared between vqsort-inl and sorting_networks-inl. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +#include "hwy/base.h" + +namespace hwy { + +// Internal constants - these are to avoid magic numbers/literals and cannot be +// changed without also changing the associated code. +struct SortConstants { +// SortingNetwork reshapes its input into a matrix. This is the maximum number +// of *keys* per vector. +#if HWY_COMPILER_MSVC + static constexpr size_t kMaxCols = 8; // avoids build timeout +#else + static constexpr size_t kMaxCols = 16; // enough for u32 in 512-bit vector +#endif + + // 16 rows is a compromise between using the 32 AVX-512/SVE/RVV registers, + // fitting within 16 AVX2 registers with only a few spills, keeping BaseCase + // code size reasonable (7 KiB for AVX-512 and 16 cols), and minimizing the + // extra logN factor for larger networks (for which only loose upper bounds + // on size are known). + static constexpr size_t kMaxRowsLog2 = 4; + static constexpr size_t kMaxRows = size_t{1} << kMaxRowsLog2; + + static HWY_INLINE size_t BaseCaseNum(size_t N) { + return kMaxRows * HWY_MIN(N, kMaxCols); + } + + // Unrolling is important (pipelining and amortizing branch mispredictions); + // 2x is sufficient to reach full memory bandwidth on SKX in Partition, but + // somewhat slower for sorting than 4x. + // + // To change, must also update left + 3 * N etc. in the loop. + static constexpr size_t kPartitionUnroll = 4; + + static HWY_INLINE size_t PartitionBufNum(size_t N) { + // The main loop reads kPartitionUnroll vectors, and first loads from + // both left and right beforehand, so it requires min = 2 * + // kPartitionUnroll vectors. To handle smaller amounts (only guaranteed + // >= BaseCaseNum), we partition the right side into a buffer. We need + // another vector at the end so CompressStore does not overwrite anything. + return (2 * kPartitionUnroll + 1) * N; + } + + // Chunk := group of keys loaded for sampling a pivot. Matches the typical + // cache line size of 64 bytes to get maximum benefit per L2 miss. If vectors + // are larger, use entire vectors to ensure we do not overrun the array. + static HWY_INLINE size_t LanesPerChunk(size_t sizeof_t, size_t N) { + return HWY_MAX(64 / sizeof_t, N); + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#endif + +#include "hwy/highway.h" + +namespace hwy { +namespace HWY_NAMESPACE { + +// Default tag / vector width selector. +// TODO(janwas): enable once LMUL < 1 is supported. 
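// Worked example of the SortConstants formulas above for a hypothetical
// 256-bit vector of uint32_t keys (N = 8 lanes):
//   BaseCaseNum(8)      = kMaxRows * HWY_MIN(8, kMaxCols) = 16 * 8 = 128 keys
//   PartitionBufNum(8)  = (2 * kPartitionUnroll + 1) * 8  = 9 * 8  = 72 lanes
//   LanesPerChunk(4, 8) = HWY_MAX(64 / 4, 8)              = 16 lanes (one 64-byte line)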
+#if HWY_TARGET == HWY_RVV && 0 +template +using SortTag = ScalableTag; +#else +template +using SortTag = ScalableTag; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/sort_test.cc b/third_party/highway/hwy/contrib/sort/sort_test.cc index c3b421fc76f2..dcb14663ca30 100644 --- a/third_party/highway/hwy/contrib/sort/sort_test.cc +++ b/third_party/highway/hwy/contrib/sort/sort_test.cc @@ -12,160 +12,553 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include - // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_test.cc" #include "hwy/foreach_target.h" -#include "hwy/contrib/sort/sort-inl.h" +#include "hwy/contrib/sort/vqsort.h" +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // BaseCase #include "hwy/tests/test_util-inl.h" // clang-format on +#include +#include +#include // memcpy + +#include // std::max +#include + +#undef VQSORT_TEST_IMPL +#if (HWY_TARGET == HWY_SCALAR) || (defined(_MSC_VER) && !HWY_IS_DEBUG_BUILD) +// Scalar does not implement these, and MSVC non-debug builds time out. +#define VQSORT_TEST_IMPL 0 +#else +#define VQSORT_TEST_IMPL 1 +#endif + +#undef VQSORT_TEST_SORT +// MSVC non-debug builds time out. +#if defined(_MSC_VER) && !HWY_IS_DEBUG_BUILD +#define VQSORT_TEST_SORT 0 +#else +#define VQSORT_TEST_SORT 1 +#endif + HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +namespace { -#if HWY_TARGET != HWY_SCALAR && HWY_ARCH_X86 - -template -size_t K(D d) { - return SortBatchSize(d); -} - -template -void Validate(D d, const TFromD* in, const TFromD* out) { - const size_t N = Lanes(d); - // Ensure it matches the sort order - for (size_t i = 0; i < K(d) - 1; ++i) { - if (!verify::Compare(out[i], out[i + 1], kOrder)) { - printf("range=%" PRIu64 " lane=%" PRIu64 " N=%" PRIu64 " %.0f %.0f\n\n", - static_cast(i), static_cast(i), - static_cast(N), static_cast(out[i + 0]), - static_cast(out[i + 1])); - for (size_t i = 0; i < K(d); ++i) { - printf("%.0f\n", static_cast(out[i])); - } - - printf("\n\nin was:\n"); - for (size_t i = 0; i < K(d); ++i) { - printf("%.0f\n", static_cast(in[i])); - } - fflush(stdout); - HWY_ABORT("Sort is incorrect"); - } - } - - // Also verify sums match (detects duplicated/lost values) - double expected_sum = 0.0; - double actual_sum = 0.0; - for (size_t i = 0; i < K(d); ++i) { - expected_sum += in[i]; - actual_sum += out[i]; - } - if (expected_sum != actual_sum) { - for (size_t i = 0; i < K(d); ++i) { - printf("%.0f %.0f\n", static_cast(in[i]), - static_cast(out[i])); - } - HWY_ABORT("Mismatch"); - } -} - -class TestReverse { - template - void TestOrder(D d, RandomState& /* rng */) { - using T = TFromD; - const size_t N = Lanes(d); - HWY_ASSERT((N % 4) == 0); - auto in = AllocateAligned(K(d)); - auto inout = AllocateAligned(K(d)); - - const size_t expected_size = SortBatchSize(d); - - for (size_t i = 0; i < K(d); ++i) { - in[i] = static_cast(K(d) - i); - inout[i] = in[i]; - } - - const size_t actual_size = SortBatch(d, inout.get()); - HWY_ASSERT_EQ(expected_size, actual_size); - Validate(d, in.get(), inout.get()); - } - - public: - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - RandomState rng; - TestOrder(d, rng); - 
TestOrder(d, rng); - } -}; - -void TestAllReverse() { - TestReverse test; - test(int32_t(), CappedTag()); - test(uint32_t(), CappedTag()); -} - -class TestRanges { - template - void TestOrder(D d, RandomState& rng) { - using T = TFromD; - const size_t N = Lanes(d); - HWY_ASSERT((N % 4) == 0); - auto in = AllocateAligned(K(d)); - auto inout = AllocateAligned(K(d)); - - const size_t expected_size = SortBatchSize(d); - - // For each range, try all 0/1 combinations and set any other lanes to - // random inputs. - constexpr size_t kRange = 8; - for (size_t range = 0; range < K(d); range += kRange) { - for (size_t bits = 0; bits < (1ull << kRange); ++bits) { - // First set all to random, will later overwrite those for `range` - for (size_t i = 0; i < K(d); ++i) { - in[i] = inout[i] = static_cast(Random32(&rng) & 0xFF); - } - // Now set the current combination of {0,1} for elements in the range. - // This is sufficient to establish correctness (arbitrary inputs could - // be mapped to 0/1 with a comparison predicate). - for (size_t i = 0; i < kRange; ++i) { - in[range + i] = inout[range + i] = (bits >> i) & 1; - } - - const size_t actual_size = SortBatch(d, inout.get()); - HWY_ASSERT_EQ(expected_size, actual_size); - Validate(d, in.get(), inout.get()); - } - } - } - - public: - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - RandomState rng; - TestOrder(d, rng); - TestOrder(d, rng); - } -}; - -void TestAllRanges() { - TestRanges test; - test(int32_t(), CappedTag()); - test(uint32_t(), CappedTag()); -} +#if VQSORT_TEST_IMPL || VQSORT_TEST_SORT +using detail::LaneTraits; +using detail::OrderAscending; +using detail::OrderAscending128; +using detail::OrderDescending; +using detail::OrderDescending128; +using detail::SharedTraits; +using detail::Traits128; +#endif +#if !VQSORT_TEST_IMPL +static void TestAllMedian() {} +static void TestAllBaseCase() {} +static void TestAllPartition() {} +static void TestAllGenerator() {} #else -void TestAllReverse() {} -void TestAllRanges() {} -#endif // HWY_TARGET != HWY_SCALAR && HWY_ARCH_X86 +template +static HWY_NOINLINE void TestMedian3() { + using T = uint64_t; + using D = CappedTag; + SharedTraits st; + const D d; + using V = Vec; + for (uint32_t bits = 0; bits < 8; ++bits) { + const V v0 = Set(d, T{(bits & (1u << 0)) ? 1u : 0u}); + const V v1 = Set(d, T{(bits & (1u << 1)) ? 1u : 0u}); + const V v2 = Set(d, T{(bits & (1u << 2)) ? 1u : 0u}); + const T m = GetLane(detail::MedianOf3(st, v0, v1, v2)); + // If at least half(rounded up) of bits are 1, so is the median. + const size_t count = PopCount(bits); + HWY_ASSERT_EQ((count >= 2) ? 
static_cast(1) : 0, m); + } +} + +HWY_NOINLINE void TestAllMedian() { + TestMedian3 >(); +} + +template +static HWY_NOINLINE void TestBaseCaseAscDesc() { + SharedTraits st; + const SortTag d; + const size_t N = Lanes(d); + const size_t base_case_num = SortConstants::BaseCaseNum(N); + const size_t N1 = st.LanesPerKey(); + + constexpr int kDebug = 0; + auto aligned_keys = hwy::AllocateAligned(N + base_case_num + N); + auto buf = hwy::AllocateAligned(base_case_num + 2 * N); + + std::vector lengths; + lengths.push_back(HWY_MAX(1, N1)); + lengths.push_back(3 * N1); + lengths.push_back(base_case_num / 2); + lengths.push_back(base_case_num / 2 + N1); + lengths.push_back(base_case_num - N1); + lengths.push_back(base_case_num); + + std::vector misalignments; + misalignments.push_back(0); + misalignments.push_back(1); + if (N >= 6) misalignments.push_back(N / 2 - 1); + misalignments.push_back(N / 2); + misalignments.push_back(N / 2 + 1); + misalignments.push_back(HWY_MIN(2 * N / 3 + 3, size_t{N - 1})); + + for (bool asc : {false, true}) { + for (size_t len : lengths) { + for (size_t misalign : misalignments) { + T* HWY_RESTRICT keys = aligned_keys.get() + misalign; + if (kDebug) { + printf("============%s asc %d N1 %d len %d misalign %d\n", + hwy::TypeName(T(), 1).c_str(), asc, static_cast(N1), + static_cast(len), static_cast(misalign)); + } + + for (size_t i = 0; i < misalign; ++i) { + aligned_keys[i] = hwy::LowestValue(); + } + InputStats input_stats; + for (size_t i = 0; i < len; ++i) { + keys[i] = + asc ? static_cast(T(i) + 1) : static_cast(T(len) - T(i)); + input_stats.Notify(keys[i]); + if (kDebug >= 2) printf("%3zu: %f\n", i, double(keys[i])); + } + for (size_t i = len; i < base_case_num + N; ++i) { + keys[i] = hwy::LowestValue(); + } + + detail::BaseCase(d, st, keys, len, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = 0; i < len; ++i) { + printf("%3zu: %f\n", i, double(keys[i])); + } + } + + HWY_ASSERT(VerifySort(st, input_stats, keys, len, "BaseAscDesc")); + for (size_t i = 0; i < misalign; ++i) { + if (aligned_keys[i] != hwy::LowestValue()) + HWY_ABORT("Overrun misalign at %d\n", static_cast(i)); + } + for (size_t i = len; i < base_case_num + N; ++i) { + if (keys[i] != hwy::LowestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // misalign + } // len + } // asc +} + +template +static HWY_NOINLINE void TestBaseCase01() { + SharedTraits st; + const SortTag d; + const size_t N = Lanes(d); + const size_t base_case_num = SortConstants::BaseCaseNum(N); + const size_t N1 = st.LanesPerKey(); + + constexpr int kDebug = 0; + auto keys = hwy::AllocateAligned(base_case_num + N); + auto buf = hwy::AllocateAligned(base_case_num + 2 * N); + + std::vector lengths; + lengths.push_back(HWY_MAX(1, N1)); + lengths.push_back(3 * N1); + lengths.push_back(base_case_num / 2); + lengths.push_back(base_case_num / 2 + N1); + lengths.push_back(base_case_num - N1); + lengths.push_back(base_case_num); + + for (size_t len : lengths) { + if (kDebug) { + printf("============%s 01 N1 %d len %d\n", hwy::TypeName(T(), 1).c_str(), + static_cast(N1), static_cast(len)); + } + const uint64_t kMaxBits = AdjustedLog2Reps(HWY_MIN(len, size_t{14})); + for (uint64_t bits = 0; bits < ((1ull << kMaxBits) - 1); ++bits) { + InputStats input_stats; + for (size_t i = 0; i < len; ++i) { + keys[i] = (i < 64 && (bits & (1ull << i))) ? 
1 : 0; + input_stats.Notify(keys[i]); + if (kDebug >= 2) printf("%3zu: %f\n", i, double(keys[i])); + } + for (size_t i = len; i < base_case_num + N; ++i) { + keys[i] = hwy::LowestValue(); + } + + detail::BaseCase(d, st, keys.get(), len, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = 0; i < len; ++i) { + printf("%3zu: %f\n", i, double(keys[i])); + } + } + + HWY_ASSERT(VerifySort(st, input_stats, keys.get(), len, "Base01")); + for (size_t i = len; i < base_case_num + N; ++i) { + if (keys[i] != hwy::LowestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // bits + } // len +} + +template +static HWY_NOINLINE void TestBaseCase() { + TestBaseCaseAscDesc(); + TestBaseCase01(); +} + +HWY_NOINLINE void TestAllBaseCase() { + // Workaround for stack overflow on MSVC debug. +#if defined(_MSC_VER) && HWY_IS_DEBUG_BUILD && (HWY_TARGET == HWY_AVX3) + return; +#endif + + TestBaseCase, int32_t>(); + TestBaseCase, int64_t>(); + TestBaseCase, uint64_t>(); + TestBaseCase, uint64_t>(); +} + +template +static HWY_NOINLINE void VerifyPartition(Traits st, T* HWY_RESTRICT keys, + size_t left, size_t border, + size_t right, const size_t N1, + const T* pivot) { + /* for (size_t i = left; i < right; ++i) { + if (i == border) printf("--\n"); + printf("%4zu: %3d\n", i, keys[i]); + }*/ + + HWY_ASSERT(left % N1 == 0); + HWY_ASSERT(border % N1 == 0); + HWY_ASSERT(right % N1 == 0); + const bool asc = typename Traits::Order().IsAscending(); + for (size_t i = left; i < border; i += N1) { + if (st.Compare1(pivot, keys + i)) { + HWY_ABORT( + "%s: asc %d left[%d] piv %.0f %.0f compares before %.0f %.0f " + "border %d", + hwy::TypeName(T(), 1).c_str(), asc, static_cast(i), + double(pivot[1]), double(pivot[0]), double(keys[i + 1]), + double(keys[i + 0]), static_cast(border)); + } + } + for (size_t i = border; i < right; i += N1) { + if (!st.Compare1(pivot, keys + i)) { + HWY_ABORT( + "%s: asc %d right[%d] piv %.0f %.0f compares after %.0f %.0f " + "border %d", + hwy::TypeName(T(), 1).c_str(), asc, static_cast(i), + double(pivot[1]), double(pivot[0]), double(keys[i + 1]), + double(keys[i]), static_cast(border)); + } + } +} + +template +static HWY_NOINLINE void TestPartition() { + const SortTag d; + SharedTraits st; + const bool asc = typename Traits::Order().IsAscending(); + const size_t N = Lanes(d); + constexpr int kDebug = 0; + const size_t base_case_num = SortConstants::BaseCaseNum(N); + // left + len + align + const size_t total = 32 + (base_case_num + 4 * HWY_MAX(N, 4)) + 2 * N; + auto aligned_keys = hwy::AllocateAligned(total); + auto buf = hwy::AllocateAligned(SortConstants::PartitionBufNum(N)); + + const size_t N1 = st.LanesPerKey(); + for (bool in_asc : {false, true}) { + for (int left_i : {0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 15, 22, 28, 29, 30, 31}) { + const size_t left = static_cast(left_i) & ~(N1 - 1); + for (size_t ofs : {N, N + 1, N + 2, N + 3, 2 * N, 2 * N + 1, 2 * N + 2, + 2 * N + 3, 3 * N - 1, 4 * N - 3, 4 * N - 2}) { + const size_t len = (base_case_num + ofs) & ~(N1 - 1); + for (T pivot1 : + {T(0), T(len / 3), T(len / 2), T(2 * len / 3), T(len)}) { + const T pivot2[2] = {pivot1, 0}; + const auto pivot = st.SetKey(d, pivot2); + for (size_t misalign = 0; misalign < N; + misalign += st.LanesPerKey()) { + T* HWY_RESTRICT keys = aligned_keys.get() + misalign; + const size_t right = left + len; + if (kDebug) { + printf( + "=========%s asc %d left %d len %d right %d piv %.0f %.0f\n", + hwy::TypeName(T(), 1).c_str(), asc, static_cast(left), + static_cast(len), 
static_cast(right), + double(pivot2[1]), double(pivot2[0])); + } + + for (size_t i = 0; i < misalign; ++i) { + aligned_keys[i] = hwy::LowestValue(); + } + for (size_t i = 0; i < left; ++i) { + keys[i] = hwy::LowestValue(); + } + for (size_t i = left; i < right; ++i) { + keys[i] = static_cast(in_asc ? T(i + 1) - static_cast(left) + : static_cast(right) - T(i)); + if (kDebug >= 2) printf("%3zu: %f\n", i, double(keys[i])); + } + for (size_t i = right; i < total - misalign; ++i) { + keys[i] = hwy::LowestValue(); + } + + size_t border = + detail::Partition(d, st, keys, left, right, pivot, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = left; i < right; ++i) { + printf("%3zu: %f\n", i, double(keys[i])); + } + for (size_t i = right; i < total - misalign; ++i) { + printf("%3zu: sentinel %f\n", i, double(keys[i])); + } + } + + VerifyPartition(st, keys, left, border, right, N1, pivot2); + for (size_t i = 0; i < misalign; ++i) { + if (aligned_keys[i] != hwy::LowestValue()) + HWY_ABORT("Overrun misalign at %d\n", static_cast(i)); + } + for (size_t i = 0; i < left; ++i) { + if (keys[i] != hwy::LowestValue()) + HWY_ABORT("Overrun left at %d\n", static_cast(i)); + } + for (size_t i = right; i < total - misalign; ++i) { + if (keys[i] != hwy::LowestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // misalign + } // pivot + } // len + } // left + } // asc +} + +HWY_NOINLINE void TestAllPartition() { + TestPartition, int16_t>(); + TestPartition, int32_t>(); + TestPartition, int64_t>(); + TestPartition, float>(); +#if HWY_HAVE_FLOAT64 + TestPartition, double>(); +#endif + TestPartition, uint64_t>(); + TestPartition, uint64_t>(); +} + +// (used for sample selection for choosing a pivot) +template +static HWY_NOINLINE void TestRandomGenerator() { + static_assert(!hwy::IsSigned(), ""); + SortTag du; + const size_t N = Lanes(du); + + detail::Generator rng(&N, N); + + const size_t lanes_per_block = HWY_MAX(64 / sizeof(TU), N); // power of two + + for (uint32_t num_blocks = 2; num_blocks < 100000; + num_blocks = 3 * num_blocks / 2) { + // Generate some numbers and ensure all are in range + uint64_t sum = 0; + constexpr size_t kReps = 10000; + for (size_t rep = 0; rep < kReps; ++rep) { + const uint32_t bits = rng() & 0xFFFFFFFF; + const size_t index = detail::RandomChunkIndex(num_blocks, bits); + HWY_ASSERT(((index + 1) * lanes_per_block) <= + num_blocks * lanes_per_block); + + sum += index; + } + + // Also ensure the mean is near the middle of the range + const double expected = (num_blocks - 1) / 2.0; + const double actual = double(sum) / kReps; + HWY_ASSERT(0.9 * expected <= actual && actual <= 1.1 * expected); + } +} + +HWY_NOINLINE void TestAllGenerator() { + TestRandomGenerator(); + TestRandomGenerator(); +} + +#endif // VQSORT_TEST_IMPL + +#if !VQSORT_TEST_SORT +static void TestAllSort() {} +#else + +// Remembers input, and compares results to that of a reference algorithm. 
+template +class CompareResults { + public: + void SetInput(const T* in, size_t num) { + copy_.resize(num); + memcpy(copy_.data(), in, num * sizeof(T)); + } + + bool Verify(const T* output) { +#if HAVE_PDQSORT + const Algo reference = Algo::kPDQ; +#else + const Algo reference = Algo::kStd; +#endif + SharedState shared; + using Order = typename Traits::Order; + Run(reference, copy_.data(), copy_.size(), shared, + /*thread=*/0); + + for (size_t i = 0; i < copy_.size(); ++i) { + if (copy_[i] != output[i]) { + fprintf(stderr, "Asc %d mismatch at %d: %A %A\n", Order().IsAscending(), + static_cast(i), double(copy_[i]), double(output[i])); + return false; + } + } + return true; + } + + private: + std::vector copy_; +}; + +std::vector AlgoForTest() { + return { +#if HAVE_AVX2SORT + Algo::kSEA, +#endif +#if HAVE_IPS4O + Algo::kIPS4O, +#endif +#if HAVE_PDQSORT + Algo::kPDQ, +#endif +#if HAVE_SORT512 + Algo::kSort512, +#endif + Algo::kHeap, Algo::kVQSort, + }; +} + +template +void TestSort(size_t num) { + // TODO(janwas): fix + if (HWY_TARGET == HWY_SSSE3) return; +// Workaround for stack overflow on clang-cl (/F 8388608 does not help). +#if defined(_MSC_VER) && HWY_IS_DEBUG_BUILD && (HWY_TARGET == HWY_AVX3) + return; +#endif + + SharedState shared; + SharedTraits st; + + constexpr size_t kMaxMisalign = 16; + auto aligned = hwy::AllocateAligned(kMaxMisalign + num + kMaxMisalign); + for (Algo algo : AlgoForTest()) { +#if HAVE_IPS4O + if (st.Is128() && (algo == Algo::kIPS4O || algo == Algo::kParallelIPS4O)) { + continue; + } +#endif + for (Dist dist : AllDist()) { + for (size_t misalign : {size_t{0}, size_t{st.LanesPerKey()}, + size_t{3 * st.LanesPerKey()}, kMaxMisalign / 2}) { + T* keys = aligned.get() + misalign; + + // Set up red zones before/after the keys to sort + for (size_t i = 0; i < misalign; ++i) { + aligned[i] = hwy::LowestValue(); + } + for (size_t i = 0; i < kMaxMisalign; ++i) { + keys[num + i] = hwy::HighestValue(); + } +#if HWY_IS_MSAN + __msan_poison(aligned.get(), misalign * sizeof(T)); + __msan_poison(keys + num, kMaxMisalign * sizeof(T)); +#endif + InputStats input_stats = GenerateInput(dist, keys, num); + + CompareResults compare; + compare.SetInput(keys, num); + + Run(algo, keys, num, shared, /*thread=*/0); + HWY_ASSERT(compare.Verify(keys)); + HWY_ASSERT(VerifySort(st, input_stats, keys, num, "TestSort")); + + // Check red zones +#if HWY_IS_MSAN + __msan_unpoison(aligned.get(), misalign * sizeof(T)); + __msan_unpoison(keys + num, kMaxMisalign * sizeof(T)); +#endif + for (size_t i = 0; i < misalign; ++i) { + if (aligned[i] != hwy::LowestValue()) + HWY_ABORT("Overrun left at %d\n", static_cast(i)); + } + for (size_t i = num; i < num + kMaxMisalign; ++i) { + if (keys[i] != hwy::HighestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // misalign + } // dist + } // algo +} + +void TestAllSort() { + const size_t num = 15 * 1000; + + TestSort, int16_t>(num); + TestSort, uint16_t>(num); + + TestSort, int32_t>(num); + TestSort, uint32_t>(num); + + TestSort, int64_t>(num); + TestSort, uint64_t>(num); + + // WARNING: for float types, SIMD comparisons will flush denormals to zero, + // causing mismatches with scalar sorts. In this test, we avoid generating + // denormal inputs. 
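// (A denormal/subnormal float is a nonzero value with magnitude below
// FLT_MIN, e.g. 1e-40f. With flush-to-zero enabled, SIMD comparisons treat it
// as 0.0f while the scalar reference sort does not, so CompareResults above
// would report spurious mismatches. A scalar check, assuming <cmath> and
// <cfloat>, is simply  x != 0.0f && std::fabs(x) < FLT_MIN;  GenerateInput
// avoids producing such values by construction.)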
+ TestSort, float>(num); +#if HWY_HAVE_FLOAT64 // protects algo-inl's GenerateRandom + if (Sorter::HaveFloat64()) { + TestSort, double>(num); + } +#endif + + TestSort, uint64_t>(num); + TestSort, uint64_t>(num); +} + +#endif // VQSORT_TEST_SORT + +} // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy @@ -174,9 +567,14 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace hwy { +namespace { HWY_BEFORE_TEST(SortTest); -HWY_EXPORT_AND_TEST_P(SortTest, TestAllReverse); -HWY_EXPORT_AND_TEST_P(SortTest, TestAllRanges); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllMedian); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllBaseCase); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllPartition); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllGenerator); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllSort); +} // namespace } // namespace hwy // Ought not to be necessary, but without this, no tests run on RVV. @@ -185,4 +583,4 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -#endif +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h b/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h new file mode 100644 index 000000000000..ac50650aaec2 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h @@ -0,0 +1,686 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#endif + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +using Constants = hwy::SortConstants; + +// ------------------------------ SharedTraits + +// Code shared between all traits. It's unclear whether these can profitably be +// specialized for Lane vs Block, or optimized like SortPairsDistance1 using +// Compare/DupOdd. +template +struct SharedTraits : public Base { + // Conditionally swaps lane 0 with 2, 1 with 3 etc. + template + HWY_INLINE Vec SortPairsDistance2(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentPairs(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template + HWY_INLINE Vec SortPairsReverse8(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys8(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. 
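// For reference, SortPairsDistance2 above has the following effect on 8
// ascending-order lanes a0..a7 (scalar sketch, inferred from the lane ops):
//   cswap(a0, a2)  cswap(a1, a3)  cswap(a4, a6)  cswap(a5, a7)
// where cswap places the smaller key in the lower lane. SwapAdjacentPairs
// supplies each lane's distance-2 partner and OddEvenPairs writes the
// min/max results back to their home positions.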
+ template + HWY_INLINE Vec SortPairsReverse16(D d, Vec v) const { + const Base* base = static_cast(this); + static_assert(Constants::kMaxCols <= 16, "Need actual Reverse16"); + Vec swapped = base->ReverseKeys(d, v); + base->Sort2(d, v, swapped); + return ConcatUpperLower(d, swapped, v); // 8 = half of the vector + } +}; + +// ------------------------------ Sorting network + +// (Green's irregular) sorting network for independent columns in 16 vectors. +template > +HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + st.Sort2(d, v8, va); + st.Sort2(d, v9, vb); + st.Sort2(d, vc, ve); + st.Sort2(d, vd, vf); + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + st.Sort2(d, v8, vc); + st.Sort2(d, v9, vd); + st.Sort2(d, va, ve); + st.Sort2(d, vb, vf); + st.Sort2(d, v0, v8); + st.Sort2(d, v1, v9); + st.Sort2(d, v2, va); + st.Sort2(d, v3, vb); + st.Sort2(d, v4, vc); + st.Sort2(d, v5, vd); + st.Sort2(d, v6, ve); + st.Sort2(d, v7, vf); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v3, vc); + st.Sort2(d, v7, vb); + st.Sort2(d, vd, ve); + st.Sort2(d, v4, v8); + st.Sort2(d, v1, v2); + st.Sort2(d, v1, v4); + st.Sort2(d, v7, vd); + st.Sort2(d, v2, v8); + st.Sort2(d, vb, ve); + st.Sort2(d, v2, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vd); + st.Sort2(d, v3, v8); + st.Sort2(d, v7, vc); + st.Sort2(d, v3, v5); + st.Sort2(d, v6, v8); + st.Sort2(d, v7, v9); + st.Sort2(d, va, vc); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v7, v8); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vc); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); +} + +// ------------------------------ Merging networks + +// Blacher's hybrid bitonic/odd-even networks, generated by print_network.cc. 
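// For readers of the networks below: every st.Sort2(d, a, b) is a per-lane
// compare-exchange, i.e. for each column j, (a[j], b[j]) becomes (min, max)
// under the sort order. A scalar sketch of one column for ascending order
// (illustration only, not part of the patch):
template <typename T>
inline void CompareExchangeSketch(T& a, T& b) {
  if (b < a) {  // out of order: swap so that a <= b afterwards
    const T t = a;
    a = b;
    b = t;
  }
}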
+ +template > +HWY_INLINE void Merge2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys2(d, v8); + v9 = st.ReverseKeys2(d, v9); + va = st.ReverseKeys2(d, va); + vb = st.ReverseKeys2(d, vb); + vc = st.ReverseKeys2(d, vc); + vd = st.ReverseKeys2(d, vd); + ve = st.ReverseKeys2(d, ve); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys2(d, v4); + vc = st.ReverseKeys2(d, vc); + v5 = st.ReverseKeys2(d, v5); + vd = st.ReverseKeys2(d, vd); + v6 = st.ReverseKeys2(d, v6); + ve = st.ReverseKeys2(d, ve); + v7 = st.ReverseKeys2(d, v7); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys2(d, v2); + v3 = st.ReverseKeys2(d, v3); + v6 = st.ReverseKeys2(d, v6); + v7 = st.ReverseKeys2(d, v7); + va = st.ReverseKeys2(d, va); + vb = st.ReverseKeys2(d, vb); + ve = st.ReverseKeys2(d, ve); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + v9 = st.ReverseKeys2(d, v9); + vb = st.ReverseKeys2(d, vb); + vd = st.ReverseKeys2(d, vd); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template > +HWY_INLINE void Merge4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys4(d, v8); + v9 = st.ReverseKeys4(d, v9); + va = st.ReverseKeys4(d, va); + vb = st.ReverseKeys4(d, vb); + vc = st.ReverseKeys4(d, vc); + vd = st.ReverseKeys4(d, vd); + ve = st.ReverseKeys4(d, ve); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys4(d, v4); + vc = st.ReverseKeys4(d, vc); + v5 = st.ReverseKeys4(d, v5); + vd = st.ReverseKeys4(d, vd); + v6 = st.ReverseKeys4(d, v6); + ve = st.ReverseKeys4(d, ve); + v7 = st.ReverseKeys4(d, v7); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = 
st.ReverseKeys4(d, v2); + v3 = st.ReverseKeys4(d, v3); + v6 = st.ReverseKeys4(d, v6); + v7 = st.ReverseKeys4(d, v7); + va = st.ReverseKeys4(d, va); + vb = st.ReverseKeys4(d, vb); + ve = st.ReverseKeys4(d, ve); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + v9 = st.ReverseKeys4(d, v9); + vb = st.ReverseKeys4(d, vb); + vd = st.ReverseKeys4(d, vd); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + v8 = st.SortPairsReverse4(d, v8); + v9 = st.SortPairsReverse4(d, v9); + va = st.SortPairsReverse4(d, va); + vb = st.SortPairsReverse4(d, vb); + vc = st.SortPairsReverse4(d, vc); + vd = st.SortPairsReverse4(d, vd); + ve = st.SortPairsReverse4(d, ve); + vf = st.SortPairsReverse4(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template > +HWY_INLINE void Merge8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys8(d, v8); + v9 = st.ReverseKeys8(d, v9); + va = st.ReverseKeys8(d, va); + vb = st.ReverseKeys8(d, vb); + vc = st.ReverseKeys8(d, vc); + vd = st.ReverseKeys8(d, vd); + ve = st.ReverseKeys8(d, ve); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys8(d, v4); + vc = st.ReverseKeys8(d, vc); + v5 = st.ReverseKeys8(d, v5); + vd = st.ReverseKeys8(d, vd); + v6 = st.ReverseKeys8(d, v6); + ve = st.ReverseKeys8(d, ve); + v7 = st.ReverseKeys8(d, v7); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys8(d, v2); + v3 = st.ReverseKeys8(d, v3); + v6 = st.ReverseKeys8(d, v6); + v7 = st.ReverseKeys8(d, v7); + va = st.ReverseKeys8(d, va); + vb = st.ReverseKeys8(d, vb); + ve = st.ReverseKeys8(d, ve); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys8(d, v1); + v3 = st.ReverseKeys8(d, 
v3); + v5 = st.ReverseKeys8(d, v5); + v7 = st.ReverseKeys8(d, v7); + v9 = st.ReverseKeys8(d, v9); + vb = st.ReverseKeys8(d, vb); + vd = st.ReverseKeys8(d, vd); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse8(d, v0); + v1 = st.SortPairsReverse8(d, v1); + v2 = st.SortPairsReverse8(d, v2); + v3 = st.SortPairsReverse8(d, v3); + v4 = st.SortPairsReverse8(d, v4); + v5 = st.SortPairsReverse8(d, v5); + v6 = st.SortPairsReverse8(d, v6); + v7 = st.SortPairsReverse8(d, v7); + v8 = st.SortPairsReverse8(d, v8); + v9 = st.SortPairsReverse8(d, v9); + va = st.SortPairsReverse8(d, va); + vb = st.SortPairsReverse8(d, vb); + vc = st.SortPairsReverse8(d, vc); + vd = st.SortPairsReverse8(d, vd); + ve = st.SortPairsReverse8(d, ve); + vf = st.SortPairsReverse8(d, vf); + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +// Unused on MSVC, see below +#if !HWY_COMPILER_MSVC + +template > +HWY_INLINE void Merge16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, + V& vd, V& ve, V& vf) { + v8 = st.ReverseKeys16(d, v8); + v9 = st.ReverseKeys16(d, v9); + va = st.ReverseKeys16(d, va); + vb = st.ReverseKeys16(d, vb); + vc = st.ReverseKeys16(d, vc); + vd = st.ReverseKeys16(d, vd); + ve = st.ReverseKeys16(d, ve); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys16(d, v4); + vc = st.ReverseKeys16(d, vc); + v5 = st.ReverseKeys16(d, v5); + vd = st.ReverseKeys16(d, vd); + v6 = st.ReverseKeys16(d, v6); + ve = st.ReverseKeys16(d, ve); + v7 = st.ReverseKeys16(d, v7); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys16(d, v2); + v3 = st.ReverseKeys16(d, v3); + v6 = st.ReverseKeys16(d, v6); + v7 = st.ReverseKeys16(d, v7); + va = st.ReverseKeys16(d, va); + vb = st.ReverseKeys16(d, vb); + ve = st.ReverseKeys16(d, ve); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + 
st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys16(d, v1); + v3 = st.ReverseKeys16(d, v3); + v5 = st.ReverseKeys16(d, v5); + v7 = st.ReverseKeys16(d, v7); + v9 = st.ReverseKeys16(d, v9); + vb = st.ReverseKeys16(d, vb); + vd = st.ReverseKeys16(d, vd); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse16(d, v0); + v1 = st.SortPairsReverse16(d, v1); + v2 = st.SortPairsReverse16(d, v2); + v3 = st.SortPairsReverse16(d, v3); + v4 = st.SortPairsReverse16(d, v4); + v5 = st.SortPairsReverse16(d, v5); + v6 = st.SortPairsReverse16(d, v6); + v7 = st.SortPairsReverse16(d, v7); + v8 = st.SortPairsReverse16(d, v8); + v9 = st.SortPairsReverse16(d, v9); + va = st.SortPairsReverse16(d, va); + vb = st.SortPairsReverse16(d, vb); + vc = st.SortPairsReverse16(d, vc); + vd = st.SortPairsReverse16(d, vd); + ve = st.SortPairsReverse16(d, ve); + vf = st.SortPairsReverse16(d, vf); + v0 = st.SortPairsDistance4(d, v0); + v1 = st.SortPairsDistance4(d, v1); + v2 = st.SortPairsDistance4(d, v2); + v3 = st.SortPairsDistance4(d, v3); + v4 = st.SortPairsDistance4(d, v4); + v5 = st.SortPairsDistance4(d, v5); + v6 = st.SortPairsDistance4(d, v6); + v7 = st.SortPairsDistance4(d, v7); + v8 = st.SortPairsDistance4(d, v8); + v9 = st.SortPairsDistance4(d, v9); + va = st.SortPairsDistance4(d, va); + vb = st.SortPairsDistance4(d, vb); + vc = st.SortPairsDistance4(d, vc); + vd = st.SortPairsDistance4(d, vd); + ve = st.SortPairsDistance4(d, ve); + vf = st.SortPairsDistance4(d, vf); + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +#endif // !HWY_COMPILER_MSVC + +// Reshapes `buf` into a matrix, sorts columns independently, and then merges +// into a sorted 1D array without transposing. +// +// `st` is SharedTraits>. This abstraction layer +// bridges differences in sort order and single-lane vs 128-bit keys. +// `buf` ensures full vectors are aligned, and enables loads/stores without +// bounds checks. 
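+// `cols` is a power of two, at least st.LanesPerKey() and at most kMaxCols;
+// the row count is fixed at Constants::kMaxRows (16).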
+// +// References: +// https://drops.dagstuhl.de/opus/volltexte/2021/13775/pdf/LIPIcs-SEA-2021-3.pdf +// https://github.com/simd-sorting/fast-and-robust/blob/master/avx2_sort_demo/avx2sort.h +// "Entwurf und Implementierung vektorisierter Sortieralgorithmen" (M. Blacher) +template +HWY_INLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) { + const CappedTag d; + using V = decltype(Zero(d)); + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // The network width depends on the number of keys, not lanes. + constexpr size_t kLanesPerKey = st.LanesPerKey(); + const size_t keys = cols / kLanesPerKey; + constexpr size_t kMaxKeys = MaxLanes(d) / kLanesPerKey; + + // These are aligned iff cols == Lanes(d). We prefer unaligned/non-constexpr + // offsets to duplicating this code for every value of cols. + static_assert(Constants::kMaxRows == 16, "Update loads/stores/args"); + V v0 = LoadU(d, buf + 0x0 * cols); + V v1 = LoadU(d, buf + 0x1 * cols); + V v2 = LoadU(d, buf + 0x2 * cols); + V v3 = LoadU(d, buf + 0x3 * cols); + V v4 = LoadU(d, buf + 0x4 * cols); + V v5 = LoadU(d, buf + 0x5 * cols); + V v6 = LoadU(d, buf + 0x6 * cols); + V v7 = LoadU(d, buf + 0x7 * cols); + V v8 = LoadU(d, buf + 0x8 * cols); + V v9 = LoadU(d, buf + 0x9 * cols); + V va = LoadU(d, buf + 0xa * cols); + V vb = LoadU(d, buf + 0xb * cols); + V vc = LoadU(d, buf + 0xc * cols); + V vd = LoadU(d, buf + 0xd * cols); + V ve = LoadU(d, buf + 0xe * cols); + V vf = LoadU(d, buf + 0xf * cols); + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Checking MaxLanes avoids generating HWY_ASSERT code for the unreachable + // code paths: if MaxLanes < 2, then keys <= cols < 2. + if (HWY_LIKELY(keys >= 2 && kMaxKeys >= 2)) { + Merge2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, + vf); + + if (HWY_LIKELY(keys >= 4 && kMaxKeys >= 4)) { + Merge4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, + vf); + + if (HWY_LIKELY(keys >= 8 && kMaxKeys >= 8)) { + Merge8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, + ve, vf); + + // Avoids build timeout +#if !HWY_COMPILER_MSVC + if (HWY_LIKELY(keys >= 16 && kMaxKeys >= 16)) { + Merge16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, + ve, vf); + + static_assert(Constants::kMaxCols <= 16, "Add more branches"); + } +#endif + } + } + } + + StoreU(v0, d, buf + 0x0 * cols); + StoreU(v1, d, buf + 0x1 * cols); + StoreU(v2, d, buf + 0x2 * cols); + StoreU(v3, d, buf + 0x3 * cols); + StoreU(v4, d, buf + 0x4 * cols); + StoreU(v5, d, buf + 0x5 * cols); + StoreU(v6, d, buf + 0x6 * cols); + StoreU(v7, d, buf + 0x7 * cols); + StoreU(v8, d, buf + 0x8 * cols); + StoreU(v9, d, buf + 0x9 * cols); + StoreU(va, d, buf + 0xa * cols); + StoreU(vb, d, buf + 0xb * cols); + StoreU(vc, d, buf + 0xc * cols); + StoreU(vd, d, buf + 0xd * cols); + StoreU(ve, d, buf + 0xe * cols); + StoreU(vf, d, buf + 0xf * cols); +} + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/traits-inl.h b/third_party/highway/hwy/contrib/sort/traits-inl.h new file mode 100644 index 000000000000..dc0553c15bee --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/traits-inl.h @@ -0,0 +1,324 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file 
except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#endif + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/contrib/sort/vqsort.h" // SortDescending +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +struct KeyLane { + constexpr size_t LanesPerKey() const { return 1; } + + // For HeapSort + template + HWY_INLINE void Swap(T* a, T* b) const { + const T temp = *a; + *a = *b; + *b = temp; + } + + // Broadcasts one key into a vector + template + HWY_INLINE Vec SetKey(D d, const TFromD* key) const { + return Set(d, *key); + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return Reverse(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D d, Vec v) const { + return Reverse2(d, v); + } + + template + HWY_INLINE Vec ReverseKeys4(D d, Vec v) const { + return Reverse4(d, v); + } + + template + HWY_INLINE Vec ReverseKeys8(D d, Vec v) const { + return Reverse8(d, v); + } + + template + HWY_INLINE Vec ReverseKeys16(D d, Vec v) const { + static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit"); + return ReverseKeys(d, v); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEven(odd, even); + } + + template + HWY_INLINE Vec SwapAdjacentPairs(D d, const Vec v) const { + const Repartition du32; + return BitCast(d, Shuffle2301(BitCast(du32, v))); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return Shuffle1032(v); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return SwapAdjacentBlocks(v); + } + + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v))); + } + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { + // Assumes max vector size = 512 + return ConcatLowerUpper(d, v, v); + } + + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenPairs(D /* tag */, Vec odd, Vec even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif 
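+    // Treating each pair of lanes as one wider lane lets OddEvenPairs on the
+    // wide type select whole quads of the original lane type.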
+ return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { + return ConcatUpperLower(d, odd, even); + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +struct OrderAscending : public KeyLane { + using Order = SortAscending; + + template + HWY_INLINE bool Compare1(const T* a, const T* b) { + return *a < *b; + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(a, b); + } + + // Two halves of Sort2, used in ScanMinMax. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue>()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue>()); + } +}; + +struct OrderDescending : public KeyLane { + using Order = SortDescending; + + template + HWY_INLINE bool Compare1(const T* a, const T* b) { + return *b < *a; + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue>()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue>()); + } +}; + +// Shared code that depends on Order. +template +struct LaneTraits : public Base { + constexpr bool Is128() const { return false; } + + // For each lane i: replaces a[i] with the first and b[i] with the second + // according to Base. + // Corresponds to a conditional swap, which is one "node" of a sorting + // network. Min/Max are cheaper than compare + blend at least for integers. + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4 + // instructions. We can reduce it to a compare + 2 IfThenElse. 
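+    // (The target check below selects x86 SSSE3 through AVX2; lower
+    // HWY_TARGET values denote newer targets.)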
+#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + if (sizeof(TFromD) == 8) { + const Mask cmp = base->Compare(d, a, b); + a = IfThenElse(cmp, a, b); + b = IfThenElse(cmp, b, a_copy); + return; + } +#endif + a = base->First(d, a, b); + b = base->Last(d, a_copy, b); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + // Further to the above optimization, Sort2+OddEvenKeys compile to four + // instructions; we can save one by combining two blends. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + const Vec cmp = VecFromMask(d, base->Compare(d, v, swapped)); + return IfVecThenElse(DupOdd(cmp), swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // (See above - we use Sort2 for non-64-bit types.) + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentQuads(d, v); + // Only used in Merge16, so this will not be used on AVX2 (which only has 4 + // u64 lanes), so skip the above optimization for 64-bit AVX2. + Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } +}; + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/traits128-inl.h b/third_party/highway/hwy/contrib/sort/traits128-inl.h new file mode 100644 index 000000000000..08c3906d6552 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/traits128-inl.h @@ -0,0 +1,368 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#endif + +#include "hwy/contrib/sort/vqsort.h" // SortDescending +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if HWY_TARGET == HWY_SCALAR + +struct OrderAscending128 { + using Order = SortAscending; + + template + HWY_INLINE bool Compare1(const T* a, const T* b) { + return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1]; + } +}; + +struct OrderDescending128 { + using Order = SortDescending; + + template + HWY_INLINE bool Compare1(const T* a, const T* b) { + return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1]; + } +}; + +template +struct Traits128 : public Order { + constexpr bool Is128() const { return true; } + constexpr size_t LanesPerKey() const { return 2; } +}; + +#else + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +struct Key128 { + constexpr size_t LanesPerKey() const { return 2; } + + template + HWY_INLINE void Swap(T* a, T* b) const { + const FixedTag d; + const auto temp = LoadU(d, a); + StoreU(LoadU(d, b), d, a); + StoreU(temp, d, b); + } + + template + HWY_INLINE Vec SetKey(D d, const TFromD* key) const { + return LoadDup128(d, key); + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return ReverseBlocks(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D /* tag */, const Vec v) const { + return SwapAdjacentBlocks(v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec ReverseKeys4(D d, const Vec v) const { + HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD)); + return ReverseKeys(d, v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { + HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD)); + return ConcatUpperLower(d, odd, even); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec ReverseKeys8(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 1024-bit vectors + } + + template + HWY_INLINE Vec ReverseKeys16(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 2048-bit vectors + } + + // This is only called for 8/16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentPairs(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentQuads(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 8 col networks (not supported). + template + HWY_INLINE Vec OddEvenQuads(D, Vec, Vec) const { + HWY_ASSERT(0); + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. 
Specializing would work, but +// we are anyway going to specialize at a higher level. +struct OrderAscending128 : public Key128 { + using Order = SortAscending; + + template + HWY_INLINE bool Compare1(const T* a, const T* b) { + return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, a, b); + } + + // Used by CompareTop + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const size_t N = Lanes(d); + Store(v, d, buf); + v = SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) { + v = First(d, v, SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const size_t N = Lanes(d); + Store(v, d, buf); + v = SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) { + v = Last(d, v, SetKey(d, buf + i)); + } + return v; + } + + // Same as for regular lanes because 128-bit lanes are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } +}; + +struct OrderDescending128 : public Key128 { + using Order = SortDescending; + + template + HWY_INLINE bool Compare1(const T* a, const T* b) { + return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, b, a); + } + + // Used by CompareTop + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const size_t N = Lanes(d); + Store(v, d, buf); + v = SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) { + v = First(d, v, SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const size_t N = Lanes(d); + Store(v, d, buf); + v = SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) { + v = Last(d, v, SetKey(d, buf + i)); + } + return v; + } + + // Same as for regular lanes because 128-bit lanes are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } +}; + +// Shared code that depends on Order. +template +class Traits128 : public Base { +#if HWY_TARGET <= HWY_AVX2 + // Returns vector with only the top u64 lane valid. Useful when the next step + // is to replicate the mask anyway. 
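+  // Per 128-bit key (two u64 lanes), the upper lane of the result is
+  // compare(hi) | (eq(hi) & compare(lo)); OrAnd combines these after the
+  // low-lane comparison has been shifted up by one lane.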
+ template + HWY_INLINE HWY_MAYBE_UNUSED Vec CompareTop(D d, Vec a, Vec b) const { + const Base* base = static_cast(this); + const Vec eqHL = VecFromMask(d, Eq(a, b)); + const Vec ltHL = VecFromMask(d, base->CompareLanes(a, b)); + const Vec ltLX = ShiftLeftLanes<1>(ltHL); + return OrAnd(ltHL, eqHL, ltLX); + } + + // We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in + // the most-significant of those lanes (the result of CompareTop), so + // replicate it 4x. Only called for >= 256-bit vectors. + template + HWY_INLINE V ReplicateTop4x(V v) const { +#if HWY_TARGET <= HWY_AVX3 + return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +#else // AVX2 + return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +#endif + } +#endif + + public: + constexpr bool Is128() const { return true; } + + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + const auto lt = base->Compare(d, a, b); + a = IfThenElse(lt, a, b); + b = IfThenElse(lt, b, a_copy); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + +#if HWY_TARGET <= HWY_AVX2 + const Vec select = ReplicateTop4x(CompareTop(d, v, swapped)); + return IfVecThenElse(select, swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + + // Only specialize for AVX3 because this requires 512-bit vectors. +#if HWY_TARGET <= HWY_AVX3 + const Vec512 outHx = CompareTop(d, v, swapped); + // Similar to ReplicateTop4x, we want to gang together 2 comparison results + // (4 lanes). They are not contiguous, so use permute to replicate 4x. + alignas(64) uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7}; + const Vec512 select = + TableLookupLanes(outHx, SetTableIndices(d, kIndices)); + return IfVecThenElse(select, swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); +#endif + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D, Vec) const { + // Only used by Merge16, which would require 2048 bit vectors (unsupported). + HWY_ASSERT(0); + } +}; + +#endif // HWY_TARGET != HWY_SCALAR + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/vqsort-inl.h b/third_party/highway/hwy/contrib/sort/vqsort-inl.h new file mode 100644 index 000000000000..756cef383d47 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort-inl.h @@ -0,0 +1,722 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// Makes it harder for adversaries to predict our sampling locations, at the +// cost of 1-2% increased runtime. +#ifndef VQSORT_SECURE_RNG +#define VQSORT_SECURE_RNG 0 +#endif + +#if VQSORT_SECURE_RNG +#include "third_party/absl/random/random.h" +#endif + +#include // memcpy + +#include "hwy/cache_control.h" // Prefetch +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" // Fill24Bytes + +#if HWY_IS_MSAN +#include +#endif + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#endif + +#include "hwy/contrib/sort/shared-inl.h" +#include "hwy/contrib/sort/sorting_networks-inl.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if HWY_TARGET == HWY_SCALAR + +template +void Swap(T* a, T* b) { + T t = *a; + *a = *b; + *b = t; +} + +// Scalar version of HeapSort (see below) +template +void HeapSort(Traits st, T* HWY_RESTRICT keys, const size_t num) { + if (num < 2) return; + + // Build heap. + for (size_t i = 1; i < num; i += 1) { + size_t j = i; + while (j != 0) { + const size_t idx_parent = ((j - 1) / 1 / 2); + if (!st.Compare1(keys + idx_parent, keys + j)) { + break; + } + Swap(keys + j, keys + idx_parent); + j = idx_parent; + } + } + + for (size_t i = num - 1; i != 0; i -= 1) { + // Swap root with last + Swap(keys + 0, keys + i); + + // Sift down the new root. + size_t j = 0; + while (j < i) { + const size_t left = 2 * j + 1; + const size_t right = 2 * j + 2; + if (left >= i) break; + size_t idx_larger = j; + if (st.Compare1(keys + j, keys + left)) { + idx_larger = left; + } + if (right < i && st.Compare1(keys + idx_larger, keys + right)) { + idx_larger = right; + } + if (idx_larger == j) break; + Swap(keys + j, keys + idx_larger); + j = idx_larger; + } + } +} + +#else + +using Constants = hwy::SortConstants; + +// ------------------------------ HeapSort + +// Heapsort: O(1) space, O(N*logN) worst-case comparisons. +// Based on LLVM sanitizer_common.h, licensed under Apache-2.0. +template +void HeapSort(Traits st, T* HWY_RESTRICT keys, const size_t num) { + constexpr size_t N1 = st.LanesPerKey(); + const FixedTag d; + + if (num < 2 * N1) return; + + // Build heap. + for (size_t i = N1; i < num; i += N1) { + size_t j = i; + while (j != 0) { + const size_t idx_parent = ((j - N1) / N1 / 2) * N1; + if (AllFalse(d, st.Compare(d, st.SetKey(d, keys + idx_parent), + st.SetKey(d, keys + j)))) { + break; + } + st.Swap(keys + j, keys + idx_parent); + j = idx_parent; + } + } + + for (size_t i = num - N1; i != 0; i -= N1) { + // Swap root with last + st.Swap(keys + 0, keys + i); + + // Sift down the new root. 
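+    // Indices are multiples of N1 (lanes per key), so the parent/child
+    // arithmetic operates on whole keys.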
+ size_t j = 0; + while (j < i) { + const size_t left = 2 * j + N1; + const size_t right = 2 * j + 2 * N1; + if (left >= i) break; + size_t idx_larger = j; + const auto key_j = st.SetKey(d, keys + j); + if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, keys + left)))) { + idx_larger = left; + } + if (right < i && AllTrue(d, st.Compare(d, st.SetKey(d, keys + idx_larger), + st.SetKey(d, keys + right)))) { + idx_larger = right; + } + if (idx_larger == j) break; + st.Swap(keys + j, keys + idx_larger); + j = idx_larger; + } + } +} + +// ------------------------------ BaseCase + +// Sorts `keys` within the range [0, num) via sorting network. +template +HWY_NOINLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + using V = decltype(Zero(d)); + + // _Nonzero32 requires num - 1 != 0. + if (HWY_UNLIKELY(num <= 1)) return; + + // Reshape into a matrix with kMaxRows rows, and columns limited by the + // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum). + const size_t num_pow2 = size_t{1} + << (32 - Num0BitsAboveMS1Bit_Nonzero32( + static_cast(num - 1))); + HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N)); + const size_t cols = + HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2); + HWY_DASSERT(cols <= N); + + // Copy `keys` to `buf`. + size_t i; + for (i = 0; i + N <= num; i += N) { + Store(LoadU(d, keys + i), d, buf + i); + } + for (; i < num; ++i) { + buf[i] = keys[i]; + } + + // Fill with padding - last in sort order, not copied to keys. + const V kPadding = st.LastValue(d); + // Initialize an extra vector because SortingNetwork loads full vectors, + // which may exceed cols*kMaxRows. + for (; i < (cols * Constants::kMaxRows + N); i += N) { + StoreU(kPadding, d, buf + i); + } + + SortingNetwork(st, buf, cols); + + for (i = 0; i + N <= num; i += N) { + StoreU(Load(d, buf + i), d, keys + i); + } + for (; i < num; ++i) { + keys[i] = buf[i]; + } +} + +// ------------------------------ Partition + +// Consumes from `left` until a multiple of kUnroll*N remains. +// Temporarily stores the right side into `buf`, then moves behind `right`. +template +HWY_NOINLINE void PartitionToMultipleOfUnroll(D d, Traits st, + T* HWY_RESTRICT keys, + size_t& left, size_t& right, + const Vec pivot, + T* HWY_RESTRICT buf) { + constexpr size_t kUnroll = Constants::kPartitionUnroll; + const size_t N = Lanes(d); + size_t readL = left; + size_t bufR = 0; + const size_t num = right - left; + // Partition requires both a multiple of kUnroll*N and at least + // 2*kUnroll*N for the initial loads. If less, consume all here. + const size_t num_rem = + (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1)); + size_t i = 0; + for (; i + N <= num_rem; i += N) { + const Vec vL = LoadU(d, keys + readL); + readL += N; + + const auto comp = st.Compare(d, pivot, vL); + left += CompressBlendedStore(vL, Not(comp), d, keys + left); + bufR += CompressStore(vL, comp, d, buf + bufR); + } + // Last iteration: only use valid lanes. + if (HWY_LIKELY(i != num_rem)) { + const auto mask = FirstN(d, num_rem - i); + const Vec vL = LoadU(d, keys + readL); + + const auto comp = st.Compare(d, pivot, vL); + left += CompressBlendedStore(vL, AndNot(comp, mask), d, keys + left); + bufR += CompressStore(vL, And(comp, mask), d, buf + bufR); + } + + // MSAN seems not to understand CompressStore. buf[0, bufR) are valid. 
+#if HWY_IS_MSAN + __msan_unpoison(buf, bufR * sizeof(T)); +#endif + + // Everything we loaded was put into buf, or behind the new `left`, after + // which there is space for bufR items. First move items from `right` to + // `left` to free up space, then copy `buf` into the vacated `right`. + // A loop with masked loads from `buf` is insufficient - we would also need to + // mask from `right`. Combining a loop with memcpy for the remainders is + // slower than just memcpy, so we use that for simplicity. + right -= bufR; + memcpy(keys + left, keys + right, bufR * sizeof(T)); + memcpy(keys + right, buf, bufR * sizeof(T)); +} + +template +HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec v, + const Vec pivot, T* HWY_RESTRICT keys, + size_t& writeL, size_t& writeR) { + const size_t N = Lanes(d); + + const auto comp = st.Compare(d, pivot, v); + const size_t num_left = CompressBlendedStore(v, Not(comp), d, keys + writeL); + writeL += num_left; + + writeR -= (N - num_left); + (void)CompressBlendedStore(v, comp, d, keys + writeR); +} + +template +HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec v0, + const Vec v1, const Vec v2, + const Vec v3, const Vec pivot, + T* HWY_RESTRICT keys, size_t& writeL, + size_t& writeR) { + StoreLeftRight(d, st, v0, pivot, keys, writeL, writeR); + StoreLeftRight(d, st, v1, pivot, keys, writeL, writeR); + StoreLeftRight(d, st, v2, pivot, keys, writeL, writeR); + StoreLeftRight(d, st, v3, pivot, keys, writeL, writeR); +} + +// Moves "<= pivot" keys to the front, and others to the back. pivot is +// broadcasted. Time-critical! +// +// Aligned loads do not seem to be worthwhile (not bottlenecked by load ports). +template +HWY_NOINLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t left, + size_t right, const Vec pivot, + T* HWY_RESTRICT buf) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all + // lanes happen to be in the right-side partition, this will overrun `keys`, + // which triggers asan errors. Avoid by special-casing the last vector. + HWY_DASSERT(right - left > 2 * N); // ensured by HandleSpecialCases + right -= N; + const size_t last = right; + const V vlast = LoadU(d, keys + last); + + PartitionToMultipleOfUnroll(d, st, keys, left, right, pivot, buf); + constexpr size_t kUnroll = Constants::kPartitionUnroll; + + // Invariant: [left, writeL) and [writeR, right) are already partitioned. + size_t writeL = left; + size_t writeR = right; + + const size_t num = right - left; + // Cannot load if there were fewer than 2 * kUnroll * N. + if (HWY_LIKELY(num != 0)) { + HWY_DASSERT(num >= 2 * kUnroll * N); + HWY_DASSERT((num & (kUnroll * N - 1)) == 0); + + // Make space for writing in-place by reading from left and right. + const V vL0 = LoadU(d, keys + left + 0 * N); + const V vL1 = LoadU(d, keys + left + 1 * N); + const V vL2 = LoadU(d, keys + left + 2 * N); + const V vL3 = LoadU(d, keys + left + 3 * N); + left += kUnroll * N; + right -= kUnroll * N; + const V vR0 = LoadU(d, keys + right + 0 * N); + const V vR1 = LoadU(d, keys + right + 1 * N); + const V vR2 = LoadU(d, keys + right + 2 * N); + const V vR3 = LoadU(d, keys + right + 3 * N); + + // The left/right updates may consume all inputs, so check before the loop. + while (left != right) { + V v0, v1, v2, v3; + + // Free up capacity for writing by loading from the side that has less. + // Data-dependent but branching is faster than forcing branch-free. 
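+      // Invariant: the gaps [writeL, left) and [right, writeR) hold only keys
+      // that are already in registers, so StoreLeftRight cannot overwrite
+      // unread keys.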
+ const size_t capacityL = left - writeL; + const size_t capacityR = writeR - right; + HWY_DASSERT(capacityL <= num && capacityR <= num); // >= 0 + if (capacityR < capacityL) { + right -= kUnroll * N; + v0 = LoadU(d, keys + right + 0 * N); + v1 = LoadU(d, keys + right + 1 * N); + v2 = LoadU(d, keys + right + 2 * N); + v3 = LoadU(d, keys + right + 3 * N); + hwy::Prefetch(keys + right - 3 * kUnroll * N); + } else { + v0 = LoadU(d, keys + left + 0 * N); + v1 = LoadU(d, keys + left + 1 * N); + v2 = LoadU(d, keys + left + 2 * N); + v3 = LoadU(d, keys + left + 3 * N); + left += kUnroll * N; + hwy::Prefetch(keys + left + 3 * kUnroll * N); + } + + StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, writeR); + } + + // Now finish writing the initial left/right to the middle. + StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, writeR); + StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, writeR); + } + + // We have partitioned [left, right) such that writeL is the boundary. + HWY_DASSERT(writeL == writeR); + // Make space for inserting vlast: move up to N of the first right-side keys + // into the unused space starting at last. If we have fewer, ensure they are + // the last items in that vector by subtracting from the *load* address, + // which is safe because we have at least two vectors (checked above). + const size_t totalR = last - writeL; + const size_t startR = totalR < N ? writeL + totalR - N : writeL; + StoreU(LoadU(d, keys + startR), d, keys + last); + + // Partition vlast: write L, then R, into the single-vector gap at writeL. + const auto comp = st.Compare(d, pivot, vlast); + writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL); + (void)CompressBlendedStore(vlast, comp, d, keys + writeL); + + return writeL; +} + +// ------------------------------ Pivot + +template +HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) { + const DFromV d; + // Slightly faster for 128-bit, apparently because not serially dependent. + if (st.Is128()) { + // Median = XOR-sum 'minus' the first and last. Calling First twice is + // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR. + const auto sum = Xor(Xor(v0, v1), v2); + const auto first = st.First(d, st.First(d, v0, v1), v2); + const auto last = st.Last(d, st.Last(d, v0, v1), v2); + return Xor(Xor(sum, first), last); + } + st.Sort2(d, v0, v2); + v1 = st.Last(d, v0, v1); + v1 = st.First(d, v1, v2); + return v1; +} + +// Replaces triplets with their median and recurses until less than 3 keys +// remain. Ignores leftover values (non-whole triplets)! 
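+// Each level alternates the roles of `keys` and `buf` (see the tail call at
+// the end), so both buffers are written to.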
+template +Vec RecursiveMedianOf3(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + constexpr size_t N1 = st.LanesPerKey(); + + if (num < 3 * N1) return st.SetKey(d, keys); + + size_t read = 0; + size_t written = 0; + + // Triplets of vectors + for (; read + 3 * N <= num; read += 3 * N) { + const auto v0 = Load(d, keys + read + 0 * N); + const auto v1 = Load(d, keys + read + 1 * N); + const auto v2 = Load(d, keys + read + 2 * N); + Store(MedianOf3(st, v0, v1, v2), d, buf + written); + written += N; + } + + // Triplets of keys + for (; read + 3 * N1 <= num; read += 3 * N1) { + const auto v0 = st.SetKey(d, keys + read + 0 * N1); + const auto v1 = st.SetKey(d, keys + read + 1 * N1); + const auto v2 = st.SetKey(d, keys + read + 2 * N1); + StoreU(MedianOf3(st, v0, v1, v2), d, buf + written); + written += N1; + } + + // Tail recursion; swap buffers + return RecursiveMedianOf3(d, st, buf, written, keys); +} + +#if VQSORT_SECURE_RNG +using Generator = absl::BitGen; +#else +// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028 +#pragma pack(push, 1) +class Generator { + public: + Generator(const void* heap, size_t num) { + Sorter::Fill24Bytes(heap, num, &a_); + k_ = 1; // stream index: must be odd + } + + uint64_t operator()() { + const uint64_t b = b_; + w_ += k_; + const uint64_t next = a_ ^ w_; + a_ = (b + (b << 3)) ^ (b >> 11); + const uint64_t rot = (b << 24) | (b >> 40); + b_ = rot + next; + return next; + } + + private: + uint64_t a_; + uint64_t b_; + uint64_t w_; + uint64_t k_; // increment +}; +#pragma pack(pop) + +#endif // !VQSORT_SECURE_RNG + +// Returns slightly biased random index of a chunk in [0, num_chunks). +// See https://www.pcg-random.org/posts/bounded-rands.html. +HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) { + const uint64_t chunk_index = (static_cast(bits) * num_chunks) >> 32; + HWY_DASSERT(chunk_index < num_chunks); + return static_cast(chunk_index); +} + +template +HWY_NOINLINE Vec ChoosePivot(D d, Traits st, T* HWY_RESTRICT keys, + const size_t begin, const size_t end, + T* HWY_RESTRICT buf, Generator& rng) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // Power of two + const size_t lanes_per_chunk = Constants::LanesPerChunk(sizeof(T), N); + + keys += begin; + size_t num = end - begin; + + // Align start of keys to chunks. We always have at least 2 chunks because the + // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks. 
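+  // Chunk-aligned offsets also allow the aligned Load below.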
+ HWY_DASSERT(num >= 2 * lanes_per_chunk); + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (lanes_per_chunk - 1); + if (misalign != 0) { + const size_t consume = lanes_per_chunk - misalign; + keys += consume; + num -= consume; + } + + // Generate enough random bits for 9 uint32 + uint64_t* bits64 = reinterpret_cast(buf); + for (size_t i = 0; i < 5; ++i) { + bits64[i] = rng(); + } + const uint32_t* bits = reinterpret_cast(buf); + + const uint32_t lpc32 = static_cast(lanes_per_chunk); + // Avoid division + const size_t log2_lpc = Num0BitsBelowLS1Bit_Nonzero32(lpc32); + const size_t num_chunks64 = num >> log2_lpc; + // Clamp to uint32 for RandomChunkIndex + const uint32_t num_chunks = + static_cast(HWY_MIN(num_chunks64, 0xFFFFFFFFull)); + + const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc; + const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc; + const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc; + const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc; + const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc; + const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc; + const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc; + const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc; + const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc; + for (size_t i = 0; i < lanes_per_chunk; i += N) { + const V v0 = Load(d, keys + offset0 + i); + const V v1 = Load(d, keys + offset1 + i); + const V v2 = Load(d, keys + offset2 + i); + const V medians0 = MedianOf3(st, v0, v1, v2); + Store(medians0, d, buf + i); + + const V v3 = Load(d, keys + offset3 + i); + const V v4 = Load(d, keys + offset4 + i); + const V v5 = Load(d, keys + offset5 + i); + const V medians1 = MedianOf3(st, v3, v4, v5); + Store(medians1, d, buf + i + lanes_per_chunk); + + const V v6 = Load(d, keys + offset6 + i); + const V v7 = Load(d, keys + offset7 + i); + const V v8 = Load(d, keys + offset8 + i); + const V medians2 = MedianOf3(st, v6, v7, v8); + Store(medians2, d, buf + i + lanes_per_chunk * 2); + } + + return RecursiveMedianOf3(d, st, buf, 3 * lanes_per_chunk, + buf + 3 * lanes_per_chunk); +} + +// Compute exact min/max to detect all-equal partitions. Only called after a +// degenerate Partition (none in the right partition). +template +HWY_NOINLINE void ScanMinMax(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf, Vec& first, + Vec& last) { + const size_t N = Lanes(d); + + first = st.LastValue(d); + last = st.FirstValue(d); + + size_t i = 0; + for (; i + N <= num; i += N) { + const Vec v = LoadU(d, keys + i); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + if (HWY_LIKELY(i != num)) { + HWY_DASSERT(num >= N); // See HandleSpecialCases + const Vec v = LoadU(d, keys + num - N); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + + first = st.FirstOfLanes(d, first, buf); + last = st.LastOfLanes(d, last, buf); +} + +template +void Recurse(D d, Traits st, T* HWY_RESTRICT keys, const size_t begin, + const size_t end, const Vec pivot, T* HWY_RESTRICT buf, + Generator& rng, size_t remaining_levels) { + HWY_DASSERT(begin + 1 < end); + const size_t num = end - begin; // >= 2 + + // Too many degenerate partitions. This is extremely unlikely to happen + // because we select pivots from large (though still O(1)) samples. 
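+  // As in introsort, falling back to HeapSort bounds the worst case at
+  // O(N*logN) comparisons.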
+ if (HWY_UNLIKELY(remaining_levels == 0)) { + HeapSort(st, keys + begin, num); // Slow but N*logN. + return; + } + + const ptrdiff_t base_case_num = + static_cast(Constants::BaseCaseNum(Lanes(d))); + const size_t bound = Partition(d, st, keys, begin, end, pivot, buf); + + const ptrdiff_t num_left = + static_cast(bound) - static_cast(begin); + const ptrdiff_t num_right = + static_cast(end) - static_cast(bound); + + // Check for degenerate partitions (i.e. Partition did not move any keys): + if (HWY_UNLIKELY(num_right == 0)) { + // Because the pivot is one of the keys, it must have been equal to the + // first or last key in sort order. Scan for the actual min/max: + // passing the current pivot as the new bound is insufficient because one of + // the partitions might not actually include that key. + Vec first, last; + ScanMinMax(d, st, keys + begin, num, buf, first, last); + if (AllTrue(d, Eq(first, last))) return; + + // Separate recursion to make sure that we don't pick `last` as the + // pivot - that would again lead to a degenerate partition. + Recurse(d, st, keys, begin, end, first, buf, rng, remaining_levels - 1); + return; + } + + if (HWY_UNLIKELY(num_left <= base_case_num)) { + BaseCase(d, st, keys + begin, static_cast(num_left), buf); + } else { + const Vec next_pivot = ChoosePivot(d, st, keys, begin, bound, buf, rng); + Recurse(d, st, keys, begin, bound, next_pivot, buf, rng, + remaining_levels - 1); + } + if (HWY_UNLIKELY(num_right <= base_case_num)) { + BaseCase(d, st, keys + bound, static_cast(num_right), buf); + } else { + const Vec next_pivot = ChoosePivot(d, st, keys, bound, end, buf, rng); + Recurse(d, st, keys, bound, end, next_pivot, buf, rng, + remaining_levels - 1); + } +} + +// Returns true if sorting is finished. +template +bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + const size_t base_case_num = Constants::BaseCaseNum(N); + + // 128-bit keys require vectors with at least two u64 lanes, which is always + // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the + // hardware vector width is less than 128bit / fraction. + const bool partial_128 = N < 2 && st.Is128(); + // Partition assumes its input is at least two vectors. If vectors are huge, + // base_case_num may actually be smaller. If so, which is only possible on + // RVV, pass a capped or partial d (LMUL < 1). + constexpr bool kPotentiallyHuge = + HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols; + const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num); + if (partial_128 || huge_vec) { + // PERFORMANCE WARNING: falling back to HeapSort. + HeapSort(st, keys, num); + return true; + } + + // Small arrays: use sorting network, no need for other checks. + if (HWY_UNLIKELY(num <= base_case_num)) { + BaseCase(d, st, keys, num, buf); + return true; + } + + // We could also check for already sorted/reverse/equal, but that's probably + // counterproductive if vqsort is used as a base case. + + return false; // not finished sorting +} + +#endif // HWY_TARGET != HWY_SCALAR +} // namespace detail + +// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`. +// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons. +// Non-stable (order of equal keys may change), except for the common case where +// the upper bits of T are the key, and the lower bits are a sequential or at +// least unique ID. 
+// There is no upper limit on `num`, but note that pivots may be chosen by
+// sampling only from the first 256 GiB.
+//
+// `d` is typically SortTag<T> (chooses between full and partial vectors).
+// `st` is SharedTraits<{LaneTraits|Traits128}>. This abstraction layer
+// bridges differences in sort order and single-lane vs 128-bit keys.
+template <class D, class Traits, typename T>
+void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
+          T* HWY_RESTRICT buf) {
+#if HWY_TARGET == HWY_SCALAR
+  (void)d;
+  (void)buf;
+  // PERFORMANCE WARNING: vqsort is not enabled for the non-SIMD target
+  return detail::HeapSort(st, keys, num);
+#else
+  if (detail::HandleSpecialCases(d, st, keys, num, buf)) return;
+
+#if HWY_MAX_BYTES > 64
+  // sorting_networks-inl and traits assume no more than 512 bit vectors.
+  if (Lanes(d) > 64 / sizeof(T)) {
+    return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
+  }
+#endif  // HWY_MAX_BYTES > 64
+
+  // Pulled out of the recursion so we can special-case degenerate partitions.
+  detail::Generator rng(keys, num);
+  const Vec<D> pivot = detail::ChoosePivot(d, st, keys, 0, num, buf, rng);
+
+  // Introspection: switch to worst-case N*logN heapsort after this many.
+  const size_t max_levels = 2 * hwy::CeilLog2(num) + 4;
+
+  detail::Recurse(d, st, keys, 0, num, pivot, buf, rng, max_levels);
+#endif  // HWY_TARGET == HWY_SCALAR
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort.cc b/third_party/highway/hwy/contrib/sort/vqsort.cc
new file mode 100644
index 000000000000..951a0bd51d5b
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort.cc
@@ -0,0 +1,148 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/contrib/sort/vqsort.h"
+
+#include <string.h>  // memset
+
+#include "hwy/aligned_allocator.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort.cc"
+#include "hwy/foreach_target.h"
+
+// After foreach_target
+#include "hwy/contrib/sort/shared-inl.h"
+
+// Seed source for SFC generator: 1=getrandom, 2=CryptGenRandom
+// (not all Android support the getrandom wrapper)
+#ifndef VQSORT_SECURE_SEED
+
+#if (defined(linux) || defined(__linux__)) && \
+    !(defined(ANDROID) || defined(__ANDROID__) || HWY_ARCH_RVV)
+#define VQSORT_SECURE_SEED 1
+#elif defined(_WIN32) || defined(_WIN64)
+#define VQSORT_SECURE_SEED 2
+#else
+#define VQSORT_SECURE_SEED 0
+#endif
+
+#endif  // VQSORT_SECURE_SEED
+
+#if !VQSORT_SECURE_RNG
+
+#include <time.h>
+#if VQSORT_SECURE_SEED == 1
+#include <sys/random.h>
+#elif VQSORT_SECURE_SEED == 2
+#include <windows.h>
+#pragma comment(lib, "Advapi32.lib")
+// Must come after windows.h.
+#include <wincrypt.h>
+#endif  // VQSORT_SECURE_SEED
+
+#endif  // !VQSORT_SECURE_RNG
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+size_t VectorSize() { return Lanes(ScalableTag<uint8_t>()); }
+bool HaveFloat64() { return HWY_HAVE_FLOAT64; }
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(VectorSize);
+HWY_EXPORT(HaveFloat64);
+
+HWY_INLINE size_t PivotBufNum(size_t sizeof_t, size_t N) {
+  // 3 chunks of medians, 1 chunk of median medians plus two padding vectors.
+  const size_t lpc = SortConstants::LanesPerChunk(sizeof_t, N);
+  return (3 + 1) * lpc + 2 * N;
+}
+
+}  // namespace
+
+Sorter::Sorter() {
+  // Determine the largest buffer size required for any type by trying them all.
+  // (The capping of N in BaseCaseNum means that smaller N but larger sizeof_t
+  // may require a larger buffer.)
+  const size_t vector_size = HWY_DYNAMIC_DISPATCH(VectorSize)();
+  size_t max_bytes = 0;
+  for (size_t sizeof_t :
+       {sizeof(uint16_t), sizeof(uint32_t), sizeof(uint64_t)}) {
+    const size_t N = vector_size / sizeof_t;
+    // One extra for padding plus another for full-vector loads.
+    const size_t base_case = SortConstants::BaseCaseNum(N) + 2 * N;
+    const size_t partition_num = SortConstants::PartitionBufNum(N);
+    const size_t buf_lanes =
+        HWY_MAX(base_case, HWY_MAX(partition_num, PivotBufNum(sizeof_t, N)));
+    max_bytes = HWY_MAX(max_bytes, buf_lanes * sizeof_t);
+  }
+
+  ptr_ = hwy::AllocateAlignedBytes(max_bytes, nullptr, nullptr);
+
+  // Prevent msan errors by initializing.
+  memset(ptr_, 0, max_bytes);
+}
+
+void Sorter::Delete() {
+  FreeAlignedBytes(ptr_, nullptr, nullptr);
+  ptr_ = nullptr;
+}
+
+#if !VQSORT_SECURE_RNG
+
+void Sorter::Fill24Bytes(const void* seed_heap, size_t seed_num, void* bytes) {
+#if VQSORT_SECURE_SEED == 1
+  // May block if urandom is not yet initialized.
+  const ssize_t ret = getrandom(bytes, 24, /*flags=*/0);
+  if (ret == 24) return;
+#elif VQSORT_SECURE_SEED == 2
+  HCRYPTPROV hProvider{};
+  if (CryptAcquireContextA(&hProvider, nullptr, nullptr, PROV_RSA_FULL,
+                           CRYPT_VERIFYCONTEXT)) {
+    const BOOL ok =
+        CryptGenRandom(hProvider, 24, reinterpret_cast<BYTE*>(bytes));
+    CryptReleaseContext(hProvider, 0);
+    if (ok) return;
+  }
+#endif
+
+  // VQSORT_SECURE_SEED == 0, or one of the above failed. Get some entropy from
+  // stack/heap/code addresses and the clock() timer.
+  uint64_t* words = reinterpret_cast<uint64_t*>(bytes);
+  uint64_t** seed_stack = &words;
+  void (*seed_code)(const void*, size_t, void*) = &Fill24Bytes;
+  const uintptr_t bits_stack = reinterpret_cast<uintptr_t>(seed_stack);
+  const uintptr_t bits_heap = reinterpret_cast<uintptr_t>(seed_heap);
+  const uintptr_t bits_code = reinterpret_cast<uintptr_t>(seed_code);
+  const uint64_t bits_time = static_cast<uint64_t>(clock());
+  words[0] = bits_stack ^ bits_time ^ seed_num;
+  words[1] = bits_heap ^ bits_time ^ seed_num;
+  words[2] = bits_code ^ bits_time ^ seed_num;
+}
+
+#endif  // !VQSORT_SECURE_RNG
+
+bool Sorter::HaveFloat64() { return HWY_DYNAMIC_DISPATCH(HaveFloat64)(); }
+
+}  // namespace hwy
+#endif  // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort.h b/third_party/highway/hwy/contrib/sort/vqsort.h
new file mode 100644
index 000000000000..6be9fcdafabd
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort.h
@@ -0,0 +1,104 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Interface to vectorized quicksort with dynamic dispatch.
+
+#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_
+#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_
+
+#include "hwy/base.h"
+
+namespace hwy {
+
+// Aligned 128-bit type. Cannot use __int128 because clang doesn't yet align it:
+// https://reviews.llvm.org/D86310
+#pragma pack(push, 1)
+struct alignas(16) uint128_t {
+  uint64_t lo;  // little-endian layout
+  uint64_t hi;
+};
+#pragma pack(pop)
+
+// Tag arguments that determine the sort order.
+struct SortAscending {
+  constexpr bool IsAscending() const { return true; }
+};
+struct SortDescending {
+  constexpr bool IsAscending() const { return false; }
+};
+
+// Allocates O(1) space. Type-erased RAII wrapper over hwy/aligned_allocator.h.
+// This allows amortizing the allocation over multiple sorts.
+class HWY_CONTRIB_DLLEXPORT Sorter {
+ public:
+  Sorter();
+  ~Sorter() { Delete(); }
+
+  // Move-only
+  Sorter(const Sorter&) = delete;
+  Sorter& operator=(const Sorter&) = delete;
+  Sorter(Sorter&& other) {
+    Delete();
+    ptr_ = other.ptr_;
+    other.ptr_ = nullptr;
+  }
+  Sorter& operator=(Sorter&& other) {
+    Delete();
+    ptr_ = other.ptr_;
+    other.ptr_ = nullptr;
+    return *this;
+  }
+
+  // Sorts keys[0, n). Dispatches to the best available instruction set,
+  // and does not allocate memory.
+  void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+  void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+  void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+
+  void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+  void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+  void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+
+  void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const;
+  void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const;
+
+  void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
+  void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
+
+  // For internal use only
+  static void Fill24Bytes(const void* seed_heap, size_t seed_num, void* bytes);
+  static bool HaveFloat64();
+
+ private:
+  void Delete();
+
+  template <typename T>
+  T* Get() const {
+    return static_cast<T*>(ptr_);
+  }
+
+  void* ptr_ = nullptr;
+};
+
+}  // namespace hwy
+
+#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_128a.cc
b/third_party/highway/hwy/contrib/sort/vqsort_128a.cc new file mode 100644 index 000000000000..54431e98f966 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_128a.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void Sort128Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Asc); +} // namespace + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(Sort128Asc) + (reinterpret_cast(keys), n * 2, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_128d.cc b/third_party/highway/hwy/contrib/sort/vqsort_128d.cc new file mode 100644 index 000000000000..3c505464321f --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_128d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void Sort128Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Desc); +} // namespace + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(Sort128Desc) + (reinterpret_cast(keys), n * 2, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc new file mode 100644 index 000000000000..878c146f73d6 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc @@ -0,0 +1,53 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF32Asc(float* HWY_RESTRICT keys, size_t num, float* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Asc); +} // namespace + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortF32Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc new file mode 100644 index 000000000000..0ab7d94635e1 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF32Desc(float* HWY_RESTRICT keys, size_t num, + float* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Desc); +} // namespace + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortF32Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc new file mode 100644 index 000000000000..349d0d22d94a --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc @@ -0,0 +1,61 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF64Asc(double* HWY_RESTRICT keys, size_t num, + double* HWY_RESTRICT buf) { +#if HWY_HAVE_FLOAT64 + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void)keys; + (void)num; + (void)buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Asc); +} // namespace + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortF64Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc new file mode 100644 index 000000000000..9fe50919b8af --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc @@ -0,0 +1,61 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF64Desc(double* HWY_RESTRICT keys, size_t num, + double* HWY_RESTRICT buf) { +#if HWY_HAVE_FLOAT64 + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void)keys; + (void)num; + (void)buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Desc); +} // namespace + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortF64Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc new file mode 100644 index 000000000000..2e065acf0032 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc @@ -0,0 +1,59 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +// Workaround for build timeout +#if !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI16Asc(int16_t* HWY_RESTRICT keys, size_t num, + int16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Asc); +} // namespace + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI16Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE + +#endif // !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc new file mode 100644 index 000000000000..139bc18afc2a --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc @@ -0,0 +1,59 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +// Workaround for build timeout +#if !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI16Desc(int16_t* HWY_RESTRICT keys, size_t num, + int16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Desc); +} // namespace + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI16Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE + +#endif // !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc new file mode 100644 index 000000000000..2a549ae779cb --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI32Asc(int32_t* HWY_RESTRICT keys, size_t num, + int32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Asc); +} // namespace + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI32Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc new file mode 100644 index 000000000000..b89837f4fe35 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI32Desc(int32_t* HWY_RESTRICT keys, size_t num, + int32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Desc); +} // namespace + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI32Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc new file mode 100644 index 000000000000..417d23af1f45 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI64Asc(int64_t* HWY_RESTRICT keys, size_t num, + int64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Asc); +} // namespace + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI64Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc new file mode 100644 index 000000000000..bc17f40da391 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI64Desc(int64_t* HWY_RESTRICT keys, size_t num, + int64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Desc); +} // namespace + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI64Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc new file mode 100644 index 000000000000..a2382137bd97 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc @@ -0,0 +1,59 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +// Workaround for build timeout +#if !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU16Asc(uint16_t* HWY_RESTRICT keys, size_t num, + uint16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Asc); +} // namespace + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU16Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE + +#endif // !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc new file mode 100644 index 000000000000..fb688f2b8867 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc @@ -0,0 +1,59 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +// Workaround for build timeout +#if !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU16Desc(uint16_t* HWY_RESTRICT keys, size_t num, + uint16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Desc); +} // namespace + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU16Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE + +#endif // !HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc new file mode 100644 index 000000000000..eee8b5105057 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU32Asc(uint32_t* HWY_RESTRICT keys, size_t num, + uint32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Asc); +} // namespace + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU32Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc new file mode 100644 index 000000000000..898c3c88345f --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU32Desc(uint32_t* HWY_RESTRICT keys, size_t num, + uint32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Desc); +} // namespace + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU32Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc new file mode 100644 index 000000000000..fa342a2b45a1 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64a.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU64Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Asc); +} // namespace + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU64Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc new file mode 100644 index 000000000000..617f4913373b --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/disabled_targets.h" +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64d.cc" +#include "hwy/foreach_target.h" + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU64Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Desc); +} // namespace + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU64Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/detect_compiler_arch.h b/third_party/highway/hwy/detect_compiler_arch.h index 27f573c565ad..163dd4e02335 100644 --- a/third_party/highway/hwy/detect_compiler_arch.h +++ b/third_party/highway/hwy/detect_compiler_arch.h @@ -106,20 +106,6 @@ //------------------------------------------------------------------------------ // Architecture -#if defined(HWY_EMULATE_SVE) - -#define HWY_ARCH_X86_32 0 -#define HWY_ARCH_X86_64 0 -#define HWY_ARCH_X86 0 -#define HWY_ARCH_PPC 0 -#define HWY_ARCH_ARM_A64 1 -#define HWY_ARCH_ARM_V7 0 -#define HWY_ARCH_ARM 1 -#define HWY_ARCH_WASM 0 -#define HWY_ARCH_RVV 0 - -#else - #if defined(__i386__) || defined(_M_IX86) #define HWY_ARCH_X86_32 1 #else @@ -182,8 +168,6 @@ #define HWY_ARCH_RVV 0 #endif -#endif // defined(HWY_EMULATE_SVE) - // It is an error to detect multiple architectures at the same time, but OK to // detect none of the above. #if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_WASM + \ diff --git a/third_party/highway/hwy/detect_targets.h b/third_party/highway/hwy/detect_targets.h index e1e46b2e33ff..2208f400137b 100644 --- a/third_party/highway/hwy/detect_targets.h +++ b/third_party/highway/hwy/detect_targets.h @@ -161,11 +161,6 @@ // user to override this without any guarantee of success. #ifndef HWY_BASELINE_TARGETS -#if defined(HWY_EMULATE_SVE) -#define HWY_BASELINE_TARGETS HWY_SVE // does not support SVE2 -#define HWY_BASELINE_AVX3_DL 0 -#else - // Also check HWY_ARCH to ensure that simulating unknown platforms ends up with // HWY_TARGET == HWY_SCALAR. @@ -186,7 +181,7 @@ #define HWY_BASELINE_PPC8 0 #endif -// SVE compiles, but is not yet tested. +// SVE2 compiles, but is not yet tested. #if HWY_ARCH_ARM && defined(__ARM_FEATURE_SVE2) #define HWY_BASELINE_SVE2 HWY_SVE2 #else @@ -307,8 +302,6 @@ HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \ HWY_BASELINE_AVX3_DL | HWY_BASELINE_RVV) -#endif // HWY_EMULATE_SVE - #else // User already defined HWY_BASELINE_TARGETS, but we still need to define // HWY_BASELINE_AVX3 (matching user's definition) for HWY_CHECK_AVX3_DL. diff --git a/third_party/highway/hwy/examples/benchmark.cc b/third_party/highway/hwy/examples/benchmark.cc index 159e4c780c4a..63725c598ca5 100644 --- a/third_party/highway/hwy/examples/benchmark.cc +++ b/third_party/highway/hwy/examples/benchmark.cc @@ -25,15 +25,17 @@ #include // iota #include "hwy/aligned_allocator.h" +// Must come after foreach_target.h to avoid redefinition errors. 
#include "hwy/highway.h" #include "hwy/nanobenchmark.h" + HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { // These templates are not found via ADL. #if HWY_TARGET != HWY_SCALAR -using hwy::HWY_NAMESPACE::CombineShiftRightBytes; +using hwy::HWY_NAMESPACE::CombineShiftRightLanes; #endif class TwoArray { @@ -87,14 +89,14 @@ void RunBenchmark(const char* caption) { } void Intro() { - HWY_ALIGN const float in[16] = {1, 2, 3, 4, 5, 6}; - HWY_ALIGN float out[16]; + const float in[16] = {1, 2, 3, 4, 5, 6}; + float out[16]; const ScalableTag d; // largest possible vector for (size_t i = 0; i < 16; i += Lanes(d)) { - const auto vec = Load(d, in + i); // aligned! - auto result = vec * vec; - result += result; // can update if not const - Store(result, d, out + i); + const auto vec = LoadU(d, in + i); // no alignment requirement + auto result = Mul(vec, vec); + result = Add(result, result); // can update if not const + StoreU(result, d, out + i); } printf("\nF(x)->2*x^2, F(%.0f) = %.1f\n", in[2], out[2]); } @@ -109,32 +111,34 @@ class BenchmarkDot : public TwoArray { const ScalableTag d; const size_t N = Lanes(d); using V = decltype(Zero(d)); - constexpr size_t unroll = 8; // Compiler doesn't make independent sum* accumulators, so unroll manually. - // Some older compilers might not be able to fit the 8 arrays in registers, - // so manual unrolling can be helpfull if you run into this issue. - // 2 FMA ports * 4 cycle latency = 8x unrolled. - V sum[unroll]; - for (size_t i = 0; i < unroll; ++i) { - sum[i] = Zero(d); - } + // We cannot use an array because V might be a sizeless type. For reasonable + // code, we unroll 4x, but 8x might help (2 FMA ports * 4 cycle latency). + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); const float* const HWY_RESTRICT pa = &a_[0]; const float* const HWY_RESTRICT pb = b_; - for (size_t i = 0; i < num_items; i += unroll * N) { - for (size_t j = 0; j < unroll; ++j) { - const auto a = Load(d, pa + i + j * N); - const auto b = Load(d, pb + i + j * N); - sum[j] = MulAdd(a, b, sum[j]); - } + for (size_t i = 0; i < num_items; i += 4 * N) { + const auto a0 = Load(d, pa + i + 0 * N); + const auto b0 = Load(d, pb + i + 0 * N); + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = Load(d, pa + i + 1 * N); + const auto b1 = Load(d, pb + i + 1 * N); + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = Load(d, pa + i + 2 * N); + const auto b2 = Load(d, pb + i + 2 * N); + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = Load(d, pa + i + 3 * N); + const auto b3 = Load(d, pb + i + 3 * N); + sum3 = MulAdd(a3, b3, sum3); } - // Reduction tree: sum of all accumulators by pairs into sum[0], then the - // lanes. - for (size_t power = 1; power < unroll; power *= 2) { - for (size_t i = 0; i < unroll; i += 2 * power) { - sum[i] += sum[i + power]; - } - } - dot_ = GetLane(SumOfLanes(d, sum[0])); + // Reduction tree: sum of all accumulators by pairs into sum0. 
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + dot_ = GetLane(SumOfLanes(d, sum0)); return static_cast(dot_); } void Verify(size_t num_items) { @@ -193,9 +197,9 @@ struct BenchmarkDelta : public TwoArray { auto prev = Load(df, &a_[0]); for (; i < num_items; i += Lanes(df)) { const auto a = Load(df, &a_[i]); - const auto shifted = CombineShiftRightLanes<3>(a, prev); + const auto shifted = CombineShiftRightLanes<3>(df, a, prev); prev = a; - Store(a - shifted, df, &b_[i]); + Store(Sub(a, shifted), df, &b_[i]); } #endif return static_cast(b_[num_items - 1]); diff --git a/third_party/highway/hwy/examples/skeleton.cc b/third_party/highway/hwy/examples/skeleton.cc index 590d8be7fdc3..02599a25bdc2 100644 --- a/third_party/highway/hwy/examples/skeleton.cc +++ b/third_party/highway/hwy/examples/skeleton.cc @@ -24,11 +24,14 @@ // Generates code for each enabled target by re-including this source file. #include "hwy/foreach_target.h" +// Must come after foreach_target.h to avoid redefinition errors. #include "hwy/highway.h" // Optional, can instead add HWY_ATTR to all functions. HWY_BEFORE_NAMESPACE(); namespace skeleton { +// This namespace name is unique per target, which allows code for multiple +// targets to co-exist in the same translation unit. namespace HWY_NAMESPACE { // Highway ops reside here; ADL does not find templates nor builtins. @@ -47,7 +50,7 @@ template ATTR_MSAN void OneFloorLog2(const DF df, const uint8_t* HWY_RESTRICT values, uint8_t* HWY_RESTRICT log2) { // Type tags for converting to other element types (Rebind = same count). - const Rebind d32; + const RebindToSigned d32; const Rebind d8; const auto u8 = Load(d8, values); @@ -59,7 +62,7 @@ ATTR_MSAN void OneFloorLog2(const DF df, const uint8_t* HWY_RESTRICT values, void CodepathDemo() { // Highway defaults to portability, but per-target codepaths may be selected // via #if HWY_TARGET == HWY_SSE4 or by testing capability macros: -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 const char* gather = "Has int64"; #else const char* gather = "No int64"; @@ -71,20 +74,16 @@ void FloorLog2(const uint8_t* HWY_RESTRICT values, size_t count, uint8_t* HWY_RESTRICT log2) { CodepathDemo(); - // Second argument is necessary on RVV until it supports fractional lengths. - const ScalableTag df; - + const ScalableTag df; const size_t N = Lanes(df); size_t i = 0; for (; i + N <= count; i += N) { OneFloorLog2(df, values + i, log2 + i); } - // TODO(janwas): implement -#if HWY_TARGET != HWY_RVV for (; i < count; ++i) { - OneFloorLog2(HWY_CAPPED(float, 1)(), values + i, log2 + i); + CappedTag d1; + OneFloorLog2(d1, values + i, log2 + i); } -#endif } // NOLINTNEXTLINE(google-readability-namespace-comments) @@ -92,6 +91,9 @@ void FloorLog2(const uint8_t* HWY_RESTRICT values, size_t count, } // namespace skeleton HWY_AFTER_NAMESPACE(); +// The table of pointers to the various implementations in HWY_NAMESPACE must +// be compiled only once (foreach_target #includes this file multiple times). +// HWY_ONCE is true for only one of these 'compilation passes'. #if HWY_ONCE namespace skeleton { @@ -105,6 +107,8 @@ HWY_EXPORT(FloorLog2); // is equivalent to inlining this function. void CallFloorLog2(const uint8_t* HWY_RESTRICT in, const size_t count, uint8_t* HWY_RESTRICT out) { + // This must reside outside of HWY_NAMESPACE because it references (calls the + // appropriate one from) the per-target implementations there. 
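The comments added to skeleton.cc spell out the foreach_target pattern: per-target code lives in HWY_NAMESPACE, the pointer table is emitted once under HWY_ONCE, and callers go through HWY_DYNAMIC_DISPATCH. A minimal hypothetical translation unit following that same pattern is sketched below; the names mylib, MulBy2 and CallMulBy2 are illustrative and not part of this patch, and the HWY_TARGET_INCLUDE path must match how the build locates the file:

```cpp
// mylib.cc -- hypothetical example, mirroring the skeleton.cc structure above.
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "mylib.cc"  // this file's path, as the build sees it
#include "hwy/foreach_target.h"        // re-includes this file once per target

// Must come after foreach_target.h to avoid redefinition errors.
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace mylib {
namespace HWY_NAMESPACE {  // unique per target

namespace hn = hwy::HWY_NAMESPACE;

// Doubles each element; compiled once per enabled target.
void MulBy2(const float* HWY_RESTRICT in, size_t count,
            float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  size_t i = 0;
  for (; i + N <= count; i += N) {
    const auto v = hn::LoadU(d, in + i);
    hn::StoreU(hn::Add(v, v), d, out + i);
  }
  for (; i < count; ++i) out[i] = in[i] + in[i];  // scalar remainder
}

}  // namespace HWY_NAMESPACE
}  // namespace mylib
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace mylib {
HWY_EXPORT(MulBy2);  // table of per-target pointers; emitted only once

void CallMulBy2(const float* in, size_t count, float* out) {
  HWY_DYNAMIC_DISPATCH(MulBy2)(in, count, out);
}
}  // namespace mylib
#endif  // HWY_ONCE
```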
return HWY_DYNAMIC_DISPATCH(FloorLog2)(in, count, out); } diff --git a/third_party/highway/hwy/examples/skeleton_test.cc b/third_party/highway/hwy/examples/skeleton_test.cc index 7f79b189f6d2..8058f84f5d4b 100644 --- a/third_party/highway/hwy/examples/skeleton_test.cc +++ b/third_party/highway/hwy/examples/skeleton_test.cc @@ -21,10 +21,13 @@ #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "examples/skeleton_test.cc" #include "hwy/foreach_target.h" + +// Must come after foreach_target.h to avoid redefinition errors. #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" // Optional: factor out parts of the implementation into *-inl.h +// (must also come after foreach_target.h to avoid redefinition errors) #include "hwy/examples/skeleton-inl.h" HWY_BEFORE_NAMESPACE(); @@ -50,10 +53,7 @@ struct TestFloorLog2 { CallFloorLog2(in.get(), count, out.get()); int sum = 0; for (size_t i = 0; i < count; ++i) { - // TODO(janwas): implement -#if HWY_TARGET != HWY_RVV HWY_ASSERT_EQ(expected[i], out[i]); -#endif sum += out[i]; } hwy::PreventElision(sum); diff --git a/third_party/highway/hwy/foreach_target.h b/third_party/highway/hwy/foreach_target.h index 8ce0560cbabd..d31cdab9b764 100644 --- a/third_party/highway/hwy/foreach_target.h +++ b/third_party/highway/hwy/foreach_target.h @@ -74,6 +74,28 @@ #endif #endif +#if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2) && (HWY_STATIC_TARGET != HWY_SVE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + #if (HWY_TARGETS & HWY_SSSE3) && (HWY_STATIC_TARGET != HWY_SSSE3) #undef HWY_TARGET #define HWY_TARGET HWY_SSSE3 diff --git a/third_party/highway/hwy/highway.h b/third_party/highway/hwy/highway.h index 174e171f309a..48a56dc47c69 100644 --- a/third_party/highway/hwy/highway.h +++ b/third_party/highway/hwy/highway.h @@ -27,7 +27,7 @@ namespace hwy { // API version (https://semver.org/); keep in sync with CMakeLists.txt. #define HWY_MAJOR 0 -#define HWY_MINOR 15 +#define HWY_MINOR 16 #define HWY_PATCH 0 //------------------------------------------------------------------------------ @@ -37,7 +37,9 @@ namespace hwy { // HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of // registers in the group, and is ignored on targets that do not support groups. -#define HWY_FULL1(T) hwy::HWY_NAMESPACE::Simd +#define HWY_FULL1(T) hwy::HWY_NAMESPACE::ScalableTag +#define HWY_FULL2(T, LMUL) \ + hwy::HWY_NAMESPACE::ScalableTag #define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3 // Workaround for MSVC grouping __VA_ARGS__ into a single argument #define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren @@ -46,9 +48,9 @@ namespace hwy { HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, )) #define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__) -// Vector of up to MAX_N lanes. Discouraged, when possible, use Half<> instead. +// Vector of up to MAX_N lanes. It's better to use full vectors where possible. 
#define HWY_CAPPED(T, MAX_N) \ - hwy::HWY_NAMESPACE::Simd + hwy::HWY_NAMESPACE::CappedTag //------------------------------------------------------------------------------ // Export user functions for static/dynamic dispatch @@ -109,6 +111,7 @@ struct FunctionCache { template static RetType ChooseAndCall(Args... args) { // If we are running here it means we need to update the chosen target. + ChosenTarget& chosen_target = GetChosenTarget(); chosen_target.Update(); return (table[chosen_target.GetIndex()])(args...); } @@ -263,10 +266,15 @@ FunctionCache FunctionCacheFactory(RetType (*)(Args...)) { HWY_CHOOSE_SCALAR(FUNC_NAME), \ } #define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \ - (*(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::chosen_target.GetIndex()])) + (*(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()])) #endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) +// DEPRECATED names; please use HWY_HAVE_* instead. +#define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64 +#define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16 +#define HWY_CAP_FLOAT64 HWY_HAVE_FLOAT64 + } // namespace hwy #endif // HWY_HIGHWAY_INCLUDED @@ -283,13 +291,6 @@ FunctionCache FunctionCacheFactory(RetType (*)(Args...)) { #define HWY_HIGHWAY_PER_TARGET #endif -#undef HWY_FULL2 -#if HWY_TARGET == HWY_RVV -#define HWY_FULL2(T, LMUL) hwy::HWY_NAMESPACE::Simd -#else -#define HWY_FULL2(T, LMUL) hwy::HWY_NAMESPACE::Simd -#endif - // These define ops inside namespace hwy::HWY_NAMESPACE. #if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 #include "hwy/ops/x86_128-inl.h" diff --git a/third_party/highway/hwy/highway_export.h b/third_party/highway/hwy/highway_export.h new file mode 100644 index 000000000000..e3d2546d5f2c --- /dev/null +++ b/third_party/highway/hwy/highway_export.h @@ -0,0 +1,106 @@ +// Pseudo-generated file to handle both cmake & bazel build system. 
+ +// Initial generation done using cmake code: +// include(GenerateExportHeader) +// generate_export_header(hwy EXPORT_MACRO_NAME HWY_DLLEXPORT EXPORT_FILE_NAME +// hwy/highway_export.h) +// code reformatted using clang-format --style=Google + +#ifndef HWY_DLLEXPORT_H +#define HWY_DLLEXPORT_H + +// Bazel build are always static: +#if !defined(HWY_SHARED_DEFINE) && !defined(HWY_STATIC_DEFINE) +#define HWY_STATIC_DEFINE +#endif + +#ifdef HWY_STATIC_DEFINE +#define HWY_DLLEXPORT +#define HWY_NO_EXPORT +#define HWY_CONTRIB_DLLEXPORT +#define HWY_CONTRIB_NO_EXPORT +#define HWY_TEST_DLLEXPORT +#define HWY_TEST_NO_EXPORT +#else + +#ifndef HWY_DLLEXPORT +#if defined(hwy_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllexport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else +/* We are using this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllimport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif +#endif + +#ifndef HWY_NO_EXPORT +#ifdef _WIN32 +#define HWY_NO_EXPORT +#else +#define HWY_NO_EXPORT __attribute__((visibility("hidden"))) +#endif +#endif + +#ifndef HWY_CONTRIB_DLLEXPORT +#if defined(hwy_contrib_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllexport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else +/* We are using this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllimport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif +#endif + +#ifndef HWY_CONTRIB_NO_EXPORT +#ifdef _WIN32 +#define HWY_CONTRIB_NO_EXPORT +#else +#define HWY_CONTRIB_NO_EXPORT __attribute__((visibility("hidden"))) +#endif +#endif + +#ifndef HWY_TEST_DLLEXPORT +#if defined(hwy_test_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllexport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else +/* We are using this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllimport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif +#endif + +#ifndef HWY_TEST_NO_EXPORT +#ifdef _WIN32 +#define HWY_TEST_NO_EXPORT +#else +#define HWY_TEST_NO_EXPORT __attribute__((visibility("hidden"))) +#endif +#endif + +#endif + +#endif /* HWY_DLLEXPORT_H */ diff --git a/third_party/highway/hwy/highway_test.cc b/third_party/highway/hwy/highway_test.cc index d71f419c3611..cef616bc9592 100644 --- a/third_party/highway/hwy/highway_test.cc +++ b/third_party/highway/hwy/highway_test.cc @@ -15,6 +15,8 @@ #include #include +#include + #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "highway_test.cc" #include "hwy/foreach_target.h" @@ -26,6 +28,53 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +// For testing that ForPartialVectors reaches every possible size: +using NumLanesSet = std::bitset; + +// Monostate pattern because ForPartialVectors takes a template argument, not a +// functor by reference. 
+static NumLanesSet* NumLanesForSize(size_t sizeof_t) { + HWY_ASSERT(sizeof_t <= sizeof(uint64_t)); + static NumLanesSet num_lanes[sizeof(uint64_t) + 1]; + return num_lanes + sizeof_t; +} +static size_t* MaxLanesForSize(size_t sizeof_t) { + HWY_ASSERT(sizeof_t <= sizeof(uint64_t)); + static size_t num_lanes[sizeof(uint64_t) + 1] = {0}; + return num_lanes + sizeof_t; +} + +struct TestMaxLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const size_t kMax = MaxLanes(d); + HWY_ASSERT(N <= kMax); + HWY_ASSERT(kMax <= (HWY_MAX_BYTES / sizeof(T))); + + NumLanesForSize(sizeof(T))->set(N); + *MaxLanesForSize(sizeof(T)) = HWY_MAX(*MaxLanesForSize(sizeof(T)), N); + } +}; + +HWY_NOINLINE void TestAllMaxLanes() { + ForAllTypes(ForPartialVectors()); + + // Ensure ForPartialVectors visited all powers of two [1, N]. + for (size_t sizeof_t : {sizeof(uint8_t), sizeof(uint16_t), sizeof(uint32_t), + sizeof(uint64_t)}) { + const size_t N = *MaxLanesForSize(sizeof_t); + for (size_t i = 1; i <= N; i += i) { + if (!NumLanesForSize(sizeof_t)->test(i)) { + fprintf(stderr, "T=%d: did not visit for N=%d, max=%d\n", + static_cast(sizeof_t), static_cast(i), + static_cast(N)); + HWY_ASSERT(false); + } + } + } +} + struct TestSet { template HWY_NOINLINE void operator()(T /*unused*/, D d) { @@ -322,6 +371,7 @@ HWY_AFTER_NAMESPACE(); namespace hwy { HWY_BEFORE_TEST(HighwayTest); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllMaxLanes); HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSet); HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllOverflow); HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllClamp); diff --git a/third_party/highway/hwy/hwy.version b/third_party/highway/hwy/hwy.version new file mode 100644 index 000000000000..9ff6be6a2d72 --- /dev/null +++ b/third_party/highway/hwy/hwy.version @@ -0,0 +1,19 @@ +HWY_0 { + global: + extern "C++" { + *hwy::*; + }; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. 
+ extern "C++" { + *std::*; + }; +}; diff --git a/third_party/highway/hwy/nanobenchmark.cc b/third_party/highway/hwy/nanobenchmark.cc index 9998c7ed3e0b..c26d539d8447 100644 --- a/third_party/highway/hwy/nanobenchmark.cc +++ b/third_party/highway/hwy/nanobenchmark.cc @@ -37,7 +37,7 @@ #include #endif -#if defined(__MACH__) +#if defined(__APPLE__) #include #include #endif @@ -148,7 +148,7 @@ inline Ticks Start() { LARGE_INTEGER counter; (void)QueryPerformanceCounter(&counter); t = counter.QuadPart; -#elif defined(__MACH__) +#elif defined(__APPLE__) t = mach_absolute_time(); #elif defined(__HAIKU__) t = system_time_nsecs(); // since boot @@ -405,7 +405,7 @@ double NominalClockRate() { } // namespace -double InvariantTicksPerSecond() { +HWY_DLLEXPORT double InvariantTicksPerSecond() { #if HWY_ARCH_PPC && defined(__GLIBC__) return double(__ppc_get_timebase_freq()); #elif HWY_ARCH_X86 @@ -415,7 +415,7 @@ double InvariantTicksPerSecond() { LARGE_INTEGER freq; (void)QueryPerformanceFrequency(&freq); return double(freq.QuadPart); -#elif defined(__MACH__) +#elif defined(__APPLE__) // https://developer.apple.com/library/mac/qa/qa1398/_index.html mach_timebase_info_data_t timebase; (void)mach_timebase_info(&timebase); @@ -426,12 +426,12 @@ double InvariantTicksPerSecond() { #endif } -double Now() { +HWY_DLLEXPORT double Now() { static const double mul = 1.0 / InvariantTicksPerSecond(); return static_cast(timer::Start()) * mul; } -uint64_t TimerResolution() { +HWY_DLLEXPORT uint64_t TimerResolution() { // Nested loop avoids exceeding stack/L1 capacity. timer::Ticks repetitions[Params::kTimerSamples]; for (size_t rep = 0; rep < Params::kTimerSamples; ++rep) { @@ -656,10 +656,11 @@ timer::Ticks Overhead(const uint8_t* arg, const InputVec* inputs, } // namespace -int Unpredictable1() { return timer::Start() != ~0ULL; } +HWY_DLLEXPORT int Unpredictable1() { return timer::Start() != ~0ULL; } -size_t Measure(const Func func, const uint8_t* arg, const FuncInput* inputs, - const size_t num_inputs, Result* results, const Params& p) { +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p) { NANOBENCHMARK_CHECK(num_inputs != 0); #if HWY_ARCH_X86 diff --git a/third_party/highway/hwy/nanobenchmark.h b/third_party/highway/hwy/nanobenchmark.h index 18065f8f97aa..c5c726176760 100644 --- a/third_party/highway/hwy/nanobenchmark.h +++ b/third_party/highway/hwy/nanobenchmark.h @@ -47,6 +47,8 @@ #include #include +#include "hwy/highway_export.h" + // Enables sanity checks that verify correct operation at the cost of // longer benchmark runs. #ifndef NANOBENCHMARK_ENABLE_CHECKS @@ -72,23 +74,23 @@ namespace platform { // Returns tick rate, useful for converting measurements to seconds. Invariant // means the tick counter frequency is independent of CPU throttling or sleep. // This call may be expensive, callers should cache the result. -double InvariantTicksPerSecond(); +HWY_DLLEXPORT double InvariantTicksPerSecond(); // Returns current timestamp [in seconds] relative to an unspecified origin. // Features: monotonic (no negative elapsed time), steady (unaffected by system // time changes), high-resolution (on the order of microseconds). -double Now(); +HWY_DLLEXPORT double Now(); // Returns ticks elapsed in back to back timer calls, i.e. a function of the // timer resolution (minimum measurable difference) and overhead. // This call is expensive, callers should cache the result. 
-uint64_t TimerResolution(); +HWY_DLLEXPORT uint64_t TimerResolution(); } // namespace platform // Returns 1, but without the compiler knowing what the value is. This prevents // optimizing out code. -int Unpredictable1(); +HWY_DLLEXPORT int Unpredictable1(); // Input influencing the function being measured (e.g. number of bytes to copy). using FuncInput = size_t; @@ -164,9 +166,9 @@ struct Result { // uniform distribution over [0, 4) could be represented as {3,0,2,1}. // Returns how many Result were written to "results": one per unique input, or // zero if the measurement failed (an error message goes to stderr). -size_t Measure(const Func func, const uint8_t* arg, const FuncInput* inputs, - const size_t num_inputs, Result* results, - const Params& p = Params()); +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p = Params()); // Calls operator() of the given closure (lambda function). template diff --git a/third_party/highway/hwy/ops/arm_neon-inl.h b/third_party/highway/hwy/ops/arm_neon-inl.h index 774ca5db1868..4392a314e3ca 100644 --- a/third_party/highway/hwy/ops/arm_neon-inl.h +++ b/third_party/highway/hwy/ops/arm_neon-inl.h @@ -15,6 +15,9 @@ // 128-bit ARM64 NEON vectors and operations. // External include guard in highway.h - see comment there. +// ARM NEON intrinsics are documented at: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon] + #include #include #include @@ -27,7 +30,13 @@ namespace hwy { namespace HWY_NAMESPACE { template -using Full128 = Simd; +using Full128 = Simd; + +template +using Full64 = Simd; + +template +using Full32 = Simd; namespace detail { // for code folding and Raw128 @@ -39,18 +48,19 @@ namespace detail { // for code folding and Raw128 #define HWY_NEON_BUILD_TPL_2 #define HWY_NEON_BUILD_TPL_3 -// HWY_NEON_BUILD_RET_* is return type. -#define HWY_NEON_BUILD_RET_1(type, size) Vec128 -#define HWY_NEON_BUILD_RET_2(type, size) Vec128 -#define HWY_NEON_BUILD_RET_3(type, size) Vec128 +// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can +// extend it to int32x4x2_t packs. +#define HWY_NEON_BUILD_RET_1(type, size) Vec128 +#define HWY_NEON_BUILD_RET_2(type, size) Vec128 +#define HWY_NEON_BUILD_RET_3(type, size) Vec128 // HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives. -#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128 a +#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128 a #define HWY_NEON_BUILD_PARAM_2(type, size) \ - const Vec128 a, const Vec128 b -#define HWY_NEON_BUILD_PARAM_3(type, size) \ - const Vec128 a, const Vec128 b, \ - const Vec128 c + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_PARAM_3(type, size) \ + const Vec128 a, const Vec128 b, \ + const Vec128 c // HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying // function. @@ -86,70 +96,76 @@ namespace detail { // for code folding and Raw128 // using args=2. 
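The comment block above describes how these generator macros paste the intrinsic name together from prefix, infix and suffix, and this hunk drops the _t from the type token so the macros can append _t themselves (the diff comment notes this also allows extending to NEON x2_t tuple types). As a rough illustration, assuming the HWY_NEON_DEF_FUNCTION definition earlier in this header (not shown in this hunk), an invocation such as HWY_NEON_DEF_FUNCTION(uint8, 16, operator+, vaddq, _, u8, 2) expands to approximately:

HWY_API Vec128<uint8_t, 16> operator+(const Vec128<uint8_t, 16> a,
                                      const Vec128<uint8_t, 16> b) {
  // prefix##infix##suffix = vaddq_u8; "2" selects the two-argument
  // RET/PARAM/ARG builders, and the type token regains its _t suffix here.
  return Vec128<uint8_t, 16>(vaddq_u8(a.raw, b.raw));
}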
// uint8_t -#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(uint8_t, 16, name, prefix##q, infix, u8, args) \ - HWY_NEON_DEF_FUNCTION(uint8_t, 8, name, prefix, infix, u8, args) \ - HWY_NEON_DEF_FUNCTION(uint8_t, 4, name, prefix, infix, u8, args) \ - HWY_NEON_DEF_FUNCTION(uint8_t, 2, name, prefix, infix, u8, args) \ - HWY_NEON_DEF_FUNCTION(uint8_t, 1, name, prefix, infix, u8, args) +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args) // int8_t -#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(int8_t, 16, name, prefix##q, infix, s8, args) \ - HWY_NEON_DEF_FUNCTION(int8_t, 8, name, prefix, infix, s8, args) \ - HWY_NEON_DEF_FUNCTION(int8_t, 4, name, prefix, infix, s8, args) \ - HWY_NEON_DEF_FUNCTION(int8_t, 2, name, prefix, infix, s8, args) \ - HWY_NEON_DEF_FUNCTION(int8_t, 1, name, prefix, infix, s8, args) +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args) // uint16_t -#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(uint16_t, 8, name, prefix##q, infix, u16, args) \ - HWY_NEON_DEF_FUNCTION(uint16_t, 4, name, prefix, infix, u16, args) \ - HWY_NEON_DEF_FUNCTION(uint16_t, 2, name, prefix, infix, u16, args) \ - HWY_NEON_DEF_FUNCTION(uint16_t, 1, name, prefix, infix, u16, args) +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args) // int16_t -#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(int16_t, 8, name, prefix##q, infix, s16, args) \ - HWY_NEON_DEF_FUNCTION(int16_t, 4, name, prefix, infix, s16, args) \ - HWY_NEON_DEF_FUNCTION(int16_t, 2, name, prefix, infix, s16, args) \ - HWY_NEON_DEF_FUNCTION(int16_t, 1, name, prefix, infix, s16, args) +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args) // uint32_t -#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(uint32_t, 4, name, prefix##q, infix, u32, args) \ - HWY_NEON_DEF_FUNCTION(uint32_t, 2, name, prefix, infix, u32, args) \ - HWY_NEON_DEF_FUNCTION(uint32_t, 1, name, prefix, infix, u32, args) +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args) 
\ + HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args) // int32_t -#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(int32_t, 4, name, prefix##q, infix, s32, args) \ - HWY_NEON_DEF_FUNCTION(int32_t, 2, name, prefix, infix, s32, args) \ - HWY_NEON_DEF_FUNCTION(int32_t, 1, name, prefix, infix, s32, args) +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args) // uint64_t -#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(uint64_t, 2, name, prefix##q, infix, u64, args) \ - HWY_NEON_DEF_FUNCTION(uint64_t, 1, name, prefix, infix, u64, args) +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) // int64_t -#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(int64_t, 2, name, prefix##q, infix, s64, args) \ - HWY_NEON_DEF_FUNCTION(int64_t, 1, name, prefix, infix, s64, args) +#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) + +// float +#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args) + +// double +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args) // float and double #if HWY_ARCH_ARM_A64 -#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(float, 4, name, prefix##q, infix, f32, args) \ - HWY_NEON_DEF_FUNCTION(float, 2, name, prefix, infix, f32, args) \ - HWY_NEON_DEF_FUNCTION(float, 1, name, prefix, infix, f32, args) \ - HWY_NEON_DEF_FUNCTION(double, 2, name, prefix##q, infix, f64, args) \ - HWY_NEON_DEF_FUNCTION(double, 1, name, prefix, infix, f64, args) +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) #else -#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ - HWY_NEON_DEF_FUNCTION(float, 4, name, prefix##q, infix, f32, args) \ - HWY_NEON_DEF_FUNCTION(float, 2, name, prefix, infix, f32, args) \ - HWY_NEON_DEF_FUNCTION(float, 1, name, prefix, infix, f32, args) +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) #endif // Helper macros to define for more than one type. @@ -501,6 +517,12 @@ class Vec128 { Raw raw; }; +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + // FF..FF or 0. 
template class Mask128 { @@ -518,11 +540,11 @@ class Mask128 { namespace detail { -// Deduce Simd from Vec128 +// Deduce Simd from Vec128 struct DeduceD { template - Simd operator()(Vec128) const { - return Simd(); + Simd operator()(Vec128) const { + return Simd(); } }; @@ -542,8 +564,8 @@ namespace detail { // vreinterpret*_u8_*() set of functions. #define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 #define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ - Vec128 -#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v + Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v #define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw // Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. @@ -575,7 +597,7 @@ HWY_INLINE Vec128 BitCastToByte(Vec128 v) { #undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return v; } @@ -583,47 +605,47 @@ HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, // 64-bit or less: template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_s8_u8(v.raw)); } template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_u16_u8(v.raw)); } template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_s16_u8(v.raw)); } template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_u32_u8(v.raw)); } template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_s32_u8(v.raw)); } template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_f32_u8(v.raw)); } -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, - Vec128 v) { - return Vec128(vreinterpret_u64_u8(v.raw)); +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_u64_u8(v.raw)); } -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, - Vec128 v) { - return Vec128(vreinterpret_s64_u8(v.raw)); +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_s64_u8(v.raw)); } #if HWY_ARCH_ARM_A64 -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, - Vec128 v) { - return Vec128(vreinterpret_f64_u8(v.raw)); +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_f64_u8(v.raw)); } #endif @@ -671,20 +693,20 @@ HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, // Special cases for [b]float16_t, which have the same Raw as uint16_t. 
template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { - return Vec128(BitCastFromByte(Simd(), v).raw); + return Vec128(BitCastFromByte(Simd(), v).raw); } template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, - Vec128 v) { - return Vec128(BitCastFromByte(Simd(), v).raw); +HWY_INLINE Vec128 BitCastFromByte( + Simd /* tag */, Vec128 v) { + return Vec128(BitCastFromByte(Simd(), v).raw); } } // namespace detail template -HWY_API Vec128 BitCast(Simd d, +HWY_API Vec128 BitCast(Simd d, Vec128 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } @@ -693,9 +715,9 @@ HWY_API Vec128 BitCast(Simd d, // Returns a vector with all lanes set to "t". #define HWY_NEON_BUILD_TPL_HWY_SET1 -#define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128 +#define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128 #define HWY_NEON_BUILD_PARAM_HWY_SET1(type, size) \ - Simd /* tag */, const type t + Simd /* tag */, const type##_t t #define HWY_NEON_BUILD_ARG_HWY_SET1 t HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) @@ -707,13 +729,13 @@ HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) // Returns an all-zero vector. template -HWY_API Vec128 Zero(Simd d) { +HWY_API Vec128 Zero(Simd d) { return Set(d, 0); } template -HWY_API Vec128 Zero(Simd /* tag */) { - return Vec128(Zero(Simd()).raw); +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128(Zero(Simd()).raw); } template @@ -721,7 +743,7 @@ using VFromD = decltype(Zero(D())); // Returns a vector with uninitialized elements. template -HWY_API Vec128 Undefined(Simd /*d*/) { +HWY_API Vec128 Undefined(Simd /*d*/) { HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") typename detail::Raw128::type a; @@ -731,7 +753,7 @@ HWY_API Vec128 Undefined(Simd /*d*/) { // Returns a vector with lane i=[0, N) set to "first" + i. 
template -Vec128 Iota(const Simd d, const T2 first) { +Vec128 Iota(const Simd d, const T2 first) { HWY_ALIGN T lanes[16 / sizeof(T)]; for (size_t i = 0; i < 16 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); @@ -792,30 +814,26 @@ HWY_API int32_t GetLane(const Vec128 v) { HWY_API uint64_t GetLane(const Vec128 v) { return vgetq_lane_u64(v.raw, 0); } -HWY_API uint64_t GetLane(const Vec128 v) { +HWY_API uint64_t GetLane(const Vec64 v) { return vget_lane_u64(v.raw, 0); } HWY_API int64_t GetLane(const Vec128 v) { return vgetq_lane_s64(v.raw, 0); } -HWY_API int64_t GetLane(const Vec128 v) { +HWY_API int64_t GetLane(const Vec64 v) { return vget_lane_s64(v.raw, 0); } HWY_API float GetLane(const Vec128 v) { return vgetq_lane_f32(v.raw, 0); } -HWY_API float GetLane(const Vec128 v) { - return vget_lane_f32(v.raw, 0); -} -HWY_API float GetLane(const Vec128 v) { - return vget_lane_f32(v.raw, 0); -} +HWY_API float GetLane(const Vec64 v) { return vget_lane_f32(v.raw, 0); } +HWY_API float GetLane(const Vec32 v) { return vget_lane_f32(v.raw, 0); } #if HWY_ARCH_ARM_A64 HWY_API double GetLane(const Vec128 v) { return vgetq_lane_f64(v.raw, 0); } -HWY_API double GetLane(const Vec128 v) { +HWY_API double GetLane(const Vec64 v) { return vget_lane_f64(v.raw, 0); } #endif @@ -828,7 +846,16 @@ HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) // ------------------------------ Subtraction HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) -// ------------------------------ Saturating addition and subtraction +// ------------------------------ SumsOf8 + +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); +} + +// ------------------------------ SaturatedAdd // Only defined for uint8_t, uint16_t and their signed versions, as in other // architectures. @@ -838,6 +865,8 @@ HWY_NEON_DEF_FUNCTION_INT_16(SaturatedAdd, vqadd, _, 2) HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedAdd, vqadd, _, 2) HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedAdd, vqadd, _, 2) +// ------------------------------ SaturatedSub + // Returns a - b clamped to the destination range. HWY_NEON_DEF_FUNCTION_INT_8(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_INT_16(SaturatedSub, vqsub, _, 2) @@ -863,11 +892,11 @@ HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below -HWY_API Vec128 Neg(const Vec128 v) { +HWY_API Vec64 Neg(const Vec64 v) { #if HWY_ARCH_ARM_A64 - return Vec128(vneg_s64(v.raw)); + return Vec64(vneg_s64(v.raw)); #else - return Zero(Simd()) - v; + return Zero(Full64()) - v; #endif } @@ -886,9 +915,9 @@ HWY_API Vec128 Neg(const Vec128 v) { #undef HWY_NEON_DEF_FUNCTION #define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ template \ - HWY_API Vec128 name(const Vec128 v) { \ + HWY_API Vec128 name(const Vec128 v) { \ return kBits == 0 ? 
v \ - : Vec128(HWY_NEON_EVAL( \ + : Vec128(HWY_NEON_EVAL( \ prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ } @@ -954,9 +983,9 @@ HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); } -HWY_API Vec128 operator<<(const Vec128 v, - const Vec128 bits) { - return Vec128(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); } HWY_API Vec128 operator<<(const Vec128 v, @@ -993,9 +1022,9 @@ HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s64(v.raw, bits.raw)); } -HWY_API Vec128 operator<<(const Vec128 v, - const Vec128 bits) { - return Vec128(vshl_s64(v.raw, bits.raw)); +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_s64(v.raw, bits.raw)); } // ------------------------------ Shr (Neg) @@ -1008,7 +1037,7 @@ HWY_API Vec128 operator>>(const Vec128 v, template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { - const int8x8_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + const int8x8_t neg_bits = Neg(BitCast(Simd(), bits)).raw; return Vec128(vshl_u8(v.raw, neg_bits)); } @@ -1020,7 +1049,7 @@ HWY_API Vec128 operator>>(const Vec128 v, template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { - const int16x4_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + const int16x4_t neg_bits = Neg(BitCast(Simd(), bits)).raw; return Vec128(vshl_u16(v.raw, neg_bits)); } @@ -1032,7 +1061,7 @@ HWY_API Vec128 operator>>(const Vec128 v, template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { - const int32x2_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + const int32x2_t neg_bits = Neg(BitCast(Simd(), bits)).raw; return Vec128(vshl_u32(v.raw, neg_bits)); } @@ -1041,10 +1070,10 @@ HWY_API Vec128 operator>>(const Vec128 v, const int64x2_t neg_bits = Neg(BitCast(Full128(), bits)).raw; return Vec128(vshlq_u64(v.raw, neg_bits)); } -HWY_API Vec128 operator>>(const Vec128 v, - const Vec128 bits) { - const int64x1_t neg_bits = Neg(BitCast(Simd(), bits)).raw; - return Vec128(vshl_u64(v.raw, neg_bits)); +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + const int64x1_t neg_bits = Neg(BitCast(Full64(), bits)).raw; + return Vec64(vshl_u64(v.raw, neg_bits)); } HWY_API Vec128 operator>>(const Vec128 v, @@ -1081,20 +1110,20 @@ HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); } -HWY_API Vec128 operator>>(const Vec128 v, - const Vec128 bits) { - return Vec128(vshl_s64(v.raw, Neg(bits).raw)); +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_s64(v.raw, Neg(bits).raw)); } // ------------------------------ ShiftLeftSame (Shl) template HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { - return v << Set(Simd(), static_cast(bits)); + return v << Set(Simd(), static_cast(bits)); } template HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { - return v >> Set(Simd(), static_cast(bits)); + return v >> Set(Simd(), static_cast(bits)); } // ------------------------------ Integer multiplication @@ -1256,10 +1285,9 @@ HWY_API Vec128 MulAdd(const Vec128 mul, #endif #if HWY_ARCH_ARM_A64 -HWY_API Vec128 MulAdd(const Vec128 mul, - const Vec128 x, - const Vec128 add) { - return Vec128(vfma_f64(add.raw, mul.raw, x.raw)); +HWY_API Vec64 MulAdd(const Vec64 mul, const Vec64 x, + const Vec64 add) { + return Vec64(vfma_f64(add.raw, 
mul.raw, x.raw)); } HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { @@ -1290,10 +1318,9 @@ HWY_API Vec128 NegMulAdd(const Vec128 mul, #endif #if HWY_ARCH_ARM_A64 -HWY_API Vec128 NegMulAdd(const Vec128 mul, - const Vec128 x, - const Vec128 add) { - return Vec128(vfms_f64(add.raw, mul.raw, x.raw)); +HWY_API Vec64 NegMulAdd(const Vec64 mul, const Vec64 x, + const Vec64 add) { + return Vec64(vfms_f64(add.raw, mul.raw, x.raw)); } HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, @@ -1372,7 +1399,7 @@ HWY_API Vec128 Sqrt(const Vec128 v) { recip *= detail::ReciprocalSqrtStep(v * recip, recip); const auto root = v * recip; - return IfThenZeroElse(v == Zero(Simd()), root); + return IfThenZeroElse(v == Zero(Simd()), root); } #endif @@ -1389,7 +1416,7 @@ HWY_API Vec128 Not(const Vec128 v) { } template HWY_API Vec128 Not(const Vec128 v) { - const Simd d; + const Simd d; const Repartition d8; using V8 = decltype(Zero(d8)); return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); @@ -1401,32 +1428,34 @@ HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) // Uses the u32/64 defined above. template HWY_API Vec128 And(const Vec128 a, const Vec128 b) { - const Simd, N> d; - return BitCast(Simd(), BitCast(d, a) & BitCast(d, b)); + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) & BitCast(du, b)); } // ------------------------------ AndNot -namespace internal { +namespace detail { // reversed_andnot returns a & ~b. HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) -} // namespace internal +} // namespace detail // Returns ~not_mask & mask. template HWY_API Vec128 AndNot(const Vec128 not_mask, const Vec128 mask) { - return internal::reversed_andnot(mask, not_mask); + return detail::reversed_andnot(mask, not_mask); } // Uses the u32/64 defined above. template HWY_API Vec128 AndNot(const Vec128 not_mask, const Vec128 mask) { - const Simd, N> du; - Vec128, N> ret = - internal::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); - return BitCast(Simd(), ret); + const DFromV d; + const RebindToUnsigned du; + VFromD ret = + detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(d, ret); } // ------------------------------ Or @@ -1436,8 +1465,9 @@ HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) // Uses the u32/64 defined above. template HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { - const Simd, N> d; - return BitCast(Simd(), BitCast(d, a) | BitCast(d, b)); + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) | BitCast(du, b)); } // ------------------------------ Xor @@ -1447,8 +1477,24 @@ HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) // Uses the u32/64 defined above. 
template HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { - const Simd, N> d; - return BitCast(Simd(), BitCast(d, a) ^ BitCast(d, b)); + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); } // ------------------------------ Operator overloads (internal-only if float) @@ -1486,7 +1532,7 @@ HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { - const Simd d8; + const Simd d8; return Vec128(vcnt_u8(BitCast(d8, v).raw)); } @@ -1500,7 +1546,7 @@ HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d8; + const Repartition> d8; const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); return Vec128(vpaddl_u8(bytes)); } @@ -1514,7 +1560,7 @@ HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { - const Repartition> d8; + const Repartition> d8; const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); return Vec128(vpaddl_u16(vpaddl_u8(bytes))); } @@ -1528,7 +1574,7 @@ HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { - const Repartition> d8; + const Repartition> d8; const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); return Vec128(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); } @@ -1581,8 +1627,8 @@ HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabsq_f64(v.raw)); } -HWY_API Vec128 Abs(const Vec128 v) { - return Vec128(vabs_f64(v.raw)); +HWY_API Vec64 Abs(const Vec64 v) { + return Vec64(vabs_f64(v.raw)); } #endif @@ -1592,7 +1638,7 @@ template HWY_API Vec128 CopySign(const Vec128 magn, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - const auto msb = SignBit(Simd()); + const auto msb = SignBit(Simd()); return Or(AndNot(msb, magn), And(msb, sign)); } @@ -1600,7 +1646,7 @@ template HWY_API Vec128 CopySignToAbs(const Vec128 abs, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - return Or(abs, And(SignBit(Simd()), sign)); + return Or(abs, And(SignBit(Simd()), sign)); } // ------------------------------ BroadcastSignBit @@ -1617,36 +1663,30 @@ HWY_API Vec128 BroadcastSignBit(const Vec128 v) { // Mask and Vec have the same representation (true = FF..FF). 
template HWY_API Mask128 MaskFromVec(const Vec128 v) { - const Simd, N> du; + const Simd, N, 0> du; return Mask128(BitCast(du, v).raw); } -// DEPRECATED template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return BitCast(Simd(), Vec128, N>(v.raw)); -} - -template -HWY_API Vec128 VecFromMask(Simd d, const Mask128 v) { +HWY_API Vec128 VecFromMask(Simd d, const Mask128 v) { return BitCast(d, Vec128, N>(v.raw)); } // ------------------------------ RebindMask template -HWY_API Mask128 RebindMask(Simd dto, Mask128 m) { +HWY_API Mask128 RebindMask(Simd dto, Mask128 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); - return MaskFromVec(BitCast(dto, VecFromMask(Simd(), m))); + return MaskFromVec(BitCast(dto, VecFromMask(Simd(), m))); } // ------------------------------ IfThenElse(mask, yes, no) = mask ? b : a. #define HWY_NEON_BUILD_TPL_HWY_IF -#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 -#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ - const Mask128 mask, const Vec128 yes, \ - const Vec128 no +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128 mask, const Vec128 yes, \ + const Vec128 no #define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) @@ -1660,19 +1700,30 @@ HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) template HWY_API Vec128 IfThenElseZero(const Mask128 mask, const Vec128 yes) { - return yes & VecFromMask(Simd(), mask); + return yes & VecFromMask(Simd(), mask); } // mask ? 0 : no template HWY_API Vec128 IfThenZeroElse(const Mask128 mask, const Vec128 no) { - return AndNot(VecFromMask(Simd(), mask), no); + return AndNot(VecFromMask(Simd(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Simd d; + const RebindToSigned di; + + Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); } template HWY_API Vec128 ZeroIfNegative(Vec128 v) { - const auto zero = Zero(Simd()); + const auto zero = Zero(Simd()); return Max(zero, v); } @@ -1680,30 +1731,30 @@ HWY_API Vec128 ZeroIfNegative(Vec128 v) { template HWY_API Mask128 Not(const Mask128 m) { - return MaskFromVec(Not(VecFromMask(Simd(), m))); + return MaskFromVec(Not(VecFromMask(Simd(), m))); } template HWY_API Mask128 And(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Or(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } @@ -1714,14 +1765,14 @@ HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { // ------------------------------ Shuffle2301 (for i64 compares) // Swap 32-bit halves in 64-bits -HWY_API Vec128 Shuffle2301(const Vec128 v) { - return Vec128(vrev64_u32(v.raw)); +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_u32(v.raw)); } -HWY_API Vec128 Shuffle2301(const Vec128 v) { - return Vec128(vrev64_s32(v.raw)); +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_s32(v.raw)); } 
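The hunks above add OrAnd, IfVecThenElse and IfNegativeThenElse for this target. A small sketch of how they compose, assuming it sits in a per-target block with highway.h included; the function and parameter names are illustrative. Note that IfNegativeThenElse is restricted to signed or floating-point lane types because it broadcasts the sign bit:

// Illustrative only: select per lane on the sign of v, then OR in a masked term.
template <typename T, size_t N>
Vec128<T, N> BlendExample(Vec128<T, N> v, Vec128<T, N> yes, Vec128<T, N> no,
                          Vec128<T, N> bits, Vec128<T, N> mask_vec) {
  // Lanes where v is negative take `yes`, the rest take `no`.
  const Vec128<T, N> sel = IfNegativeThenElse(v, yes, no);
  // OrAnd(o, a1, a2) is Or(o, And(a1, a2)); IfVecThenElse expects mask_vec
  // lanes to be all-ones or all-zero, as produced by comparisons.
  return IfVecThenElse(mask_vec, OrAnd(sel, bits, mask_vec), sel);
}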
-HWY_API Vec128 Shuffle2301(const Vec128 v) { - return Vec128(vrev64_f32(v.raw)); +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_f32(v.raw)); } HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128(vrev64q_u32(v.raw)); @@ -1734,9 +1785,9 @@ HWY_API Vec128 Shuffle2301(const Vec128 v) { } #define HWY_NEON_BUILD_TPL_HWY_COMPARE -#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 #define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ - const Vec128 a, const Vec128 b + const Vec128 a, const Vec128 b #define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw // ------------------------------ Equality @@ -1779,8 +1830,8 @@ HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { - const Simd d32; - const Simd d64; + const Simd d32; + const Simd d64; const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); const auto cmp64 = cmp32 & Shuffle2301(cmp32); return MaskFromVec(BitCast(d64, cmp64)); @@ -1789,8 +1840,8 @@ HWY_API Mask128 operator==(const Vec128 a, template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { - const Simd d32; - const Simd d64; + const Simd d32; + const Simd d64; const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); const auto cmp64 = cmp32 & Shuffle2301(cmp32); return MaskFromVec(BitCast(d64, cmp64)); @@ -1801,17 +1852,17 @@ HWY_API Mask128 operator<(const Vec128 a, const int64x2_t sub = vqsubq_s64(a.raw, b.raw); return MaskFromVec(BroadcastSignBit(Vec128(sub))); } -HWY_API Mask128 operator<(const Vec128 a, - const Vec128 b) { +HWY_API Mask128 operator<(const Vec64 a, + const Vec64 b) { const int64x1_t sub = vqsub_s64(a.raw, b.raw); - return MaskFromVec(BroadcastSignBit(Vec128(sub))); + return MaskFromVec(BroadcastSignBit(Vec64(sub))); } template HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { - const Simd di; - const Simd du; + const DFromV du; + const RebindToSigned di; const Vec128 msb = AndNot(a, b) | AndNot(a ^ b, a - b); return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); } @@ -1832,7 +1883,7 @@ HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { // ------------------------------ FirstN (Iota, Lt) template -HWY_API Mask128 FirstN(const Simd d, size_t num) { +HWY_API Mask128 FirstN(const Simd d, size_t num) { const RebindToSigned di; // Signed comparisons are cheaper. 
return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); } @@ -1840,9 +1891,9 @@ HWY_API Mask128 FirstN(const Simd d, size_t num) { // ------------------------------ TestBit (Eq) #define HWY_NEON_BUILD_TPL_HWY_TESTBIT -#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 +#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 #define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ - Vec128 v, Vec128 bit + Vec128 v, Vec128 bit #define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw #if HWY_ARCH_ARM_A64 @@ -1878,11 +1929,11 @@ HWY_API Vec128 Abs(const Vec128 v) { return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } -HWY_API Vec128 Abs(const Vec128 v) { +HWY_API Vec64 Abs(const Vec64 v) { #if HWY_ARCH_ARM_A64 - return Vec128(vabs_s64(v.raw)); + return Vec64(vabs_s64(v.raw)); #else - const auto zero = Zero(Simd()); + const auto zero = Zero(Full64()); return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } @@ -1898,8 +1949,8 @@ HWY_API Vec128 Min(const Vec128 a, #if HWY_ARCH_ARM_A64 return IfThenElse(b < a, b, a); #else - const Simd du; - const Simd di; + const DFromV du; + const RebindToSigned di; return BitCast(du, BitCast(di, a) - BitCast(di, detail::SaturatedSub(a, b))); #endif } @@ -1936,8 +1987,8 @@ HWY_API Vec128 Max(const Vec128 a, #if HWY_ARCH_ARM_A64 return IfThenElse(b < a, a, b); #else - const Simd du; - const Simd di; + const DFromV du; + const RebindToSigned di; return BitCast(du, BitCast(di, b) + BitCast(di, detail::SaturatedSub(a, b))); #endif } @@ -2012,46 +2063,46 @@ HWY_API Vec128 LoadU(Full128 /* tag */, // ------------------------------ Load 64 -HWY_API Vec128 LoadU(Simd /* tag */, - const uint8_t* HWY_RESTRICT p) { - return Vec128(vld1_u8(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint8_t* HWY_RESTRICT p) { + return Vec64(vld1_u8(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const uint16_t* HWY_RESTRICT p) { - return Vec128(vld1_u16(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint16_t* HWY_RESTRICT p) { + return Vec64(vld1_u16(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const uint32_t* HWY_RESTRICT p) { - return Vec128(vld1_u32(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint32_t* HWY_RESTRICT p) { + return Vec64(vld1_u32(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const uint64_t* HWY_RESTRICT p) { - return Vec128(vld1_u64(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint64_t* HWY_RESTRICT p) { + return Vec64(vld1_u64(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const int8_t* HWY_RESTRICT p) { - return Vec128(vld1_s8(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const int8_t* HWY_RESTRICT p) { + return Vec64(vld1_s8(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const int16_t* HWY_RESTRICT p) { - return Vec128(vld1_s16(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const int16_t* HWY_RESTRICT p) { + return Vec64(vld1_s16(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const int32_t* HWY_RESTRICT p) { - return Vec128(vld1_s32(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const int32_t* HWY_RESTRICT p) { + return Vec64(vld1_s32(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const int64_t* HWY_RESTRICT p) { - return Vec128(vld1_s64(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const int64_t* HWY_RESTRICT p) { + return Vec64(vld1_s64(p)); } -HWY_API Vec128 LoadU(Simd /* tag */, - const float* HWY_RESTRICT p) { - return Vec128(vld1_f32(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const float* HWY_RESTRICT p) { + return Vec64(vld1_f32(p)); } #if HWY_ARCH_ARM_A64 -HWY_API Vec128 
LoadU(Simd /* tag */, - const double* HWY_RESTRICT p) { - return Vec128(vld1_f64(p)); +HWY_API Vec64 LoadU(Full64 /* tag */, + const double* HWY_RESTRICT p) { + return Vec64(vld1_f64(p)); } #endif @@ -2062,86 +2113,85 @@ HWY_API Vec128 LoadU(Simd /* tag */, // we don't actually care what is in it, and we don't want // to introduce extra overhead by initializing it to something. -HWY_API Vec128 LoadU(Simd /*tag*/, - const uint8_t* HWY_RESTRICT p) { - uint32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, + const uint8_t* HWY_RESTRICT p) { + uint32x2_t a = Undefined(Full64()).raw; uint32x2_t b = vld1_lane_u32(reinterpret_cast(p), a, 0); - return Vec128(vreinterpret_u8_u32(b)); + return Vec32(vreinterpret_u8_u32(b)); } -HWY_API Vec128 LoadU(Simd /*tag*/, - const uint16_t* HWY_RESTRICT p) { - uint32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, + const uint16_t* HWY_RESTRICT p) { + uint32x2_t a = Undefined(Full64()).raw; uint32x2_t b = vld1_lane_u32(reinterpret_cast(p), a, 0); - return Vec128(vreinterpret_u16_u32(b)); + return Vec32(vreinterpret_u16_u32(b)); } -HWY_API Vec128 LoadU(Simd /*tag*/, - const uint32_t* HWY_RESTRICT p) { - uint32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, + const uint32_t* HWY_RESTRICT p) { + uint32x2_t a = Undefined(Full64()).raw; uint32x2_t b = vld1_lane_u32(p, a, 0); - return Vec128(b); + return Vec32(b); } -HWY_API Vec128 LoadU(Simd /*tag*/, - const int8_t* HWY_RESTRICT p) { - int32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, + const int8_t* HWY_RESTRICT p) { + int32x2_t a = Undefined(Full64()).raw; int32x2_t b = vld1_lane_s32(reinterpret_cast(p), a, 0); - return Vec128(vreinterpret_s8_s32(b)); + return Vec32(vreinterpret_s8_s32(b)); } -HWY_API Vec128 LoadU(Simd /*tag*/, - const int16_t* HWY_RESTRICT p) { - int32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, + const int16_t* HWY_RESTRICT p) { + int32x2_t a = Undefined(Full64()).raw; int32x2_t b = vld1_lane_s32(reinterpret_cast(p), a, 0); - return Vec128(vreinterpret_s16_s32(b)); + return Vec32(vreinterpret_s16_s32(b)); } -HWY_API Vec128 LoadU(Simd /*tag*/, - const int32_t* HWY_RESTRICT p) { - int32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, + const int32_t* HWY_RESTRICT p) { + int32x2_t a = Undefined(Full64()).raw; int32x2_t b = vld1_lane_s32(p, a, 0); - return Vec128(b); + return Vec32(b); } -HWY_API Vec128 LoadU(Simd /*tag*/, - const float* HWY_RESTRICT p) { - float32x2_t a = Undefined(Simd()).raw; +HWY_API Vec32 LoadU(Full32 /*tag*/, const float* HWY_RESTRICT p) { + float32x2_t a = Undefined(Full64()).raw; float32x2_t b = vld1_lane_f32(p, a, 0); - return Vec128(b); + return Vec32(b); } // ------------------------------ Load 16 -HWY_API Vec128 LoadU(Simd /*tag*/, +HWY_API Vec128 LoadU(Simd /*tag*/, const uint8_t* HWY_RESTRICT p) { - uint16x4_t a = Undefined(Simd()).raw; + uint16x4_t a = Undefined(Full64()).raw; uint16x4_t b = vld1_lane_u16(reinterpret_cast(p), a, 0); return Vec128(vreinterpret_u8_u16(b)); } -HWY_API Vec128 LoadU(Simd /*tag*/, +HWY_API Vec128 LoadU(Simd /*tag*/, const uint16_t* HWY_RESTRICT p) { - uint16x4_t a = Undefined(Simd()).raw; + uint16x4_t a = Undefined(Full64()).raw; uint16x4_t b = vld1_lane_u16(p, a, 0); return Vec128(b); } -HWY_API Vec128 LoadU(Simd /*tag*/, +HWY_API Vec128 LoadU(Simd /*tag*/, const int8_t* HWY_RESTRICT p) { - int16x4_t a = Undefined(Simd()).raw; + int16x4_t a = Undefined(Full64()).raw; int16x4_t b = 
vld1_lane_s16(reinterpret_cast(p), a, 0); return Vec128(vreinterpret_s8_s16(b)); } -HWY_API Vec128 LoadU(Simd /*tag*/, +HWY_API Vec128 LoadU(Simd /*tag*/, const int16_t* HWY_RESTRICT p) { - int16x4_t a = Undefined(Simd()).raw; + int16x4_t a = Undefined(Full64()).raw; int16x4_t b = vld1_lane_s16(p, a, 0); return Vec128(b); } // ------------------------------ Load 8 -HWY_API Vec128 LoadU(Simd d, +HWY_API Vec128 LoadU(Simd d, const uint8_t* HWY_RESTRICT p) { uint8x8_t a = Undefined(d).raw; uint8x8_t b = vld1_lane_u8(p, a, 0); return Vec128(b); } -HWY_API Vec128 LoadU(Simd d, +HWY_API Vec128 LoadU(Simd d, const int8_t* HWY_RESTRICT p) { int8x8_t a = Undefined(d).raw; int8x8_t b = vld1_lane_s8(p, a, 0); @@ -2150,35 +2200,36 @@ HWY_API Vec128 LoadU(Simd d, // [b]float16_t use the same Raw as uint16_t, so forward to that. template -HWY_API Vec128 LoadU(Simd /*d*/, +HWY_API Vec128 LoadU(Simd d, const float16_t* HWY_RESTRICT p) { - const Simd du16; + const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return Vec128(LoadU(du16, pu16).raw); } template -HWY_API Vec128 LoadU(Simd /*d*/, +HWY_API Vec128 LoadU(Simd d, const bfloat16_t* HWY_RESTRICT p) { - const Simd du16; + const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return Vec128(LoadU(du16, pu16).raw); } // On ARM, Load is the same as LoadU. template -HWY_API Vec128 Load(Simd d, const T* HWY_RESTRICT p) { +HWY_API Vec128 Load(Simd d, const T* HWY_RESTRICT p) { return LoadU(d, p); } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } // 128-bit SIMD => nothing to duplicate, same as an unaligned load. template -HWY_API Vec128 LoadDup128(Simd d, const T* const HWY_RESTRICT p) { +HWY_API Vec128 LoadDup128(Simd d, + const T* const HWY_RESTRICT p) { return LoadU(d, p); } @@ -2229,44 +2280,44 @@ HWY_API void StoreU(const Vec128 v, Full128 /* tag */, // ------------------------------ Store 64 -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint8_t* HWY_RESTRICT p) { vst1_u8(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint16_t* HWY_RESTRICT p) { vst1_u16(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint32_t* HWY_RESTRICT p) { vst1_u32(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint64_t* HWY_RESTRICT p) { vst1_u64(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int8_t* HWY_RESTRICT p) { vst1_s8(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int16_t* HWY_RESTRICT p) { vst1_s16(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int32_t* HWY_RESTRICT p) { vst1_s32(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int64_t* HWY_RESTRICT p) { vst1_s64(p, v.raw); } -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, float* HWY_RESTRICT p) { vst1_f32(p, v.raw); } #if HWY_ARCH_ARM_A64 -HWY_API void StoreU(const Vec128 v, Simd /* tag */, +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, 
double* HWY_RESTRICT p) { vst1_f64(p, v.raw); } @@ -2274,90 +2325,90 @@ HWY_API void StoreU(const Vec128 v, Simd /* tag */, // ------------------------------ Store 32 -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, uint8_t* HWY_RESTRICT p) { uint32x2_t a = vreinterpret_u32_u8(v.raw); vst1_lane_u32(reinterpret_cast(p), a, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, uint16_t* HWY_RESTRICT p) { uint32x2_t a = vreinterpret_u32_u16(v.raw); vst1_lane_u32(reinterpret_cast(p), a, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, uint32_t* HWY_RESTRICT p) { vst1_lane_u32(p, v.raw, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, int8_t* HWY_RESTRICT p) { int32x2_t a = vreinterpret_s32_s8(v.raw); vst1_lane_s32(reinterpret_cast(p), a, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, int16_t* HWY_RESTRICT p) { int32x2_t a = vreinterpret_s32_s16(v.raw); vst1_lane_s32(reinterpret_cast(p), a, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, int32_t* HWY_RESTRICT p) { vst1_lane_s32(p, v.raw, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec32 v, Full32, float* HWY_RESTRICT p) { vst1_lane_f32(p, v.raw, 0); } // ------------------------------ Store 16 -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec128 v, Simd, uint8_t* HWY_RESTRICT p) { uint16x4_t a = vreinterpret_u16_u8(v.raw); vst1_lane_u16(reinterpret_cast(p), a, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec128 v, Simd, uint16_t* HWY_RESTRICT p) { vst1_lane_u16(p, v.raw, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec128 v, Simd, int8_t* HWY_RESTRICT p) { int16x4_t a = vreinterpret_s16_s8(v.raw); vst1_lane_s16(reinterpret_cast(p), a, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec128 v, Simd, int16_t* HWY_RESTRICT p) { vst1_lane_s16(p, v.raw, 0); } // ------------------------------ Store 8 -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec128 v, Simd, uint8_t* HWY_RESTRICT p) { vst1_lane_u8(p, v.raw, 0); } -HWY_API void StoreU(const Vec128 v, Simd, +HWY_API void StoreU(const Vec128 v, Simd, int8_t* HWY_RESTRICT p) { vst1_lane_s8(p, v.raw, 0); } // [b]float16_t use the same Raw as uint16_t, so forward to that. template -HWY_API void StoreU(Vec128 v, Simd /* tag */, +HWY_API void StoreU(Vec128 v, Simd d, float16_t* HWY_RESTRICT p) { - const Simd du16; + const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return StoreU(Vec128(v.raw), du16, pu16); } template -HWY_API void StoreU(Vec128 v, Simd /* tag */, +HWY_API void StoreU(Vec128 v, Simd d, bfloat16_t* HWY_RESTRICT p) { - const Simd du16; + const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return StoreU(Vec128(v.raw), du16, pu16); } // On ARM, Store is the same as StoreU. template -HWY_API void Store(Vec128 v, Simd d, T* HWY_RESTRICT aligned) { +HWY_API void Store(Vec128 v, Simd d, T* HWY_RESTRICT aligned) { StoreU(v, d, aligned); } @@ -2366,7 +2417,7 @@ HWY_API void Store(Vec128 v, Simd d, T* HWY_RESTRICT aligned) { // Same as aligned stores on non-x86. 
template -HWY_API void Stream(const Vec128 v, Simd d, +HWY_API void Stream(const Vec128 v, Simd d, T* HWY_RESTRICT aligned) { Store(v, d, aligned); } @@ -2377,72 +2428,69 @@ HWY_API void Stream(const Vec128 v, Simd d, // Unsigned: zero-extend to full vector. HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vmovl_u8(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec32 v) { uint16x8_t a = vmovl_u8(v.raw); return Vec128(vmovl_u16(vget_low_u16(a))); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vmovl_u16(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vmovl_u32(v.raw)); } -HWY_API Vec128 PromoteTo(Full128 d, - const Vec128 v) { +HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { return BitCast(d, Vec128(vmovl_u8(v.raw))); } -HWY_API Vec128 PromoteTo(Full128 d, - const Vec128 v) { +HWY_API Vec128 PromoteTo(Full128 d, const Vec32 v) { uint16x8_t a = vmovl_u8(v.raw); return BitCast(d, Vec128(vmovl_u16(vget_low_u16(a)))); } -HWY_API Vec128 PromoteTo(Full128 d, - const Vec128 v) { +HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { return BitCast(d, Vec128(vmovl_u16(v.raw))); } // Unsigned: zero-extend to half vector. template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_u16(vmovl_u8(v.raw))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { uint16x8_t a = vmovl_u8(v.raw); return Vec128(vget_low_u32(vmovl_u16(vget_low_u16(a)))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_u32(vmovl_u16(v.raw))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_u64(vmovl_u32(v.raw))); } template -HWY_API Vec128 PromoteTo(Simd d, +HWY_API Vec128 PromoteTo(Simd d, const Vec128 v) { return BitCast(d, Vec128(vget_low_u16(vmovl_u8(v.raw)))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { uint16x8_t a = vmovl_u8(v.raw); uint32x4_t b = vmovl_u16(vget_low_u16(a)); return Vec128(vget_low_s32(vreinterpretq_s32_u32(b))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { uint32x4_t a = vmovl_u16(v.raw); return Vec128(vget_low_s32(vreinterpretq_s32_u32(a))); @@ -2450,43 +2498,43 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, // Signed: replicate sign bit to full vector. HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vmovl_s8(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec32 v) { int16x8_t a = vmovl_s8(v.raw); return Vec128(vmovl_s16(vget_low_s16(a))); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vmovl_s16(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vmovl_s32(v.raw)); } // Signed: replicate sign bit to half vector. 
template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_s16(vmovl_s8(v.raw))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { int16x8_t a = vmovl_s8(v.raw); int32x4_t b = vmovl_s16(vget_low_s16(a)); return Vec128(vget_low_s32(b)); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_s32(vmovl_s16(v.raw))); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_s64(vmovl_s32(v.raw))); } @@ -2499,7 +2547,7 @@ HWY_API Vec128 PromoteTo(Full128 /* tag */, return Vec128(f32); } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); return Vec128(vget_low_f32(f32)); @@ -2508,11 +2556,10 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #else template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { - const Simd di32; - const Simd du32; - const Simd df32; + const RebindToSigned di32; + const RebindToUnsigned du32; // Expand to u32 so we can shift. const auto bits16 = PromoteTo(du32, Vec128{v.raw}); const auto sign = ShiftRight<15>(bits16); @@ -2534,25 +2581,25 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #if HWY_ARCH_ARM_A64 HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { return Vec128(vcvt_f64_f32(v.raw)); } -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_low_f64(vcvt_f64_f32(v.raw))); +HWY_API Vec64 PromoteTo(Full64 /* tag */, + const Vec32 v) { + return Vec64(vget_low_f64(vcvt_f64_f32(v.raw))); } HWY_API Vec128 PromoteTo(Full128 /* tag */, - const Vec128 v) { + const Vec64 v) { const int64x2_t i64 = vmovl_s32(v.raw); return Vec128(vcvtq_f64_s64(i64)); } -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec64 PromoteTo(Full64 /* tag */, + const Vec32 v) { const int64x1_t i64 = vget_low_s64(vmovl_s32(v.raw)); - return Vec128(vcvt_f64_s64(i64)); + return Vec64(vcvt_f64_s64(i64)); } #endif @@ -2560,75 +2607,75 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, // ------------------------------ Demotions (full -> part w/ narrow lanes) // From full vector to half or quarter -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vqmovun_s32(v.raw)); +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovun_s32(v.raw)); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vqmovn_s32(v.raw)); +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovn_s32(v.raw)); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec128 v) { const uint16x4_t a = vqmovun_s32(v.raw); - return Vec128(vqmovn_u16(vcombine_u16(a, a))); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vqmovun_s16(v.raw)); +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovun_s16(v.raw)); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec128 v) { const int16x4_t a = vqmovn_s32(v.raw); - return Vec128(vqmovn_s16(vcombine_s16(a, a))); + return 
Vec32(vqmovn_s16(vcombine_s16(a, a))); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vqmovn_s16(v.raw)); +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovn_s16(v.raw)); } // From half vector to partial half template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovun_s32(vcombine_s32(v.raw, v.raw))); } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovn_s32(vcombine_s32(v.raw, v.raw))); } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); return Vec128(vqmovn_u16(vcombine_u16(a, a))); } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovun_s16(vcombine_s16(v.raw, v.raw))); } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); return Vec128(vqmovn_s16(vcombine_s16(a, a))); } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovn_s16(vcombine_s16(v.raw, v.raw))); } #if __ARM_FP & 2 -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec128{vreinterpret_u16_f16(vcvt_f16_f32(v.raw))}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const float16x4_t f16 = vcvt_f16_f32(vcombine_f32(v.raw, v.raw)); return Vec128(vreinterpret_u16_f16(f16)); @@ -2637,11 +2684,11 @@ HWY_API Vec128 DemoteTo(Simd /* tag */, #else template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd df16, const Vec128 v) { - const Simd di; - const Simd du; - const Simd du16; + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); @@ -2669,7 +2716,7 @@ HWY_API Vec128 DemoteTo(Simd /* tag */, #endif template -HWY_API Vec128 DemoteTo(Simd dbf16, +HWY_API Vec128 DemoteTo(Simd dbf16, const Vec128 v) { const Rebind di32; const Rebind du32; // for logical shift right @@ -2680,34 +2727,32 @@ HWY_API Vec128 DemoteTo(Simd dbf16, #if HWY_ARCH_ARM_A64 -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vcvt_f32_f64(v.raw)); +HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { + return Vec64(vcvt_f32_f64(v.raw)); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec64 v) { + return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { const int64x2_t i64 = vcvtq_s64_f64(v.raw); - return Vec128(vqmovn_s64(i64)); + return Vec64(vqmovn_s64(i64)); } -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec64 v) { const int64x1_t i64 = vcvt_s64_f64(v.raw); // There is no i64x1 -> i32x1 narrow, so expand to int64x2_t first. 
const int64x2_t i64x2 = vcombine_s64(i64, i64); - return Vec128(vqmovn_s64(i64x2)); + return Vec32(vqmovn_s64(i64x2)); } #endif -HWY_API Vec128 U8FromU32(const Vec128 v) { +HWY_API Vec32 U8FromU32(const Vec128 v) { const uint8x16_t org_v = detail::BitCastToByte(v).raw; const uint8x16_t w = vuzp1q_u8(org_v, org_v); - return Vec128(vget_low_u8(vuzp1q_u8(w, w))); + return Vec32(vget_low_u8(vuzp1q_u8(w, w))); } template HWY_API Vec128 U8FromU32(const Vec128 v) { @@ -2723,18 +2768,18 @@ HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { - Vec128 a = DemoteTo(Simd(), v); + Vec128 a = DemoteTo(Simd(), v); Vec128 b; uint16x8_t c = vcombine_u16(a.raw, b.raw); return Vec128(vqmovn_u16(c)); } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { - Vec128 a = DemoteTo(Simd(), v); + Vec128 a = DemoteTo(Simd(), v); Vec128 b; int16x8_t c = vcombine_s16(a.raw, b.raw); return Vec128(vqmovn_s16(c)); @@ -2749,7 +2794,7 @@ HWY_API Vec128 ConvertTo(Full128 /* tag */, return Vec128(vcvtq_f32_s32(v.raw)); } template -HWY_API Vec128 ConvertTo(Simd /* tag */, +HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128(vcvt_f32_s32(v.raw)); } @@ -2760,7 +2805,7 @@ HWY_API Vec128 ConvertTo(Full128 /* tag */, return Vec128(vcvtq_s32_f32(v.raw)); } template -HWY_API Vec128 ConvertTo(Simd /* tag */, +HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128(vcvt_s32_f32(v.raw)); } @@ -2771,9 +2816,9 @@ HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_f64_s64(v.raw)); } -HWY_API Vec128 ConvertTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vcvt_f64_s64(v.raw)); +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_f64_s64(v.raw)); } // Truncates (rounds toward zero). @@ -2781,9 +2826,9 @@ HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_s64_f64(v.raw)); } -HWY_API Vec128 ConvertTo(Simd /* tag */, - const Vec128 v) { - return Vec128(vcvt_s64_f64(v.raw)); +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_s64_f64(v.raw)); } #endif @@ -2817,14 +2862,14 @@ namespace detail { // large (i.e. the value is already an integer). 
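The ARMv7 Trunc/Round/Ceil/Floor fallbacks that follow build rounding out of a truncating float->int conversion: convert, convert back, apply a +/-1 correction, and pass through any value whose magnitude is at least 2^23 (already an integer) or NaN. A scalar sketch of the Floor/Ceil logic, assuming the same MantissaEnd<float>() == 2^23 threshold; Round additionally relies on adding a large constant under round-to-nearest-even, which is not sketched here:

#include <cmath>
#include <cstdint>
#include <cstdio>

float FloorFromTrunc(float v) {
  if (!(std::fabs(v) < 8388608.0f)) return v;       // >= 2^23 or NaN: already integral
  const int32_t integer = static_cast<int32_t>(v);  // rounds toward zero
  const float int_f = static_cast<float>(integer);
  // Truncating a negative non-integer lands above v; step down by 1.
  return (int_f > v) ? int_f - 1.0f : int_f;
}

float CeilFromTrunc(float v) {
  if (!(std::fabs(v) < 8388608.0f)) return v;
  const int32_t integer = static_cast<int32_t>(v);
  const float int_f = static_cast<float>(integer);
  // Truncating a positive non-integer lands below v; step up by 1.
  return (int_f < v) ? int_f + 1.0f : int_f;
}

int main() {
  printf("%g %g\n", FloorFromTrunc(-1.5f), CeilFromTrunc(-1.5f));  // -2 -1
  printf("%g %g\n", FloorFromTrunc(2.25f), CeilFromTrunc(2.25f));  // 2 3
  return 0;
}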
template HWY_INLINE Mask128 UseInt(const Vec128 v) { - return Abs(v) < Set(Simd(), MantissaEnd()); + return Abs(v) < Set(Simd(), MantissaEnd()); } } // namespace detail template HWY_API Vec128 Trunc(const Vec128 v) { - const Simd df; + const DFromV df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 @@ -2835,7 +2880,7 @@ HWY_API Vec128 Trunc(const Vec128 v) { template HWY_API Vec128 Round(const Vec128 v) { - const Simd df; + const DFromV df; // ARMv7 also lacks a native NearestInt, but we can instead rely on rounding // (we assume the current mode is nearest-even) after addition with a large @@ -2852,7 +2897,7 @@ HWY_API Vec128 Round(const Vec128 v) { template HWY_API Vec128 Ceil(const Vec128 v) { - const Simd df; + const DFromV df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 @@ -2866,8 +2911,8 @@ HWY_API Vec128 Ceil(const Vec128 v) { template HWY_API Vec128 Floor(const Vec128 v) { - const Simd df; - const Simd di; + const DFromV df; + const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); @@ -2896,7 +2941,7 @@ HWY_API Vec128 NearestInt(const Vec128 v) { template HWY_API Vec128 NearestInt(const Vec128 v) { - const Simd di; + const RebindToSigned> di; return ConvertTo(di, Round(v)); } @@ -2912,41 +2957,42 @@ HWY_API Vec128 LowerHalf(const Vec128 v) { return Vec128(v.raw); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_u8(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u8(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_u16(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u16(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_u32(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u32(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_u64(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u64(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_s8(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s8(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_s16(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s16(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_s32(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s32(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_s64(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s64(v.raw)); } -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_f32(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_f32(v.raw)); } #if HWY_ARCH_ARM_A64 -HWY_API Vec128 LowerHalf(const Vec128 v) { - return Vec128(vget_low_f64(v.raw)); +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_f64(v.raw)); } #endif template -HWY_API Vec128 LowerHalf(Simd /* tag */, Vec128 v) { +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { return LowerHalf(v); } @@ -2962,8 +3008,8 @@ HWY_API V128 CombineShiftRightBytes(Full128 d, V128 hi, V128 lo) { } // 64-bit -template > -HWY_API V64 CombineShiftRightBytes(Simd d, V64 hi, V64 lo) { +template +HWY_API Vec64 CombineShiftRightBytes(Full64 d, Vec64 hi, Vec64 lo) { static_assert(0 < kBytes && kBytes < 8, "kBytes must be 
in [1, 7]"); const Repartition d8; uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); @@ -2991,7 +3037,7 @@ struct ShiftLeftBytesT { template HWY_INLINE Vec128 operator()(const Vec128 v) { // Expand to 64-bit so we only use the native EXT instruction. - const Simd d64; + const Full64 d64; const auto zero64 = Zero(d64); const decltype(zero64) v64(v.raw); return Vec128( @@ -3009,7 +3055,7 @@ template <> struct ShiftLeftBytesT<0xFF> { template HWY_INLINE Vec128 operator()(const Vec128 /* v */) { - return Zero(Simd()); + return Zero(Simd()); } }; @@ -3017,11 +3063,11 @@ template struct ShiftRightBytesT { template HWY_INLINE Vec128 operator()(Vec128 v) { - const Simd d; + const Simd d; // For < 64-bit vectors, zero undefined lanes so we shift in zeros. if (N * sizeof(T) < 8) { constexpr size_t kReg = N * sizeof(T) == 16 ? 16 : 8; - const Simd dreg; + const Simd dreg; v = Vec128( IfThenElseZero(FirstN(dreg, N), VFromD(v.raw)).raw); } @@ -3039,55 +3085,55 @@ template <> struct ShiftRightBytesT<0xFF> { template HWY_INLINE Vec128 operator()(const Vec128 /* v */) { - return Zero(Simd()); + return Zero(Simd()); } }; } // namespace detail template -HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { return detail::ShiftLeftBytesT < kBytes >= N * sizeof(T) ? 0xFF : kBytes > ()(v); } template HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { - return ShiftLeftBytes(Simd(), v); + return ShiftLeftBytes(Simd(), v); } template -HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { - return ShiftLeftLanes(Simd(), v); + return ShiftLeftLanes(Simd(), v); } // 0x01..0F, kBytes = 1 => 0x0001..0E template -HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { return detail::ShiftRightBytesT < kBytes >= N * sizeof(T) ? 
0xFF : kBytes > ()(v); } template -HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { const Repartition d8; - return BitCast(d, ShiftRightBytes(BitCast(d8, v))); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // Calls ShiftLeftBytes template -HWY_API Vec128 CombineShiftRightBytes(Simd d, Vec128 hi, +HWY_API Vec128 CombineShiftRightBytes(Simd d, Vec128 hi, Vec128 lo) { constexpr size_t kSize = N * sizeof(T); static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); const Repartition d8; - const Simd d_full8; + const Full64 d_full8; const Repartition d_full; using V64 = VFromD; const V64 hi64(BitCast(d8, hi).raw); @@ -3101,56 +3147,56 @@ HWY_API Vec128 CombineShiftRightBytes(Simd d, Vec128 hi, // ------------------------------ UpperHalf (ShiftRightBytes) // Full input -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_u8(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u8(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_u16(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u16(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_u32(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u32(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_u64(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u64(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_s8(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s8(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_s16(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s16(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_s32(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s32(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_s64(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s64(v.raw)); } -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_f32(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64(vget_high_f32(v.raw)); } #if HWY_ARCH_ARM_A64 -HWY_API Vec128 UpperHalf(Simd /* tag */, - const Vec128 v) { - return Vec128(vget_high_f64(v.raw)); +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_f64(v.raw)); } #endif // Partial template -HWY_API Vec128 UpperHalf(Half> /* tag */, +HWY_API Vec128 UpperHalf(Half> /* tag */, Vec128 v) { - const Simd d; - const auto vu = BitCast(RebindToUnsigned(), v); - const auto upper = BitCast(d, ShiftRightBytes(vu)); + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); return Vec128(upper.raw); } @@ -3183,7 +3229,7 @@ HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_laneq_u64(v.raw, kLane)); } -// Vec128 is defined below. +// Vec64 is defined below. 
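CombineShiftRightBytes above maps directly onto NEON's EXT instruction: the result is a kSize-byte window taken from the concatenation {lo, hi}, starting kBytes into lo. A scalar byte-array sketch of that semantics (illustrative, not the Highway API):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Bytes [kBytes, kBytes + kSize) of {lo, hi}: the window vext_u8/vextq_u8 extracts.
template <int kBytes, size_t kSize>
void CombineShiftRightBytesScalar(const uint8_t (&hi)[kSize],
                                  const uint8_t (&lo)[kSize],
                                  uint8_t (&out)[kSize]) {
  static_assert(0 < kBytes && kBytes < static_cast<int>(kSize), "kBytes invalid");
  uint8_t both[2 * kSize];
  std::memcpy(both, lo, kSize);          // lo occupies the lower-addressed bytes
  std::memcpy(both + kSize, hi, kSize);
  std::memcpy(out, both + kBytes, kSize);
}

int main() {
  const uint8_t lo[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  const uint8_t hi[8] = {8, 9, 10, 11, 12, 13, 14, 15};
  uint8_t out[8];
  CombineShiftRightBytesScalar<3>(hi, lo, out);
  for (uint8_t b : out) printf("%d ", b);  // 3 4 5 6 7 8 9 10
  printf("\n");
  return 0;
}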
// Signed template @@ -3211,7 +3257,7 @@ HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_laneq_s64(v.raw, kLane)); } -// Vec128 is defined below. +// Vec64 is defined below. // Float template @@ -3230,7 +3276,7 @@ HWY_API Vec128 Broadcast(const Vec128 v) { return Vec128(vdupq_laneq_f64(v.raw, kLane)); } template -HWY_API Vec128 Broadcast(const Vec128 v) { +HWY_API Vec64 Broadcast(const Vec64 v) { static_assert(0 <= kLane && kLane < 1, "Invalid lane"); return v; } @@ -3264,7 +3310,7 @@ HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); } -// Vec128 is defined below. +// Vec64 is defined below. // Signed template @@ -3292,7 +3338,7 @@ HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); } -// Vec128 is defined below. +// Vec64 is defined below. // Float template @@ -3309,12 +3355,12 @@ HWY_API Vec128 Broadcast(const Vec128 v) { #endif template -HWY_API Vec128 Broadcast(const Vec128 v) { +HWY_API Vec64 Broadcast(const Vec64 v) { static_assert(0 <= kLane && kLane < 1, "Invalid lane"); return v; } template -HWY_API Vec128 Broadcast(const Vec128 v) { +HWY_API Vec64 Broadcast(const Vec64 v) { static_assert(0 <= kLane && kLane < 1, "Invalid lane"); return v; } @@ -3328,10 +3374,10 @@ struct Indices128 { }; template -HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD - const Simd di; + const Rebind di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(N))))); #endif @@ -3368,14 +3414,14 @@ HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { } template -HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { - const Simd d; + const DFromV d; const RebindToSigned di; return BitCast( d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128{idx.raw}))); @@ -3385,13 +3431,13 @@ HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { // Single lane: no change template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return v; } // Two lanes: shuffle template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return Vec128(Shuffle2301(v)); } @@ -3408,11 +3454,75 @@ HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { // 16-bit template -HWY_API Vec128 Reverse(Simd d, const Vec128 v) { +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { const RepartitionToWide> du32; return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); } +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32_u16(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u16(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, 
Vec128(vrev64_u32(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u32(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64_u16(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse4(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u16(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { + return Reverse(d, v); +} + +template +HWY_API Vec128 Reverse8(Simd, const Vec128) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + // ------------------------------ Other shuffles (TableLookupBytes) // Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). @@ -3496,13 +3606,12 @@ HWY_API Vec128 InterleaveLower(const Vec128 a, // < 64 bit parts template HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { - using V64 = Vec128; - return Vec128(InterleaveLower(V64(a.raw), V64(b.raw)).raw); + return Vec128(InterleaveLower(Vec64(a.raw), Vec64(b.raw)).raw); } // Additional overload for the optional Simd<> tag. template > -HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { +HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { return InterleaveLower(a, b); } @@ -3539,22 +3648,22 @@ HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { return Vec128(vzip2q_f32(a.raw, b.raw)); } -HWY_API Vec128 InterleaveUpper(const Vec128 a, - const Vec128 b) { - return Vec128(vzip2_f32(a.raw, b.raw)); +HWY_API Vec64 InterleaveUpper(const Vec64 a, + const Vec64 b) { + return Vec64(vzip2_f32(a.raw, b.raw)); } } // namespace detail // Full register template > -HWY_API V InterleaveUpper(Simd /* tag */, V a, V b) { +HWY_API V InterleaveUpper(Simd /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // Partial template > -HWY_API V InterleaveUpper(Simd d, V a, V b) { +HWY_API V InterleaveUpper(Simd d, V a, V b) { const Half d2; return InterleaveLower(d, V(UpperHalf(d2, a).raw), V(UpperHalf(d2, b).raw)); } @@ -3563,26 +3672,24 @@ HWY_API V InterleaveUpper(Simd d, V a, V b) { // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
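As the comment above notes, Zip* exists because the single-lane scalar target cannot return two vectors from one op, so the interleaved result is reinterpreted as double-width lanes; on this little-endian target, wide lane i holds a's lane in the low half and b's lane in the high half. A scalar sketch for u8 -> u16 (ZipLowerU8 is an illustrative stand-in, not the Highway API):

#include <cstdint>
#include <cstdio>

// ZipLower(a, b) interleaves the lower halves (a0, b0, a1, b1, ...) and views
// each adjacent byte pair as one u16 lane: lane i == a[i] | (b[i] << 8).
void ZipLowerU8(const uint8_t a[8], const uint8_t b[8], uint16_t out[4]) {
  for (int i = 0; i < 4; ++i) {
    out[i] = static_cast<uint16_t>(a[i] | (b[i] << 8));
  }
}

int main() {
  const uint8_t a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  const uint8_t b[8] = {0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80};
  uint16_t out[4];
  ZipLowerU8(a, b, out);
  for (uint16_t w : out) printf("0x%04X ", static_cast<unsigned>(w));
  // 0x1001 0x2002 0x3003 0x4004
  printf("\n");
  return 0;
}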
-template >> -HWY_API VFromD ZipLower(Vec128 a, Vec128 b) { +template >> +HWY_API VFromD ZipLower(V a, V b) { return BitCast(DW(), InterleaveLower(a, b)); } -template , - class DW = RepartitionToWide> -HWY_API VFromD ZipLower(DW dw, Vec128 a, Vec128 b) { +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { return BitCast(dw, InterleaveLower(D(), a, b)); } -template , - class DW = RepartitionToWide> -HWY_API VFromD ZipUpper(DW dw, Vec128 a, Vec128 b) { +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) template -HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, Vec128 a, Vec128 b, const Vec128 sum0, @@ -3603,70 +3710,67 @@ HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, // ------------------------------ Combine (InterleaveLower) // Full result -HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_u8(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, - Vec128 lo) { + Vec64 hi, Vec64 lo) { return Vec128(vcombine_u16(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, - Vec128 lo) { + Vec64 hi, Vec64 lo) { return Vec128(vcombine_u32(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, - Vec128 lo) { + Vec64 hi, Vec64 lo) { return Vec128(vcombine_u64(lo.raw, hi.raw)); } -HWY_API Vec128 Combine(Full128 /* tag */, Vec128 hi, - Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_s8(lo.raw, hi.raw)); } -HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_s16(lo.raw, hi.raw)); } -HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_s32(lo.raw, hi.raw)); } -HWY_API Vec128 Combine(Full128 /* tag */, - Vec128 hi, Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_s64(lo.raw, hi.raw)); } -HWY_API Vec128 Combine(Full128 /* tag */, Vec128 hi, - Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_f32(lo.raw, hi.raw)); } #if HWY_ARCH_ARM_A64 -HWY_API Vec128 Combine(Full128 /* tag */, Vec128 hi, - Vec128 lo) { +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { return Vec128(vcombine_f64(lo.raw, hi.raw)); } #endif // < 64bit input, <= 64 bit result template -HWY_API Vec128 Combine(Simd d, Vec128 hi, +HWY_API Vec128 Combine(Simd d, Vec128 hi, Vec128 lo) { // First double N (only lower halves will be used). const Vec128 hi2(hi.raw); const Vec128 lo2(lo.raw); // Repartition to two unsigned lanes (each the size of the valid input). 
- const Simd, 2> du; + const Simd, 2, 0> du; return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); } // ------------------------------ ZeroExtendVector (Combine) template -HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { return Combine(d, Zero(Half()), lo); } @@ -3674,108 +3778,111 @@ HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { // 64 or 128-bit input: just interleave template -HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, +HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as a single lane and interleave them. const Repartition, decltype(d)> du; return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); } -#if HWY_ARCH_ARM_A64 namespace detail { +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_FLOAT_32(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveOdd, vtrn2, _, 2) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveOdd, vtrn2, _, 2) +HWY_NEON_DEF_FUNCTION_FLOAT_32(InterleaveOdd, vtrn2, _, 2) +#else -HWY_INLINE Vec128 ConcatEven(Vec128 hi, - Vec128 lo) { - return Vec128(vtrn1_u8(lo.raw, hi.raw)); -} -HWY_INLINE Vec128 ConcatEven(Vec128 hi, - Vec128 lo) { - return Vec128(vtrn1_u16(lo.raw, hi.raw)); -} +// vtrn returns a struct with even and odd result. +#define HWY_NEON_BUILD_TPL_HWY_TRN +#define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t +// Pass raw args so we can accept uint16x2 args, for which there is no +// corresponding uint16x2x2 return type. +#define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ + Raw128::type a, Raw128::type b +#define HWY_NEON_BUILD_ARG_HWY_TRN a, b +// Cannot use UINT8 etc. type macros because the x2_t tuples are only defined +// for full and half vectors. +HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) +#endif } // namespace detail // <= 32-bit input/output template -HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, +HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as two lanes and take every second one. 
const Repartition, decltype(d)> du; - return BitCast(d, detail::ConcatEven(BitCast(du, hi), BitCast(du, lo))); -} - +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); #else - -template -HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, - Vec128 lo) { - const Half d2; - return Combine(LowerHalf(d2, hi), LowerHalf(d2, lo)); + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[0])); +#endif } -#endif // HWY_ARCH_ARM_A64 // ------------------------------ ConcatUpperUpper // 64 or 128-bit input: just interleave template -HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, +HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as a single lane and interleave them. const Repartition, decltype(d)> du; return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); } -#if HWY_ARCH_ARM_A64 -namespace detail { - -HWY_INLINE Vec128 ConcatOdd(Vec128 hi, - Vec128 lo) { - return Vec128(vtrn2_u8(lo.raw, hi.raw)); -} -HWY_INLINE Vec128 ConcatOdd(Vec128 hi, - Vec128 lo) { - return Vec128(vtrn2_u16(lo.raw, hi.raw)); -} - -} // namespace detail - // <= 32-bit input/output template -HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, +HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as two lanes and take every second one. const Repartition, decltype(d)> du; - return BitCast(d, detail::ConcatOdd(BitCast(du, hi), BitCast(du, lo))); -} - +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); #else - -template -HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, - Vec128 lo) { - const Half d2; - return Combine(UpperHalf(d2, hi), UpperHalf(d2, lo)); + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[1])); +#endif } -#endif // HWY_ARCH_ARM_A64 - // ------------------------------ ConcatLowerUpper (ShiftLeftBytes) // 64 or 128-bit input: extract from concatenated template -HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, +HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, Vec128 lo) { return CombineShiftRightBytes(d, hi, lo); } // <= 32-bit input/output template -HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, +HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, Vec128 lo) { constexpr size_t kSize = N * sizeof(T); const Repartition d8; - const Simd d8x8; - const Simd d64; + const Full64 d8x8; + const Full64 d64; using V8x8 = VFromD; const V8x8 hi8x8(BitCast(d8, hi).raw); // Move into most-significant bytes @@ -3789,7 +3896,7 @@ HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, // Works for all N. 
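For the <= 32-bit ConcatLowerLower/ConcatUpperUpper cases above, each input's payload is just two half-width lanes, so "lower halves of lo and hi" is the same as "even-indexed elements of the (lo, hi) interleave", which vtrn1/vtrn2 (or val[0]/val[1] of vtrn on ARMv7) produce directly. A scalar sketch with a two-lane stand-in type (U16x2 and the function names are illustrative, not Highway API):

#include <cstdint>
#include <cstdio>

struct U16x2 { uint16_t lane[2]; };  // a 32-bit vector viewed as 2 x u16

// vtrn1-style: even-indexed elements of (lo, hi) == both lower halves.
U16x2 ConcatLowerLowerScalar(U16x2 hi, U16x2 lo) { return {{lo.lane[0], hi.lane[0]}}; }
// vtrn2-style: odd-indexed elements == both upper halves.
U16x2 ConcatUpperUpperScalar(U16x2 hi, U16x2 lo) { return {{lo.lane[1], hi.lane[1]}}; }

int main() {
  const U16x2 lo = {{0x1111, 0x2222}};  // lower half 0x1111, upper half 0x2222
  const U16x2 hi = {{0x3333, 0x4444}};
  const U16x2 ll = ConcatLowerLowerScalar(hi, lo);
  const U16x2 uu = ConcatUpperUpperScalar(hi, lo);
  printf("%04X %04X\n", unsigned{ll.lane[0]}, unsigned{ll.lane[1]});  // 1111 3333
  printf("%04X %04X\n", unsigned{uu.lane[0]}, unsigned{uu.lane[1]});  // 2222 4444
  return 0;
}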
template -HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, Vec128 lo) { return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); } @@ -3812,19 +3919,19 @@ HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, // 32-bit partial template -HWY_API Vec128 ConcatOdd(Simd /* tag */, +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, Vec128 lo) { return Vec128(vuzp2_u32(lo.raw, hi.raw)); } template -HWY_API Vec128 ConcatOdd(Simd /* tag */, +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, Vec128 lo) { return Vec128(vuzp2_s32(lo.raw, hi.raw)); } template -HWY_API Vec128 ConcatOdd(Simd /* tag */, +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, Vec128 lo) { return Vec128(vuzp2_f32(lo.raw, hi.raw)); } @@ -3854,19 +3961,19 @@ HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, // 32-bit partial template -HWY_API Vec128 ConcatEven(Simd /* tag */, +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, Vec128 lo) { return Vec128(vuzp1_u32(lo.raw, hi.raw)); } template -HWY_API Vec128 ConcatEven(Simd /* tag */, +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, Vec128 lo) { return Vec128(vuzp1_s32(lo.raw, hi.raw)); } template -HWY_API Vec128 ConcatEven(Simd /* tag */, +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, Vec128 lo) { return Vec128(vuzp1_f32(lo.raw, hi.raw)); } @@ -3878,11 +3985,43 @@ HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { return InterleaveLower(d, lo, hi); } +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); +#endif +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(Simd(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); +#endif +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(Simd(), v, v); +} + // ------------------------------ OddEven (IfThenElse) template HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { - const Simd d; + const Simd d; const Repartition d8; alignas(16) constexpr uint8_t kBytes[16] = { ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 
0 : 0xFF, @@ -3911,11 +4050,19 @@ HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { return v; } +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + // ------------------------------ ReorderDemote2To (OddEven) template HWY_API Vec128 ReorderDemote2To( - Simd dbf16, Vec128 a, Vec128 b) { + Simd dbf16, Vec128 a, Vec128 b) { const RebindToUnsigned du16; const Repartition du32; const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); @@ -3943,6 +4090,11 @@ HWY_API Vec128 AESRound(Vec128 state, round_key; } +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { return Vec128((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); } @@ -3957,7 +4109,7 @@ HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { // ================================================== MISC template -HWY_API Vec128 PromoteTo(Simd df32, +HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; @@ -3986,7 +4138,7 @@ HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { template HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { - const Simd d; + const DFromV d; int32x2_t a_packed = ConcatEven(d, a, a).raw; int32x2_t b_packed = ConcatEven(d, b, b).raw; return Vec128( @@ -3995,7 +4147,7 @@ HWY_API Vec128 MulEven(const Vec128 a, template HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { - const Simd d; + const DFromV d; uint32x2_t a_packed = ConcatEven(d, a, a).raw; uint32x2_t b_packed = ConcatEven(d, b, b).raw; return Vec128( @@ -4042,7 +4194,7 @@ template HWY_API Vec128 TableLookupBytes(const Vec128 bytes, const Vec128 from) { const Full128 d_full; - const Vec128 from64(from.raw); + const Vec64 from64(from.raw); const auto idx_full = Combine(d_full, from64, from64); const auto out_full = TableLookupBytes(bytes, idx_full); return Vec128(LowerHalf(Half(), out_full).raw); @@ -4059,10 +4211,10 @@ HWY_API Vec128 TableLookupBytes(const Vec128 bytes, // Partial both template -HWY_API VFromD>> TableLookupBytes( +HWY_API VFromD>> TableLookupBytes( Vec128 bytes, Vec128 from) { - const Simd d; - const Simd d_idx; + const Simd d; + const Simd d_idx; const Repartition d_idx8; // uint8x8 const auto bytes8 = BitCast(Repartition(), bytes); @@ -4080,7 +4232,8 @@ HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { // ------------------------------ Scatter (Store) template -HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); @@ -4088,7 +4241,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Rebind(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { @@ -4097,7 +4250,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, } template -HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); @@ -4105,7 +4258,7 @@ HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(16) Index 
index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Rebind(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; @@ -4115,13 +4268,13 @@ HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, // ------------------------------ Gather (Load/Store) template -HWY_API Vec128 GatherOffset(const Simd d, +HWY_API Vec128 GatherOffset(const Simd d, const T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Rebind(), offset_lanes); alignas(16) T lanes[N]; const uint8_t* base_bytes = reinterpret_cast(base); @@ -4132,12 +4285,13 @@ HWY_API Vec128 GatherOffset(const Simd d, } template -HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Rebind(), index_lanes); alignas(16) T lanes[N]; for (size_t i = 0; i < N; ++i) { @@ -4264,35 +4418,35 @@ HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, // u16/i16 template HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d32; + const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(d32, Min(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); } template HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d32; + const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(d32, Max(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); } } // namespace detail template -HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { return detail::SumOfLanes(v); } template -HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { return detail::MinOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { return detail::MaxOfLanes(hwy::SizeTag(), v); } @@ -4304,9 +4458,9 @@ namespace detail { // overload is required to call the q vs non-q intrinsics. Note that 8-bit // LoadMaskBits only requires 16 bits, but 64 avoids casting. 
template -HWY_INLINE Vec128 Set64(Simd /* tag */, uint64_t mask_bits) { - const auto v64 = Vec128(vdup_n_u64(mask_bits)); - return Vec128(BitCast(Simd(), v64).raw); +HWY_INLINE Vec128 Set64(Simd /* tag */, uint64_t mask_bits) { + const auto v64 = Vec64(vdup_n_u64(mask_bits)); + return Vec128(BitCast(Full64(), v64).raw); } template HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { @@ -4314,7 +4468,7 @@ HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; // Easier than Set(), which would require an >8-bit type, which would not // compile for T=uint8_t, N=1. @@ -4331,7 +4485,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; const auto vmask_bits = Set(du, static_cast(mask_bits)); @@ -4339,7 +4493,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; const auto vmask_bits = Set(du, static_cast(mask_bits)); @@ -4347,7 +4501,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint64_t kBit[8] = {1, 2}; return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); @@ -4357,7 +4511,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { // `p` points to at least 8 readable bytes, not all of which need be valid. template -HWY_API Mask128 LoadMaskBits(Simd d, +HWY_API Mask128 LoadMaskBits(Simd d, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; CopyBytes<(N + 7) / 8>(bits, &mask_bits); @@ -4400,9 +4554,9 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, // we load all kSliceLanes so the upper lanes do not pollute the valid bits. alignas(8) constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80}; - const Simd d; - const Simd du; - const Vec128 slice(Load(Simd(), kSliceLanes).raw); + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; #if HWY_ARCH_ARM_A64 @@ -4439,9 +4593,9 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, // Upper lanes of partial loads are undefined. OnlyActive will fix this if // we load all kSliceLanes so the upper lanes do not pollute the valid bits. alignas(8) constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; - const Simd d; - const Simd du; - const Vec128 slice(Load(Simd(), kSliceLanes).raw); + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; #if HWY_ARCH_ARM_A64 return vaddv_u16(values.raw); @@ -4474,9 +4628,9 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, // Upper lanes of partial loads are undefined. OnlyActive will fix this if // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
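The u16/i16 MinOfLanes/MaxOfLanes above avoid a dedicated 16-bit reduction by repartitioning to 32-bit lanes: the even and odd 16-bit halves are split out, combined with Min/Max, reduced as 32-bit lanes, and the result is broadcast back into both halves. A scalar sketch of why the even/odd split covers every lane (array form, illustrative only):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Pair 16-bit lanes into 32-bit lanes: each 32-bit lane contributes its low
// half (even element) and high half (odd element), so reducing Min(even, odd)
// over the 32-bit lanes visits every 16-bit lane exactly once.
uint16_t MinOfU16Lanes(const uint16_t* lanes, size_t n) {  // n is even
  uint32_t min32 = 0xFFFFFFFFu;
  for (size_t i = 0; i < n; i += 2) {
    const uint32_t even = lanes[i];      // low 16 bits of the 32-bit lane
    const uint32_t odd = lanes[i + 1];   // high 16 bits, shifted down
    min32 = std::min(min32, std::min(even, odd));
  }
  return static_cast<uint16_t>(min32);
}

int main() {
  const uint16_t lanes[8] = {7, 3, 9, 12, 5, 30, 2, 8};
  printf("%u\n", unsigned{MinOfU16Lanes(lanes, 8)});  // 2
  return 0;
}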
alignas(8) constexpr uint32_t kSliceLanes[2] = {1, 2}; - const Simd d; - const Simd du; - const Vec128 slice(Load(Simd(), kSliceLanes).raw); + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; #if HWY_ARCH_ARM_A64 return vaddv_u32(values.raw); @@ -4503,10 +4657,9 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { - const Simd d; - const Simd du; - const Vec128 values = - BitCast(du, VecFromMask(d, m)) & Set(du, 1); + const Full64 d; + const Full64 du; + const Vec64 values = BitCast(du, VecFromMask(d, m)) & Set(du, 1); return vget_lane_u64(values.raw, 0); } @@ -4596,20 +4749,20 @@ HWY_API size_t CountTrue(Full128 /* tag */, const Mask128 mask) { // Partial template -HWY_API size_t CountTrue(Simd /* tag */, const Mask128 mask) { +HWY_API size_t CountTrue(Simd /* tag */, const Mask128 mask) { return PopCount(detail::BitsFromMask(mask)); } template -HWY_API intptr_t FindFirstTrue(const Simd /* tag */, - const Mask128 mask) { +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { const uint64_t bits = detail::BitsFromMask(mask); return bits ? static_cast(Num0BitsBelowLS1Bit_Nonzero64(bits)) : -1; } // `p` points to at least 8 writable bytes. template -HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, +HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, uint8_t* bits) { const uint64_t mask_bits = detail::BitsFromMask(mask); const size_t kNumBytes = (N + 7) / 8; @@ -4633,13 +4786,13 @@ HWY_API bool AllFalse(const Full128 d, const Mask128 m) { // Partial template -HWY_API bool AllFalse(const Simd /* tag */, const Mask128 m) { +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 m) { return detail::BitsFromMask(m) == 0; } template -HWY_API bool AllTrue(const Simd d, const Mask128 m) { - return AllFalse(VecFromMask(d, m) == Zero(d)); +HWY_API bool AllTrue(const Simd d, const Mask128 m) { + return AllFalse(d, VecFromMask(d, m) == Zero(d)); } // ------------------------------ Compress @@ -4655,7 +4808,7 @@ HWY_INLINE Vec128 Load8Bytes(Full128 /*d*/, // Load 8 bytes and return half-reg with N <= 8 bytes. 
template -HWY_INLINE Vec128 Load8Bytes(Simd d, +HWY_INLINE Vec128 Load8Bytes(Simd d, const uint8_t* bytes) { return Load(d, bytes); } @@ -4664,9 +4817,9 @@ template HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); - const Simd d; + const Simd d; const Repartition d8; - const Simd du; + const Simd du; // ARM does not provide an equivalent of AVX2 permutevar, so we need byte // indices for VTBL (one vector's worth for each of 256 combinations of @@ -4821,12 +4974,12 @@ HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const Simd d; + const Simd d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } -#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 template HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, @@ -4840,7 +4993,7 @@ HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const Simd d; + const Simd d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } @@ -4853,7 +5006,7 @@ template HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { const auto idx = detail::IdxFromBits(hwy::SizeTag(), mask_bits); - using D = Simd; + using D = Simd; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } @@ -4883,7 +5036,7 @@ HWY_INLINE Vec128 CompressBits(Vec128 v, // ------------------------------ CompressStore template HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, - Simd d, T* HWY_RESTRICT unaligned) { + Simd d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(mask); StoreU(detail::Compress(v, mask_bits), d, unaligned); return PopCount(mask_bits); @@ -4892,7 +5045,8 @@ HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, // ------------------------------ CompressBlendedStore template HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, - Simd d, T* HWY_RESTRICT unaligned) { + Simd d, + T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; // so we can support fp16/bf16 using TU = TFromD; const uint64_t mask_bits = detail::BitsFromMask(m); @@ -4908,8 +5062,8 @@ HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, template HWY_API size_t CompressBitsStore(Vec128 v, - const uint8_t* HWY_RESTRICT bits, Simd d, - T* HWY_RESTRICT unaligned) { + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(bits, &mask_bits); @@ -4929,17 +5083,15 @@ HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v2, Full128 /*tag*/, uint8_t* HWY_RESTRICT unaligned) { - const uint8x16x3_t triple = {v0.raw, v1.raw, v2.raw}; + const uint8x16x3_t triple = {{v0.raw, v1.raw, v2.raw}}; vst3q_u8(unaligned, triple); } // 64 bits -HWY_API void StoreInterleaved3(const Vec128 v0, - const Vec128 v1, - const Vec128 v2, - Simd /*tag*/, +HWY_API void StoreInterleaved3(const Vec64 v0, const Vec64 v1, + const Vec64 v2, Full64 /*tag*/, uint8_t* HWY_RESTRICT unaligned) { - const uint8x8x3_t triple = {v0.raw, v1.raw, v2.raw}; + const uint8x8x3_t triple = {{v0.raw, v1.raw, v2.raw}}; vst3_u8(unaligned, triple); } @@ -4948,10 +5100,10 @@ template HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, const Vec128 v2, - Simd 
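BitsFromMask above exploits the fact that NEON mask lanes are all-ones or all-zeros: AND each lane with its bit weight from kSliceLanes (1, 2, 4, ...) and a horizontal add (vaddv, or paired adds on ARMv7) yields the packed bitmask; CountTrue is then a popcount and FindFirstTrue a count of trailing zeros. A scalar sketch with the mask given as bools (illustrative only):

#include <bitset>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Lane i contributes 2^i when its mask lane is set; summing the weighted lanes
// packs the mask into an integer, as the vaddv path above does.
uint64_t BitsFromMaskScalar(const bool* mask_lanes, size_t n) {  // n <= 64
  uint64_t bits = 0;
  for (size_t i = 0; i < n; ++i) {
    bits += mask_lanes[i] ? (uint64_t{1} << i) : 0;  // AND with weight, then add
  }
  return bits;
}

size_t CountTrueScalar(const bool* mask_lanes, size_t n) {
  return std::bitset<64>(BitsFromMaskScalar(mask_lanes, n)).count();
}

int main() {
  const bool m[4] = {true, false, true, true};
  const unsigned long long bits = BitsFromMaskScalar(m, 4);
  printf("%llu %zu\n", bits, CountTrueScalar(m, 4));  // 13 3
  return 0;
}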
/*tag*/, + Simd /*tag*/, uint8_t* HWY_RESTRICT unaligned) { alignas(16) uint8_t buf[24]; - const uint8x8x3_t triple = {v0.raw, v1.raw, v2.raw}; + const uint8x8x3_t triple = {{v0.raw, v1.raw, v2.raw}}; vst3_u8(buf, triple); CopyBytes(buf, unaligned); } @@ -4965,18 +5117,16 @@ HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v3, Full128 /*tag*/, uint8_t* HWY_RESTRICT unaligned) { - const uint8x16x4_t quad = {v0.raw, v1.raw, v2.raw, v3.raw}; + const uint8x16x4_t quad = {{v0.raw, v1.raw, v2.raw, v3.raw}}; vst4q_u8(unaligned, quad); } // 64 bits -HWY_API void StoreInterleaved4(const Vec128 v0, - const Vec128 v1, - const Vec128 v2, - const Vec128 v3, - Simd /*tag*/, +HWY_API void StoreInterleaved4(const Vec64 v0, const Vec64 v1, + const Vec64 v2, const Vec64 v3, + Full64 /*tag*/, uint8_t* HWY_RESTRICT unaligned) { - const uint8x8x4_t quad = {v0.raw, v1.raw, v2.raw, v3.raw}; + const uint8x8x4_t quad = {{v0.raw, v1.raw, v2.raw, v3.raw}}; vst4_u8(unaligned, quad); } @@ -4986,108 +5136,64 @@ HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, const Vec128 v2, const Vec128 v3, - Simd /*tag*/, + Simd /*tag*/, uint8_t* HWY_RESTRICT unaligned) { alignas(16) uint8_t buf[32]; - const uint8x8x4_t quad = {v0.raw, v1.raw, v2.raw, v3.raw}; + const uint8x8x4_t quad = {{v0.raw, v1.raw, v2.raw, v3.raw}}; vst4_u8(buf, quad); CopyBytes(buf, unaligned); } -// ================================================== DEPRECATED +// ------------------------------ Lt128 -template -HWY_API size_t StoreMaskBits(const Mask128 mask, uint8_t* bits) { - return StoreMaskBits(Simd(), mask, bits); +namespace detail { + +template +Mask128 ShiftMaskLeft(Mask128 m) { + return MaskFromVec(ShiftLeftLanes(VecFromMask(Simd(), m))); } -template -HWY_API bool AllTrue(const Mask128 mask) { - return AllTrue(Simd(), mask); +} // namespace detail + +template +HWY_INLINE Mask128 Lt128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "Use u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128 eqHL = Eq(a, b); + const Mask128 ltHL = Lt(a, b); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. + const Mask128 ltLx = detail::ShiftMaskLeft<1>(ltHL); + const Mask128 outHx = Or(ltHL, And(eqHL, ltLx)); + const Vec128 vecHx = VecFromMask(d, outHx); + return MaskFromVec(InterleaveUpper(d, vecHx, vecHx)); } -template -HWY_API bool AllFalse(const Mask128 mask) { - return AllFalse(Simd(), mask); +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. 
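The Lt128 truth table above collapses to "high lanes less, or high lanes equal and low lanes less", after which the result is broadcast to both 64-bit lanes so Min128/Max128 can select whole 128-bit values. A scalar sketch of the predicate, treating each operand as two u64 halves (U128 and the function names are illustrative stand-ins):

#include <cstdint>
#include <cstdio>

struct U128 { uint64_t lo, hi; };  // one 128-bit number as two 64-bit lanes

// cH | (=H & cL): exactly the surviving rows of the truth table above.
bool Lt128Scalar(U128 a, U128 b) {
  const bool eqH = a.hi == b.hi;
  const bool ltH = a.hi < b.hi;
  const bool ltL = a.lo < b.lo;
  return ltH || (eqH && ltL);
}

U128 Min128Scalar(U128 a, U128 b) { return Lt128Scalar(a, b) ? a : b; }

int main() {
  const U128 a = {/*lo=*/5, /*hi=*/1};
  const U128 b = {/*lo=*/4, /*hi=*/2};
  const U128 c = {/*lo=*/3, /*hi=*/2};
  printf("%d %d %d\n", Lt128Scalar(a, b), Lt128Scalar(c, b), Lt128Scalar(b, c));  // 1 1 0
  printf("%llu\n", (unsigned long long)Min128Scalar(b, c).lo);  // 3: c wins on the low lane
  return 0;
}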
+template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); } -template -HWY_API size_t CountTrue(const Mask128 mask) { - return CountTrue(Simd(), mask); -} - -template -HWY_API Vec128 SumOfLanes(const Vec128 v) { - return SumOfLanes(Simd(), v); -} -template -HWY_API Vec128 MinOfLanes(const Vec128 v) { - return MinOfLanes(Simd(), v); -} -template -HWY_API Vec128 MaxOfLanes(const Vec128 v) { - return MaxOfLanes(Simd(), v); -} - -template -HWY_API Vec128 UpperHalf(Vec128 v) { - return UpperHalf(Half>(), v); -} - -template -HWY_API Vec128 ShiftRightBytes(const Vec128 v) { - return ShiftRightBytes(Simd(), v); -} - -template -HWY_API Vec128 ShiftRightLanes(const Vec128 v) { - return ShiftRightLanes(Simd(), v); -} - -template -HWY_API Vec128 CombineShiftRightBytes(Vec128 hi, Vec128 lo) { - return CombineShiftRightBytes(Simd(), hi, lo); -} - -template -HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { - return InterleaveUpper(Simd(), a, b); -} - -template > -HWY_API VFromD> ZipUpper(Vec128 a, Vec128 b) { - return InterleaveUpper(RepartitionToWide(), a, b); -} - -template -HWY_API Vec128 Combine(Vec128 hi2, Vec128 lo2) { - return Combine(Simd(), hi2, lo2); -} - -template -HWY_API Vec128 ZeroExtendVector(Vec128 lo) { - return ZeroExtendVector(Simd(), lo); -} - -template -HWY_API Vec128 ConcatLowerLower(Vec128 hi, Vec128 lo) { - return ConcatLowerLower(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatUpperUpper(Vec128 hi, Vec128 lo) { - return ConcatUpperUpper(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatLowerUpper(const Vec128 hi, - const Vec128 lo) { - return ConcatLowerUpper(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatUpperLower(Vec128 hi, Vec128 lo) { - return ConcatUpperLower(Simd(), hi, lo); +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), b, a); } // ================================================== Operator wrapper diff --git a/third_party/highway/hwy/ops/arm_sve-inl.h b/third_party/highway/hwy/ops/arm_sve-inl.h index 85b4e340b806..2b25b7d0fcf6 100644 --- a/third_party/highway/hwy/ops/arm_sve-inl.h +++ b/third_party/highway/hwy/ops/arm_sve-inl.h @@ -18,11 +18,7 @@ #include #include -#if defined(HWY_EMULATE_SVE) -#include "third_party/farm_sve/farm_sve.h" -#else #include -#endif #include "hwy/base.h" #include "hwy/ops/shared-inl.h" @@ -31,10 +27,6 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { -// SVE only supports fractions, not LMUL > 1. 
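In the SVE header below, descriptors gain a kPow2 parameter (Simd<T, N, kPow2>), so fractional vectors replace the removed Full<> alias and the HWY_EMULATE_SVE path, and Lanes() becomes the hardware lane count scaled by 2^kPow2 and capped by N. A scalar sketch of that cap-and-scale computation; HardwareLanes stands in for the svcnt* query and the 256-bit vector width is an assumption for the example:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for detail::HardwareLanes (svcntb/h/w/d in the real code).
size_t HardwareLanes(size_t lane_bytes) { return 32 / lane_bytes; }  // 256-bit SVE

// ScaleByPower-style scaling: kPow2 >= 0 multiplies, kPow2 < 0 divides.
size_t ScaleByPower(size_t lanes, int pow2) {
  return pow2 >= 0 ? (lanes << pow2) : (lanes >> (-pow2));
}

// Lanes(Simd<T, N, kPow2>): cap the scaled hardware count by N. May be 0 for
// small fractions on narrow hardware, as the comment in the diff notes.
size_t LanesScalar(size_t lane_bytes, size_t N, int kPow2) {
  return std::min(ScaleByPower(HardwareLanes(lane_bytes), kPow2), N);
}

int main() {
  // u32 lanes on 256-bit vectors: full = 8, half descriptor (kPow2 = -1) = 4,
  // eighth (kPow2 = -3) = 1 (it would be 0 on 128-bit hardware).
  printf("%zu %zu %zu\n", LanesScalar(4, SIZE_MAX, 0), LanesScalar(4, SIZE_MAX, -1),
         LanesScalar(4, SIZE_MAX, -3));  // 8 4 1
  return 0;
}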
-template -using Full = Simd> (-kShift)) : 0>; - template struct DFromV_t {}; // specialized in macros template @@ -56,21 +48,26 @@ using TFromV = TFromD>; namespace detail { // for code folding // Unsigned: -#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, NAME, OP) -#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, NAME, OP) -#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) X_MACRO(uint, u, 32, NAME, OP) -#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) X_MACRO(uint, u, 64, NAME, OP) +#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 64, 32, NAME, OP) // Signed: -#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, NAME, OP) -#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, NAME, OP) -#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, NAME, OP) -#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, NAME, OP) +#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) // Float: -#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) X_MACRO(float, f, 16, NAME, OP) -#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) X_MACRO(float, f, 32, NAME, OP) -#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) X_MACRO(float, f, 64, NAME, OP) +#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 16, 16, NAME, OP) +#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 64, 32, NAME, OP) // For all element sizes: #define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ @@ -129,59 +126,52 @@ namespace detail { // for code folding // Assemble types for use in x-macros #define HWY_SVE_T(BASE, BITS) BASE##BITS##_t -#define HWY_SVE_D(BASE, BITS, N) Simd +#define HWY_SVE_D(BASE, BITS, N, POW2) Simd #define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t } // namespace detail -#define HWY_SPECIALIZE(BASE, CHAR, BITS, NAME, OP) \ - template <> \ - struct DFromV_t { \ - using type = HWY_SVE_D(BASE, BITS, HWY_LANES(HWY_SVE_T(BASE, BITS))); \ +#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using type = ScalableTag; \ }; HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) #undef HWY_SPECIALIZE -// vector = f(d), e.g. Undefined -#define HWY_SVE_RETV_ARGD(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N) d) { \ - return sv##OP##_##CHAR##BITS(); \ - } - // Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX // instructions, and we anyway only use it when the predicate is ptrue. // vector = f(vector), e.g. 
Not -#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ } -#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_##CHAR##BITS(v); \ } -// vector = f(vector, scalar), e.g. detail::AddK -#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, NAME, OP) \ +// vector = f(vector, scalar), e.g. detail::AddN +#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ } -#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS(a, b); \ } // vector = f(vector, vector), e.g. Add -#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ } -#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS(a, b); \ @@ -221,22 +211,14 @@ HWY_INLINE size_t HardwareLanes(hwy::SizeTag<8> /* tag */) { } // namespace detail -// Capped to <= 128-bit: SVE is at least that large, so no need to query actual. -template -HWY_API constexpr size_t Lanes(Simd /* tag */) { - return N; -} - -// Returns actual number of lanes after dividing by div={1,2,4,8}. -// May return 0 if div > 16/sizeof(T): there is no "1/8th" of a u32x4, but it -// would be valid for u32x8 (i.e. hardware vectors >= 256 bits). -template -HWY_API size_t Lanes(Simd /* tag */) { - static_assert(N <= HWY_LANES(T), "N cannot exceed a full vector"); - +// Returns actual number of lanes after capping by N and shifting. May return 0 +// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). +template +HWY_API size_t Lanes(Simd d) { const size_t actual = detail::HardwareLanes(hwy::SizeTag()); - const size_t div = HWY_LANES(T) / N; - return (div <= 8) ? actual / div : HWY_MIN(actual, N); + // Common case of full vectors: avoid any extra instructions. + if (detail::IsFull(d)) return actual; + return HWY_MIN(detail::ScaleByPower(actual, kPow2), N); } // ================================================== MASK INIT @@ -244,10 +226,11 @@ HWY_API size_t Lanes(Simd /* tag */) { // One mask bit per byte; only the one belonging to the lowest byte is valid. // ------------------------------ FirstN -#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, KN) /* d */, size_t N) { \ - return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(N)); \ +#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ + const size_t limit = detail::IsFull(d) ? 
count : HWY_MIN(Lanes(d), count); \ + return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(limit)); \ } HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) #undef HWY_SVE_FIRSTN @@ -257,10 +240,10 @@ namespace detail { // All-true mask from a macro #define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) -#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N) d) { \ - return HWY_SVE_PTRUE(BITS); \ +#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_PTRUE(BITS); \ } HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true @@ -271,10 +254,10 @@ HWY_API svbool_t PFalse() { return svpfalse_b(); } // Returns all-true if d is HWY_FULL or FirstN(N) after capping N. // // This is used in functions that load/store memory; other functions (e.g. -// arithmetic on partial vectors) can ignore d and use PTrue instead. -template -svbool_t Mask(Simd d) { - return N == HWY_LANES(T) ? PTrue(d) : FirstN(d, Lanes(d)); +// arithmetic) can ignore d and use PTrue instead. +template +svbool_t MakeMask(D d) { + return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d)); } } // namespace detail @@ -283,19 +266,19 @@ svbool_t Mask(Simd d) { // ------------------------------ Set // vector = f(d, scalar), e.g. Set -#define HWY_SVE_SET(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_T(BASE, BITS) arg) { \ - return sv##OP##_##CHAR##BITS(arg); \ +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) arg) { \ + return sv##OP##_##CHAR##BITS(arg); \ } HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) #undef HWY_SVE_SET // Required for Zero and VFromD -template -svuint16_t Set(Simd d, bfloat16_t arg) { +template +svuint16_t Set(Simd d, bfloat16_t arg) { return Set(RebindToUnsigned(), arg.bits); } @@ -311,39 +294,39 @@ VFromD Zero(D d) { // ------------------------------ Undefined -#if defined(HWY_EMULATE_SVE) -template -VFromD Undefined(D d) { - return Zero(d); -} -#else -HWY_SVE_FOREACH(HWY_SVE_RETV_ARGD, Undefined, undef) -#endif +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) // ------------------------------ BitCast namespace detail { // u8: no change -#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, NAME, OP) \ - HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ - return v; \ - } \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ - HWY_SVE_D(BASE, BITS, N) /* d */, HWY_SVE_V(BASE, BITS) v) { \ - return v; \ +#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return v; \ } // All other types -#define HWY_SVE_CAST(BASE, CHAR, BITS, NAME, OP) \ - HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ - return sv##OP##_u8_##CHAR##BITS(v); \ - } \ - template \ - HWY_INLINE HWY_SVE_V(BASE, BITS) \ - BitCastFromByte(HWY_SVE_D(BASE, BITS, N) /* d */, svuint8_t v) { \ - return 
sv##OP##_##CHAR##BITS##_u8(v); \ +#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u8_##CHAR##BITS(v); \ + } \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ + return sv##OP##_##CHAR##BITS##_u8(v); \ } HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) @@ -356,10 +339,10 @@ HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) #undef HWY_SVE_CAST_NOP #undef HWY_SVE_CAST -template -HWY_INLINE svuint16_t BitCastFromByte(Simd /* d */, +template +HWY_INLINE svuint16_t BitCastFromByte(Simd /* d */, svuint8_t v) { - return BitCastFromByte(Simd(), v); + return BitCastFromByte(Simd(), v); } } // namespace detail @@ -421,20 +404,20 @@ HWY_API V Xor(const V a, const V b) { // ------------------------------ AndNot namespace detail { -#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, NAME, OP) \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ - return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ +#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ } HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) #undef HWY_SVE_RETV_ARGPVN_SWAP } // namespace detail -#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, NAME, OP) \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ - return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ +#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ } HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) #undef HWY_SVE_RETV_ARGPVV_SWAP @@ -446,6 +429,13 @@ HWY_API V AndNot(const V a, const V b) { return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); } +// ------------------------------ OrAnd + +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + // ------------------------------ PopulationCount #ifdef HWY_NATIVE_POPCNT @@ -455,7 +445,7 @@ HWY_API V AndNot(const V a, const V b) { #endif // Need to return original type instead of unsigned. -#define HWY_SVE_POPCNT(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ return BitCast(DFromV(), \ sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ @@ -499,7 +489,7 @@ HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) namespace detail { // Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. 
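// [Editorial note, not part of the patch] Reader aid for the X-macro scheme
// above: each HWY_SVE_FOREACH_* entry now passes an extra HALF argument (the
// bit width of the half-sized lane type), which generators such as
// HWY_SVE_PROMOTE_TO use to name the narrower vector type; most generators
// simply ignore it. For instance, HWY_SVE_RETV_ARGPVV(uint, u, 32, 16, Add,
// add) expands to approximately:
//
//   HWY_API svuint32_t Add(svuint32_t a, svuint32_t b) {
//     return svadd_u32_x(svptrue_pat_b32(SV_POW2), a, b);
//   }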
-#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ @@ -511,6 +501,21 @@ HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) +// ------------------------------ SumsOf8 +HWY_API svuint64_t SumsOf8(const svuint8_t v) { + const ScalableTag du32; + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + + const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); + // Compute pairwise sum of u32 and extend to u64. + // TODO(janwas): on SVE2, we can instead use svaddp. + const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) + const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); + return Add(hi, lo); +} + // ------------------------------ SaturatedAdd HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) @@ -526,7 +531,7 @@ HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) // ------------------------------ ShiftLeft[Same] -#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ template \ HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ @@ -558,12 +563,12 @@ HWY_API V RotateRight(const V v) { // ------------------------------ Shl/r -#define HWY_SVE_SHIFT(BASE, CHAR, BITS, NAME, OP) \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ - using TU = HWY_SVE_T(uint, BITS); \ - return sv##OP##_##CHAR##BITS##_x( \ - HWY_SVE_PTRUE(BITS), v, BitCast(Simd(), bits)); \ +#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToUnsigned> du; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + BitCast(du, bits)); \ } HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) @@ -609,7 +614,7 @@ HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) // ------------------------------ MulAdd -#define HWY_SVE_FMA(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ HWY_SVE_V(BASE, BITS) add) { \ @@ -667,10 +672,10 @@ HWY_API svbool_t Xor(svbool_t a, svbool_t b) { // ------------------------------ CountTrue -#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N) d, svbool_t m) { \ - return sv##OP##_b##BITS(detail::Mask(d), m); \ +#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ + return sv##OP##_b##BITS(detail::MakeMask(d), m); \ } HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) @@ -679,10 +684,10 @@ HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) // For 16-bit Compress: full vector, not limited to SV_POW2. 
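// [Editorial sketch, not part of the patch] Scalar reference for the SumsOf8
// added above: each u64 output lane is the sum of the 8 consecutive u8 inputs
// occupying the same 64 bits. The SVE code reaches this via svdot against a
// broadcast 1 (yielding sums of 4) and then adds the two adjacent u32 sums
// within each u64.
#include <cstddef>
#include <cstdint>

static inline void SumsOf8Ref(const uint8_t* in, uint64_t* out,
                              size_t num_u64) {
  for (size_t i = 0; i < num_u64; ++i) {
    uint64_t sum = 0;
    for (size_t j = 0; j < 8; ++j) sum += in[8 * i + j];
    out[i] = sum;
  }
}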
namespace detail { -#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N) d, svbool_t m) { \ - return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ +#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ + return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ } HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) @@ -691,25 +696,27 @@ HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) } // namespace detail // ------------------------------ AllFalse -template -HWY_API bool AllFalse(Simd d, svbool_t m) { - return !svptest_any(detail::Mask(d), m); +template +HWY_API bool AllFalse(D d, svbool_t m) { + return !svptest_any(detail::MakeMask(d), m); } // ------------------------------ AllTrue -template -HWY_API bool AllTrue(Simd d, svbool_t m) { +template +HWY_API bool AllTrue(D d, svbool_t m) { return CountTrue(d, m) == Lanes(d); } // ------------------------------ FindFirstTrue -template -HWY_API intptr_t FindFirstTrue(Simd d, svbool_t m) { - return AllFalse(d, m) ? -1 : CountTrue(d, svbrkb_b_z(detail::Mask(d), m)); +template +HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast( + CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); } // ------------------------------ IfThenElse -#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ return sv##OP##_##CHAR##BITS(m, yes, no); \ @@ -733,25 +740,31 @@ HWY_API V IfThenZeroElse(const M mask, const V no) { // ================================================== COMPARE // mask = f(vector, vector) -#define HWY_SVE_COMPARE(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ } -#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ } // ------------------------------ Eq HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) +} // namespace detail // ------------------------------ Ne HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) +} // namespace detail // ------------------------------ Lt HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) namespace detail { -HWY_SVE_FOREACH_IF(HWY_SVE_COMPARE_N, LtN, cmplt_n) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) } // namespace detail // ------------------------------ Le @@ -774,13 +787,13 @@ HWY_API svbool_t Ge(const V a, const V b) { // ------------------------------ TestBit template HWY_API svbool_t TestBit(const V a, const V bit) { - return Ne(And(a, bit), Zero(DFromV())); + return detail::NeN(And(a, bit), 0); } // ------------------------------ MaskFromVec (Ne) template HWY_API svbool_t MaskFromVec(const V v) { - return Ne(v, Zero(DFromV())); + return detail::NeN(v, static_cast>(0)); } // ------------------------------ VecFromMask @@ -796,48 +809,57 @@ HWY_API VFromD VecFromMask(const D d, svbool_t mask) { return 
BitCast(d, VecFromMask(RebindToUnsigned(), mask)); } +// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + // TODO(janwas): use svbsl for SVE2 + return IfThenElse(MaskFromVec(mask), yes, no); +} + // ================================================== MEMORY // ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream -#define HWY_SVE_LOAD(BASE, CHAR, BITS, NAME, OP) \ - template \ +#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - return sv##OP##_##CHAR##BITS(detail::Mask(d), p); \ + return sv##OP##_##CHAR##BITS(detail::MakeMask(d), p); \ } -#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N) d, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - return sv##OP##_##CHAR##BITS(m, p); \ +#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(m, p); \ } -#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - /* All-true predicate to load all 128 bits. */ \ - return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p); \ +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. */ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p); \ } -#define HWY_SVE_STORE(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d, \ - HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - sv##OP##_##CHAR##BITS(detail::Mask(d), p, v); \ - } - -#define HWY_SVE_MASKED_STORE(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API void NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v, \ - HWY_SVE_D(BASE, BITS, N) d, \ +#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - sv##OP##_##CHAR##BITS(m, p, v); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), p, v); \ + } + +#define HWY_SVE_MASKED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(m, p, v); \ } HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1) @@ -854,15 +876,15 @@ HWY_SVE_FOREACH(HWY_SVE_MASKED_STORE, MaskedStore, st1) #undef HWY_SVE_MASKED_STORE // BF16 is the same as svuint16_t because BF16 is optional before v8.6. 
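// [Editorial sketch, not part of the patch] Per-lane model of the MaskFromVec
// and IfVecThenElse definitions above: a lane counts as true iff its bits are
// nonzero (callers normally pass all-ones or all-zero lanes), and
// IfVecThenElse then picks yes or no per lane. The SVE2 BSL mentioned in the
// TODO would blend bitwise instead, which is equivalent for such inputs.
#include <cstdint>

static inline uint64_t IfVecThenElseRefLane(uint64_t mask, uint64_t yes,
                                            uint64_t no) {
  return (mask != 0) ? yes : no;  // MaskFromVec: lane != 0 => true
}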
-template -HWY_API svuint16_t Load(Simd d, +template +HWY_API svuint16_t Load(Simd d, const bfloat16_t* HWY_RESTRICT p) { return Load(RebindToUnsigned(), reinterpret_cast(p)); } -template -HWY_API void Store(svuint16_t v, Simd d, +template +HWY_API void Store(svuint16_t v, Simd d, bfloat16_t* HWY_RESTRICT p) { Store(v, RebindToUnsigned(), reinterpret_cast(p)); @@ -884,20 +906,22 @@ HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { // ------------------------------ ScatterOffset/Index -#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d, \ +#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ HWY_SVE_V(int, BITS) offset) { \ - sv##OP##_s##BITS##offset_##CHAR##BITS(detail::Mask(d), base, offset, v); \ + sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ + v); \ } -#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d, \ - HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ - HWY_SVE_V(int, BITS) index) { \ - sv##OP##_s##BITS##index_##CHAR##BITS(detail::Mask(d), base, index, v); \ +#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, HWY_SVE_V(int, BITS) index) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, index, v); \ } HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) @@ -907,22 +931,23 @@ HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter) // ------------------------------ GatherOffset/Index -#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ - HWY_SVE_V(int, BITS) offset) { \ - return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::Mask(d), base, \ - offset); \ +#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ + offset); \ } -#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ - HWY_SVE_V(int, BITS) index) { \ - return sv##OP##_s##BITS##index_##CHAR##BITS(detail::Mask(d), base, index); \ +#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) index) { \ + return sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, \ + index); \ } HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) @@ -932,13 +957,14 @@ HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather) // ------------------------------ StoreInterleaved3 -#define HWY_SVE_STORE3(BASE, CHAR, BITS, NAME, OP) \ - template \ +#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, 
OP) \ + template \ HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ - HWY_SVE_V(BASE, BITS) v2, HWY_SVE_D(BASE, BITS, N) d, \ + HWY_SVE_V(BASE, BITS) v2, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \ - sv##OP##_##CHAR##BITS(detail::Mask(d), unaligned, triple); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, triple); \ } HWY_SVE_FOREACH_U08(HWY_SVE_STORE3, StoreInterleaved3, st3) @@ -946,15 +972,15 @@ HWY_SVE_FOREACH_U08(HWY_SVE_STORE3, StoreInterleaved3, st3) // ------------------------------ StoreInterleaved4 -#define HWY_SVE_STORE4(BASE, CHAR, BITS, NAME, OP) \ - template \ +#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ - HWY_SVE_D(BASE, BITS, N) d, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ const sv##BASE##BITS##x4_t quad = \ svcreate4##_##CHAR##BITS(v0, v1, v2, v3); \ - sv##OP##_##CHAR##BITS(detail::Mask(d), unaligned, quad); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, quad); \ } HWY_SVE_FOREACH_U08(HWY_SVE_STORE4, StoreInterleaved4, st4) @@ -965,14 +991,11 @@ HWY_SVE_FOREACH_U08(HWY_SVE_STORE4, StoreInterleaved4, st4) // ------------------------------ PromoteTo // Same sign -#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) /* tag */, \ - VFromD, \ - HWY_LANES(HWY_SVE_T(BASE, BITS)) * 2>> \ - v) { \ - return sv##OP##_##CHAR##BITS(v); \ +#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ + return sv##OP##_##CHAR##BITS(v); \ } HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) @@ -980,34 +1003,34 @@ HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) // 2x -template -HWY_API svuint32_t PromoteTo(Simd dto, svuint8_t vfrom) { +template +HWY_API svuint32_t PromoteTo(Simd dto, svuint8_t vfrom) { const RepartitionToWide> d2; return PromoteTo(dto, PromoteTo(d2, vfrom)); } -template -HWY_API svint32_t PromoteTo(Simd dto, svint8_t vfrom) { +template +HWY_API svint32_t PromoteTo(Simd dto, svint8_t vfrom) { const RepartitionToWide> d2; return PromoteTo(dto, PromoteTo(d2, vfrom)); } -template +template HWY_API svuint32_t U32FromU8(svuint8_t v) { - return PromoteTo(Simd(), v); + return PromoteTo(Simd(), v); } // Sign change -template -HWY_API svint16_t PromoteTo(Simd dto, svuint8_t vfrom) { +template +HWY_API svint16_t PromoteTo(Simd dto, svuint8_t vfrom) { const RebindToUnsigned du; return BitCast(dto, PromoteTo(du, vfrom)); } -template -HWY_API svint32_t PromoteTo(Simd dto, svuint16_t vfrom) { +template +HWY_API svint32_t PromoteTo(Simd dto, svuint16_t vfrom) { const RebindToUnsigned du; return BitCast(dto, PromoteTo(du, vfrom)); } -template -HWY_API svint32_t PromoteTo(Simd dto, svuint8_t vfrom) { +template +HWY_API svint32_t PromoteTo(Simd dto, svuint8_t vfrom) { const Repartition> du16; const Repartition di16; return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom))); @@ -1015,19 +1038,33 @@ HWY_API svint32_t PromoteTo(Simd dto, svuint8_t vfrom) { // ------------------------------ PromoteTo F -template -HWY_API svfloat32_t 
PromoteTo(Simd /* d */, const svfloat16_t v) { - return svcvt_f32_f16_x(detail::PTrue(Simd()), v); +// svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so +// first replicate each lane once. +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLower, zip1) +// Do not use zip2 to implement PromoteUpperTo or similar because vectors may be +// non-powers of two, so getting the actual "upper half" requires MaskUpperHalf. +} // namespace detail + +template +HWY_API svfloat32_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + const svfloat16_t vv = detail::ZipLower(v, v); + return svcvt_f32_f16_x(detail::PTrue(Simd()), vv); } -template -HWY_API svfloat64_t PromoteTo(Simd /* d */, const svfloat32_t v) { - return svcvt_f64_f32_x(detail::PTrue(Simd()), v); +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLower(v, v); + return svcvt_f64_f32_x(detail::PTrue(Simd()), vv); } -template -HWY_API svfloat64_t PromoteTo(Simd /* d */, const svint32_t v) { - return svcvt_f64_s32_x(detail::PTrue(Simd()), v); +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svint32_t v) { + const svint32_t vv = detail::ZipLower(v, v); + return svcvt_f64_s32_x(detail::PTrue(Simd()), vv); } // For 16-bit Compress @@ -1035,8 +1072,8 @@ namespace detail { HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) #undef HWY_SVE_PROMOTE_TO -template -HWY_API svfloat32_t PromoteUpperTo(Simd df, const svfloat16_t v) { +template +HWY_API svfloat32_t PromoteUpperTo(Simd df, svfloat16_t v) { const RebindToUnsigned du; const RepartitionToNarrow dn; return BitCast(df, PromoteUpperTo(du, BitCast(dn, v))); @@ -1057,44 +1094,43 @@ VU SaturateU(VU v) { // Saturates unsigned vectors to half/quarter-width TN. template VI SaturateI(VI v) { - const DFromV di; return detail::MinN(detail::MaxN(v, LimitsMin()), LimitsMax()); } } // namespace detail -template -HWY_API svuint8_t DemoteTo(Simd dn, const svint16_t v) { +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint16_t v) { const DFromV di; const RebindToUnsigned du; using TN = TFromD; // First clamp negative numbers to zero and cast to unsigned. - const svuint16_t clamped = BitCast(du, Max(Zero(di), v)); + const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); // Saturate to unsigned-max and halve the width. const svuint8_t vn = BitCast(dn, detail::SaturateU(clamped)); return svuzp1_u8(vn, vn); } -template -HWY_API svuint16_t DemoteTo(Simd dn, const svint32_t v) { +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint32_t v) { const DFromV di; const RebindToUnsigned du; using TN = TFromD; // First clamp negative numbers to zero and cast to unsigned. - const svuint32_t clamped = BitCast(du, Max(Zero(di), v)); + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); // Saturate to unsigned-max and halve the width. const svuint16_t vn = BitCast(dn, detail::SaturateU(clamped)); return svuzp1_u16(vn, vn); } -template -HWY_API svuint8_t DemoteTo(Simd dn, const svint32_t v) { +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint32_t v) { const DFromV di; const RebindToUnsigned du; const RepartitionToNarrow d2; using TN = TFromD; // First clamp negative numbers to zero and cast to unsigned. - const svuint32_t clamped = BitCast(du, Max(Zero(di), v)); + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); // Saturate to unsigned-max and quarter the width. 
const svuint16_t cast16 = BitCast(d2, detail::SaturateU(clamped)); const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); @@ -1114,38 +1150,35 @@ HWY_API svuint8_t U8FromU32(const svuint32_t v) { // ------------------------------ DemoteTo I -template -HWY_API svint8_t DemoteTo(Simd dn, const svint16_t v) { - const DFromV di; - using TN = TFromD; +template +HWY_API svint8_t DemoteTo(Simd dn, const svint16_t v) { #if HWY_TARGET == HWY_SVE2 const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); #else + using TN = TFromD; const svint8_t vn = BitCast(dn, detail::SaturateI(v)); #endif return svuzp1_s8(vn, vn); } -template -HWY_API svint16_t DemoteTo(Simd dn, const svint32_t v) { - const DFromV di; - using TN = TFromD; +template +HWY_API svint16_t DemoteTo(Simd dn, const svint32_t v) { #if HWY_TARGET == HWY_SVE2 const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); #else + using TN = TFromD; const svint16_t vn = BitCast(dn, detail::SaturateI(v)); #endif return svuzp1_s16(vn, vn); } -template -HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { - const DFromV di; - using TN = TFromD; +template +HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { const RepartitionToWide d2; #if HWY_TARGET == HWY_SVE2 const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); #else + using TN = TFromD; const svint16_t cast16 = BitCast(d2, detail::SaturateI(v)); #endif const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); @@ -1158,10 +1191,10 @@ HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { // full vector length, not rounded down to a power of two as we require). namespace detail { -#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, NAME, OP) \ - HWY_INLINE HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ - return sv##OP##_##CHAR##BITS(lo, hi); \ +#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi); \ } HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEven, uzp1) HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOdd, uzp2) @@ -1169,7 +1202,7 @@ HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOdd, uzp2) // Used to slide up / shift whole register left; mask indicates which range // to take from lo, and the rest is filled from hi starting at its lowest. 
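// [Editorial sketch, not part of the patch] Scalar model of detail::Splice as
// described in the comment above, assuming the mask selects the first `count`
// lanes (how it is used in this file) and count does not exceed the lane
// count: the result is those lanes of lo, followed by lanes taken from the
// bottom of hi. The argument order matches detail::Splice(hi, lo, mask).
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> SpliceRef(const std::vector<T>& hi, const std::vector<T>& lo,
                         size_t count) {
  std::vector<T> out(lo.begin(), lo.begin() + count);   // active part of lo
  out.insert(out.end(), hi.begin(), hi.end() - count);  // rest from bottom of hi
  return out;
}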
-#define HWY_SVE_SPLICE(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) NAME( \ HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ return sv##OP##_##CHAR##BITS(mask, lo, hi); \ @@ -1203,40 +1236,43 @@ HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { // ------------------------------ DemoteTo F -template -HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { - return svcvt_f16_f32_x(detail::PTrue(d), v); +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEven(in_even, in_even); // only low 1/2 of result valid } -template -HWY_API svuint16_t DemoteTo(Simd d, const svfloat32_t v) { - const svuint16_t halves = BitCast(Full(), v); - return detail::ConcatOdd(halves, halves); // can ignore upper half of vec +template +HWY_API svuint16_t DemoteTo(Simd /* d */, svfloat32_t v) { + const svuint16_t in_even = BitCast(ScalableTag(), v); + return detail::ConcatOdd(in_even, in_even); // can ignore upper half of vec } -template -HWY_API svfloat32_t DemoteTo(Simd d, const svfloat64_t v) { - return svcvt_f32_f64_x(detail::PTrue(d), v); +template +HWY_API svfloat32_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); + return detail::ConcatEven(in_even, in_even); // only low 1/2 of result valid } -template -HWY_API svint32_t DemoteTo(Simd d, const svfloat64_t v) { - return svcvt_s32_f64_x(detail::PTrue(d), v); +template +HWY_API svint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); + return detail::ConcatEven(in_even, in_even); // only low 1/2 of result valid } // ------------------------------ ConvertTo F -#define HWY_SVE_CONVERT(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) /* d */, HWY_SVE_V(int, BITS) v) { \ - return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ - } \ - /* Truncates (rounds toward zero). */ \ - template \ - HWY_API HWY_SVE_V(int, BITS) \ - NAME(HWY_SVE_D(int, BITS, N) /* d */, HWY_SVE_V(BASE, BITS) v) { \ - return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ +#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Truncates (rounds toward zero). */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ } // API only requires f32 but we provide f64 for use by Iota. 
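// [Editorial sketch, not part of the patch] Scalar model of the ConcatEven
// helper that the float DemoteTo overloads above rely on: the narrowing svcvt
// writes each result into the even (lower) half of its wider source lane, so
// taking the even lanes of (v, v) packs the results into the lower half of
// the vector; callers only use that lower half.
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> ConcatEvenRef(const std::vector<T>& hi,
                             const std::vector<T>& lo) {
  std::vector<T> out;
  for (size_t i = 0; i < lo.size(); i += 2) out.push_back(lo[i]);  // evens of lo
  for (size_t i = 0; i < hi.size(); i += 2) out.push_back(hi[i]);  // then of hi
  return out;
}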
@@ -1253,11 +1289,11 @@ HWY_API VFromD NearestInt(VF v) { // ------------------------------ Iota (Add, ConvertTo) -#define HWY_SVE_IOTA(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_T(BASE, BITS) first) { \ - return sv##OP##_##CHAR##BITS(first, 1); \ +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) first) { \ + return sv##OP##_##CHAR##BITS(first, 1); \ } HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) @@ -1273,19 +1309,19 @@ HWY_API VFromD Iota(const D d, TFromD first) { namespace detail { -template -svbool_t MaskLowerHalf(Simd d) { +template +svbool_t MaskLowerHalf(D d) { return FirstN(d, Lanes(d) / 2); } -template -svbool_t MaskUpperHalf(Simd d) { +template +svbool_t MaskUpperHalf(D d) { // For Splice to work as intended, make sure bits above Lanes(d) are zero. - return AndNot(MaskLowerHalf(d), detail::Mask(d)); + return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); } // Right-shift vector pair by constexpr; can be used to slide down (=N) or up // (=Lanes()-N). -#define HWY_SVE_EXT(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ template \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ @@ -1348,7 +1384,7 @@ HWY_API V LowerHalf(const V v) { } template -HWY_API V UpperHalf(const D2 d2, const V v) { +HWY_API V UpperHalf(const D2 /* d2 */, const V v) { return detail::Splice(v, v, detail::MaskUpperHalf(Twice())); } @@ -1356,7 +1392,7 @@ HWY_API V UpperHalf(const D2 d2, const V v) { // ------------------------------ GetLane -#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_T(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_##CHAR##BITS(detail::PFalse(), v); \ } @@ -1364,12 +1400,32 @@ HWY_API V UpperHalf(const D2 d2, const V v) { HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLane, lasta) #undef HWY_SVE_GET_LANE +// ------------------------------ DupEven + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) +} // namespace detail + +template +HWY_API V DupEven(const V v) { + return detail::InterleaveEven(v, v); +} + +// ------------------------------ DupOdd + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) +} // namespace detail + +template +HWY_API V DupOdd(const V v) { + return detail::InterleaveOdd(v, v); +} + // ------------------------------ OddEven namespace detail { HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVN, Insert, insr_n) -HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) -HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) } // namespace detail template @@ -1382,42 +1438,26 @@ HWY_API V OddEven(const V odd, const V even) { template HWY_API V OddEvenBlocks(const V odd, const V even) { const RebindToUnsigned> du; - constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV)); + using TU = TFromD; + constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); const auto idx_block = ShiftRight(Iota(du, 0)); - const svbool_t is_even = Eq(detail::AndN(idx_block, 1), Zero(du)); + const auto lsb = detail::AndN(idx_block, static_cast(1)); + const svbool_t is_even = detail::EqN(lsb, static_cast(0)); return IfThenElse(is_even, even, odd); } -// ------------------------------ SwapAdjacentBlocks - -namespace detail { - -template -constexpr size_t LanesPerBlock(Simd /* tag */) { - // 
We might have a capped vector smaller than a block, so honor that. - return HWY_MIN(16 / sizeof(T), N); -} - -} // namespace detail - -template -HWY_API V SwapAdjacentBlocks(const V v) { - const DFromV d; - constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); - const V down = detail::Ext(v, v); - const V up = detail::Splice(v, v, FirstN(d, kLanesPerBlock)); - return OddEvenBlocks(up, down); -} - // ------------------------------ TableLookupLanes template HWY_API VFromD> IndicesFromVec(D d, VI vec) { - static_assert(sizeof(TFromD) == sizeof(TFromV), "Index != lane"); + using TI = TFromV; + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size mismatch"); const RebindToUnsigned du; const auto indices = BitCast(du, vec); #if HWY_IS_DEBUG_BUILD - HWY_DASSERT(AllTrue(du, Lt(indices, Set(du, Lanes(d))))); + HWY_DASSERT(AllTrue(du, detail::LtN(indices, static_cast(Lanes(d))))); +#else + (void)d; #endif return indices; } @@ -1429,7 +1469,7 @@ HWY_API VFromD> SetTableIndices(D d, const TI* idx) { } // <32bit are not part of Highway API, but used in Broadcast. -#define HWY_SVE_TABLE(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ return sv##OP##_##CHAR##BITS(v, idx); \ @@ -1438,30 +1478,95 @@ HWY_API VFromD> SetTableIndices(D d, const TI* idx) { HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) #undef HWY_SVE_TABLE +// ------------------------------ SwapAdjacentBlocks (TableLookupLanes) + +namespace detail { + +template +constexpr size_t LanesPerBlock(Simd /* tag */) { + // We might have a capped vector smaller than a block, so honor that. + return HWY_MIN(16 / sizeof(T), detail::ScaleByPower(N, kPow2)); +} + +} // namespace detail + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr auto kLanesPerBlock = + static_cast>(detail::LanesPerBlock(d)); + const VFromD idx = detail::XorN(Iota(du, 0), kLanesPerBlock); + return TableLookupLanes(v, idx); +} + // ------------------------------ Reverse #if 0 // if we could assume VL is a power of two #error "Update macro" #endif -#define HWY_SVE_REVERSE(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(Simd d, HWY_SVE_V(BASE, BITS) v) { \ - const auto reversed = sv##OP##_##CHAR##BITS(v); \ - /* Shift right to remove extra (non-pow2 and remainder) lanes. */ \ - const size_t all_lanes = \ - detail::AllHardwareLanes(hwy::SizeTag()); \ - /* TODO(janwas): on SVE2, use whilege. */ \ - const svbool_t mask = Not(FirstN(d, all_lanes - Lanes(d))); \ - return detail::Splice(reversed, reversed, mask); \ +#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, HWY_SVE_V(BASE, BITS) v) { \ + const auto reversed = sv##OP##_##CHAR##BITS(v); \ + /* Shift right to remove extra (non-pow2 and remainder) lanes. */ \ + const size_t all_lanes = \ + detail::AllHardwareLanes(hwy::SizeTag()); \ + /* TODO(janwas): on SVE2, use whilege. */ \ + /* Avoids FirstN truncating to the return vector size. 
*/ \ + const ScalableTag dfull; \ + const svbool_t mask = Not(FirstN(dfull, all_lanes - Lanes(d))); \ + return detail::Splice(reversed, reversed, mask); \ } HWY_SVE_FOREACH(HWY_SVE_REVERSE, Reverse, rev) #undef HWY_SVE_REVERSE +// ------------------------------ Reverse2 + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D /* tag */, const VFromD v) { // 3210 + const auto even_in_odd = detail::Insert(v, 0); // 210z + return detail::InterleaveOdd(v, even_in_odd); // 2301 +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +// TODO(janwas): is this approach faster than Shuffle0123? +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 7); + return TableLookupLanes(v, idx); +} + // ------------------------------ Compress (PromoteTo) -#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, NAME, OP) \ +#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ return sv##OP##_##CHAR##BITS(mask, v); \ } @@ -1541,7 +1646,7 @@ HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { template svbool_t FirstNPerBlock(D d) { - const RebindToSigned di; + const RebindToSigned di; constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di); const auto idx_mod = detail::AndN(Iota(di, 0), kLanesPerBlock - 1); return detail::LtN(BitCast(di, idx_mod), kLanes); @@ -1562,21 +1667,11 @@ HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { // ------------------------------ Shuffle2301 -#define HWY_SVE_SHUFFLE_2301(BASE, CHAR, BITS, NAME, OP) \ - HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ - const DFromV d; \ - const svuint64_t vu64 = BitCast(Repartition(), v); \ - return BitCast(d, sv##OP##_u64_x(HWY_SVE_PTRUE(64), vu64)); \ - } - -HWY_SVE_FOREACH_UI32(HWY_SVE_SHUFFLE_2301, Shuffle2301, revw) -#undef HWY_SVE_SHUFFLE_2301 - -template +template HWY_API V Shuffle2301(const V v) { - const DFromV df; - const RebindToUnsigned du; - return BitCast(df, Shuffle2301(BitCast(du, v))); + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return Reverse2(d, v); } // ------------------------------ Shuffle2103 @@ -1625,6 +1720,13 @@ HWY_API V Shuffle0123(const V v) { return Shuffle2301(Shuffle1032(v)); } +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { + const Repartition du64; + return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); +} + // ------------------------------ TableLookupBytes template @@ -1643,11 +1745,7 @@ HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { const Repartition di8; auto idx8 = BitCast(di8, idx); - const auto msb = Lt(idx8, Zero(di8)); -// Prevent overflow in table lookups (unnecessary if native) -#if defined(HWY_EMULATE_SVE) - idx8 = IfThenZeroElse(msb, idx8); -#endif + const auto 
msb = detail::LtN(idx8, 0); const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); return BitCast(d, IfThenZeroElse(msb, lookup)); @@ -1672,7 +1770,6 @@ HWY_API V Broadcast(const V v) { template > HWY_API V ShiftLeftLanes(D d, const V v) { - const RebindToSigned di; const auto zero = Zero(d); const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); // Match x86 semantics by zeroing lower lanes in 128-bit blocks @@ -1685,12 +1782,11 @@ HWY_API V ShiftLeftLanes(const V v) { } // ------------------------------ ShiftRightLanes -template >> -HWY_API V ShiftRightLanes(Simd d, V v) { - const RebindToSigned di; - // For partial vectors, clear upper lanes so we shift in zeros. - if (N != HWY_LANES(T)) { - v = IfThenElseZero(detail::Mask(d), v); +template > +HWY_API V ShiftRightLanes(D d, V v) { + // For capped/fractional vectors, clear upper lanes so we shift in zeros. + if (!detail::IsFull(d)) { + v = IfThenElseZero(detail::MakeMask(d), v); } const auto shifted = detail::Ext(v, v); @@ -1722,12 +1818,6 @@ HWY_API V ShiftRightBytes(const D d, const V v) { // ------------------------------ InterleaveLower -namespace detail { -HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLower, zip1) -// Do not use zip2 to implement PromoteUpperTo or similar because vectors may be -// non-powers of two, so getting the actual "upper half" requires MaskUpperHalf. -} // namespace detail - template HWY_API V InterleaveLower(D d, const V a, const V b) { static_assert(IsSame, TFromV>(), "D/V mismatch"); @@ -1749,8 +1839,9 @@ HWY_API V InterleaveLower(const V a, const V b) { // ------------------------------ InterleaveUpper // Full vector: guaranteed to have at least one block -template >> -HWY_API V InterleaveUpper(Simd d, const V a, const V b) { +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { // Move upper halves of blocks to lower half of vector. 
const Repartition d64; const auto a64 = BitCast(d64, a); @@ -1760,26 +1851,16 @@ HWY_API V InterleaveUpper(Simd d, const V a, const V b) { return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks)); } -// Capped: less than one block -template >> -HWY_API V InterleaveUpper(Simd d, const V a, const V b) { - static_assert(IsSame>(), "D/V mismatch"); - const Half d2; - return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); -} - -// Partial: need runtime check -template = 16)>* = nullptr, - class V = VFromD>> -HWY_API V InterleaveUpper(Simd d, const V a, const V b) { - static_assert(IsSame>(), "D/V mismatch"); +// Capped/fraction: need runtime check +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { // Less than one block: treat as capped - if (Lanes(d) * sizeof(T) < 16) { + if (Lanes(d) * sizeof(TFromD) < 16) { const Half d2; return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); } - return InterleaveUpper(Full(), a, b); + return InterleaveUpper(DFromV(), a, b); } // ------------------------------ ZipLower @@ -1805,11 +1886,12 @@ HWY_API VFromD ZipUpper(DW dw, V a, V b) { // ================================================== REDUCE -#define HWY_SVE_REDUCE(BASE, CHAR, BITS, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_V(BASE, BITS) v) { \ - return Set(d, sv##OP##_##CHAR##BITS(detail::Mask(d), v)); \ +#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, HWY_SVE_V(BASE, BITS) v) { \ + return Set(d, static_cast( \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), v))); \ } HWY_SVE_FOREACH(HWY_SVE_REDUCE, SumOfLanes, addv) @@ -1825,16 +1907,17 @@ HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanes, maxnmv) // ------------------------------ PromoteTo bfloat16 (ZipLower) -template -HWY_API svfloat32_t PromoteTo(Simd df32, const svuint16_t v) { +template +HWY_API svfloat32_t PromoteTo(Simd df32, + const svuint16_t v) { return BitCast(df32, detail::ZipLower(svdup_n_u16(0), v)); } // ------------------------------ ReorderDemote2To (OddEven) -template -HWY_API svuint16_t ReorderDemote2To(Simd dbf16, svfloat32_t a, - svfloat32_t b) { +template +HWY_API svuint16_t ReorderDemote2To(Simd dbf16, + svfloat32_t a, svfloat32_t b) { const RebindToUnsigned du16; const Repartition du32; const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b)); @@ -1844,9 +1927,7 @@ HWY_API svuint16_t ReorderDemote2To(Simd dbf16, svfloat32_t a, // ------------------------------ ZeroIfNegative (Lt, IfThenElse) template HWY_API V ZeroIfNegative(const V v) { - const auto v0 = Zero(DFromV()); - // We already have a zero constant, so avoid IfThenZeroElse. 
- return IfThenElse(Lt(v, v0), v0, v); + return IfThenZeroElse(detail::LtN(v, 0), v); } // ------------------------------ BroadcastSignBit (ShiftRight) @@ -1855,6 +1936,17 @@ HWY_API V BroadcastSignBit(const V v) { return ShiftRight) * 8 - 1>(v); } +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + const svbool_t m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + // ------------------------------ AverageRound (ShiftRight) #if HWY_TARGET == HWY_SVE2 @@ -1863,7 +1955,7 @@ HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) #else template V AverageRound(const V a, const V b) { - return ShiftRight<1>(Add(Add(a, b), Set(DFromV(), 1))); + return ShiftRight<1>(detail::AddN(Add(a, b), 1)); } #endif // HWY_TARGET == HWY_SVE2 @@ -1944,51 +2036,57 @@ HWY_INLINE svbool_t LoadMaskBits(D /* tag */, namespace detail { -// Returns mask ? 1 : 0 in BYTE lanes. -template -HWY_API svuint8_t BoolFromMask(Simd d, svbool_t m) { +// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { return svdup_n_u8_z(m, 1); } -template -HWY_API svuint8_t BoolFromMask(Simd d, svbool_t m) { - const Repartition d8; +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d8; const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); return detail::ConcatEven(b16, b16); // only lower half needed } -template -HWY_API svuint8_t BoolFromMask(Simd d, svbool_t m) { +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { return U8FromU32(svdup_n_u32_z(m, 1)); } -template -HWY_API svuint8_t BoolFromMask(Simd d, svbool_t m) { - const Repartition d32; +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d32; const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); return U8FromU32(detail::ConcatEven(b64, b64)); // only lower half needed } +// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. +HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { + const ScalableTag d8; + const ScalableTag d16; + const ScalableTag d32; + const ScalableTag d64; + // TODO(janwas): could use SVE2 BDEP, but it's optional. + x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); + x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); + x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); + return BitCast(d64, x); +} + } // namespace detail // `p` points to at least 8 writable bytes. -template -HWY_API size_t StoreMaskBits(Simd d, svbool_t m, uint8_t* bits) { - const Repartition d8; - const Repartition d16; - const Repartition d32; - const Repartition d64; - auto x = detail::BoolFromMask(d, m); - // Compact bytes to bits. Could use SVE2 BDEP, but it's optional. - x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); - x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); - x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); +template +HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { + svuint64_t bits_in_u64 = + detail::BitsFromBool(detail::BoolFromMask>(m)); const size_t num_bits = Lanes(d); const size_t num_bytes = (num_bits + 8 - 1) / 8; // Round up, see below - // Truncate to 8 bits and store. - svst1b_u64(FirstN(d64, num_bytes), bits, BitCast(d64, x)); + // Truncate each u64 to 8 bits and store to u8. 
+ svst1b_u64(FirstN(ScalableTag(), num_bytes), bits, bits_in_u64); // Non-full byte, need to clear the undefined upper bits. Can happen for - // capped/partial vectors or large T and small hardware vectors. + // capped/fractional vectors or large T and small hardware vectors. if (num_bits < 8) { const int mask = (1 << num_bits) - 1; bits[0] = static_cast(bits[0] & mask); @@ -2015,7 +2113,14 @@ HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, #if HWY_TARGET == HWY_SVE2 namespace detail { -HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulEven, mullb) +#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEven, mullb) +#undef HWY_SVE_MUL_EVEN } // namespace detail #endif @@ -2044,9 +2149,9 @@ HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) -template -HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd df32, svuint16_t a, - svuint16_t b, +template +HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd df32, + svuint16_t a, svuint16_t b, const svfloat32_t sum0, svfloat32_t& sum1) { // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16. @@ -2073,12 +2178,13 @@ HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd df32, svuint16_t a, #endif HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { - // NOTE: it is important that AESE and AESMC be consecutive instructions so - // they can be fused. AESE includes AddRoundKey, which is a different ordering - // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual - // round key (the compiler will hopefully optimize this for multiple rounds). + // It is not clear whether E and MC fuse like they did on NEON. const svuint8_t zero = svdup_n_u8(0); - return Xor(vaesmcq_u8(vaeseq_u8(state, zero), round_key)); + return Xor(svaesmc_u8(svaese_u8(state, zero)), round_key); +} + +HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { + return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); } HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { @@ -2091,6 +2197,44 @@ HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { #endif // __ARM_FEATURE_SVE2_AES +// ------------------------------ Lt128 + +template +HWY_INLINE svbool_t Lt128(D /* d */, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, "Use u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) = IfThenElse(=H, cL, cH) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const svbool_t eqHL = Eq(a, b); + const svbool_t ltHL = Lt(a, b); + // trn (interleave even/odd) allow us to move and copy masks across lanes. + const svbool_t cmpLL = svtrn1_b64(ltHL, ltHL); + const svbool_t outHx = svsel_b(eqHL, cmpLL, ltHL); // See truth table above. 
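The truth table above is the usual lexicographic rule: a 128-bit value is less when its high half is less, or when the high halves are equal and the low half is less; the svtrn2 that follows then replicates that per-pair answer into both 64-bit lanes of the block. A scalar sketch of the selection on a (hi, lo) pair of u64 (struct and function names invented for the sketch):

#include <cstdint>
#include <cstdio>

struct U128 { uint64_t hi, lo; };

// Mirrors outHx = IfThenElse(eqH, ltL, ltH) from the truth table above.
bool Lt128Scalar(const U128 a, const U128 b) {
  const bool eqH = a.hi == b.hi;
  const bool ltH = a.hi < b.hi;
  const bool ltL = a.lo < b.lo;
  return eqH ? ltL : ltH;
}

int main() {
  const U128 x{1, 0}, y{1, 5};
  std::printf("%d %d\n", Lt128Scalar(x, y), Lt128Scalar(y, x));  // 1 0
}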
+ return svtrn2_b64(outHx, outHx); // replicate to HH +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128(d, a, b), b, a); +} + // ================================================== END MACROS namespace detail { // for code folding #undef HWY_IF_FLOAT_V @@ -2121,7 +2265,6 @@ namespace detail { // for code folding #undef HWY_SVE_FOREACH_UI64 #undef HWY_SVE_FOREACH_UIF3264 #undef HWY_SVE_PTRUE -#undef HWY_SVE_RETV_ARGD #undef HWY_SVE_RETV_ARGPV #undef HWY_SVE_RETV_ARGPVN #undef HWY_SVE_RETV_ARGPVV @@ -2129,6 +2272,7 @@ namespace detail { // for code folding #undef HWY_SVE_RETV_ARGVN #undef HWY_SVE_RETV_ARGVV #undef HWY_SVE_T +#undef HWY_SVE_UNDEFINED #undef HWY_SVE_V } // namespace detail diff --git a/third_party/highway/hwy/ops/generic_ops-inl.h b/third_party/highway/hwy/ops/generic_ops-inl.h index 35cec12f75f0..0c0b5229c51d 100644 --- a/third_party/highway/hwy/ops/generic_ops-inl.h +++ b/third_party/highway/hwy/ops/generic_ops-inl.h @@ -19,14 +19,14 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { -// The lane type of a vector type, e.g. float for Vec>. +// The lane type of a vector type, e.g. float for Vec>. template using LaneType = decltype(GetLane(V())); -// Vector type, e.g. Vec128 for Simd. Useful as the return type -// of functions that do not take a vector argument, or as an argument type if -// the function only has a template argument for D, or for explicit type names -// instead of auto. This may be a built-in type. +// Vector type, e.g. Vec128 for CappedTag. Useful as the return +// type of functions that do not take a vector argument, or as an argument type +// if the function only has a template argument for D, or for explicit type +// names instead of auto. This may be a built-in type. template using Vec = decltype(Zero(D())); @@ -53,12 +53,6 @@ HWY_API V CombineShiftRightLanes(D d, const V hi, const V lo) { return CombineShiftRightBytes(d, hi, lo); } -// DEPRECATED -template -HWY_API V CombineShiftRightLanes(const V hi, const V lo) { - return CombineShiftRightLanes(DFromV(), hi, lo); -} - #endif // Returns lanes with the most significant bit set and all other bits zero. @@ -208,6 +202,15 @@ HWY_API V AESRound(V state, const V round_key) { return state; } +template // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + // Constant-time implementation inspired by // https://www.bearssl.org/constanttime.html, but about half the cost because we // use 64x64 multiplies and 128-bit XORs. @@ -278,23 +281,47 @@ HWY_API V CLMulUpper(V a, V b) { #define HWY_NATIVE_POPCNT #endif -template +#if HWY_TARGET == HWY_RVV +#define HWY_MIN_POW2_FOR_128 1 +#else +// All other targets except HWY_SCALAR (which is excluded by HWY_IF_GE128_D) +// guarantee 128 bits anyway. +#define HWY_MIN_POW2_FOR_128 0 +#endif + +// This algorithm requires vectors to be at least 16 bytes, which is the case +// for LMUL >= 2. If not, use the fallback below. 
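The table-based PopulationCount that follows splits each byte into two 4-bit nibbles and adds their counts from a 16-entry table. For reference, the same idea in scalar form (function name invented for the sketch; the table is the one used below):

#include <cstdint>
#include <cstdio>

// Per-nibble bit counts, same 16-entry table as the vector lookup.
static const uint8_t kNibblePopCount[16] = {0, 1, 1, 2, 1, 2, 2, 3,
                                            1, 2, 2, 3, 2, 3, 3, 4};

uint8_t PopCountByte(uint8_t v) {
  return kNibblePopCount[v & 0xF] + kNibblePopCount[v >> 4];
}

int main() {
  std::printf("%d %d\n", PopCountByte(0xFF), PopCountByte(0xA5));  // 8 4
}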
+template ), + HWY_IF_POW2_GE(DFromV, HWY_MIN_POW2_FOR_128)> HWY_API V PopulationCount(V v) { - constexpr DFromV d; + const DFromV d; HWY_ALIGN constexpr uint8_t kLookup[16] = { 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, }; - auto lo = And(v, Set(d, 0xF)); - auto hi = ShiftRight<4>(v); - auto lookup = LoadDup128(Simd(), kLookup); + const auto lo = And(v, Set(d, 0xF)); + const auto hi = ShiftRight<4>(v); + const auto lookup = LoadDup128(d, kLookup); return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); } +// RVV has a specialization that avoids the Set(). +#if HWY_TARGET != HWY_RVV +// Slower fallback for capped vectors. +template )> +HWY_API V PopulationCount(V v) { + const DFromV d; + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, And(ShiftRight<1>(v), Set(d, 0x55))); + v = Add(And(ShiftRight<2>(v), Set(d, 0x33)), And(v, Set(d, 0x33))); + return And(Add(v, ShiftRight<4>(v)), Set(d, 0x0F)); +} +#endif // HWY_TARGET != HWY_RVV + template HWY_API V PopulationCount(V v) { const DFromV d; - Repartition d8; - auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + const Repartition d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); return Add(ShiftRight<8>(vals), And(vals, Set(d, 0xFF))); } @@ -306,7 +333,7 @@ HWY_API V PopulationCount(V v) { return Add(ShiftRight<16>(vals), And(vals, Set(d, 0xFF))); } -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 template HWY_API V PopulationCount(V v) { const DFromV d; diff --git a/third_party/highway/hwy/ops/rvv-inl.h b/third_party/highway/hwy/ops/rvv-inl.h index 14a0306fe84b..44e10d7db6e6 100644 --- a/third_party/highway/hwy/ops/rvv-inl.h +++ b/third_party/highway/hwy/ops/rvv-inl.h @@ -34,17 +34,14 @@ using DFromV = typename DFromV_t>::type; template using TFromV = TFromD>; -template -HWY_INLINE constexpr size_t MLenFromD(Simd /* tag */) { - // Returns divisor = type bits / LMUL - return sizeof(T) * 8 / (N / HWY_LANES(T)); +template +constexpr size_t MLenFromD(Simd d) { + // Returns divisor = type bits / LMUL. Folding *8 into the ScaleByPower + // argument enables fractional LMUL < 1. Limit to 64 because that is the + // largest value for which vbool##_t are defined. + return HWY_MIN(64, sizeof(T) * 8 * 8 / detail::ScaleByPower(8, kPow2)); } -// kShift = log2 of multiplier: 0 for m1, 1 for m2, -2 for mf4 -template -using Full = Simd> (-kShift)) - : (HWY_LANES(T) << kShift)>; - // ================================================== MACROS // Generate specializations and function definitions using X macros. Although @@ -63,231 +60,308 @@ namespace detail { // for code folding X_MACRO(8, 2, 2, NAME, OP) \ X_MACRO(8, 3, 1, NAME, OP) -// For given SEW, iterate over all LMUL. Precompute SEW/LMUL => MLEN because we -// need to token-paste the result. For the same reason, we also pass the -// twice-as-long and half-as-long LMUL suffixes as arguments. -// TODO(janwas): add fractional LMUL -#define HWY_RVV_FOREACH_08(X_MACRO, BASE, CHAR, NAME, OP) \ - X_MACRO(BASE, CHAR, 8, m1, m2, mf2, /*kShift=*/0, /*MLEN=*/8, NAME, OP) \ - X_MACRO(BASE, CHAR, 8, m2, m4, m1, /*kShift=*/1, /*MLEN=*/4, NAME, OP) \ - X_MACRO(BASE, CHAR, 8, m4, m8, m2, /*kShift=*/2, /*MLEN=*/2, NAME, OP) \ - X_MACRO(BASE, CHAR, 8, m8, __, m4, /*kShift=*/3, /*MLEN=*/1, NAME, OP) +// For given SEW, iterate over one of LMULS: _TRUNC, _EXT, _ALL. This allows +// reusing type lists such as HWY_RVV_FOREACH_U for _ALL (the usual case) or +// _EXT (for Combine). To achieve this, we HWY_CONCAT with the LMULS suffix. 
+// +// Precompute SEW/LMUL => MLEN to allow token-pasting the result. For the same +// reason, also pass the double-width and half SEW and LMUL (suffixed D and H, +// respectively). "__" means there is no corresponding LMUL (e.g. LMULD for m8). +// Args: BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP -#define HWY_RVV_FOREACH_16(X_MACRO, BASE, CHAR, NAME, OP) \ - X_MACRO(BASE, CHAR, 16, m1, m2, mf2, /*kShift=*/0, /*MLEN=*/16, NAME, OP) \ - X_MACRO(BASE, CHAR, 16, m2, m4, m1, /*kShift=*/1, /*MLEN=*/8, NAME, OP) \ - X_MACRO(BASE, CHAR, 16, m4, m8, m2, /*kShift=*/2, /*MLEN=*/4, NAME, OP) \ - X_MACRO(BASE, CHAR, 16, m8, __, m4, /*kShift=*/3, /*MLEN=*/2, NAME, OP) +// LMULS = _TRUNC: truncatable (not the smallest LMUL) +#define HWY_RVV_FOREACH_08_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) -#define HWY_RVV_FOREACH_32(X_MACRO, BASE, CHAR, NAME, OP) \ - X_MACRO(BASE, CHAR, 32, m1, m2, mf2, /*kShift=*/0, /*MLEN=*/32, NAME, OP) \ - X_MACRO(BASE, CHAR, 32, m2, m4, m1, /*kShift=*/1, /*MLEN=*/16, NAME, OP) \ - X_MACRO(BASE, CHAR, 32, m4, m8, m2, /*kShift=*/2, /*MLEN=*/8, NAME, OP) \ - X_MACRO(BASE, CHAR, 32, m8, __, m4, /*kShift=*/3, /*MLEN=*/4, NAME, OP) +#define HWY_RVV_FOREACH_16_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) -#define HWY_RVV_FOREACH_64(X_MACRO, BASE, CHAR, NAME, OP) \ - X_MACRO(BASE, CHAR, 64, m1, m2, mf2, /*kShift=*/0, /*MLEN=*/64, NAME, OP) \ - X_MACRO(BASE, CHAR, 64, m2, m4, m1, /*kShift=*/1, /*MLEN=*/32, NAME, OP) \ - X_MACRO(BASE, CHAR, 64, m4, m8, m2, /*kShift=*/2, /*MLEN=*/16, NAME, OP) \ - X_MACRO(BASE, CHAR, 64, m8, __, m4, /*kShift=*/3, /*MLEN=*/8, NAME, OP) +#define HWY_RVV_FOREACH_32_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. 
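The MLEN column above follows MLEN = SEW / LMUL (capped at 64, the largest vbool##_t), and MLenFromD folds an extra *8 into numerator and denominator so fractional LMUL stays in integer math. A worked check in plain C++, assuming detail::ScaleByPower(8, kPow2) yields 8 * 2^kPow2 (an assumption for this sketch; all other names are invented):

#include <cstddef>
#include <cstdio>

// Assumed behavior of ScaleByPower: 8 * 2^kPow2 in integer arithmetic.
constexpr size_t ScaleByPower(size_t val, int pow2) {
  return pow2 >= 0 ? (val << pow2) : (val >> -pow2);
}

constexpr size_t Min(size_t a, size_t b) { return a < b ? a : b; }

// MLEN = SEW / LMUL, capped at 64; mirrors the MLenFromD formula above.
constexpr size_t MLenFrom(size_t sizeof_t, int kPow2) {
  return Min(64, sizeof_t * 8 * 8 / ScaleByPower(8, kPow2));
}

static_assert(MLenFrom(1, -2) == 32, "u8 at LMUL=1/4, matches table");
static_assert(MLenFrom(2, -1) == 32, "u16 at LMUL=1/2, matches table");
static_assert(MLenFrom(8, 3) == 8, "u64 at LMUL=8, matches table");

int main() { std::printf("%zu\n", MLenFrom(4, 0)); }  // 32 for u32 at LMUL=1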
+#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _EXT: not the largest LMUL +#define HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) + +// LMULS = _ALL (2^MinPow2() <= LMUL <= 8) +#define HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, 
CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) // SEW for unsigned: -#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_08(X_MACRO, uint, u, NAME, OP) -#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_16(X_MACRO, uint, u, NAME, OP) -#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_32(X_MACRO, uint, u, NAME, OP) -#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_64(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, uint, u, NAME, OP) // SEW for signed: -#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_08(X_MACRO, int, i, NAME, OP) -#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_16(X_MACRO, int, i, NAME, OP) -#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_32(X_MACRO, int, i, NAME, OP) -#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_64(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, int, i, NAME, OP) // SEW for float: -#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_16(X_MACRO, float, f, NAME, OP) -#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_32(X_MACRO, float, f, NAME, OP) -#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_64(X_MACRO, float, f, NAME, OP) +#if HWY_HAVE_FLOAT16 +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, float, f, NAME, OP) +#else +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#endif +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, float, f, NAME, OP) + +// Commonly used type/SEW groups: +#define HWY_RVV_FOREACH_UI08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) + +#define 
HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) // For all combinations of SEW: -#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP) +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) -#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP) +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) -#if HWY_CAP_FLOAT16 -#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP) -#else -#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP) -#endif - -// Commonly used type categories for a given SEW: -#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP) - -#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP) - -#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP) +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) // Commonly used type categories: -#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I(X_MACRO, NAME, OP) +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) -#define HWY_RVV_FOREACH(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_U(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_I(X_MACRO, NAME, OP) \ - HWY_RVV_FOREACH_F(X_MACRO, NAME, OP) +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, 
LMULS) // Assemble types for use in x-macros #define HWY_RVV_T(BASE, SEW) BASE##SEW##_t -#define HWY_RVV_D(CHAR, SEW, LMUL) D##CHAR##SEW##LMUL +#define HWY_RVV_D(BASE, SEW, N, SHIFT) Simd #define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##LMUL##_t #define HWY_RVV_M(MLEN) vbool##MLEN##_t } // namespace detail -// TODO(janwas): remove typedefs and only use HWY_RVV_V etc. directly - // Until we have full intrinsic support for fractional LMUL, mixed-precision // code can use LMUL 1..8 (adequate unless they need many registers). -#define HWY_SPECIALIZE(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ - using HWY_RVV_D(CHAR, SEW, LMUL) = Full; \ - using V##CHAR##SEW##LMUL = HWY_RVV_V(BASE, SEW, LMUL); \ +#define HWY_SPECIALIZE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ template <> \ struct DFromV_t { \ using Lane = HWY_RVV_T(BASE, SEW); \ - using type = Full; \ + using type = ScalableTag; \ }; -#if HWY_CAP_FLOAT16 -using Vf16m1 = vfloat16m1_t; -using Vf16m2 = vfloat16m2_t; -using Vf16m4 = vfloat16m4_t; -using Vf16m8 = vfloat16m8_t; -using Df16m1 = Full; -using Df16m2 = Full; -using Df16m4 = Full; -using Df16m8 = Full; -#endif -HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _) +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) #undef HWY_SPECIALIZE // ------------------------------ Lanes // WARNING: we want to query VLMAX/sizeof(T), but this actually changes VL! // vlenb is not exposed through intrinsics and vreadvl is not VLMAX. -#define HWY_RVV_LANES(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ - HWY_API size_t NAME(HWY_RVV_D(CHAR, SEW, LMUL) /* d */) { \ - return v##OP##SEW##LMUL(); \ +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + const size_t actual = v##OP##SEW##LMUL(); \ + /* Common case of full vectors: avoid any extra instructions. */ \ + /* actual includes LMUL, so do not shift again. */ \ + return detail::IsFull(d) ? actual : HWY_MIN(actual, N); \ } -HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e) +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL) #undef HWY_RVV_LANES -// Capped -template * = nullptr> -HWY_API size_t Lanes(Simd /* tag*/) { - return HWY_MIN(N, Lanes(Full())); -} - -template -HWY_API size_t Lanes(Simd /* tag*/) { - return Lanes(Simd()); +template +HWY_API size_t Lanes(Simd /* tag*/) { + return Lanes(Simd()); } // ------------------------------ Common x-macros -// Last argument to most intrinsics. Use when the op has no d arg of its own. -#define HWY_RVV_AVL(SEW, SHIFT) Lanes(Full()) +// Last argument to most intrinsics. Use when the op has no d arg of its own, +// which means there is no user-specified cap. +#define HWY_RVV_AVL(SEW, SHIFT) \ + Lanes(ScalableTag()) // vector = f(vector), e.g. Not -#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ } // vector = f(vector, scalar), e.g. 
detail::AddS -#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ - return v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ } // vector = f(vector, vector), e.g. Add -#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ - return v##OP##_vv_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ } // ================================================== INIT // ------------------------------ Set -#define HWY_RVV_SET(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_D(CHAR, SEW, LMUL) d, HWY_RVV_T(BASE, SEW) arg) { \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_T(BASE, SEW) arg) { \ return v##OP##_##CHAR##SEW##LMUL(arg, Lanes(d)); \ } -HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x) -HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f) +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL) #undef HWY_RVV_SET // Treat bfloat16_t as uint16_t (using the previously defined Set overloads); // required for Zero and VFromD. -template -decltype(Set(Simd(), 0)) Set(Simd d, - bfloat16_t arg) { +template +decltype(Set(Simd(), 0)) Set(Simd d, + bfloat16_t arg) { return Set(RebindToUnsigned(), arg.bits); } -// Capped vectors -template * = nullptr> -HWY_API decltype(Set(Full(), T{0})) Set(Simd /*tag*/, T arg) { - return Set(Full(), arg); -} - template using VFromD = decltype(Set(D(), TFromD())); // ------------------------------ Zero -template -HWY_API VFromD> Zero(Simd d) { +template +HWY_API VFromD> Zero(Simd d) { return Set(d, T(0)); } @@ -297,14 +371,15 @@ HWY_API VFromD> Zero(Simd d) { // by it gives unpredictable results. It should only be used for maskoff, so // keep it internal. For the Highway op, just use Zero (single instruction). namespace detail { -#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_D(CHAR, SEW, LMUL) /* tag */) { \ - return v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ +#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) /* tag */) { \ + return v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ } -HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined) +HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined, _ALL) #undef HWY_RVV_UNDEFINED } // namespace detail @@ -318,68 +393,74 @@ HWY_API VFromD Undefined(D d) { namespace detail { // There is no reinterpret from u8 <-> u8, so just return. 
-#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API vuint8##LMUL##_t BitCastToByte(vuint8##LMUL##_t v) { return v; } \ - HWY_API vuint8##LMUL##_t BitCastFromByte(HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ - vuint8##LMUL##_t v) { \ - return v; \ +#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API vuint8##LMUL##_t BitCastToByte(vuint8##LMUL##_t v) { return v; } \ + template \ + HWY_API vuint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v; \ } // For i8, need a single reinterpret (HWY_RVV_CAST_IF does two). -#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API vuint8##LMUL##_t BitCastToByte(vint8##LMUL##_t v) { \ - return vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ - } \ - HWY_API vint8##LMUL##_t BitCastFromByte(HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ - vuint8##LMUL##_t v) { \ - return vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ +#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API vuint8##LMUL##_t BitCastToByte(vint8##LMUL##_t v) { \ + return vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API vint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ } // Separate u/i because clang only provides signed <-> unsigned reinterpret for // the same SEW. -#define HWY_RVV_CAST_U(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_CAST_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ HWY_API vuint8##LMUL##_t BitCastToByte(HWY_RVV_V(BASE, SEW, LMUL) v) { \ return v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v); \ } \ + template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ - HWY_RVV_D(CHAR, SEW, LMUL) /* d */, vuint8##LMUL##_t v) { \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v); \ } // Signed/Float: first cast to/from unsigned -#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API vuint8##LMUL##_t BitCastToByte(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ - v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ - } \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ - HWY_RVV_D(CHAR, SEW, LMUL) /* d */, vuint8##LMUL##_t v) { \ - return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ - v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ +#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API vuint8##LMUL##_t BitCastToByte(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ } -HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret) -HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret) -HWY_RVV_FOREACH_U16(HWY_RVV_CAST_U, _, reinterpret) -HWY_RVV_FOREACH_U32(HWY_RVV_CAST_U, _, reinterpret) -HWY_RVV_FOREACH_U64(HWY_RVV_CAST_U, _, reinterpret) -HWY_RVV_FOREACH_I16(HWY_RVV_CAST_IF, _, reinterpret) -HWY_RVV_FOREACH_I32(HWY_RVV_CAST_IF, _, reinterpret) -HWY_RVV_FOREACH_I64(HWY_RVV_CAST_IF, _, 
reinterpret) -HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret) +// Cannot use existing type lists because U/I8 are no-ops. +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U32(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U64(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I32(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I64(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret, _ALL) #undef HWY_RVV_CAST_U8 #undef HWY_RVV_CAST_I8 #undef HWY_RVV_CAST_U #undef HWY_RVV_CAST_IF -template -HWY_INLINE VFromD> BitCastFromByte( - Simd /* d */, VFromD> v) { - return BitCastFromByte(Simd(), v); +template +HWY_INLINE VFromD> BitCastFromByte( + Simd /* d */, VFromD> v) { + return BitCastFromByte(Simd(), v); } } // namespace detail @@ -389,13 +470,6 @@ HWY_API VFromD BitCast(D d, FromV v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } -// Capped -template * = nullptr> -HWY_API VFromD> BitCast(Simd /*tag*/, FromV v) { - return BitCast(Full(), v); -} - namespace detail { template >> @@ -409,12 +483,14 @@ HWY_INLINE VFromD BitCastToUnsigned(V v) { namespace detail { -#define HWY_RVV_IOTA(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(CHAR, SEW, LMUL) d) { \ - return v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ +#define HWY_RVV_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + return v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ } -HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v) +HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v, _ALL) #undef HWY_RVV_IOTA template > @@ -422,20 +498,13 @@ HWY_INLINE VFromD Iota0(const D /*d*/) { return BitCastToUnsigned(Iota0(DU())); } -// Capped -template , - hwy::EnableIf<(N < HWY_LANES(T) / 8)>* = nullptr> -HWY_INLINE VFromD> Iota0(Simd /*tag*/) { - return Iota0(Full()); -} - } // namespace detail // ================================================== LOGICAL // ------------------------------ Not -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not ) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not, _ALL) template HWY_API V Not(const V v) { @@ -448,10 +517,10 @@ HWY_API V Not(const V v) { // Non-vector version (ideally immediate) for use with Iota0 namespace detail { -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx, _ALL) } // namespace detail -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and, _ALL) template HWY_API V And(const V a, const V b) { @@ -462,9 +531,7 @@ HWY_API V And(const V a, const V b) { // ------------------------------ Or -#undef HWY_RVV_OR_MASK - -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or, _ALL) template HWY_API V Or(const V a, const V b) { @@ -477,10 +544,10 @@ HWY_API V Or(const V a, const V b) { // Non-vector version (ideally immediate) for use with Iota0 namespace detail { -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx, _ALL) } // namespace detail -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor, _ALL) template HWY_API V Xor(const V a, const 
V b) { @@ -496,9 +563,16 @@ HWY_API V AndNot(const V not_a, const V b) { return And(Not(not_a), b); } +// ------------------------------ OrAnd + +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + // ------------------------------ CopySign -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj, _ALL) template HWY_API V CopySignToAbs(const V abs, const V sign) { @@ -511,43 +585,46 @@ HWY_API V CopySignToAbs(const V abs, const V sign) { // ------------------------------ Add namespace detail { -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, ReverseSubS, rsub_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, ReverseSubS, frsub_vf, _ALL) } // namespace detail -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd, _ALL) // ------------------------------ Sub -HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) // ------------------------------ SaturatedAdd -HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu) -HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu) +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) -HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd) -HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd) +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) // ------------------------------ SaturatedSub -HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu) -HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu) +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) -HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub) -HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub) +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) // ------------------------------ AverageRound // TODO(janwas): check vxrm rounding mode -HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, AverageRound, aaddu) -HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu) +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) // ------------------------------ ShiftLeft[Same] // Intrinsics do not define .vi forms, so use .vx instead. 
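The HWY_RVV_FOREACH / HWY_RVV_RETV_ARGVV pattern used for Add, Sub and friends above is an X-macro: one macro supplies the function body, another enumerates the rows (here SEW/LMUL combinations) and stamps out one overload per row. A minimal standalone illustration of the technique in plain C++ (all names invented for the demo; the real macros additionally paste LMUL suffixes and pass the AVL):

#include <cstdint>
#include <cstdio>

// Row list, analogous to HWY_RVV_FOREACH_* enumerating types.
#define DEMO_FOREACH(X, NAME, OP) \
  X(uint16_t, u16, NAME, OP)      \
  X(uint32_t, u32, NAME, OP)      \
  X(uint64_t, u64, NAME, OP)

// Body template, analogous to HWY_RVV_RETV_ARGVV generating NAME from OP.
#define DEMO_RETT_ARGTT(T, SUFFIX, NAME, OP) \
  inline T NAME(T a, T b) { return scalar_##OP##_##SUFFIX(a, b); }

// Stand-ins for the per-type intrinsics the generated code forwards to.
inline uint16_t scalar_add_u16(uint16_t a, uint16_t b) { return a + b; }
inline uint32_t scalar_add_u32(uint32_t a, uint32_t b) { return a + b; }
inline uint64_t scalar_add_u64(uint64_t a, uint64_t b) { return a + b; }

DEMO_FOREACH(DEMO_RETT_ARGTT, Add, add)  // defines Add for u16, u32, u64
#undef DEMO_RETT_ARGTT
#undef DEMO_FOREACH

int main() {
  std::printf("%u\n", static_cast<unsigned>(Add(uint32_t{2}, uint32_t{3})));
}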
-#define HWY_RVV_SHIFT(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_SHIFT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ return v##OP##_vx_##CHAR##SEW##LMUL(v, kBits, HWY_RVV_AVL(SEW, SHIFT)); \ @@ -558,15 +635,39 @@ HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu) HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll) +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll, _ALL) // ------------------------------ ShiftRight[Same] -HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl) -HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra) +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) #undef HWY_RVV_SHIFT +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API VFromD>> SumsOf8(const VU8 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = detail::AndS(BitCast(du16, v), 0xFF); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); +} + // ------------------------------ RotateRight template HWY_API V RotateRight(const V v) { @@ -577,109 +678,110 @@ HWY_API V RotateRight(const V v) { } // ------------------------------ Shl -#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ - return v##OP##_vv_##CHAR##SEW##LMUL(v, bits, HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, bits, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll) +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll, _ALL) -#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ - return v##OP##_vv_##CHAR##SEW##LMUL(v, detail::BitCastToUnsigned(bits), \ - HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, detail::BitCastToUnsigned(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll, _ALL) // ------------------------------ Shr -HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl) -HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra) 
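SumsOf8 above reduces each group of 8 consecutive u8 lanes to one u64 by adding lanes to their shifted neighbors in progressively wider units and masking the final sum to 16 bits (the maximum is 8 * 255). Its intended result in scalar form, assuming the little-endian lane order used throughout (function name invented for the sketch):

#include <cstdint>
#include <cstdio>

// Sum of the 8 bytes of a u64; one u64 corresponds to 8 u8 lanes above.
uint64_t SumOf8Bytes(uint64_t v) {
  uint64_t sum = 0;
  for (int i = 0; i < 8; ++i) sum += (v >> (8 * i)) & 0xFF;
  return sum;
}

int main() {
  // Bytes 1..8 -> 36.
  std::printf("%llu\n", static_cast<unsigned long long>(
                            SumOf8Bytes(0x0807060504030201ull)));
}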
+HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) #undef HWY_RVV_SHIFT_II #undef HWY_RVV_SHIFT_VV // ------------------------------ Min -HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu) -HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin, _ALL) // ------------------------------ Max namespace detail { -HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx) -HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf, _ALL) } // namespace detail -HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu) -HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax, _ALL) // ------------------------------ Mul // Only for internal use (Highway only promises Mul for 16/32-bit inputs). // Used by MulLower. namespace detail { -HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, Mul, mul) +HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) } // namespace detail -HWY_RVV_FOREACH_UI16(HWY_RVV_RETV_ARGVV, Mul, mul) -HWY_RVV_FOREACH_UI32(HWY_RVV_RETV_ARGVV, Mul, mul) -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul) +HWY_RVV_FOREACH_UI16(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_UI32(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) // ------------------------------ MulHigh // Only for internal use (Highway only promises MulHigh for 16-bit inputs). // Used by MulEven; vwmul does not work for m8. namespace detail { -HWY_RVV_FOREACH_I32(HWY_RVV_RETV_ARGVV, MulHigh, mulh) -HWY_RVV_FOREACH_U32(HWY_RVV_RETV_ARGVV, MulHigh, mulhu) -HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, MulHigh, mulhu) +HWY_RVV_FOREACH_I32(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) +HWY_RVV_FOREACH_U32(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) } // namespace detail -HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu) -HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) // ------------------------------ Div -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) // ------------------------------ ApproximateReciprocal -HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7) +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7, _ALL) // ------------------------------ Sqrt -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt, _ALL) // ------------------------------ ApproximateReciprocalSqrt -HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7) +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7, _ALL) // ------------------------------ MulAdd // Note: op is still named vv, not vvv. 
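The four fused ops defined below differ only in the signs applied to the product and the addend; Highway's convention is that the first two arguments are multiplied, and the RVV intrinsics accumulate in place, which is why the add operand is passed first in the macro. A scalar sketch of the naming (std::fma stands in for the vector ops; these free functions are just for illustration):

#include <cmath>
#include <cstdio>

float MulAdd(float mul, float x, float add) { return std::fma(mul, x, add); }
float NegMulAdd(float mul, float x, float add) { return std::fma(-mul, x, add); }
float MulSub(float mul, float x, float add) { return std::fma(mul, x, -add); }
float NegMulSub(float mul, float x, float add) { return std::fma(-mul, x, -add); }

int main() {
  // 2*3+1, -2*3+1, 2*3-1, -2*3-1
  std::printf("%g %g %g %g\n", MulAdd(2, 3, 1), NegMulAdd(2, 3, 1),
              MulSub(2, 3, 1), NegMulSub(2, 3, 1));  // 7 -5 5 -7
}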
-#define HWY_RVV_FMA(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_FMA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ HWY_RVV_V(BASE, SEW, LMUL) add) { \ return v##OP##_vv_##CHAR##SEW##LMUL(add, mul, x, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc, _ALL) // ------------------------------ NegMulAdd -HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac, _ALL) // ------------------------------ MulSub -HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac, _ALL) // ------------------------------ NegMulSub -HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) #undef HWY_RVV_FMA @@ -690,42 +792,53 @@ HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc) // of all bits; SLEN 8 / LMUL 4 = half of all bits. // mask = f(vector, vector) -#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_M(MLEN) \ - NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ - return v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN(a, b, \ - HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN(a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ } // mask = f(vector, scalar) -#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ +#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ HWY_API HWY_RVV_M(MLEN) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ - return v##OP##_vx_##CHAR##SEW##LMUL##_b##MLEN(a, b, \ - HWY_RVV_AVL(SEW, SHIFT)); \ + return v##OP##_##CHAR##SEW##LMUL##_b##MLEN(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ } // ------------------------------ Eq -HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq) -HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq) - -// ------------------------------ Ne -HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne) -HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne) - -// ------------------------------ Lt -HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu) -HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt) -HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt) +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL) namespace detail { -HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt) +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL) +} // namespace detail + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL) +} // namespace detail + +// ------------------------------ Lt +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL) 
+HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, LtS, msltu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL) } // namespace detail // ------------------------------ Le -HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL) #undef HWY_RVV_RETM_ARGVV #undef HWY_RVV_RETM_ARGVS @@ -745,7 +858,7 @@ HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { // ------------------------------ TestBit template HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { - return Ne(And(a, bit), Zero(DFromV())); + return detail::NeS(And(a, bit), 0); } // ------------------------------ Not @@ -782,15 +895,15 @@ HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) #undef HWY_RVV_RETM_ARGMM // ------------------------------ IfThenElse -#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, \ - NAME, OP) \ +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ HWY_RVV_V(BASE, SEW, LMUL) no) { \ return v##OP##_vvm_##CHAR##SEW##LMUL(m, no, yes, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge) +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge, _ALL) #undef HWY_RVV_IF_THEN_ELSE @@ -801,16 +914,24 @@ HWY_API V IfThenElseZero(const M mask, const V yes) { } // ------------------------------ IfThenZeroElse -template -HWY_API V IfThenZeroElse(const M mask, const V no) { - return IfThenElse(mask, Zero(DFromV()), no); -} + +#define HWY_RVV_IF_THEN_ZERO_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_##CHAR##SEW##LMUL(m, no, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, merge_vxm, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) + +#undef HWY_RVV_IF_THEN_ZERO_ELSE // ------------------------------ MaskFromVec template HWY_API auto MaskFromVec(const V v) -> decltype(Eq(v, v)) { - return Ne(v, Zero(DFromV())); + return detail::NeS(v, 0); } template @@ -826,15 +947,15 @@ HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { // ------------------------------ VecFromMask namespace detail { -#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, \ - NAME, OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_M(MLEN) m) { \ - return v##OP##_##CHAR##SEW##LMUL##_m(m, v0, v0, 1, \ - HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_M(MLEN) m) { \ + return v##OP##_##CHAR##SEW##LMUL##_m(m, v0, v0, 1, \ + HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, SubS, sub_vx) +HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, SubS, sub_vx, _ALL) #undef HWY_RVV_VEC_FROM_MASK } // namespace detail @@ -848,12 +969,17 @@ HWY_API VFromD VecFromMask(const D d, MFromD mask) { return BitCast(d, VecFromMask(RebindToUnsigned(), mask)); } +// ------------------------------ IfVecThenElse (MaskFromVec) + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + // ------------------------------ ZeroIfNegative 
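VecFromMask above appears to produce all-ones or all-zero lanes with a masked subtract of 1 from a zero vector: active lanes compute 0 - 1, which wraps to all bits set, and inactive lanes keep the zero maskedoff value. The per-lane effect in scalar form (function name invented for the sketch):

#include <cstdint>
#include <cstdio>

// mask ? ~0u : 0u, written as the masked "0 - 1" used by VecFromMask above.
uint32_t VecFromMaskLane(bool m) {
  const uint32_t zero = 0;
  return m ? (zero - 1u) : zero;  // 0 - 1 wraps to 0xFFFFFFFF
}

int main() {
  std::printf("%08x %08x\n", VecFromMaskLane(true), VecFromMaskLane(false));
}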
template HWY_API V ZeroIfNegative(const V v) { - const auto v0 = Zero(DFromV()); - // We already have a zero constant, so avoid IfThenZeroElse. - return IfThenElse(Lt(v, v0), v0, v); + return IfThenZeroElse(detail::LtS(v, 0), v); } // ------------------------------ BroadcastSignBit @@ -862,6 +988,18 @@ HWY_API V BroadcastSignBit(const V v) { return ShiftRight) * 8 - 1>(v); } +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + MFromD m = + MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + // ------------------------------ FindFirstTrue #define HWY_RVV_FIND_FIRST_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ @@ -908,33 +1046,28 @@ HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) // ------------------------------ Load -#define HWY_RVV_LOAD(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_D(CHAR, SEW, LMUL) d, \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le) +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL) #undef HWY_RVV_LOAD -// Capped -template * = nullptr> -HWY_API VFromD> Load(Simd d, const T* HWY_RESTRICT p) { - return Load(d, p); -} - // There is no native BF16, treat as uint16_t. -template -HWY_API VFromD> Load(Simd d, - const bfloat16_t* HWY_RESTRICT p) { +template +HWY_API VFromD> Load( + Simd d, const bfloat16_t* HWY_RESTRICT p) { return Load(RebindToUnsigned(), reinterpret_cast(p)); } -template -HWY_API void Store(VFromD> v, Simd d, - bfloat16_t* HWY_RESTRICT p) { +template +HWY_API void Store(VFromD> v, + Simd d, bfloat16_t* HWY_RESTRICT p) { Store(v, RebindToUnsigned(), reinterpret_cast(p)); } @@ -949,59 +1082,55 @@ HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { // ------------------------------ MaskedLoad -#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, \ - NAME, OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(CHAR, SEW, LMUL) d, \ - const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, Zero(d), p, Lanes(d)); \ +#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, Zero(d), p, Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le) +HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL) #undef HWY_RVV_MASKED_LOAD // ------------------------------ Store -#define HWY_RVV_RET_ARGVDP(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ - HWY_RVV_D(CHAR, SEW, LMUL) d, \ - HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, Lanes(d)); \ +#define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * 
HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_RET_ARGVDP, Store, se) -#undef HWY_RVV_RET_ARGVDP - -// Capped -template * = nullptr> -HWY_API void Store(VFromD> v, Simd /* d */, - T* HWY_RESTRICT p) { - return Store(v, Full(), p); -} +HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL) +#undef HWY_RVV_STORE // ------------------------------ MaskedStore -#define HWY_RVV_RET_ARGMVDP(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, \ - NAME, OP) \ - HWY_API void NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) v, \ - HWY_RVV_D(CHAR, SEW, LMUL) d, \ - HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, p, v, Lanes(d)); \ +#define HWY_RVV_MASKED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, p, v, Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_RET_ARGMVDP, MaskedStore, se) -#undef HWY_RVV_RET_ARGMVDP +HWY_RVV_FOREACH(HWY_RVV_MASKED_STORE, MaskedStore, se, _ALL) +#undef HWY_RVV_MASKED_STORE namespace detail { -#define HWY_RVV_RET_ARGNVDP(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, \ - NAME, OP) \ - HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ - HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ - HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, count); \ +#define HWY_RVV_STOREN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, count); \ } -HWY_RVV_FOREACH(HWY_RVV_RET_ARGNVDP, StoreN, se) -#undef HWY_RVV_RET_ARGNVDP +HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL) +#undef HWY_RVV_STOREN } // namespace detail @@ -1021,27 +1150,19 @@ HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { // ------------------------------ ScatterOffset -#define HWY_RVV_SCATTER(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ - HWY_RVV_D(CHAR, SEW, LMUL) d, \ - HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ - HWY_RVV_V(int, SEW, LMUL) offset) { \ - return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ - base, detail::BitCastToUnsigned(offset), v, Lanes(d)); \ +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), v, Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux) +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL) #undef HWY_RVV_SCATTER -// Capped -template * = nullptr> -HWY_API void ScatterOffset(VFromD> v, Simd /* d */, - T* HWY_RESTRICT base, - VFromD, N>> offset) { - return ScatterOffset(v, Full(), base, offset); -} - // ------------------------------ ScatterIndex template @@ -1058,26 +1179,19 @@ HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, // ------------------------------ GatherOffset -#define HWY_RVV_GATHER(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, 
MLEN, NAME, OP) \ +#define HWY_RVV_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_D(CHAR, SEW, LMUL) d, \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ HWY_RVV_V(int, SEW, LMUL) offset) { \ return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ base, detail::BitCastToUnsigned(offset), Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux) +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL) #undef HWY_RVV_GATHER -// Capped -template * = nullptr> -HWY_API VFromD> GatherOffset(Simd /* d */, - const T* HWY_RESTRICT base, - VFromD, N>> offset) { - return GatherOffset(Full(), base, offset); -} - // ------------------------------ GatherIndex template @@ -1097,10 +1211,12 @@ HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, // ------------------------------ StoreInterleaved3 -#define HWY_RVV_STORE3(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ HWY_API void NAME( \ HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ - HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(CHAR, SEW, LMUL) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ const v##BASE##SEW##LMUL##x3_t triple = \ vcreate_##CHAR##SEW##LMUL##x3(v0, v1, v2); \ @@ -1112,22 +1228,15 @@ HWY_RVV_STORE3(uint, u, 8, m2, /*kShift=*/1, 4, StoreInterleaved3, sseg3) #undef HWY_RVV_STORE3 -// Capped -template * = nullptr> -HWY_API void StoreInterleaved3(VFromD> v0, VFromD> v1, - VFromD> v2, Simd /*tag*/, - T* unaligned) { - return StoreInterleaved3(v0, v1, v2, Full(), unaligned); -} - // ------------------------------ StoreInterleaved4 -#define HWY_RVV_STORE4(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_STORE4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ HWY_API void NAME( \ HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ - HWY_RVV_D(CHAR, SEW, LMUL) d, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned) { \ const v##BASE##SEW##LMUL##x4_t quad = \ vcreate_##CHAR##SEW##LMUL##x4(v0, v1, v2, v3); \ @@ -1139,54 +1248,58 @@ HWY_RVV_STORE4(uint, u, 8, m2, /*kShift=*/1, 4, StoreInterleaved4, sseg4) #undef HWY_RVV_STORE4 -// Capped -template * = nullptr> -HWY_API void StoreInterleaved4(VFromD> v0, VFromD> v1, - VFromD> v2, VFromD> v3, - Simd /*tag*/, T* unaligned) { - return StoreInterleaved4(v0, v1, v2, v3, Full(), unaligned); -} - #endif // GCC // ================================================== CONVERT -#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN) \ - HWY_API HWY_RVV_V(BASE, BITS, LMUL) PromoteTo( \ - HWY_RVV_D(CHAR, BITS, LMUL) d, HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ +// ------------------------------ PromoteTo + +// SEW is for the input so we can use F16 (no-op if not supported). 
+#define HWY_RVV_PROMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWD##LMULD(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT) +HWY_RVV_FOREACH_U16(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT) +HWY_RVV_FOREACH_U32(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT) +HWY_RVV_FOREACH_I08(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT) +HWY_RVV_FOREACH_I16(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT) +HWY_RVV_FOREACH_I32(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT) +HWY_RVV_FOREACH_F16(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT) +HWY_RVV_FOREACH_F32(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT) +#undef HWY_RVV_PROMOTE + +// The above X-macro cannot handle 4x promotion nor type switching. +// TODO(janwas): use BASE2 arg to allow the latter. +#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN, \ + SHIFT, ADD) \ + template \ + HWY_API HWY_RVV_V(BASE, BITS, LMUL) \ + PromoteTo(HWY_RVV_D(BASE, BITS, N, SHIFT + ADD) d, \ + HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ return OP##CHAR##BITS##LMUL(v, Lanes(d)); \ } -#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4) +#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1, 0, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2, 1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4, 2, 1) -#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1) \ - HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2) +#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, mf2, mf8, -3, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4, -2, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2, -1, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1, 0, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2, 1, 2) -// ------------------------------ PromoteTo - -HWY_RVV_PROMOTE_X2(vzext_vf2_, uint, u, 16, uint, 8) -HWY_RVV_PROMOTE_X2(vzext_vf2_, uint, u, 32, uint, 16) -HWY_RVV_PROMOTE_X2(vzext_vf2_, uint, u, 64, uint, 32) HWY_RVV_PROMOTE_X4(vzext_vf4_, uint, u, 32, uint, 8) - -HWY_RVV_PROMOTE_X2(vsext_vf2_, int, i, 16, int, 8) -HWY_RVV_PROMOTE_X2(vsext_vf2_, int, i, 32, int, 16) -HWY_RVV_PROMOTE_X2(vsext_vf2_, int, i, 64, int, 32) HWY_RVV_PROMOTE_X4(vsext_vf4_, int, i, 32, int, 8) -#if HWY_CAP_FLOAT16 -HWY_RVV_PROMOTE_X2(vfwcvt_f_f_v_, float, f, 32, float, 16) -#endif -HWY_RVV_PROMOTE_X2(vfwcvt_f_f_v_, float, f, 64, float, 32) - // i32 to f64 HWY_RVV_PROMOTE_X2(vfwcvt_f_x_v_, float, f, 64, int, 32) @@ -1194,26 +1307,31 @@ HWY_RVV_PROMOTE_X2(vfwcvt_f_x_v_, float, f, 64, int, 32) 
#undef HWY_RVV_PROMOTE_X2 #undef HWY_RVV_PROMOTE -template -HWY_API auto PromoteTo(Simd d, VFromD> v) +// Unsigned to signed: cast for unsigned promote. +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) -> VFromD { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, VFromD> v) +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) -> VFromD { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, VFromD> v) +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) -> VFromD { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, VFromD> v) +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) -> VFromD { const RebindToSigned di32; const Rebind du16; @@ -1222,150 +1340,165 @@ HWY_API auto PromoteTo(Simd d, VFromD> v) // ------------------------------ DemoteTo U +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 0, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME##Shr16( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + // Unsigned -> unsigned (also used for bf16) namespace detail { - -HWY_INLINE Vu16m1 DemoteTo(Du16m1 d, const Vu32m2 v) { - return vnclipu_wx_u16m1(v, 0, Lanes(d)); -} -HWY_INLINE Vu16m2 DemoteTo(Du16m2 d, const Vu32m4 v) { - return vnclipu_wx_u16m2(v, 0, Lanes(d)); -} -HWY_INLINE Vu16m4 DemoteTo(Du16m4 d, const Vu32m8 v) { - return vnclipu_wx_u16m4(v, 0, Lanes(d)); -} - -HWY_INLINE Vu8m1 DemoteTo(Du8m1 d, const Vu16m2 v) { - return vnclipu_wx_u8m1(v, 0, Lanes(d)); -} -HWY_INLINE Vu8m2 DemoteTo(Du8m2 d, const Vu16m4 v) { - return vnclipu_wx_u8m2(v, 0, Lanes(d)); -} -HWY_INLINE Vu8m4 DemoteTo(Du8m4 d, const Vu16m8 v) { - return vnclipu_wx_u8m4(v, 0, Lanes(d)); -} - +HWY_RVV_FOREACH_U16(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE) } // namespace detail -// First clamp negative numbers to zero to match x86 packus. -HWY_API Vu16m1 DemoteTo(Du16m1 d, const Vi32m2 v) { - return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_I_TO_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(uint, SEWH, LMULH) NAME( \ + HWY_RVV_D(uint, SEWH, N, SHIFT - 1) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + /* First clamp negative numbers to zero to match x86 packus. 
*/ \ + return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); \ + } +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE) +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE) +#undef HWY_RVV_DEMOTE_I_TO_U + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return vnclipu_wx_u8mf8(DemoteTo(Simd(), v), 0, Lanes(d)); } -HWY_API Vu16m2 DemoteTo(Du16m2 d, const Vi32m4 v) { - return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return vnclipu_wx_u8mf4(DemoteTo(Simd(), v), 0, Lanes(d)); } -HWY_API Vu16m4 DemoteTo(Du16m4 d, const Vi32m8 v) { - return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return vnclipu_wx_u8mf2(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return vnclipu_wx_u8m1(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return vnclipu_wx_u8m2(DemoteTo(Simd(), v), 0, Lanes(d)); } -HWY_API Vu8m1 DemoteTo(Du8m1 d, const Vi32m4 v) { - return vnclipu_wx_u8m1(DemoteTo(Du16m2(), v), 0, Lanes(d)); +HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf8(vnclipu_wx_u16mf4(v, 0, avl), 0, avl); } -HWY_API Vu8m2 DemoteTo(Du8m2 d, const Vi32m8 v) { - return vnclipu_wx_u8m2(DemoteTo(Du16m4(), v), 0, Lanes(d)); +HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf4(vnclipu_wx_u16mf2(v, 0, avl), 0, avl); } - -HWY_API Vu8m1 DemoteTo(Du8m1 d, const Vi16m2 v) { - return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); +HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf2(vnclipu_wx_u16m1(v, 0, avl), 0, avl); } -HWY_API Vu8m2 DemoteTo(Du8m2 d, const Vi16m4 v) { - return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); -} -HWY_API Vu8m4 DemoteTo(Du8m4 d, const Vi16m8 v) { - return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); -} - -HWY_API Vu8m1 U8FromU32(const Vu32m4 v) { - const size_t avl = Lanes(Full()); +HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) { + const size_t avl = Lanes(ScalableTag()); return vnclipu_wx_u8m1(vnclipu_wx_u16m2(v, 0, avl), 0, avl); } -HWY_API Vu8m2 U8FromU32(const Vu32m8 v) { - const size_t avl = Lanes(Full()); +HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) { + const size_t avl = Lanes(ScalableTag()); return vnclipu_wx_u8m2(vnclipu_wx_u16m4(v, 0, avl), 0, avl); } // ------------------------------ DemoteTo I -HWY_API Vi8m1 DemoteTo(Di8m1 d, const Vi16m2 v) { - return vnclip_wx_i8m1(v, 0, Lanes(d)); +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE) + +template +HWY_API vint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); } -HWY_API Vi8m2 DemoteTo(Di8m2 d, const Vi16m4 v) { - return vnclip_wx_i8m2(v, 0, Lanes(d)); +template +HWY_API vint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); } -HWY_API Vi8m4 DemoteTo(Di8m4 d, const Vi16m8 v) { - return vnclip_wx_i8m4(v, 0, Lanes(d)); +template +HWY_API vint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return DemoteTo(d, 
DemoteTo(Simd(), v)); +} +template +HWY_API vint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); } -HWY_API Vi16m1 DemoteTo(Di16m1 d, const Vi32m2 v) { - return vnclip_wx_i16m1(v, 0, Lanes(d)); -} -HWY_API Vi16m2 DemoteTo(Di16m2 d, const Vi32m4 v) { - return vnclip_wx_i16m2(v, 0, Lanes(d)); -} -HWY_API Vi16m4 DemoteTo(Di16m4 d, const Vi32m8 v) { - return vnclip_wx_i16m4(v, 0, Lanes(d)); -} - -HWY_API Vi8m1 DemoteTo(Di8m1 d, const Vi32m4 v) { - return DemoteTo(d, DemoteTo(Di16m2(), v)); -} -HWY_API Vi8m2 DemoteTo(Di8m2 d, const Vi32m8 v) { - return DemoteTo(d, DemoteTo(Di16m4(), v)); -} +#undef HWY_RVV_DEMOTE // ------------------------------ DemoteTo F -#if HWY_CAP_FLOAT16 -HWY_API Vf16m1 DemoteTo(Df16m1 d, const Vf32m2 v) { - return vfncvt_rod_f_f_w_f16m1(v, Lanes(d)); -} -HWY_API Vf16m2 DemoteTo(Df16m2 d, const Vf32m4 v) { - return vfncvt_rod_f_f_w_f16m2(v, Lanes(d)); -} -HWY_API Vf16m4 DemoteTo(Df16m4 d, const Vf32m8 v) { - return vfncvt_rod_f_f_w_f16m4(v, Lanes(d)); -} +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_F(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##SEWH##LMULH(v, Lanes(d)); \ + } + +#if HWY_HAVE_FLOAT16 +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, _DEMOTE) #endif +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, _DEMOTE) +#undef HWY_RVV_DEMOTE_F -HWY_API Vf32m1 DemoteTo(Df32m1 d, const Vf64m2 v) { - return vfncvt_rod_f_f_w_f32m1(v, Lanes(d)); +// TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); } -HWY_API Vf32m2 DemoteTo(Df32m2 d, const Vf64m4 v) { - return vfncvt_rod_f_f_w_f32m2(v, Lanes(d)); -} -HWY_API Vf32m4 DemoteTo(Df32m4 d, const Vf64m8 v) { - return vfncvt_rod_f_f_w_f32m4(v, Lanes(d)); -} - -HWY_API Vi32m1 DemoteTo(Di32m1 d, const Vf64m2 v) { +template +HWY_API vint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { return vfncvt_rtz_x_f_w_i32m1(v, Lanes(d)); } -HWY_API Vi32m2 DemoteTo(Di32m2 d, const Vf64m4 v) { +template +HWY_API vint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { return vfncvt_rtz_x_f_w_i32m2(v, Lanes(d)); } -HWY_API Vi32m4 DemoteTo(Di32m4 d, const Vf64m8 v) { +template +HWY_API vint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { return vfncvt_rtz_x_f_w_i32m4(v, Lanes(d)); } -template -HWY_API VFromD> DemoteTo(Simd d, - VFromD> v) { +template +HWY_API VFromD> DemoteTo( + Simd d, VFromD> v) { const RebindToUnsigned du16; const Rebind du32; - return DemoteTo(du16, BitCast(du32, v)); + return detail::DemoteToShr16(du16, BitCast(du32, v)); } // ------------------------------ ConvertTo F -#define HWY_RVV_CONVERT(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - ConvertTo(HWY_RVV_D(CHAR, SEW, LMUL) d, HWY_RVV_V(int, SEW, LMUL) v) { \ +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(int, SEW, LMUL) v) { \ return vfcvt_f_x_v_f##SEW##LMUL(v, Lanes(d)); \ } \ /* Truncates (rounds toward zero). 
*/ \ - HWY_API HWY_RVV_V(int, SEW, LMUL) \ - ConvertTo(HWY_RVV_D(i, SEW, LMUL) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + template \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(int, SEW, N, SHIFT) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ return vfcvt_rtz_x_f_v_i##SEW##LMUL(v, Lanes(d)); \ } \ /* Uses default rounding mode. */ \ @@ -1375,24 +1508,17 @@ HWY_API VFromD> DemoteTo(Simd d, // API only requires f32 but we provide f64 for internal use (otherwise, it // seems difficult to implement Iota without a _mf2 vector half). -HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _) +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL) #undef HWY_RVV_CONVERT -// Capped -template * = nullptr> -HWY_API VFromD> ConvertTo(Simd /*tag*/, FromV v) { - return ConvertTo(Full(), v); -} - // ================================================== COMBINE namespace detail { // For x86-compatible behaviour mandated by Highway API: TableLookupBytes // offsets are implicitly relative to the start of their 128-bit block. -template -constexpr size_t LanesPerBlock(Simd /* tag */) { +template +constexpr size_t LanesPerBlock(Simd /* tag */) { // Also cap to the limit imposed by D (for fixed-size <= 128-bit vectors). return HWY_MIN(16 / sizeof(T), N); } @@ -1413,7 +1539,8 @@ HWY_INLINE MFromD FirstNPerBlock(D /* tag */) { } // vector = f(vector, vector, size_t) -#define HWY_RVV_SLIDE(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_SLIDE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) dst, HWY_RVV_V(BASE, SEW, LMUL) src, \ size_t lanes) { \ @@ -1421,8 +1548,8 @@ HWY_INLINE MFromD FirstNPerBlock(D /* tag */) { HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideUp, slideup) -HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideDown, slidedown) +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideUp, slideup, _ALL) +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideDown, slidedown, _ALL) #undef HWY_RVV_SLIDE @@ -1459,90 +1586,131 @@ HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { // ------------------------------ Combine -// TODO(janwas): implement after LMUL ext/trunc -#if 0 +namespace detail { +#define HWY_RVV_EXT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULD) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULD(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) +#undef HWY_RVV_EXT +} // namespace detail -template -HWY_API V Combine(const V a, const V b) { - using D = DFromV; - // double LMUL of inputs, then SlideUp with Lanes(). +template +HWY_API VFromD Combine(D2 d2, const V hi, const V lo) { + return detail::SlideUp(detail::Ext(lo), detail::Ext(hi), Lanes(d2) / 2); } -#endif - // ------------------------------ ZeroExtendVector -template -HWY_API V ZeroExtendVector(const V lo) { - return Combine(Xor(lo, lo), lo); +template +HWY_API VFromD ZeroExtendVector(D2 d2, const V lo) { + return Combine(d2, Xor(lo, lo), lo); } // ------------------------------ Lower/UpperHalf namespace detail { -#define HWY_RVV_TRUNC(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ - HWY_API HWY_RVV_V(BASE, SEW, HALF) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##HALF(v); /* no AVL */ \ +// LMUL is for the source so we can use _TRUNC. 
+#define HWY_RVV_TRUNC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(v); /* no AVL */ \ } -HWY_RVV_FOREACH_U08(HWY_RVV_TRUNC, Trunc, lmul_trunc) -HWY_RVV_FOREACH_I08(HWY_RVV_TRUNC, Trunc, lmul_trunc) -HWY_RVV_FOREACH_UI16(HWY_RVV_TRUNC, Trunc, lmul_trunc) -HWY_RVV_FOREACH_UI32(HWY_RVV_TRUNC, Trunc, lmul_trunc) -#if HWY_CAP_FLOAT16 -HWY_RVV_FOREACH_F16(HWY_RVV_TRUNC, Trunc, lmul_trunc) -#endif -HWY_RVV_FOREACH_F32(HWY_RVV_TRUNC, Trunc, lmul_trunc) +HWY_RVV_FOREACH_UI08(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +HWY_RVV_FOREACH_UI16(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +HWY_RVV_FOREACH_UI32(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +HWY_RVV_FOREACH_F16(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +HWY_RVV_FOREACH_F32(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) #undef HWY_RVV_TRUNC } // namespace detail -template -HWY_API VFromD LowerHalf(const D /* tag */, const VFromD v) { +template +HWY_API VFromD LowerHalf(const D2 /* tag */, const VFromD> v) { return detail::Trunc(v); } // Intrinsics do not provide mf2 for 64-bit T because VLEN might only be 64, // so "half-vectors" might not exist. However, the application processor profile // requires VLEN >= 128. Bypass this by casting to 32-bit. -template -HWY_API VFromD LowerHalf(const D d, const VFromD v) { - const Repartition d32; - return BitCast(d, detail::Trunc(BitCast(Twice(), v))); +template +HWY_API VFromD LowerHalf(const D2 d2, const VFromD> v) { + // Always use unsigned in case the type is float, in which case we might not + // support float16. + using TH = UnsignedFromSize) / 2>; + const Twice> dn; + return BitCast(d2, detail::Trunc(BitCast(dn, v))); } -template -HWY_API VFromD UpperHalf(const D d, const VFromD v) { - return LowerHalf(d, detail::SlideDown(v, v, Lanes(d))); +// Same, but without D arg +template +HWY_API VFromD>> LowerHalf(const V v) { + return LowerHalf(Half>(), v); +} + +template +HWY_API VFromD UpperHalf(const D2 d2, const VFromD> v) { + return LowerHalf(d2, detail::SlideDown(v, v, Lanes(d2))); } // ================================================== SWIZZLE -// ------------------------------ GetLane - -#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ +namespace detail { +// Special instruction for 1 lane is presumably faster? 
+#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x) -HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f) +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Up, slide1up_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Up, fslide1up_vf, _ALL) +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Down, slide1down_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Down, fslide1down_vf, _ALL) +#undef HWY_RVV_SLIDE1 +} // namespace detail + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f, _ALL) #undef HWY_RVV_GET_LANE // ------------------------------ OddEven template HWY_API V OddEven(const V a, const V b) { const RebindToUnsigned> du; // Iota0 is unsigned only - const auto is_even = Eq(detail::AndS(detail::Iota0(du), 1), Zero(du)); + const auto is_even = detail::EqS(detail::AndS(detail::Iota0(du), 1), 0); return IfThenElse(is_even, b, a); } +// ------------------------------ DupEven (OddEven) +template +HWY_API V DupEven(const V v) { + const V up = detail::Slide1Up(v); + return OddEven(up, v); +} + +// ------------------------------ DupOdd (OddEven) +template +HWY_API V DupOdd(const V v) { + const V down = detail::Slide1Down(v); + return OddEven(v, down); +} + // ------------------------------ OddEvenBlocks template HWY_API V OddEvenBlocks(const V a, const V b) { const RebindToUnsigned> du; // Iota0 is unsigned only constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV)); const auto idx_block = ShiftRight(detail::Iota0(du)); - const auto is_even = Eq(detail::AndS(idx_block, 1), Zero(du)); + const auto is_even = detail::EqS(detail::AndS(idx_block, 1), 0); return IfThenElse(is_even, b, a); } @@ -1565,7 +1733,7 @@ HWY_API VFromD> IndicesFromVec(D d, VI vec) { const RebindToUnsigned du; // instead of : avoids unused d. const auto indices = BitCast(du, vec); #if HWY_IS_DEBUG_BUILD - HWY_DASSERT(AllTrue(du, Lt(indices, Set(du, Lanes(d))))); + HWY_DASSERT(AllTrue(du, detail::LtS(indices, Lanes(d)))); #endif return indices; } @@ -1578,13 +1746,14 @@ HWY_API VFromD> SetTableIndices(D d, const TI* idx) { // <32bit are not part of Highway API, but used in Broadcast. This limits VLMAX // to 2048! We could instead use vrgatherei16. 
-#define HWY_RVV_TABLE(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ return v##OP##_vv_##CHAR##SEW##LMUL(v, idx, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather) +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) #undef HWY_RVV_TABLE // ------------------------------ Reverse @@ -1593,23 +1762,73 @@ HWY_API VFromD Reverse(D /* tag */, VFromD v) { const RebindToUnsigned du; using TU = TFromD; const size_t N = Lanes(du); - const auto idx = Sub(Set(du, static_cast(N - 1)), detail::Iota0(du)); + const auto idx = + detail::ReverseSubS(detail::Iota0(du), static_cast(N - 1)); return TableLookupLanes(v, idx); } +// ------------------------------ Reverse2 (RotateRight, OddEven) + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Repartition du64; + return BitCast(d, RotateRight<32>(BitCast(du64, v))); +} + +template +HWY_API VFromD Reverse2(D /* tag */, const VFromD v) { + const VFromD up = detail::Slide1Up(v); + const VFromD down = detail::Slide1Down(v); + return OddEven(up, down); +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 3); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 7); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { + const Repartition du64; + const size_t N = Lanes(du64); + const auto rev = + detail::ReverseSubS(detail::Iota0(du64), static_cast(N - 1)); + // Swap lo/hi u64 within each block + const auto idx = detail::XorS(rev, 1); + return BitCast(d, TableLookupLanes(BitCast(du64, v), idx)); +} + // ------------------------------ Compress -#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ return v##OP##_vm_##CHAR##SEW##LMUL(mask, v, v, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_UI16(HWY_RVV_COMPRESS, Compress, compress) -HWY_RVV_FOREACH_UI32(HWY_RVV_COMPRESS, Compress, compress) -HWY_RVV_FOREACH_UI64(HWY_RVV_COMPRESS, Compress, compress) -HWY_RVV_FOREACH_F(HWY_RVV_COMPRESS, Compress, compress) +HWY_RVV_FOREACH_UI163264(HWY_RVV_COMPRESS, Compress, compress, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_COMPRESS, Compress, compress, _ALL) #undef HWY_RVV_COMPRESS // ------------------------------ CompressStore @@ -1705,7 +1924,8 @@ HWY_API V Shuffle0123(const V v) { template HWY_API VI TableLookupBytes(const V v, const VI idx) { - const DFromV d; + const DFromV d; + const DFromV di; const Repartition d8; const auto offsets128 = detail::OffsetsOf128BitBlocks(d8, detail::Iota0(d8)); const auto idx8 = Add(BitCast(d8, idx), 
offsets128); @@ -1718,7 +1938,7 @@ HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { // Mask size must match vector type, so cast everything to this type. const Repartition di8; const auto lookup = TableLookupBytes(BitCast(di8, v), BitCast(di8, idx)); - const auto msb = Lt(BitCast(di8, idx), Zero(di8)); + const auto msb = detail::LtS(BitCast(di8, idx), 0); return BitCast(d, IfThenZeroElse(msb, lookup)); } @@ -1740,11 +1960,12 @@ HWY_API V Broadcast(const V v) { template > HWY_API V ShiftLeftLanes(const D d, const V v) { const RebindToSigned di; + using TI = TFromD; const auto shifted = detail::SlideUp(v, v, kLanes); // Match x86 semantics by zeroing lower lanes in 128-bit blocks constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di); const auto idx_mod = detail::AndS(detail::Iota0(di), kLanesPerBlock - 1); - const auto clear = Lt(BitCast(di, idx_mod), Set(di, kLanes)); + const auto clear = detail::LtS(BitCast(di, idx_mod), static_cast(kLanes)); return IfThenZeroElse(clear, shifted); } @@ -1755,8 +1976,8 @@ HWY_API V ShiftLeftLanes(const V v) { // ------------------------------ ShiftLeftBytes -template -HWY_API V ShiftLeftBytes(DFromV d, const V v) { +template +HWY_API VFromD ShiftLeftBytes(D d, const VFromD v) { const Repartition d8; return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); } @@ -1767,9 +1988,11 @@ HWY_API V ShiftLeftBytes(const V v) { } // ------------------------------ ShiftRightLanes -template >> -HWY_API V ShiftRightLanes(const Simd d, V v) { +template >> +HWY_API V ShiftRightLanes(const Simd d, V v) { const RebindToSigned di; + using TI = TFromD; // For partial vectors, clear upper lanes so we shift in zeros. if (N <= 16 / sizeof(T)) { v = IfThenElseZero(FirstN(d, N), v); @@ -1779,7 +2002,8 @@ HWY_API V ShiftRightLanes(const Simd d, V v) { // Match x86 semantics by zeroing upper lanes in 128-bit blocks constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di); const auto idx_mod = detail::AndS(detail::Iota0(di), kLanesPerBlock - 1); - const auto keep = Lt(BitCast(di, idx_mod), Set(di, kLanesPerBlock - kLanes)); + const auto keep = detail::LtS(BitCast(di, idx_mod), + static_cast(kLanesPerBlock - kLanes)); return IfThenElseZero(keep, shifted); } @@ -1801,7 +2025,7 @@ HWY_API V InterleaveLower(D d, const V a, const V b) { const auto i = detail::Iota0(du); const auto idx_mod = ShiftRight<1>(detail::AndS(i, kLanesPerBlock - 1)); const auto idx = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); - const auto is_even = Eq(detail::AndS(i, 1), Zero(du)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); return IfThenElse(is_even, TableLookupLanes(a, idx), TableLookupLanes(b, idx)); } @@ -1822,7 +2046,7 @@ HWY_API V InterleaveUpper(const D d, const V a, const V b) { const auto idx_mod = ShiftRight<1>(detail::AndS(i, kLanesPerBlock - 1)); const auto idx_lower = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); const auto idx = detail::AddS(idx_lower, kLanesPerBlock / 2); - const auto is_even = Eq(detail::AndS(i, 1), Zero(du)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); return IfThenElse(is_even, TableLookupLanes(a, idx), TableLookupLanes(b, idx)); } @@ -1852,61 +2076,74 @@ HWY_API VFromD ZipUpper(DW dw, V a, V b) { // ================================================== REDUCE // vector = f(vector, zero_m1) -#define HWY_RVV_REDUCE(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - 
NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, m1) v0) { \ - return Set(HWY_RVV_D(CHAR, SEW, LMUL)(), \ - GetLane(v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ - v0, v, v0, HWY_RVV_AVL(SEW, SHIFT)))); \ + NAME(D d, HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, m1) v0) { \ + return Set(d, GetLane(v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v0, v, v0, Lanes(d)))); \ } // ------------------------------ SumOfLanes namespace detail { -HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum) -HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum) +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL) } // namespace detail template -HWY_API VFromD SumOfLanes(D /* d */, const VFromD v) { - const auto v0 = Zero(Full>()); // always m1 - return detail::RedSum(v, v0); +HWY_API VFromD SumOfLanes(D d, const VFromD v) { + const auto v0 = Zero(ScalableTag>()); // always m1 + return detail::RedSum(d, v, v0); } // ------------------------------ MinOfLanes namespace detail { -HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu) -HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin) -HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin) +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL) } // namespace detail template -HWY_API VFromD MinOfLanes(D /* d */, const VFromD v) { +HWY_API VFromD MinOfLanes(D d, const VFromD v) { using T = TFromD; - const Full d1; // always m1 + const ScalableTag d1; // always m1 const auto neutral = Set(d1, HighestValue()); - return detail::RedMin(v, neutral); + return detail::RedMin(d, v, neutral); } // ------------------------------ MaxOfLanes namespace detail { -HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu) -HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax) -HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax) +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL) } // namespace detail template -HWY_API VFromD MaxOfLanes(D /* d */, const VFromD v) { +HWY_API VFromD MaxOfLanes(D d, const VFromD v) { using T = TFromD; - const Full d1; // always m1 + const ScalableTag d1; // always m1 const auto neutral = Set(d1, LowestValue()); - return detail::RedMax(v, neutral); + return detail::RedMax(d, v, neutral); } #undef HWY_RVV_REDUCE // ================================================== Ops with dependencies +// ------------------------------ PopulationCount (ShiftRight) + +// Handles LMUL >= 2 or capped vectors, which generic_ops-inl cannot. 
+template , HWY_IF_LANES_ARE(uint8_t, V), + hwy::EnableIf* = nullptr> +HWY_API V PopulationCount(V v) { + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, detail::AndS(ShiftRight<1>(v), 0x55)); + v = Add(detail::AndS(ShiftRight<2>(v), 0x33), detail::AndS(v, 0x33)); + return detail::AndS(Add(v, ShiftRight<4>(v)), 0x0F); +} + // ------------------------------ LoadDup128 template @@ -1919,22 +2156,17 @@ HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { } // ------------------------------ StoreMaskBits -#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ - /* DEPRECATED */ \ - HWY_API size_t StoreMaskBits(HWY_RVV_M(MLEN) m, uint8_t* bits) { \ - /* LMUL=1 is always enough */ \ - Full d8; \ - const size_t num_bytes = (Lanes(d8) + MLEN - 1) / MLEN; \ - /* TODO(janwas): how to convert vbool* to vuint?*/ \ - /*Store(m, d8, bits);*/ \ - (void)m; \ - (void)bits; \ - return num_bytes; \ - } \ - template \ - HWY_API size_t StoreMaskBits(D /* tag */, HWY_RVV_M(MLEN) m, \ - uint8_t* bits) { \ - return StoreMaskBits(m, bits); \ +#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t StoreMaskBits(D /*d*/, HWY_RVV_M(MLEN) m, uint8_t* bits) { \ + /* LMUL=1 is always enough */ \ + ScalableTag d8; \ + const size_t num_bytes = (Lanes(d8) + MLEN - 1) / MLEN; \ + /* TODO(janwas): how to convert vbool* to vuint?*/ \ + /*Store(m, d8, bits);*/ \ + (void)m; \ + (void)bits; \ + return num_bytes; \ } HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, _, _) #undef HWY_RVV_STORE_MASK_BITS @@ -1947,11 +2179,12 @@ HWY_API MFromD FirstN(const D d, const size_t n) { const RebindToSigned di; using TI = TFromD; return RebindMask( - d, Lt(BitCast(di, detail::Iota0(d)), Set(di, static_cast(n)))); + d, detail::LtS(BitCast(di, detail::Iota0(d)), static_cast(n))); } template HWY_API MFromD FirstN(const D d, const size_t n) { + // TODO(janwas): for reasons unknown, this freezes spike. 
const auto zero = Zero(d); const auto one = Set(d, 1); return Eq(detail::SlideUp(one, zero, n), one); @@ -1961,17 +2194,17 @@ HWY_API MFromD FirstN(const D d, const size_t n) { template HWY_API V Neg(const V v) { - return Sub(Zero(DFromV()), v); + return detail::ReverseSubS(v, 0); } // vector = f(vector), but argument is repeated -#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, \ - OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return v##OP##_vv_##CHAR##SEW##LMUL(v, v, HWY_RVV_AVL(SEW, SHIFT)); \ +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, v, HWY_RVV_AVL(SEW, SHIFT)); \ } -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) // ------------------------------ Abs (Max, Neg) @@ -1980,7 +2213,7 @@ HWY_API V Abs(const V v) { return Max(v, Neg(v)); } -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx, _ALL) #undef HWY_RVV_RETV_ARGV2 @@ -2002,7 +2235,7 @@ enum RoundingModes { kNear, kTrunc, kDown, kUp }; template HWY_INLINE auto UseInt(const V v) -> decltype(MaskFromVec(v)) { - return Lt(Abs(v), Set(DFromV(), MantissaEnd>())); + return detail::LtS(Abs(v), MantissaEnd>()); } } // namespace detail @@ -2051,13 +2284,13 @@ HWY_API V Floor(const V v) { template HWY_API VFromD Iota(const D d, TFromD first) { - return Add(detail::Iota0(d), Set(d, first)); + return detail::AddS(detail::Iota0(d), first); } template HWY_API VFromD Iota(const D d, TFromD first) { const RebindToUnsigned du; - return Add(BitCast(d, detail::Iota0(du)), Set(d, first)); + return detail::AddS(BitCast(d, detail::Iota0(du)), first); } template @@ -2069,26 +2302,12 @@ HWY_API VFromD Iota(const D d, TFromD first) { // ------------------------------ MulEven/Odd (Mul, OddEven) -namespace detail { -// Special instruction for 1 lane is presumably faster? -#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, LMUL, X2, HALF, SHIFT, MLEN, NAME, OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return v##OP##_vx_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ - } - -HWY_RVV_FOREACH_UI32(HWY_RVV_SLIDE1, Slide1Up, slide1up) -HWY_RVV_FOREACH_U64(HWY_RVV_SLIDE1, Slide1Up, slide1up) -HWY_RVV_FOREACH_UI32(HWY_RVV_SLIDE1, Slide1Down, slide1down) -HWY_RVV_FOREACH_U64(HWY_RVV_SLIDE1, Slide1Down, slide1down) -#undef HWY_RVV_SLIDE1 -} // namespace detail - -template -HWY_API VFromD>> MulEven(const V a, const V b) { +template , + class DW = RepartitionToWide> +HWY_API VFromD MulEven(const V a, const V b) { const auto lo = Mul(a, b); const auto hi = detail::MulHigh(a, b); - const RepartitionToWide> dw; - return BitCast(dw, OddEven(detail::Slide1Up(hi), lo)); + return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo)); } // There is no 64x64 vwmul. 
@@ -2108,27 +2327,33 @@ HWY_INLINE V MulOdd(const V a, const V b) { // ------------------------------ ReorderDemote2To (OddEven) -template > -HWY_API VFromD> ReorderDemote2To(Simd dbf16, - VFromD a, VFromD b) { +template +HWY_API VFromD> ReorderDemote2To( + Simd dbf16, + VFromD> a, + VFromD> b) { const RebindToUnsigned du16; - const RebindToUnsigned du32; + const RebindToUnsigned> du32; const VFromD b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) -template > -HWY_API auto ReorderWidenMulAccumulate(Simd df32, VFromD a, - VFromD b, +template +using DU16FromDF = RepartitionToNarrow>; + +template +HWY_API auto ReorderWidenMulAccumulate(Simd df32, + VFromD> a, + VFromD> b, const VFromD sum0, VFromD& sum1) -> VFromD { - const DU16 du16; + const DU16FromDF du16; const RebindToUnsigned du32; using VU32 = VFromD; - const VFromD zero = Zero(du16); + const VFromD zero = Zero(du16); const VU32 a0 = ZipLower(du32, zero, BitCast(du16, a)); const VU32 a1 = ZipUpper(du32, zero, BitCast(du16, a)); const VU32 b0 = ZipLower(du32, zero, BitCast(du16, b)); @@ -2137,20 +2362,89 @@ HWY_API auto ReorderWidenMulAccumulate(Simd df32, VFromD a, return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); } +// ------------------------------ Lt128 + +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, "Use u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Shift leftward so L can influence H. + const VFromD ltLx = detail::Slide1Up(ltHL); + const VFromD vecHx = OrAnd(ltHL, eqHL, ltLx); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD minHL = Min(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, a, b); + // The upper lane is just minHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(minHL, IfThenElse(eqXH, minHL, lo)); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD maxHL = Max(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, b, a); + // The upper lane is just maxHL; if they are equal, we also need to use the + // actual min of the lower lanes. 
+ return OddEven(maxHL, IfThenElse(eqXH, maxHL, lo)); +} + // ================================================== END MACROS namespace detail { // for code folding -#undef HWY_IF_FLOAT_V -#undef HWY_IF_SIGNED_V -#undef HWY_IF_UNSIGNED_V - +#undef HWY_RVV_AVL +#undef HWY_RVV_D #undef HWY_RVV_FOREACH -#undef HWY_RVV_FOREACH_08 -#undef HWY_RVV_FOREACH_16 -#undef HWY_RVV_FOREACH_32 -#undef HWY_RVV_FOREACH_64 +#undef HWY_RVV_FOREACH_08_ALL +#undef HWY_RVV_FOREACH_08_DEMOTE +#undef HWY_RVV_FOREACH_08_EXT +#undef HWY_RVV_FOREACH_08_TRUNC +#undef HWY_RVV_FOREACH_16_ALL +#undef HWY_RVV_FOREACH_16_DEMOTE +#undef HWY_RVV_FOREACH_16_EXT +#undef HWY_RVV_FOREACH_16_TRUNC +#undef HWY_RVV_FOREACH_32_ALL +#undef HWY_RVV_FOREACH_32_DEMOTE +#undef HWY_RVV_FOREACH_32_EXT +#undef HWY_RVV_FOREACH_32_TRUNC +#undef HWY_RVV_FOREACH_64_ALL +#undef HWY_RVV_FOREACH_64_DEMOTE +#undef HWY_RVV_FOREACH_64_EXT +#undef HWY_RVV_FOREACH_64_TRUNC #undef HWY_RVV_FOREACH_B #undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F16 #undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F3264 #undef HWY_RVV_FOREACH_F64 #undef HWY_RVV_FOREACH_I #undef HWY_RVV_FOREACH_I08 @@ -2163,19 +2457,18 @@ namespace detail { // for code folding #undef HWY_RVV_FOREACH_U32 #undef HWY_RVV_FOREACH_U64 #undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI08 #undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI163264 #undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI3264 #undef HWY_RVV_FOREACH_UI64 - +#undef HWY_RVV_M #undef HWY_RVV_RETV_ARGV #undef HWY_RVV_RETV_ARGVS #undef HWY_RVV_RETV_ARGVV - #undef HWY_RVV_T -#undef HWY_RVV_D #undef HWY_RVV_V -#undef HWY_RVV_M - } // namespace detail // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE diff --git a/third_party/highway/hwy/ops/scalar-inl.h b/third_party/highway/hwy/ops/scalar-inl.h index 3e7758fcbca8..7f0d8231fd59 100644 --- a/third_party/highway/hwy/ops/scalar-inl.h +++ b/third_party/highway/hwy/ops/scalar-inl.h @@ -27,7 +27,7 @@ namespace HWY_NAMESPACE { // Single instruction, single data. template -using Sisd = Simd; +using Sisd = Simd; // (Wrapper class required for overloading comparison operators.) template @@ -187,6 +187,20 @@ HWY_API Vec1 operator^(const Vec1 a, const Vec1 b) { return Xor(a, b); } +// ------------------------------ OrAnd + +template +HWY_API Vec1 OrAnd(const Vec1 o, const Vec1 a1, const Vec1 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec1 IfVecThenElse(Vec1 mask, Vec1 yes, Vec1 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + // ------------------------------ CopySign template @@ -275,6 +289,11 @@ HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { return mask.bits ? Vec1(0) : no; } +template +HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { + return v.raw < 0 ? yes : no; +} + template HWY_API Vec1 ZeroIfNegative(const Vec1 v) { return v.raw < 0 ? Vec1(0) : v; @@ -423,7 +442,13 @@ HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { return Vec1(a.raw - b.raw); } -// ------------------------------ Saturating addition +// ------------------------------ SumsOf8 + +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} + +// ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. 
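
The single-lane semantics introduced in this hunk (OrAnd, IfNegativeThenElse) can be cross-checked against their definitions with a small standalone sketch; plain int32_t stands in for Vec1<int32_t> and the *Ref names are made up for illustration, this is not Highway code:

#include <cassert>
#include <cstdint>

// Reference semantics: OrAnd(o, a1, a2) == Or(o, And(a1, a2)).
int32_t OrAndRef(int32_t o, int32_t a1, int32_t a2) { return o | (a1 & a2); }

// Reference semantics: the sign of v selects between yes and no.
int32_t IfNegativeThenElseRef(int32_t v, int32_t yes, int32_t no) {
  return v < 0 ? yes : no;
}

int main() {
  assert(OrAndRef(0x10, 0x0F, 0x03) == (0x10 | 0x03));
  assert(IfNegativeThenElseRef(-1, 7, 9) == 7);
  assert(IfNegativeThenElseRef(+1, 7, 9) == 9);
  return 0;
}
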
@@ -931,21 +956,30 @@ HWY_API Vec1 PromoteTo(Sisd /* tag */, Vec1 from) { return Vec1(static_cast(from.raw)); } -template -HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { - static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); - +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { // Prevent ubsan errors when converting float to narrower integer/float if (std::isinf(from.raw) || - std::fabs(from.raw) > static_cast(HighestValue())) { - return Vec1(std::signbit(from.raw) ? LowestValue() - : HighestValue()); + std::fabs(from.raw) > static_cast(HighestValue())) { + return Vec1(std::signbit(from.raw) ? LowestValue() + : HighestValue()); } - return Vec1(static_cast(from.raw)); + return Vec1(static_cast(from.raw)); +} +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (std::isinf(from.raw) || + std::fabs(from.raw) > static_cast(HighestValue())) { + return Vec1(std::signbit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); } -template +template HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + static_assert(!IsFloat(), "FromT=double are handled above"); static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); // Int to int: choose closest value in ToT to `from` (avoids UB) @@ -1083,6 +1117,12 @@ HWY_API T GetLane(const Vec1 v) { return v.raw; } +template +HWY_API Vec1 DupEven(Vec1 v) { + return v; +} +// DupOdd is unsupported. + template HWY_API Vec1 OddEven(Vec1 /* odd */, Vec1 even) { return even; @@ -1125,6 +1165,14 @@ HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { return v; } +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec1 ReverseBlocks(Sisd /* tag */, const Vec1 v) { + return v; +} + // ------------------------------ Reverse template @@ -1132,6 +1180,21 @@ HWY_API Vec1 Reverse(Sisd /* tag */, const Vec1 v) { return v; } +template +HWY_API Vec1 Reverse2(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse4(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse8(Sisd /* tag */, const Vec1 v) { + return v; +} + // ================================================== BLOCKWISE // Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. 
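
The Reverse2/4/8 and DupEven overloads added above are identities because a single lane has nothing to swap or duplicate; their intended multi-lane behaviour only shows up on wider targets (see the RVV Reverse2/DupEven definitions earlier in this patch). As a standalone reference for a hypothetical 4-lane vector (plain C++, names are illustrative):

#include <array>
#include <cassert>

using Lanes4 = std::array<int, 4>;  // stand-in for a 4-lane vector

// Reverse2 swaps lanes within adjacent pairs.
Lanes4 Reverse2Ref(const Lanes4& v) { return {v[1], v[0], v[3], v[2]}; }
// DupEven copies each even lane into the following odd lane.
Lanes4 DupEvenRef(const Lanes4& v) { return {v[0], v[0], v[2], v[2]}; }

int main() {
  const Lanes4 v{10, 11, 12, 13};
  assert((Reverse2Ref(v) == Lanes4{11, 10, 13, 12}));
  assert((DupEvenRef(v) == Lanes4{10, 10, 12, 12}));
  return 0;
}
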
@@ -1308,41 +1371,6 @@ HWY_API Vec1 MaxOfLanes(Sisd /* tag */, const Vec1 v) { return v; } -// ================================================== DEPRECATED - -template -HWY_API size_t StoreMaskBits(const Mask1 mask, uint8_t* bits) { - return StoreMaskBits(Sisd(), mask, bits); -} - -template -HWY_API bool AllTrue(const Mask1 mask) { - return AllTrue(Sisd(), mask); -} - -template -HWY_API bool AllFalse(const Mask1 mask) { - return AllFalse(Sisd(), mask); -} - -template -HWY_API size_t CountTrue(const Mask1 mask) { - return CountTrue(Sisd(), mask); -} - -template -HWY_API Vec1 SumOfLanes(const Vec1 v) { - return SumOfLanes(Sisd(), v); -} -template -HWY_API Vec1 MinOfLanes(const Vec1 v) { - return MinOfLanes(Sisd(), v); -} -template -HWY_API Vec1 MaxOfLanes(const Vec1 v) { - return MaxOfLanes(Sisd(), v); -} - // ================================================== Operator wrapper template diff --git a/third_party/highway/hwy/ops/set_macros-inl.h b/third_party/highway/hwy/ops/set_macros-inl.h index 1da80cd0ef4e..cc7c7709a315 100644 --- a/third_party/highway/hwy/ops/set_macros-inl.h +++ b/third_party/highway/hwy/ops/set_macros-inl.h @@ -32,9 +32,10 @@ #undef HWY_MAX_BYTES #undef HWY_LANES -#undef HWY_CAP_INTEGER64 -#undef HWY_CAP_FLOAT16 -#undef HWY_CAP_FLOAT64 +#undef HWY_HAVE_SCALABLE +#undef HWY_HAVE_INTEGER64 +#undef HWY_HAVE_FLOAT16 +#undef HWY_HAVE_FLOAT64 #undef HWY_CAP_GE256 #undef HWY_CAP_GE512 @@ -79,9 +80,10 @@ #define HWY_MAX_BYTES 16 #define HWY_LANES(T) (16 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_AES 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -96,9 +98,10 @@ #define HWY_MAX_BYTES 16 #define HWY_LANES(T) (16 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -113,9 +116,10 @@ #define HWY_MAX_BYTES 32 #define HWY_LANES(T) (32 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 1 #define HWY_CAP_GE512 0 @@ -129,9 +133,10 @@ #define HWY_MAX_BYTES 64 #define HWY_LANES(T) (64 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 1 #define HWY_CAP_GE512 1 @@ -159,9 +164,10 @@ #define HWY_MAX_BYTES 16 #define HWY_LANES(T) (16 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 0 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -177,15 +183,16 @@ #define HWY_MAX_BYTES 16 #define HWY_LANES(T) (16 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 #if HWY_ARCH_ARM_A64 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_FLOAT64 1 #else -#define HWY_CAP_FLOAT64 0 +#define HWY_HAVE_FLOAT64 0 #endif #define HWY_NAMESPACE N_NEON @@ -196,23 
+203,19 @@ // SVE[2] #elif HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE -#if defined(HWY_EMULATE_SVE) && !defined(__F16C__) -#error "Disable HWY_CAP_FLOAT16 or ensure farm_sve actually converts to f16" -#endif - // SVE only requires lane alignment, not natural alignment of the entire vector. #define HWY_ALIGN alignas(8) #define HWY_MAX_BYTES 256 -// <= HWY_MAX_BYTES / sizeof(T): exact size. Otherwise a fraction 1/div (div = -// 1,2,4,8) is encoded as HWY_LANES(T) / div. This value leaves enough room for -// div=8 and demoting to 1/8 the lane width while still exceeding HWY_MAX_BYTES. -#define HWY_LANES(T) (32768 / sizeof(T)) +// Value ensures MaxLanes() is the tightest possible upper bound to reduce +// overallocation. +#define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -232,9 +235,10 @@ #define HWY_MAX_BYTES 16 #define HWY_LANES(T) (16 / sizeof(T)) -#define HWY_CAP_INTEGER64 0 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 0 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -250,9 +254,10 @@ #define HWY_MAX_BYTES 32 #define HWY_LANES(T) (32 / sizeof(T)) -#define HWY_CAP_INTEGER64 0 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 0 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -271,20 +276,20 @@ // The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8). #define HWY_MAX_BYTES 65536 -// <= HWY_MAX_BYTES / sizeof(T): exact size. Otherwise a fraction 1/div (div = -// 1,2,4,8) is encoded as HWY_LANES(T) / div. This value leaves enough room for -// div=8 and demoting to 1/8 the lane width while still exceeding HWY_MAX_BYTES. -#define HWY_LANES(T) (8388608 / sizeof(T)) +// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual +// LMUL. This is the tightest possible upper bound. +#define HWY_LANES(T) (8192 / sizeof(T)) -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 #if defined(__riscv_zfh) -#define HWY_CAP_FLOAT16 1 +#define HWY_HAVE_FLOAT16 1 #else -#define HWY_CAP_FLOAT16 0 +#define HWY_HAVE_FLOAT16 0 #endif #define HWY_NAMESPACE N_RVV @@ -300,9 +305,10 @@ #define HWY_MAX_BYTES 8 #define HWY_LANES(T) 1 -#define HWY_CAP_INTEGER64 1 -#define HWY_CAP_FLOAT16 1 -#define HWY_CAP_FLOAT64 1 +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -344,7 +350,3 @@ #else #define HWY_ATTR #endif - -// DEPRECATED -#undef HWY_GATHER_LANES -#define HWY_GATHER_LANES(T) HWY_LANES(T) diff --git a/third_party/highway/hwy/ops/shared-inl.h b/third_party/highway/hwy/ops/shared-inl.h index 4a4ed1e29777..65245cbe9258 100644 --- a/third_party/highway/hwy/ops/shared-inl.h +++ b/third_party/highway/hwy/ops/shared-inl.h @@ -26,65 +26,117 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { -// SIMD operations are implemented as overloaded functions selected using a tag -// type D := Simd. 
T is the lane type, N an opaque integer for internal -// use only. Users create D via aliases ScalableTag() (a full vector), -// CappedTag or FixedTag. The actual number of lanes -// (always a power of two) is Lanes(D()). -template +// Highway operations are implemented as overloaded functions selected using an +// internal-only tag type D := Simd. T is the lane type. kPow2 is a +// shift count applied to scalable vectors. Instead of referring to Simd<> +// directly, users create D via aliases ScalableTag() (defaults to a +// full vector, or fractions/groups if the argument is negative/positive), +// CappedTag or FixedTag. The actual number of lanes is +// Lanes(D()), a power of two. For scalable vectors, N is either HWY_LANES or a +// cap. For constexpr-size vectors, N is the actual number of lanes. This +// ensures Half> is the same type as Full256, as required by x86. +template struct Simd { constexpr Simd() = default; using T = Lane; static_assert((N & (N - 1)) == 0 && N != 0, "N must be a power of two"); + // Only for use by MaxLanes, required by MSVC. Cannot be enum because GCC + // warns when using enums and non-enums in the same expression. Cannot be + // static constexpr function (another MSVC limitation). + static constexpr size_t kPrivateN = N; + static constexpr int kPrivatePow2 = kPow2; + + template + static constexpr size_t NewN() { + // Round up to correctly handle scalars with N=1. + return (N * sizeof(T) + sizeof(NewT) - 1) / sizeof(NewT); + } + +#if HWY_HAVE_SCALABLE + template + static constexpr int Pow2Ratio() { + return (sizeof(NewT) > sizeof(T)) + ? static_cast(CeilLog2(sizeof(NewT) / sizeof(T))) + : -static_cast(CeilLog2(sizeof(T) / sizeof(NewT))); + } +#endif + // Widening/narrowing ops change the number of lanes and/or their type. // To initialize such vectors, we need the corresponding tag types: - // PromoteTo/DemoteTo() with another lane type, but same number of lanes. - template - using Rebind = Simd; +// PromoteTo/DemoteTo() with another lane type, but same number of lanes. +#if HWY_HAVE_SCALABLE + template + using Rebind = Simd()>; +#else + template + using Rebind = Simd; +#endif - // MulEven() with another lane type, but same total size. - // Round up to correctly handle scalars with N=1. - template - using Repartition = - Simd; + // Change lane type while keeping the same vector size, e.g. for MulEven. + template + using Repartition = Simd(), kPow2>; - // LowerHalf() with the same lane type, but half the lanes. - // Round up to correctly handle scalars with N=1. - using Half = Simd; +// Half the lanes while keeping the same lane type, e.g. for LowerHalf. +// Round up to correctly handle scalars with N=1. +#if HWY_HAVE_SCALABLE + // Reducing the cap (N) is required for SVE - if N is the limiter for f32xN, + // then we expect Half> to have N/2 lanes (rounded up). + using Half = Simd; +#else + using Half = Simd; +#endif - // Combine() with the same lane type, but twice the lanes. - using Twice = Simd; +// Twice the lanes while keeping the same lane type, e.g. for Combine. +#if HWY_HAVE_SCALABLE + using Twice = Simd; +#else + using Twice = Simd; +#endif }; namespace detail { -// Given N from HWY_LANES(T), returns N for use in Simd to describe: -// - a full vector (pow2 = 0); -// - 2,4,8 regs on RVV, otherwise a full vector (pow2 [1,3]); -// - a fraction of a register from 1/8 to 1/2 (pow2 [-3,-1]). -constexpr size_t ScaleByPower(size_t N, int pow2) { -#if HWY_TARGET == HWY_RVV - // For fractions, if N == 1 ensure we still return at least one lane. 
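// The NewN() helper above is the whole of the Repartition arithmetic: keep the
// total byte size, rescale the lane count by the size ratio, and round up so
// N=1 (scalar) stays at one lane. A standalone model under that assumption;
// RepartitionLanes is an illustrative name, not part of the Highway API.
#include <cstddef>
#include <cstdint>

template <typename OldT, typename NewT>
constexpr size_t RepartitionLanes(size_t n) {
  return (n * sizeof(OldT) + sizeof(NewT) - 1) / sizeof(NewT);
}

static_assert(RepartitionLanes<float, double>(4) == 2, "f32x4 -> f64x2");
static_assert(RepartitionLanes<uint16_t, uint8_t>(8) == 16, "u16x8 -> u8x16");
static_assert(RepartitionLanes<uint32_t, uint64_t>(1) == 1, "scalar rounds up to one lane");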
- return pow2 >= 0 ? (N << pow2) : HWY_MAX(1, (N >> (-pow2))); -#else - // If pow2 > 0, replace it with 0 (there is nothing wider than a full vector). - return HWY_MAX(1, N >> HWY_MAX(-pow2, 0)); +#if HWY_HAVE_SCALABLE + +template +constexpr bool IsFull(Simd /* d */) { + return N == HWY_LANES(T) && kPow2 == 0; +} + #endif + +// Returns the number of lanes (possibly zero) after applying a shift: +// - 0: no change; +// - [1,3]: a group of 2,4,8 [fractional] vectors; +// - [-3,-1]: a fraction of a vector from 1/8 to 1/2. +constexpr size_t ScaleByPower(size_t N, int pow2) { + return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); } // Struct wrappers enable validation of arguments via static_assert. template struct ScalableTagChecker { static_assert(-3 <= kPow2 && kPow2 <= 3, "Fraction must be 1/8 to 8"); - using type = Simd; +#if HWY_TARGET == HWY_RVV + // Only RVV supports register groups. + using type = Simd; +#elif HWY_HAVE_SCALABLE + // For SVE[2], only allow full or fractions. + using type = Simd; +#elif HWY_TARGET == HWY_SCALAR + using type = Simd; +#else + // Only allow full or fractions. + using type = Simd; +#endif }; template struct CappedTagChecker { static_assert(kLimit != 0, "Does not make sense to have zero lanes"); - using type = Simd; + using type = Simd; }; template @@ -95,7 +147,7 @@ struct FixedTagChecker { // HWY_MAX_BYTES would still allow uint8x8, which is not supported. static_assert(kNumLanes == 1, "Scalar only supports one lane"); #endif - using type = Simd; + using type = Simd; }; } // namespace detail @@ -114,15 +166,14 @@ using ScalableTag = typename detail::ScalableTagChecker::type; // typically used for 1D loops with a relatively low application-defined upper // bound, e.g. for 8x8 DCTs. However, it is better if data structures are // designed to be vector-length-agnostic (e.g. a hybrid SoA where there are -// chunks of say 256 DC components followed by 256 AC1 and finally 256 AC63; +// chunks of `M >= MaxLanes(d)` DC components followed by M AC1, .., and M AC63; // this would enable vector-length-agnostic loops using ScalableTag). template using CappedTag = typename detail::CappedTagChecker::type; // Alias for a tag describing a vector with *exactly* kNumLanes active lanes, -// even on targets with scalable vectors. All targets except HWY_SCALAR support -// up to 16 / sizeof(T). Other targets may allow larger kNumLanes, but relying -// on that is non-portable and discouraged. +// even on targets with scalable vectors. HWY_SCALAR only supports one lane. +// All other targets allow kNumLanes up to HWY_MAX_BYTES / sizeof(T). // // NOTE: if the application does not need to support HWY_SCALAR (+), use this // instead of CappedTag to emphasize that there will be exactly kNumLanes lanes. @@ -163,11 +214,11 @@ using RepartitionToNarrow = Repartition>, D>; template using Half = typename D::Half; -// Descriptor for the same lane type as D, but twice the lanes. +// Tag for the same lane type as D, but twice the lanes. template using Twice = typename D::Twice; -// Same as base.h macros but with a Simd argument instead of T. +// Same as base.h macros but with a Simd argument instead of T. 
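// Standalone model of the new detail::ScaleByPower above: kPow2 in [-3, 3]
// either selects a register group of 2/4/8 vectors (positive, RVV LMUL) or a
// fraction from 1/2 down to 1/8 of a vector (negative). Unlike the removed
// version, it may now return zero lanes. ScaleByPowerModel is an illustrative
// name, not part of the Highway API.
#include <cstddef>

constexpr size_t ScaleByPowerModel(size_t n, int pow2) {
  return pow2 >= 0 ? (n << pow2) : (n >> (-pow2));
}

static_assert(ScaleByPowerModel(16, 0) == 16, "full vector");
static_assert(ScaleByPowerModel(16, -1) == 8, "half-width fraction");
static_assert(ScaleByPowerModel(16, 3) == 128, "LMUL=8 register group on RVV");
static_assert(ScaleByPowerModel(1, -3) == 0, "may now underflow to zero lanes");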
#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(TFromD) #define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(TFromD) #define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(TFromD) @@ -175,6 +226,12 @@ using Twice = typename D::Twice; #define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_LANE_SIZE(TFromD, bytes) #define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_LANE_SIZE(TFromD, bytes) +// MSVC workaround: use PrivateN directly instead of MaxLanes. +#define HWY_IF_LT128_D(D) \ + hwy::EnableIf) < 16>* = nullptr +#define HWY_IF_GE128_D(D) \ + hwy::EnableIf) >= 16>* = nullptr + // Same, but with a vector argument. #define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV) #define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV) @@ -183,42 +240,59 @@ using Twice = typename D::Twice; // For implementing functions for a specific type. // IsSame<...>() in template arguments is broken on MSVC2015. -#define HWY_IF_LANES_ARE(T, V) \ - EnableIf>>::value>* = nullptr +#define HWY_IF_LANES_ARE(T, V) EnableIf>::value>* = nullptr -// Compile-time-constant, (typically but not guaranteed) an upper bound on the -// number of lanes. -// Prefer instead using Lanes() and dynamic allocation, or Rebind, or -// `#if HWY_CAP_GE*`. -template -HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(Simd) { - return N; +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr int Pow2(D /* d */) { + return D::kPrivatePow2; } -// Targets with non-constexpr Lanes define this themselves. -#if HWY_TARGET != HWY_RVV && HWY_TARGET != HWY_SVE2 && HWY_TARGET != HWY_SVE +// MSVC requires the explicit . +#define HWY_IF_POW2_GE(D, MIN) hwy::EnableIf(D()) >= (MIN)>* = nullptr + +#if HWY_HAVE_SCALABLE + +// Upper bound on the number of lanes. Intended for template arguments and +// reducing code size (e.g. for SSE4, we know at compile-time that vectors will +// not exceed 16 bytes). WARNING: this may be a loose bound, use Lanes() as the +// actual size for allocating storage. WARNING: MSVC might not be able to deduce +// arguments if this is used in EnableIf. See HWY_IF_LT128_D above. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return detail::ScaleByPower(HWY_MIN(D::kPrivateN, HWY_LANES(TFromD)), + D::kPrivatePow2); +} + +#else +// Workaround for MSVC 2017: T,N,kPow2 argument deduction fails, so returning N +// is not an option, nor does a member function work. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return D::kPrivateN; +} // (Potentially) non-constant actual size of the vector at runtime, subject to // the limit imposed by the Simd. Useful for advancing loop counters. -template -HWY_INLINE HWY_MAYBE_UNUSED size_t Lanes(Simd) { +// Targets with scalable vectors define this themselves. +template +HWY_INLINE HWY_MAYBE_UNUSED size_t Lanes(Simd) { return N; } -#endif +#endif // !HWY_HAVE_SCALABLE // NOTE: GCC generates incorrect code for vector arguments to non-inlined // functions in two situations: // - on Windows and GCC 10.3, passing by value crashes due to unaligned loads: // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412. -// - on ARM64 and GCC 9.3.0 or 11.2.1, passing by const& causes many (but not +// - on ARM64 and GCC 9.3.0 or 11.2.1, passing by value causes many (but not // all) tests to fail. // // We therefore pass by const& only on GCC and (Windows or ARM64). This alias // must be used for all vector/mask parameters of functions marked HWY_NOINLINE, // and possibly also other functions that are not inlined. 
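// Hedged usage sketch of the Lanes()/MaxLanes() split documented above, using
// the static-dispatch pattern: Lanes(d) is the runtime loop step, while
// MaxLanes(d) is only a compile-time upper bound that can be loose on scalable
// targets, so storage should be sized with Lanes(d). Assumes `size` is a
// multiple of Lanes(d) so remainder handling stays out of the sketch;
// AddInPlace is an illustrative function name.
#include <cstddef>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

void AddInPlace(float* HWY_RESTRICT a, const float* HWY_RESTRICT b, size_t size) {
  const hn::ScalableTag<float> d;
  const size_t step = hn::Lanes(d);  // actual lane count, known only at runtime
  for (size_t i = 0; i < size; i += step) {
    const auto va = hn::Load(d, a + i);
    const auto vb = hn::Load(d, b + i);
    hn::Store(hn::Add(va, vb), d, a + i);
  }
}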
#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && \ - ((defined(_WIN32) || defined(_WIN64)) || HWY_ARCH_ARM64) + ((defined(_WIN32) || defined(_WIN64)) || HWY_ARCH_ARM_A64) template using VecArg = const V&; #else diff --git a/third_party/highway/hwy/ops/wasm_128-inl.h b/third_party/highway/hwy/ops/wasm_128-inl.h index fbb9acf9ff34..39500eee1a0e 100644 --- a/third_party/highway/hwy/ops/wasm_128-inl.h +++ b/third_party/highway/hwy/ops/wasm_128-inl.h @@ -49,7 +49,10 @@ namespace hwy { namespace HWY_NAMESPACE { template -using Full128 = Simd; +using Full128 = Simd; + +template +using Full64 = Simd; namespace detail { @@ -96,6 +99,9 @@ class Vec128 { Raw raw; }; +template +using Vec64 = Vec128; + // FF..FF or 0. template struct Mask128 { @@ -104,11 +110,11 @@ struct Mask128 { namespace detail { -// Deduce Simd from Vec128 +// Deduce Simd from Vec128 struct DeduceD { template - Simd operator()(Vec128) const { - return Simd(); + Simd operator()(Vec128) const { + return Simd(); } }; @@ -148,7 +154,7 @@ struct BitCastFromInteger128 { }; template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128{BitCastFromInteger128()(v.raw)}; } @@ -156,7 +162,7 @@ HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, } // namespace detail template -HWY_API Vec128 BitCast(Simd d, +HWY_API Vec128 BitCast(Simd d, Vec128 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } @@ -165,11 +171,11 @@ HWY_API Vec128 BitCast(Simd d, // Returns an all-zero vector/part. template -HWY_API Vec128 Zero(Simd /* tag */) { +HWY_API Vec128 Zero(Simd /* tag */) { return Vec128{wasm_i32x4_splat(0)}; } template -HWY_API Vec128 Zero(Simd /* tag */) { +HWY_API Vec128 Zero(Simd /* tag */) { return Vec128{wasm_f32x4_splat(0.0f)}; } @@ -180,41 +186,44 @@ using VFromD = decltype(Zero(D())); // Returns a vector/part with all lanes set to "t". template -HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { return Vec128{wasm_i8x16_splat(static_cast(t))}; } template -HWY_API Vec128 Set(Simd /* tag */, const uint16_t t) { +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { return Vec128{wasm_i16x8_splat(static_cast(t))}; } template -HWY_API Vec128 Set(Simd /* tag */, const uint32_t t) { +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { return Vec128{wasm_i32x4_splat(static_cast(t))}; } template -HWY_API Vec128 Set(Simd /* tag */, const uint64_t t) { +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { return Vec128{wasm_i64x2_splat(static_cast(t))}; } template -HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { return Vec128{wasm_i8x16_splat(t)}; } template -HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { return Vec128{wasm_i16x8_splat(t)}; } template -HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { return Vec128{wasm_i32x4_splat(t)}; } template -HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { return Vec128{wasm_i64x2_splat(t)}; } template -HWY_API Vec128 Set(Simd /* tag */, const float t) { +HWY_API Vec128 Set(Simd /* tag */, const float t) { return Vec128{wasm_f32x4_splat(t)}; } @@ -223,7 +232,7 @@ HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. 
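// The BitCast overloads above only reinterpret the 128-bit register; no value
// conversion takes place. A scalar model of that contract using memcpy type
// punning; BitCastModel is an illustrative name.
#include <cstdint>
#include <cstring>

inline uint32_t BitCastModel(float f) {
  uint32_t bits;
  static_assert(sizeof(bits) == sizeof(float), "sizes must match, as in BitCast");
  std::memcpy(&bits, &f, sizeof(bits));
  return bits;  // e.g. BitCastModel(1.0f) == 0x3F800000
}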
template -HWY_API Vec128 Undefined(Simd d) { +HWY_API Vec128 Undefined(Simd d) { return Zero(d); } @@ -231,7 +240,7 @@ HWY_DIAGNOSTICS(pop) // Returns a vector with lane i=[0, N) set to "first" + i. template -Vec128 Iota(const Simd d, const T2 first) { +Vec128 Iota(const Simd d, const T2 first) { HWY_ALIGN T lanes[16 / sizeof(T)]; for (size_t i = 0; i < 16 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); @@ -259,6 +268,11 @@ HWY_API Vec128 operator+(const Vec128 a, const Vec128 b) { return Vec128{wasm_i32x4_add(a.raw, b.raw)}; } +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} // Signed template @@ -276,6 +290,11 @@ HWY_API Vec128 operator+(const Vec128 a, const Vec128 b) { return Vec128{wasm_i32x4_add(a.raw, b.raw)}; } +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} // Float template @@ -302,6 +321,11 @@ HWY_API Vec128 operator-(const Vec128 a, const Vec128 b) { return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; } +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} // Signed template @@ -319,6 +343,11 @@ HWY_API Vec128 operator-(const Vec128 a, const Vec128 b) { return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; } +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} // Float template @@ -327,7 +356,7 @@ HWY_API Vec128 operator-(const Vec128 a, return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; } -// ------------------------------ Saturating addition +// ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -355,7 +384,7 @@ HWY_API Vec128 SaturatedAdd(const Vec128 a, return Vec128{wasm_i16x8_add_sat(a.raw, b.raw)}; } -// ------------------------------ Saturating subtraction +// ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. @@ -416,7 +445,7 @@ HWY_API Vec128 Abs(const Vec128 v) { } template HWY_API Vec128 Abs(const Vec128 v) { - return Vec128{wasm_i62x2_abs(v.raw)}; + return Vec128{wasm_i64x2_abs(v.raw)}; } template @@ -440,9 +469,17 @@ HWY_API Vec128 ShiftLeft(const Vec128 v) { return Vec128{wasm_i32x4_shl(v.raw, kBits)}; } template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template HWY_API Vec128 ShiftRight(const Vec128 v) { return Vec128{wasm_u32x4_shr(v.raw, kBits)}; } +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u64x2_shr(v.raw, kBits)}; +} // Signed template @@ -458,14 +495,22 @@ HWY_API Vec128 ShiftLeft(const Vec128 v) { return Vec128{wasm_i32x4_shl(v.raw, kBits)}; } template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template HWY_API Vec128 ShiftRight(const Vec128 v) { return Vec128{wasm_i32x4_shr(v.raw, kBits)}; } +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i64x2_shr(v.raw, kBits)}; +} // 8-bit template HWY_API Vec128 ShiftLeft(const Vec128 v) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; return kBits == 1 @@ -475,7 +520,7 @@ HWY_API Vec128 ShiftLeft(const Vec128 v) { template HWY_API Vec128 ShiftRight(const Vec128 v) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. 
const Vec128 shifted{ ShiftRight(Vec128{v.raw}).raw}; @@ -484,8 +529,8 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { template HWY_API Vec128 ShiftRight(const Vec128 v) { - const Simd di; - const Simd du; + const DFromV di; + const RebindToUnsigned du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; @@ -502,6 +547,10 @@ HWY_API Vec128 RotateRight(const Vec128 v) { // ------------------------------ Shift lanes by same variable #bits +// After https://reviews.llvm.org/D108415 shift argument became unsigned. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + // Unsigned template HWY_API Vec128 ShiftLeftSame(const Vec128 v, @@ -523,6 +572,16 @@ HWY_API Vec128 ShiftRightSame(const Vec128 v, const int bits) { return Vec128{wasm_u32x4_shr(v.raw, bits)}; } +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u64x2_shr(v.raw, bits)}; +} // Signed template @@ -545,11 +604,21 @@ HWY_API Vec128 ShiftRightSame(const Vec128 v, const int bits) { return Vec128{wasm_i32x4_shr(v.raw, bits)}; } +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shr(v.raw, bits)}; +} // 8-bit template HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. const Vec128 shifted{ ShiftLeftSame(Vec128>{v.raw}, bits).raw}; @@ -559,7 +628,7 @@ HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { template HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. 
const Vec128 shifted{ ShiftRightSame(Vec128{v.raw}, bits).raw}; @@ -568,73 +637,67 @@ HWY_API Vec128 ShiftRightSame(Vec128 v, template HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { - const Simd di; - const Simd du; + const DFromV di; + const RebindToUnsigned du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); return (shifted ^ shifted_sign) - shifted_sign; } +// ignore Wsign-conversion +HWY_DIAGNOSTICS(pop) + // ------------------------------ Minimum // Unsigned template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_u8x16_min(a.raw, b.raw)}; } template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_u16x8_min(a.raw, b.raw)}; } template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_u32x4_min(a.raw, b.raw)}; } template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - alignas(16) float min[4]; - min[0] = - HWY_MIN(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); - min[1] = - HWY_MIN(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + alignas(16) uint64_t min[2]; + min[0] = HWY_MIN(wasm_u64x2_extract_lane(a.raw, 0), + wasm_u64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_u64x2_extract_lane(a.raw, 1), + wasm_u64x2_extract_lane(b.raw, 1)); return Vec128{wasm_v128_load(min)}; } // Signed template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_i8x16_min(a.raw, b.raw)}; } template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_i16x8_min(a.raw, b.raw)}; } template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_i32x4_min(a.raw, b.raw)}; } template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - alignas(16) float min[4]; - min[0] = - HWY_MIN(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); - min[1] = - HWY_MIN(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + alignas(16) int64_t min[4]; + min[0] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); return Vec128{wasm_v128_load(min)}; } // Float template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Min(Vec128 a, Vec128 b) { return Vec128{wasm_f32x4_min(a.raw, b.raw)}; } @@ -642,62 +705,53 @@ HWY_API Vec128 Min(const Vec128 a, // Unsigned template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_u8x16_max(a.raw, b.raw)}; } template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_u16x8_max(a.raw, b.raw)}; } template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_u32x4_max(a.raw, b.raw)}; } template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - alignas(16) float max[4]; - max[0] = - HWY_MAX(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); - max[1] = - HWY_MAX(wasm_u64x2_extract_lane(a, 1), 
wasm_u64x2_extract_lane(b, 1)); - return Vec128{wasm_v128_load(max)}; +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + alignas(16) uint64_t max[2]; + max[0] = HWY_MAX(wasm_u64x2_extract_lane(a.raw, 0), + wasm_u64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_u64x2_extract_lane(a.raw, 1), + wasm_u64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(max)}; } // Signed template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_i8x16_max(a.raw, b.raw)}; } template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_i16x8_max(a.raw, b.raw)}; } template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_i32x4_max(a.raw, b.raw)}; } template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - alignas(16) float max[4]; - max[0] = - HWY_MAX(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); - max[1] = - HWY_MAX(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + alignas(16) int64_t max[2]; + max[0] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); return Vec128{wasm_v128_load(max)}; } // Float template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { +HWY_API Vec128 Max(Vec128 a, Vec128 b) { return Vec128{wasm_f32x4_max(a.raw, b.raw)}; } @@ -781,7 +835,7 @@ HWY_API Vec128 MulEven(const Vec128 a, template HWY_API Vec128 Neg(const Vec128 v) { - return Xor(v, SignBit(Simd())); + return Xor(v, SignBit(DFromV())); } template @@ -915,7 +969,8 @@ HWY_API Vec128 Floor(const Vec128 v) { // Comparisons fill a lane with 1-bits if the condition is true, else 0. 
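// The 8-bit ShiftRight/ShiftRightSame overloads earlier in this file have no
// native instruction, so they perform a logical (zero-fill) shift and then
// restore the sign via (x ^ m) - m, where m marks where the sign bit landed.
// A scalar model of that identity; SignedShr8Model is an illustrative name.
#include <cstdint>

constexpr int8_t SignedShr8Model(int8_t v, int bits) {
  const int logical = static_cast<uint8_t>(v) >> bits;  // zero-fill shift
  const int sign_pos = 0x80 >> bits;                     // shifted sign-bit position
  return static_cast<int8_t>((logical ^ sign_pos) - sign_pos);  // sign-extend
}

static_assert(SignedShr8Model(int8_t{-128}, 1) == -64, "sign bits are replicated");
static_assert(SignedShr8Model(int8_t{80}, 2) == 20, "non-negative values unchanged");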
template -HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask128{m.raw}; } @@ -944,6 +999,11 @@ HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; } +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} // Signed template @@ -961,6 +1021,11 @@ HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; } +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} // Float template @@ -987,6 +1052,11 @@ HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; } +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} // Signed template @@ -995,8 +1065,8 @@ HWY_API Mask128 operator!=(const Vec128 a, return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; } template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; } template @@ -1004,6 +1074,11 @@ HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; } +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} // Float template @@ -1032,28 +1107,42 @@ HWY_API Mask128 operator>(const Vec128 a, template HWY_API Mask128 operator>(const Vec128 a, const Vec128 b) { - const Simd d32; + return Mask128{wasm_i64x2_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d32; const auto a32 = BitCast(d32, a); const auto b32 = BitCast(d32, b); - // If the upper half is less than or greater, this is the answer. - const auto m_gt = a32 < b32; + // If the upper halves are not equal, this is the answer. + const auto m_gt = a32 > b32; // Otherwise, the lower half decides. const auto m_eq = a32 == b32; - const auto lo_in_hi = wasm_i32x4_shuffle(m_gt, m_gt, 2, 2, 0, 0); - const auto lo_gt = And(m_eq, lo_in_hi); + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt.raw, m_gt.raw, 0, 0, 2, 2); + const auto lo_gt = And(m_eq, MaskFromVec(VFromD{lo_in_hi})); const auto gt = Or(lo_gt, m_gt); // Copy result in upper 32 bits to lower 32 bits. 
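// The uint64_t operator> above builds the 64-bit predicate from 32-bit
// compares: within each 64-bit block, lane 0 holds the low half and lane 1 the
// high half, m_gt/m_eq compare the halves, and the i32x4 shuffles move the
// low-half verdict up before broadcasting lane 1 into both mask lanes. A
// scalar model of the same predicate; U64GreaterModel is an illustrative name.
#include <cstdint>

constexpr bool U64GreaterModel(uint64_t a, uint64_t b) {
  const uint32_t a_hi = static_cast<uint32_t>(a >> 32), a_lo = static_cast<uint32_t>(a);
  const uint32_t b_hi = static_cast<uint32_t>(b >> 32), b_lo = static_cast<uint32_t>(b);
  return (a_hi > b_hi) || (a_hi == b_hi && a_lo > b_lo);
}

static_assert(U64GreaterModel(0x100000000ull, 0xFFFFFFFFull), "high half decides");
static_assert(!U64GreaterModel(5, 5) && U64GreaterModel(6, 5), "low half breaks ties");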
- return Mask128{wasm_i32x4_shuffle(gt, gt, 3, 3, 1, 1)}; -} - -template -HWY_API Mask128 operator>(Vec128 a, Vec128 b) { - const Simd du; - const RebindToSigned di; - const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); - return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); + return Mask128{wasm_i32x4_shuffle(gt.raw, gt.raw, 1, 1, 3, 3)}; } template @@ -1084,7 +1173,7 @@ HWY_API Mask128 operator>=(const Vec128 a, // ------------------------------ FirstN (Iota, Lt) template -HWY_API Mask128 FirstN(const Simd d, size_t num) { +HWY_API Mask128 FirstN(const Simd d, size_t num) { const RebindToSigned di; // Signed comparisons may be cheaper. return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); } @@ -1127,6 +1216,21 @@ HWY_API Vec128 Xor(Vec128 a, Vec128 b) { return Vec128{wasm_v128_xor(a.raw, b.raw)}; } +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + // ------------------------------ Operator overloads (internal-only if float) template @@ -1150,7 +1254,7 @@ template HWY_API Vec128 CopySign(const Vec128 magn, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - const auto msb = SignBit(Simd()); + const auto msb = SignBit(DFromV()); return Or(AndNot(msb, magn), And(msb, sign)); } @@ -1158,7 +1262,7 @@ template HWY_API Vec128 CopySignToAbs(const Vec128 abs, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - return Or(abs, And(SignBit(Simd()), sign)); + return Or(abs, And(SignBit(DFromV()), sign)); } // ------------------------------ BroadcastSignBit (compare) @@ -1169,7 +1273,8 @@ HWY_API Vec128 BroadcastSignBit(const Vec128 v) { } template HWY_API Vec128 BroadcastSignBit(const Vec128 v) { - return VecFromMask(Simd(), v < Zero(Simd())); + const DFromV d; + return VecFromMask(d, v < Zero(d)); } // ------------------------------ Mask @@ -1181,13 +1286,7 @@ HWY_API Mask128 MaskFromVec(const Vec128 v) { } template -HWY_API Vec128 VecFromMask(Simd /* tag */, Mask128 v) { - return Vec128{v.raw}; -} - -// DEPRECATED -template -HWY_API Vec128 VecFromMask(const Mask128 v) { +HWY_API Vec128 VecFromMask(Simd /* tag */, Mask128 v) { return Vec128{v.raw}; } @@ -1201,18 +1300,29 @@ HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, // mask ? yes : 0 template HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { - return yes & VecFromMask(Simd(), mask); + return yes & VecFromMask(DFromV(), mask); } // mask ? 
0 : no template HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { - return AndNot(VecFromMask(Simd(), mask), no); + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec128 ZeroIfNegative(Vec128 v) { - const Simd d; + const DFromV d; const auto zero = Zero(d); return IfThenElse(Mask128{(v > zero).raw}, v, zero); } @@ -1221,30 +1331,30 @@ HWY_API Vec128 ZeroIfNegative(Vec128 v) { template HWY_API Mask128 Not(const Mask128 m) { - return MaskFromVec(Not(VecFromMask(Simd(), m))); + return MaskFromVec(Not(VecFromMask(Simd(), m))); } template HWY_API Mask128 And(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Or(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } @@ -1260,7 +1370,7 @@ HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { template HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { - const Simd d; + const DFromV d; Mask128 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); @@ -1285,7 +1395,7 @@ HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { template HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { - const Simd d; + const DFromV d; Mask128 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); @@ -1312,11 +1422,23 @@ HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { return IfThenElse(mask, ShiftLeft<1>(v), v); } +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + alignas(16) T lanes[2]; + alignas(16) T bits_lanes[2]; + Store(v, d, lanes); + Store(bits, d, bits_lanes); + lanes[0] <<= bits_lanes[0]; + lanes[1] <<= bits_lanes[1]; + return Load(d, lanes); +} + // ------------------------------ Shr (BroadcastSignBit, IfThenElse) template HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { - const Simd d; + const DFromV d; Mask128 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); @@ -1341,7 +1463,7 @@ HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { template HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { - const Simd d; + const DFromV d; Mask128 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); @@ -1378,14 +1500,14 @@ HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } // Partial load. 
template -HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { Vec128 v; CopyBytes(p, &v); return v; @@ -1393,13 +1515,13 @@ HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { // LoadU == Load. template -HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { return Load(d, p); } // 128-bit SIMD => nothing to duplicate, same as an unaligned load. template -HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { return Load(d, p); } @@ -1412,18 +1534,18 @@ HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { // Partial store. template -HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { CopyBytes(&v, p); } -HWY_API void Store(const Vec128 v, Simd /* tag */, +HWY_API void Store(const Vec128 v, Simd /* tag */, float* HWY_RESTRICT p) { *p = wasm_f32x4_extract_lane(v.raw, 0); } // StoreU == Store. template -HWY_API void StoreU(Vec128 v, Simd d, T* HWY_RESTRICT p) { +HWY_API void StoreU(Vec128 v, Simd d, T* HWY_RESTRICT p) { Store(v, d, p); } @@ -1432,7 +1554,7 @@ HWY_API void StoreU(Vec128 v, Simd d, T* HWY_RESTRICT p) { // Same as aligned stores on non-x86. template -HWY_API void Stream(Vec128 v, Simd /* tag */, +HWY_API void Stream(Vec128 v, Simd /* tag */, T* HWY_RESTRICT aligned) { wasm_v128_store(aligned, v.raw); } @@ -1440,7 +1562,8 @@ HWY_API void Stream(Vec128 v, Simd /* tag */, // ------------------------------ Scatter (Store) template -HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); @@ -1448,7 +1571,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Rebind(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { @@ -1457,7 +1580,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, } template -HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); @@ -1465,7 +1588,7 @@ HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Rebind(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; @@ -1475,13 +1598,13 @@ HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, // ------------------------------ Gather (Load/Store) template -HWY_API Vec128 GatherOffset(const Simd d, +HWY_API Vec128 GatherOffset(const Simd d, const T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Rebind(), offset_lanes); alignas(16) T lanes[N]; const uint8_t* base_bytes = reinterpret_cast(base); @@ -1492,12 +1615,13 @@ HWY_API Vec128 GatherOffset(const Simd d, } template -HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, +HWY_API Vec128 
GatherIndex(const Simd d, + const T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Rebind(), index_lanes); alignas(16) T lanes[N]; for (size_t i = 0; i < N; ++i) { @@ -1552,20 +1676,21 @@ HWY_API float GetLane(const Vec128 v) { // ------------------------------ LowerHalf template -HWY_API Vec128 LowerHalf(Simd /* tag */, Vec128 v) { +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { return Vec128{v.raw}; } template HWY_API Vec128 LowerHalf(Vec128 v) { - return LowerHalf(Simd(), v); + return LowerHalf(Simd(), v); } // ------------------------------ ShiftLeftBytes // 0x01..0F, kBytes = 1 => 0x02..0F00 template -HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); const __i8x16 zero = wasm_i8x16_splat(0); switch (kBytes) { @@ -1644,20 +1769,20 @@ HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { template HWY_API Vec128 ShiftLeftBytes(Vec128 v) { - return ShiftLeftBytes(Simd(), v); + return ShiftLeftBytes(Simd(), v); } // ------------------------------ ShiftLeftLanes template -HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { - return ShiftLeftLanes(Simd(), v); + return ShiftLeftLanes(DFromV(), v); } // ------------------------------ ShiftRightBytes @@ -1741,7 +1866,7 @@ HWY_API __i8x16 ShrBytes(const Vec128 v) { // 0x01..0F, kBytes = 1 => 0x0001..0E template -HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { // For partial vectors, clear upper lanes so we shift in zeros. if (N != 16 / sizeof(T)) { const Vec128 vfull{v.raw}; @@ -1752,31 +1877,30 @@ HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { // ------------------------------ ShiftRightLanes template -HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { const Repartition d8; - return BitCast(d, ShiftRightBytes(BitCast(d8, v))); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ UpperHalf (ShiftRightBytes) // Full input: copy hi into lo (smaller instruction encoding than shifts). 
template -HWY_API Vec128 UpperHalf(Half> /* tag */, - const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; } -HWY_API Vec128 UpperHalf(Half> /* tag */, - const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; } // Partial template -HWY_API Vec128 UpperHalf(Half> /* tag */, +HWY_API Vec128 UpperHalf(Half> /* tag */, Vec128 v) { - const Simd d; - const auto vu = BitCast(RebindToUnsigned(), v); - const auto upper = BitCast(d, ShiftRightBytes(vu)); + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); return Vec128{upper.raw}; } @@ -1854,7 +1978,7 @@ HWY_API V CombineShiftRightBytes(Full128 /* tag */, V hi, V lo) { template > -HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { constexpr size_t kSize = N * sizeof(T); static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); const Repartition d8; @@ -1869,40 +1993,24 @@ HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { // ------------------------------ Broadcast/splat any lane -// Unsigned -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{wasm_i16x8_shuffle( - v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; } -template -HWY_API Vec128 Broadcast(const Vec128 v) { + +template +HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{ + return Vec128{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } -// Signed -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{wasm_i16x8_shuffle( - v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; -} -template -HWY_API Vec128 Broadcast(const Vec128 v) { - static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{ - wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; -} - -// Float -template -HWY_API Vec128 Broadcast(const Vec128 v) { - static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{ - wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, kLane, kLane)}; } // ------------------------------ TableLookupBytes @@ -1934,10 +2042,10 @@ HWY_API Vec128 TableLookupBytes(const Vec128 bytes, template HWY_API Vec128 TableLookupBytesOr0(const Vec128 bytes, const Vec128 from) { - const Simd d; + const Simd d; // Mask size must match vector type, so cast everything to this type. Repartition di8; - Repartition> d_bytes8; + Repartition> d_bytes8; const auto msb = BitCast(di8, from) < Zero(di8); const auto lookup = TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); @@ -1952,57 +2060,44 @@ HWY_API Vec128 TableLookupBytesOr0(const Vec128 bytes, // CombineShiftRightBytes but the shuffle_abcd notation is more convenient. // Swap 32-bit halves in 64-bit halves. 
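// Byte-level model of CombineShiftRightBytes<kBytes> above: the result is the
// 16-byte window starting kBytes into the little-endian concatenation
// {lo, hi}. Array-based and illustrative, not the Highway implementation.
#include <cstdint>
#include <cstring>

template <int kBytes>
void CombineShiftRightBytesModel(const uint8_t lo[16], const uint8_t hi[16],
                                 uint8_t out[16]) {
  static_assert(0 < kBytes && kBytes < 16, "kBytes invalid");
  uint8_t concat[32];
  std::memcpy(concat, lo, 16);            // bytes [0, 16): lo
  std::memcpy(concat + 16, hi, 16);       // bytes [16, 32): hi
  std::memcpy(out, concat + kBytes, 16);  // shift right by kBytes bytes
}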
-HWY_API Vec128 Shuffle2301(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; -} -HWY_API Vec128 Shuffle2301(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; -} -HWY_API Vec128 Shuffle2301(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; } // Swap 64-bit halves -HWY_API Vec128 Shuffle1032(const Vec128 v) { - return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; } -HWY_API Vec128 Shuffle1032(const Vec128 v) { - return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; -} -HWY_API Vec128 Shuffle1032(const Vec128 v) { - return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; } // Rotate right 32 bits -HWY_API Vec128 Shuffle0321(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; -} -HWY_API Vec128 Shuffle0321(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; -} -HWY_API Vec128 Shuffle0321(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; } + // Rotate left 32 bits -HWY_API Vec128 Shuffle2103(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; -} -HWY_API Vec128 Shuffle2103(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; -} -HWY_API Vec128 Shuffle2103(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; } // Reverse -HWY_API Vec128 Shuffle0123(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; -} -HWY_API Vec128 Shuffle0123(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; -} -HWY_API Vec128 Shuffle0123(const Vec128 v) { - return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; } // ------------------------------ TableLookupLanes @@ -2014,10 +2109,10 @@ struct Indices128 { }; template -HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD - const Simd di; + const Rebind di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(N))))); #endif @@ -2052,7 +2147,7 @@ HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { } template -HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { const Rebind di; return 
IndicesFromVec(d, LoadU(di, idx)); } @@ -2060,8 +2155,8 @@ HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { template HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { using TI = MakeSigned; - const Simd d; - const Simd di; + const DFromV d; + const Rebind di; return BitCast(d, TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); } @@ -2069,13 +2164,13 @@ HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { // Single lane: no change template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return v; } // Two lanes: shuffle template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return Vec128{Shuffle2301(Vec128{v.raw}).raw}; } @@ -2092,11 +2187,59 @@ HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { // 16-bit template -HWY_API Vec128 Reverse(Simd d, const Vec128 v) { +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { const RepartitionToWide> du32; return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); } +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + return BitCast(d, Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 3, 2, + 1, 0, 7, 6, 5, 4)}); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { + return Reverse(d, v); +} + +template +HWY_API Vec128 Reverse8(Simd, const Vec128) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + // ------------------------------ InterleaveLower template @@ -2151,9 +2294,9 @@ HWY_API Vec128 InterleaveLower(Vec128 a, return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } -// Additional overload for the optional Simd<> tag. -template > -HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { +// Additional overload for the optional tag. +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { return InterleaveLower(a, b); } @@ -2226,7 +2369,7 @@ HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { // Partial template > -HWY_API V InterleaveUpper(Simd d, V a, V b) { +HWY_API V InterleaveUpper(Simd d, V a, V b) { const Half d2; return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); } @@ -2235,19 +2378,17 @@ HWY_API V InterleaveUpper(Simd d, V a, V b) { // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
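// The ZipLower/ZipUpper overloads just below interleave two vectors and then
// reinterpret the result as double-width lanes. On little-endian lanes that
// is, per lane pair, a_lane | (b_lane << 8) for u8 -> u16. A one-lane scalar
// model; ZipLaneModel is an illustrative name.
#include <cstdint>

constexpr uint16_t ZipLaneModel(uint8_t a_lane, uint8_t b_lane) {
  return static_cast<uint16_t>(a_lane | (uint16_t{b_lane} << 8));
}

static_assert(ZipLaneModel(0x34, 0x12) == 0x1234, "b supplies the upper byte");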
-template >> -HWY_API VFromD ZipLower(Vec128 a, Vec128 b) { +template >> +HWY_API VFromD ZipLower(V a, V b) { return BitCast(DW(), InterleaveLower(a, b)); } -template , - class DW = RepartitionToWide> -HWY_API VFromD ZipLower(DW dw, Vec128 a, Vec128 b) { +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { return BitCast(dw, InterleaveLower(D(), a, b)); } -template , - class DW = RepartitionToWide> -HWY_API VFromD ZipUpper(DW dw, Vec128 a, Vec128 b) { +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } @@ -2257,7 +2398,7 @@ HWY_API VFromD ZipUpper(DW dw, Vec128 a, Vec128 b) { // N = N/2 + N/2 (upper half undefined) template -HWY_API Vec128 Combine(Simd d, Vec128 hi_half, +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, Vec128 lo_half) { const Half d2; const RebindToUnsigned du2; @@ -2271,7 +2412,7 @@ HWY_API Vec128 Combine(Simd d, Vec128 hi_half, // ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) template -HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); } @@ -2284,10 +2425,10 @@ HWY_API Vec128 ConcatLowerLower(Full128 /* tag */, const Vec128 hi, return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; } template -HWY_API Vec128 ConcatLowerLower(Simd d, const Vec128 hi, +HWY_API Vec128 ConcatLowerLower(Simd d, const Vec128 hi, const Vec128 lo) { const Half d2; - return Combine(LowerHalf(d2, hi), LowerHalf(d2, lo)); + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); } // ------------------------------ ConcatUpperUpper @@ -2298,10 +2439,10 @@ HWY_API Vec128 ConcatUpperUpper(Full128 /* tag */, const Vec128 hi, return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; } template -HWY_API Vec128 ConcatUpperUpper(Simd d, const Vec128 hi, +HWY_API Vec128 ConcatUpperUpper(Simd d, const Vec128 hi, const Vec128 lo) { const Half d2; - return Combine(UpperHalf(d2, hi), UpperHalf(d2, lo)); + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); } // ------------------------------ ConcatLowerUpper @@ -2312,15 +2453,15 @@ HWY_API Vec128 ConcatLowerUpper(Full128 d, const Vec128 hi, return CombineShiftRightBytes<8>(d, hi, lo); } template -HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, const Vec128 lo) { const Half d2; - return Combine(LowerHalf(d2, hi), UpperHalf(d2, lo)); + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); } // ------------------------------ ConcatUpperLower template -HWY_API Vec128 ConcatUpperLower(Simd d, const Vec128 hi, +HWY_API Vec128 ConcatUpperLower(Simd d, const Vec128 hi, const Vec128 lo) { return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); } @@ -2335,9 +2476,9 @@ HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { // 32-bit partial template -HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, Vec128 lo) { - return InterleaveUpper(Simd(), lo, hi); + return InterleaveUpper(Simd(), lo, hi); } // 64-bit full - no partial because we need at least two inputs to have @@ -2357,9 +2498,9 @@ HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { // 32-bit partial template -HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, Vec128 lo) { - return InterleaveLower(Simd(), lo, hi); + return InterleaveLower(Simd(), lo, 
hi); } // 64-bit full - no partial because we need at least two inputs to have @@ -2369,6 +2510,30 @@ HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { return InterleaveLower(Full128(), lo, hi); } +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 0, 0, 2, 2)}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + // ------------------------------ OddEven namespace detail { @@ -2376,7 +2541,7 @@ namespace detail { template HWY_INLINE Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, const Vec128 b) { - const Simd d; + const DFromV d; const Repartition d8; alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; @@ -2424,74 +2589,92 @@ HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { return v; } +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) // Unsigned: zero-extend. template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{ wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{ wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; } // Signed: replicate sign bit. 
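// The float16 -> float32 PromoteTo just below widens the bits to u32, splits
// sign/exponent/mantissa with shifts and masks, and rebiases the exponent. A
// scalar reference for the same f16 -> f32 conversion (covering
// zero/subnormal/normal/Inf/NaN), not a transcription of the vector code;
// F16BitsToF32Model is an illustrative name and IEEE binary16/binary32 layout
// is assumed.
#include <cstdint>
#include <cstring>

inline float F16BitsToF32Model(uint16_t h) {
  const uint32_t sign = uint32_t{h} >> 15;
  const uint32_t biased_exp = (h >> 10) & 0x1F;
  const uint32_t mantissa = h & 0x3FF;

  if (biased_exp == 0) {  // zero or subnormal: value is mantissa * 2^-24
    const float mag = static_cast<float>(mantissa) * 5.9604644775390625e-8f;
    return sign ? -mag : mag;
  }
  uint32_t bits32;
  if (biased_exp == 31) {  // Inf/NaN: exponent saturates, mantissa is widened
    bits32 = (sign << 31) | (255u << 23) | (mantissa << 13);
  } else {                 // normal: rebias exponent from 15 to 127
    bits32 = (sign << 31) | ((biased_exp + 127 - 15) << 23) | (mantissa << 13);
  }
  float result;
  std::memcpy(&result, &bits32, sizeof(result));
  return result;
}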
template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_i16x8_extend_low_i8x16(v.raw)}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{ wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_i32x4_extend_low_i16x8(v.raw)}; } +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i64x2_extend_low_i32x4(v.raw)}; +} template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_f64x2_convert_low_i32x4(v.raw)}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { - const Simd di32; - const Simd du32; - const Simd df32; + const RebindToSigned di32; + const RebindToUnsigned du32; // Expand to u32 so we can shift. const auto bits16 = PromoteTo(du32, Vec128{v.raw}); const auto sign = ShiftRight<15>(bits16); @@ -2509,7 +2692,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, } template -HWY_API Vec128 PromoteTo(Simd df32, +HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; @@ -2519,19 +2702,19 @@ HWY_API Vec128 PromoteTo(Simd df32, // ------------------------------ Demotions (full -> part w/ narrow lanes) template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); return Vec128{ @@ -2539,36 +2722,36 @@ HWY_API Vec128 DemoteTo(Simd /* tag */, } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* di */, +HWY_API Vec128 DemoteTo(Simd /* di */, const Vec128 v) { return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd df16, const Vec128 v) { - const Simd di; - const Simd du; - const Simd du16; + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); @@ -2594,7 +2777,7 @@ HWY_API Vec128 DemoteTo(Simd /* tag */, } template -HWY_API Vec128 DemoteTo(Simd dbf16, +HWY_API Vec128 DemoteTo(Simd dbf16, const Vec128 v) { const Rebind di32; const Rebind du32; // for logical shift right @@ -2605,7 +2788,7 @@ HWY_API Vec128 DemoteTo(Simd dbf16, template HWY_API Vec128 ReorderDemote2To( - Simd 
dbf16, Vec128 a, Vec128 b) { + Simd dbf16, Vec128 a, Vec128 b) { const RebindToUnsigned du16; const Repartition du32; const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); @@ -2623,30 +2806,54 @@ HWY_API Vec128 U8FromU32(const Vec128 v) { // ------------------------------ Convert i32 <=> f32 (Round) template -HWY_API Vec128 ConvertTo(Simd /* tag */, +HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_f32x4_convert_i32x4(v.raw)}; } // Truncates (rounds toward zero). template -HWY_API Vec128 ConvertTo(Simd /* tag */, +HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128{wasm_i32x4_trunc_sat_f32x4(v.raw)}; } template HWY_API Vec128 NearestInt(const Vec128 v) { - return ConvertTo(Simd(), Round(v)); + return ConvertTo(Simd(), Round(v)); } // ================================================== MISC +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = And(BitCast(du16, v), Set(du16, 0xFF)); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); +} + // ------------------------------ LoadMaskBits (TestBit) namespace detail { template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { const RebindToUnsigned du; // Easier than Set(), which would require an >8-bit type, which would not // compile for T=uint8_t, N=1. @@ -2663,7 +2870,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { const RebindToUnsigned du; alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask( @@ -2671,7 +2878,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { const RebindToUnsigned du; alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; return RebindMask( @@ -2679,7 +2886,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { const RebindToUnsigned du; alignas(16) constexpr uint64_t kBit[8] = {1, 2}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); @@ -2689,7 +2896,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { // `p` points to at least 8 readable bytes, not all of which need be valid. 
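// A scalar reference model for the new SumsOf8 above: each 64-bit output lane
// holds the sum of the eight bytes it covers. The vector code reaches the
// same result with an odd/even byte split followed by a shift-and-add
// reduction, keeping only the low 16 bits of each u64 lane at the end; this
// loop just states the intended semantics. The name is illustrative, not
// Highway API, and num_bytes is assumed to be a multiple of 8.
#include <cstddef>
#include <cstdint>

void SumsOf8Ref(const uint8_t* bytes, size_t num_bytes, uint64_t* sums) {
  for (size_t i = 0; i < num_bytes; i += 8) {
    uint64_t sum = 0;
    for (size_t j = 0; j < 8; ++j) sum += bytes[i + j];
    sums[i / 8] = sum;  // at most 8 * 255 = 2040, so it fits in 16 bits
  }
}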
template -HWY_API Mask128 LoadMaskBits(Simd d, +HWY_API Mask128 LoadMaskBits(Simd d, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; CopyBytes<(N + 7) / 8>(bits, &mask_bits); @@ -2754,6 +2961,17 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, return lanes[0] | lanes[1] | lanes[2] | lanes[3]; } +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const __i64x2 mask_i = static_cast<__i64x2>(mask.raw); + const __i64x2 slice = wasm_i64x2_make(1, 2); + const __i64x2 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1]; +} + // Returns the lowest N bits for the BitsFromMask result. template constexpr uint64_t OnlyActive(uint64_t bits) { @@ -2814,11 +3032,18 @@ HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { return PopCount(lanes[0] | lanes[1]); } +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + alignas(16) int64_t lanes[2]; + wasm_v128_store(lanes, m.raw); + return static_cast(-(lanes[0] + lanes[1])); +} + } // namespace detail // `p` points to at least 8 writable bytes. template -HWY_API size_t StoreMaskBits(const Simd /* tag */, +HWY_API size_t StoreMaskBits(const Simd /* tag */, const Mask128 mask, uint8_t* bits) { const uint64_t mask_bits = detail::BitsFromMask(mask); const size_t kNumBytes = (N + 7) / 8; @@ -2827,13 +3052,13 @@ HWY_API size_t StoreMaskBits(const Simd /* tag */, } template -HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 m) { +HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 m) { return detail::CountTrue(hwy::SizeTag(), m); } // Partial vector template -HWY_API size_t CountTrue(const Simd d, const Mask128 m) { +HWY_API size_t CountTrue(const Simd d, const Mask128 m) { // Ensure all undefined bytes are 0. const Mask128 mask{detail::BytesAbove()}; return CountTrue(d, Mask128{AndNot(mask, m).raw}); @@ -2868,32 +3093,36 @@ template HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { return wasm_i32x4_all_true(m.raw); } +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + return wasm_i64x2_all_true(m.raw); +} } // namespace detail template -HWY_API bool AllTrue(const Simd /* tag */, const Mask128 m) { +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 m) { return detail::AllTrue(hwy::SizeTag(), m); } // Partial vectors template -HWY_API bool AllFalse(Simd /* tag */, const Mask128 m) { +HWY_API bool AllFalse(Simd /* tag */, const Mask128 m) { // Ensure all undefined bytes are 0. const Mask128 mask{detail::BytesAbove()}; - return AllFalse(Mask128{AndNot(mask, m).raw}); + return AllFalse(Full128(), Mask128{AndNot(mask, m).raw}); } template -HWY_API bool AllTrue(const Simd d, const Mask128 m) { +HWY_API bool AllTrue(const Simd /* d */, const Mask128 m) { // Ensure all undefined bytes are FF. const Mask128 mask{detail::BytesAbove()}; - return AllTrue(d, Mask128{Or(mask, m).raw}); + return AllTrue(Full128(), Mask128{Or(mask, m).raw}); } template -HWY_API intptr_t FindFirstTrue(const Simd /* tag */, +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, const Mask128 mask) { const uint64_t bits = detail::BitsFromMask(mask); return bits ? 
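// A scalar sketch of the new 64-bit BitsFromMask path above: each mask lane
// is either all-ones or all-zero (as produced by a vector comparison), so
// ANDing lane i with (1 << i) and OR-ing the lanes together yields the packed
// bit pattern, which is exactly what the wasm_i64x2_make(1, 2) slice does.
// The name is illustrative, not Highway API.
#include <cstdint>

uint64_t BitsFromMask64x2Ref(uint64_t lane0, uint64_t lane1) {
  // lane0/lane1 are 0 or ~0ull.
  return (lane0 & 1u) | (lane1 & 2u);
}
// E.g. lanes {~0ull, 0} -> 0b01 and {~0ull, ~0ull} -> 0b11.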
static_cast(Num0BitsBelowLS1Bit_Nonzero64(bits)) : -1; @@ -2906,9 +3135,9 @@ namespace detail { template HWY_INLINE Vec128 Idx16x8FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); - const Simd d; + const Simd d; const Rebind d8; - const Simd du; + const Simd du; // We need byte indices for TableLookupBytes (one vector's worth for each of // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We @@ -3059,13 +3288,11 @@ HWY_INLINE Vec128 Idx32x4FromBits(const uint64_t mask_bits) { 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const Simd d; + const Simd d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } -#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 - template HWY_INLINE Vec128 Idx64x2FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 4); @@ -3077,13 +3304,11 @@ HWY_INLINE Vec128 Idx64x2FromBits(const uint64_t mask_bits) { 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const Simd d; + const Simd d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } -#endif - // Helper functions called by both Compress and CompressStore - avoids a // redundant BitsFromMask in the latter. @@ -3091,34 +3316,29 @@ template HWY_INLINE Vec128 Compress(hwy::SizeTag<2> /*tag*/, Vec128 v, const uint64_t mask_bits) { const auto idx = detail::Idx16x8FromBits(mask_bits); - using D = Simd; - const RebindToSigned di; - return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } template HWY_INLINE Vec128 Compress(hwy::SizeTag<4> /*tag*/, Vec128 v, const uint64_t mask_bits) { const auto idx = detail::Idx32x4FromBits(mask_bits); - using D = Simd; - const RebindToSigned di; - return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } -#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 - template -HWY_INLINE Vec128 Compress(hwy::SizeTag<8> /*tag*/, - Vec128 v, - const uint64_t mask_bits) { - const auto idx = detail::Idx64x2FromBits(mask_bits); - using D = Simd; - const RebindToSigned di; - return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +HWY_INLINE Vec128 Compress(hwy::SizeTag<8> /*tag*/, Vec128 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx64x2FromBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } -#endif - } // namespace detail template @@ -3145,7 +3365,7 @@ HWY_API Vec128 CompressBits(Vec128 v, // ------------------------------ CompressStore template HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, - Simd d, T* HWY_RESTRICT unaligned) { + Simd d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(mask); const auto c = detail::Compress(hwy::SizeTag(), v, mask_bits); StoreU(c, d, unaligned); @@ -3155,7 +3375,8 @@ HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, // ------------------------------ CompressBlendedStore template HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, - Simd d, T* HWY_RESTRICT unaligned) { + Simd d, + T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; // so we can support fp16/bf16 using TU = TFromD; const uint64_t mask_bits = 
detail::BitsFromMask(m); @@ -3172,8 +3393,8 @@ HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, template HWY_API size_t CompressBitsStore(Vec128 v, - const uint8_t* HWY_RESTRICT bits, Simd d, - T* HWY_RESTRICT unaligned) { + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(bits, &mask_bits); @@ -3237,7 +3458,7 @@ HWY_API void StoreInterleaved3(const Vec128 a, const Vec128 b, // 64 bits HWY_API void StoreInterleaved3(const Vec128 a, const Vec128 b, - const Vec128 c, Simd d, + const Vec128 c, Full64 d, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors for the shuffles and first result. const Full128 d_full; @@ -3281,7 +3502,7 @@ template HWY_API void StoreInterleaved3(const Vec128 a, const Vec128 b, const Vec128 c, - Simd /*tag*/, + Simd /*tag*/, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors for the shuffles and result. const Full128 d_full; @@ -3337,7 +3558,7 @@ HWY_API void StoreInterleaved4(const Vec128 in0, const Vec128 in1, const Vec128 in2, const Vec128 in3, - Simd /* tag */, + Full64 /* tag */, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors to reduce the number of stores. const Full128 d_full8; @@ -3362,7 +3583,7 @@ HWY_API void StoreInterleaved4(const Vec128 in0, const Vec128 in1, const Vec128 in2, const Vec128 in3, - Simd /*tag*/, + Simd /*tag*/, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors to reduce the number of stores. const Full128 d_full8; @@ -3404,7 +3625,7 @@ HWY_INLINE Vec128 MulOdd(const Vec128 a, // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) template -HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, Vec128 a, Vec128 b, const Vec128 sum0, @@ -3511,133 +3732,91 @@ HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, // u16/i16 template HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d32; + const DFromV d; + const Repartition d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(d32, Min(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(d, Or(min, ShiftLeft<16>(min))); } template HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d32; + const DFromV d; + const Repartition d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(d32, Max(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(d, Or(min, ShiftLeft<16>(min))); } } // namespace detail // Supported for u/i/f 32/64. Returns the same value in each lane. 
template -HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { return detail::SumOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { return detail::MinOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { return detail::MaxOfLanes(hwy::SizeTag(), v); } -// ================================================== DEPRECATED +// ------------------------------ Lt128 -template -HWY_API size_t StoreMaskBits(const Mask128 mask, uint8_t* bits) { - return StoreMaskBits(Simd(), mask, bits); +namespace detail { + +template +Mask128 ShiftMaskLeft(Mask128 m) { + return MaskFromVec(ShiftLeftLanes(VecFromMask(Simd(), m))); } -template -HWY_API bool AllTrue(const Mask128 mask) { - return AllTrue(Simd(), mask); +} // namespace detail + +template +HWY_INLINE Mask128 Lt128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "Use u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128 eqHL = Eq(a, b); + const Mask128 ltHL = Lt(a, b); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. + const Mask128 ltLx = detail::ShiftMaskLeft<1>(ltHL); + const Mask128 outHx = Or(ltHL, And(eqHL, ltLx)); + const Vec128 vecHx = VecFromMask(d, outHx); + return MaskFromVec(InterleaveUpper(d, vecHx, vecHx)); } -template -HWY_API bool AllFalse(const Mask128 mask) { - return AllFalse(Simd(), mask); +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. 
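// A scalar sketch of Lt128 as introduced above: a 128-bit unsigned value is
// stored as two u64 lanes (lo in lane 0, hi in lane 1), and a < b reduces to
// (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo), i.e. the cH | (=H & cL)
// column of the truth table in the comment. Types and names here are
// illustrative stand-ins, not Highway API.
#include <cstdint>

struct U128Ref { uint64_t lo, hi; };

bool Lt128Ref(U128Ref a, U128Ref b) {
  const bool ltH = a.hi < b.hi;
  const bool eqH = a.hi == b.hi;
  const bool ltL = a.lo < b.lo;
  return ltH || (eqH && ltL);
}
// The vector version replicates this boolean across both 64-bit lanes of the
// pair, which is why it interleaves the upper lane with itself at the end.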
+template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); } -template -HWY_API size_t CountTrue(const Mask128 mask) { - return CountTrue(Simd(), mask); -} - -template -HWY_API Vec128 SumOfLanes(const Vec128 v) { - return SumOfLanes(Simd(), v); -} -template -HWY_API Vec128 MinOfLanes(const Vec128 v) { - return MinOfLanes(Simd(), v); -} -template -HWY_API Vec128 MaxOfLanes(const Vec128 v) { - return MaxOfLanes(Simd(), v); -} - -template -HWY_API Vec128 UpperHalf(Vec128 v) { - return UpperHalf(Half>(), v); -} - -template -HWY_API Vec128 ShiftRightBytes(const Vec128 v) { - return ShiftRightBytes(Simd(), v); -} - -template -HWY_API Vec128 ShiftRightLanes(const Vec128 v) { - return ShiftRightLanes(Simd(), v); -} - -template -HWY_API Vec128 CombineShiftRightBytes(Vec128 hi, Vec128 lo) { - return CombineShiftRightBytes(Simd(), hi, lo); -} - -template -HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { - return InterleaveUpper(Simd(), a, b); -} - -template > -HWY_API VFromD> ZipUpper(Vec128 a, Vec128 b) { - return InterleaveUpper(RepartitionToWide(), a, b); -} - -template -HWY_API Vec128 Combine(Vec128 hi2, Vec128 lo2) { - return Combine(Simd(), hi2, lo2); -} - -template -HWY_API Vec128 ZeroExtendVector(Vec128 lo) { - return ZeroExtendVector(Simd(), lo); -} - -template -HWY_API Vec128 ConcatLowerLower(Vec128 hi, Vec128 lo) { - return ConcatLowerLower(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatUpperUpper(Vec128 hi, Vec128 lo) { - return ConcatUpperUpper(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatLowerUpper(const Vec128 hi, - const Vec128 lo) { - return ConcatLowerUpper(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatUpperLower(Vec128 hi, Vec128 lo) { - return ConcatUpperLower(Simd(), hi, lo); +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), b, a); } // ================================================== Operator wrapper diff --git a/third_party/highway/hwy/ops/wasm_256-inl.h b/third_party/highway/hwy/ops/wasm_256-inl.h index f66e257fa6d0..b7a74a03a086 100644 --- a/third_party/highway/hwy/ops/wasm_256-inl.h +++ b/third_party/highway/hwy/ops/wasm_256-inl.h @@ -28,7 +28,10 @@ namespace hwy { namespace HWY_NAMESPACE { template -using Full256 = Simd; +using Full256 = Simd; + +template +using Full128 = Simd; // TODO(richardwinterton): add this to DeduceD in wasm_128 similar to x86_128. template @@ -70,8 +73,8 @@ struct Mask256 { // ------------------------------ BitCast -template -HWY_API Vec256 BitCast(Simd d, Vec256 v) { +template +HWY_API Vec256 BitCast(Full256 d, Vec256 v) { const Half dh; Vec256 ret; ret.v0 = BitCast(dh, v.v0); @@ -84,13 +87,12 @@ HWY_API Vec256 BitCast(Simd d, Vec256 v) { // ------------------------------ Zero // Returns an all-zero vector/part. -template -HWY_API Vec256 Zero(Simd /* tag */) { +template +HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{wasm_i32x4_splat(0)}; } -template -HWY_API Vec128 Zero(Simd /* tag */) { - return Vec128{wasm_f32x4_splat(0.0f)}; +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{wasm_f32x4_splat(0.0f)}; } template @@ -99,51 +101,42 @@ using VFromD = decltype(Zero(D())); // ------------------------------ Set // Returns a vector/part with all lanes set to "t". 
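// A usage sketch for the new Min128/Max128 wrappers above: they select one
// whole 128-bit operand based on Lt128, never mixing halves, because Lt128
// replicates its verdict across both 64-bit lanes of the pair. Scalar
// stand-in, assuming lo/hi u64 halves; names are illustrative only.
#include <cstdint>

struct U128Pair { uint64_t lo, hi; };

static bool Lt128Scalar(U128Pair a, U128Pair b) {
  return a.hi < b.hi || (a.hi == b.hi && a.lo < b.lo);
}
U128Pair Min128Scalar(U128Pair a, U128Pair b) {
  return Lt128Scalar(a, b) ? a : b;
}
U128Pair Max128Scalar(U128Pair a, U128Pair b) {
  return Lt128Scalar(a, b) ? b : a;
}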
-template -HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { - return Vec128{wasm_i8x16_splat(static_cast(t))}; +HWY_API Vec256 Set(Full256 /* tag */, const uint8_t t) { + return Vec256{wasm_i8x16_splat(static_cast(t))}; } -template -HWY_API Vec128 Set(Simd /* tag */, const uint16_t t) { - return Vec128{wasm_i16x8_splat(static_cast(t))}; +HWY_API Vec256 Set(Full256 /* tag */, const uint16_t t) { + return Vec256{wasm_i16x8_splat(static_cast(t))}; } -template -HWY_API Vec128 Set(Simd /* tag */, const uint32_t t) { - return Vec128{wasm_i32x4_splat(static_cast(t))}; +HWY_API Vec256 Set(Full256 /* tag */, const uint32_t t) { + return Vec256{wasm_i32x4_splat(static_cast(t))}; } -template -HWY_API Vec128 Set(Simd /* tag */, const uint64_t t) { - return Vec128{wasm_i64x2_splat(static_cast(t))}; +HWY_API Vec256 Set(Full256 /* tag */, const uint64_t t) { + return Vec256{wasm_i64x2_splat(static_cast(t))}; } -template -HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { - return Vec128{wasm_i8x16_splat(t)}; +HWY_API Vec256 Set(Full256 /* tag */, const int8_t t) { + return Vec256{wasm_i8x16_splat(t)}; } -template -HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { - return Vec128{wasm_i16x8_splat(t)}; +HWY_API Vec256 Set(Full256 /* tag */, const int16_t t) { + return Vec256{wasm_i16x8_splat(t)}; } -template -HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { - return Vec128{wasm_i32x4_splat(t)}; +HWY_API Vec256 Set(Full256 /* tag */, const int32_t t) { + return Vec256{wasm_i32x4_splat(t)}; } -template -HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { - return Vec128{wasm_i64x2_splat(t)}; +HWY_API Vec256 Set(Full256 /* tag */, const int64_t t) { + return Vec256{wasm_i64x2_splat(t)}; } -template -HWY_API Vec128 Set(Simd /* tag */, const float t) { - return Vec128{wasm_f32x4_splat(t)}; +HWY_API Vec256 Set(Full256 /* tag */, const float t) { + return Vec256{wasm_f32x4_splat(t)}; } HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. -template -HWY_API Vec256 Undefined(Simd d) { +template +HWY_API Vec256 Undefined(Full256 d) { return Zero(d); } @@ -151,7 +144,7 @@ HWY_DIAGNOSTICS(pop) // Returns a vector with lane i=[0, N) set to "first" + i. 
template -Vec256 Iota(const Simd d, const T2 first) { +Vec256 Iota(const Full256 d, const T2 first) { HWY_ALIGN T lanes[16 / sizeof(T)]; for (size_t i = 0; i < 16 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); @@ -164,143 +157,123 @@ Vec256 Iota(const Simd d, const T2 first) { // ------------------------------ Addition // Unsigned -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_add(a.raw, b.raw)}; } -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_add(a.raw, b.raw)}; } -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_add(a.raw, b.raw)}; } // Signed -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_add(a.raw, b.raw)}; } -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_add(a.raw, b.raw)}; } -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_add(a.raw, b.raw)}; } // Float -template -HWY_API Vec128 operator+(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_f32x4_add(a.raw, b.raw)}; +HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_add(a.raw, b.raw)}; } // ------------------------------ Subtraction // Unsigned -template -HWY_API Vec128 operator-(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_sub(a.raw, b.raw)}; } -template -HWY_API Vec128 operator-(Vec128 a, - Vec128 b) { - return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{wasm_i16x8_sub(a.raw, b.raw)}; } -template -HWY_API Vec128 operator-(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_sub(a.raw, b.raw)}; } // Signed -template -HWY_API Vec128 operator-(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_sub(a.raw, b.raw)}; } -template -HWY_API Vec128 operator-(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_sub(a.raw, b.raw)}; } -template -HWY_API Vec128 operator-(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_sub(a.raw, b.raw)}; } // Float -template -HWY_API Vec128 operator-(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; +HWY_API Vec256 
operator-(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_sub(a.raw, b.raw)}; } -// ------------------------------ Saturating addition +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + HWY_ABORT("not implemented"); +} + +// ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. // Unsigned -template -HWY_API Vec128 SaturatedAdd(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u8x16_add_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u8x16_add_sat(a.raw, b.raw)}; } -template -HWY_API Vec128 SaturatedAdd(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u16x8_add_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_add_sat(a.raw, b.raw)}; } // Signed -template -HWY_API Vec128 SaturatedAdd(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_add_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_add_sat(a.raw, b.raw)}; } -template -HWY_API Vec128 SaturatedAdd(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_add_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_add_sat(a.raw, b.raw)}; } -// ------------------------------ Saturating subtraction +// ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. // Unsigned -template -HWY_API Vec128 SaturatedSub(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u8x16_sub_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u8x16_sub_sat(a.raw, b.raw)}; } -template -HWY_API Vec128 SaturatedSub(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u16x8_sub_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_sub_sat(a.raw, b.raw)}; } // Signed -template -HWY_API Vec128 SaturatedSub(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_sub_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_sub_sat(a.raw, b.raw)}; } -template -HWY_API Vec128 SaturatedSub(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_sub_sat(a.raw, b.raw)}; +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_sub_sat(a.raw, b.raw)}; } // ------------------------------ Average @@ -308,84 +281,77 @@ HWY_API Vec128 SaturatedSub(const Vec128 a, // Returns (a + b + 1) / 2 // Unsigned -template -HWY_API Vec128 AverageRound(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u8x16_avgr(a.raw, b.raw)}; +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u8x16_avgr(a.raw, b.raw)}; } -template -HWY_API Vec128 AverageRound(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_avgr(a.raw, b.raw)}; } // ------------------------------ Absolute value // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. 
-template -HWY_API Vec128 Abs(const Vec128 v) { - return Vec128{wasm_i8x16_abs(v.raw)}; +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i8x16_abs(v.raw)}; } -template -HWY_API Vec128 Abs(const Vec128 v) { - return Vec128{wasm_i16x8_abs(v.raw)}; +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i16x8_abs(v.raw)}; } -template -HWY_API Vec128 Abs(const Vec128 v) { - return Vec128{wasm_i32x4_abs(v.raw)}; +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i32x4_abs(v.raw)}; } -template -HWY_API Vec128 Abs(const Vec128 v) { - return Vec128{wasm_i62x2_abs(v.raw)}; +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i62x2_abs(v.raw)}; } -template -HWY_API Vec128 Abs(const Vec128 v) { - return Vec128{wasm_f32x4_abs(v.raw)}; +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_f32x4_abs(v.raw)}; } // ------------------------------ Shift lanes by constant #bits // Unsigned -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i16x8_shl(v.raw, kBits)}; } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{wasm_u16x8_shr(v.raw, kBits)}; +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_u16x8_shr(v.raw, kBits)}; } -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i32x4_shl(v.raw, kBits)}; } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{wasm_u32x4_shr(v.raw, kBits)}; +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_u32x4_shr(v.raw, kBits)}; } // Signed -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i16x8_shl(v.raw, kBits)}; } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{wasm_i16x8_shr(v.raw, kBits)}; +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_i16x8_shr(v.raw, kBits)}; } -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i32x4_shl(v.raw, kBits)}; } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{wasm_i32x4_shr(v.raw, kBits)}; +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_i32x4_shr(v.raw, kBits)}; } // 8-bit template HWY_API Vec256 ShiftLeft(const Vec256 v) { - const Simd d8; + const Full256 d8; // Use raw instead of BitCast to support N=1. const Vec256 shifted{ShiftLeft(Vec128>{v.raw}).raw}; return kBits == 1 @@ -393,19 +359,18 @@ HWY_API Vec256 ShiftLeft(const Vec256 v) { : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - const Simd d8; +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 d8; // Use raw instead of BitCast to support N=1. 
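// A scalar sketch of the 8-bit ShiftLeft technique above: the code shifts
// 16-bit lanes (there is no per-byte shift in this path) and then clears the
// low kBits bits of every byte, which is where bits leaking across a byte
// boundary would land, using the mask (0xFF << kBits) & 0xFF. Modeled here on
// one 16-bit lane holding two byte lanes; the name is illustrative only.
#include <cstdint>

template <int kBits>
uint16_t ShiftLeftBytesIn16Ref(uint16_t two_bytes) {
  static_assert(0 <= kBits && kBits < 8, "per-byte shift count");
  const uint16_t shifted = static_cast<uint16_t>(two_bytes << kBits);
  // Replicate the per-byte mask into both bytes of the 16-bit lane.
  const uint16_t mask =
      static_cast<uint16_t>(((0xFF << kBits) & 0xFF) * 0x0101);
  return static_cast<uint16_t>(shifted & mask);
}
// E.g. with kBits = 3, 0x00FF becomes 0x00F8: the 0x07 that leaked into the
// upper byte is cleared, matching a true per-byte shift.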
- const Vec128 shifted{ - ShiftRight(Vec128{v.raw}).raw}; + const Vec256 shifted{ShiftRight(Vec128{v.raw}).raw}; return shifted & Set(d8, 0xFF >> kBits); } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - const Simd di; - const Simd du; +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; @@ -423,72 +388,59 @@ HWY_API Vec256 RotateRight(const Vec256 v) { // ------------------------------ Shift lanes by same variable #bits // Unsigned -template -HWY_API Vec128 ShiftLeftSame(const Vec128 v, - const int bits) { - return Vec128{wasm_i16x8_shl(v.raw, bits)}; +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i16x8_shl(v.raw, bits)}; } -template -HWY_API Vec128 ShiftRightSame(const Vec128 v, - const int bits) { - return Vec128{wasm_u16x8_shr(v.raw, bits)}; +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_u16x8_shr(v.raw, bits)}; } -template -HWY_API Vec128 ShiftLeftSame(const Vec128 v, - const int bits) { - return Vec128{wasm_i32x4_shl(v.raw, bits)}; +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i32x4_shl(v.raw, bits)}; } -template -HWY_API Vec128 ShiftRightSame(const Vec128 v, - const int bits) { - return Vec128{wasm_u32x4_shr(v.raw, bits)}; +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_u32x4_shr(v.raw, bits)}; } // Signed -template -HWY_API Vec128 ShiftLeftSame(const Vec128 v, - const int bits) { - return Vec128{wasm_i16x8_shl(v.raw, bits)}; +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{wasm_i16x8_shl(v.raw, bits)}; } -template -HWY_API Vec128 ShiftRightSame(const Vec128 v, - const int bits) { - return Vec128{wasm_i16x8_shr(v.raw, bits)}; +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i16x8_shr(v.raw, bits)}; } -template -HWY_API Vec128 ShiftLeftSame(const Vec128 v, - const int bits) { - return Vec128{wasm_i32x4_shl(v.raw, bits)}; +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{wasm_i32x4_shl(v.raw, bits)}; } -template -HWY_API Vec128 ShiftRightSame(const Vec128 v, - const int bits) { - return Vec128{wasm_i32x4_shr(v.raw, bits)}; +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i32x4_shr(v.raw, bits)}; } // 8-bit template HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { - const Simd d8; + const Full256 d8; // Use raw instead of BitCast to support N=1. const Vec256 shifted{ShiftLeftSame(Vec128>{v.raw}, bits).raw}; return shifted & Set(d8, (0xFF << bits) & 0xFF); } -template -HWY_API Vec128 ShiftRightSame(Vec128 v, - const int bits) { - const Simd d8; +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; // Use raw instead of BitCast to support N=1. 
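// A scalar sketch of the signed 8-bit ShiftRight above: perform a logical
// (zero-filling) shift, then restore the sign with the classic
// (shifted ^ sign_pos) - sign_pos trick, where sign_pos = 0x80 >> kBits. For
// negative inputs this turns the zero-filled high bits back into copies of
// the sign bit. The name is illustrative only.
#include <cstdint>

template <int kBits>
int8_t ShiftRightI8Ref(int8_t v) {
  static_assert(0 <= kBits && kBits < 8, "shift count must fit in a byte");
  const uint8_t logical =
      static_cast<uint8_t>(static_cast<uint8_t>(v) >> kBits);
  const uint8_t sign_pos = static_cast<uint8_t>(0x80 >> kBits);
  return static_cast<int8_t>(
      static_cast<uint8_t>((logical ^ sign_pos) - sign_pos));
}
// E.g. ShiftRightI8Ref<2>(-128) == -32, matching an arithmetic shift.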
- const Vec128 shifted{ + const Vec256 shifted{ ShiftRightSame(Vec128{v.raw}, bits).raw}; return shifted & Set(d8, 0xFF >> bits); } -template -HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { - const Simd di; - const Simd du; +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); return (shifted ^ shifted_sign) - shifted_sign; @@ -497,159 +449,124 @@ HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { // ------------------------------ Minimum // Unsigned -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u8x16_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_u8x16_min(a.raw, b.raw)}; } -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u16x8_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_min(a.raw, b.raw)}; } -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u32x4_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u32x4_min(a.raw, b.raw)}; } -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - alignas(16) float min[4]; +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + alignas(32) float min[4]; min[0] = HWY_MIN(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); min[1] = HWY_MIN(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); - return Vec128{wasm_v128_load(min)}; + return Vec256{wasm_v128_load(min)}; } // Signed -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i8x16_min(a.raw, b.raw)}; } -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i16x8_min(a.raw, b.raw)}; } -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i32x4_min(a.raw, b.raw)}; } -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - alignas(16) float min[4]; +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + alignas(32) float min[4]; min[0] = HWY_MIN(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); min[1] = HWY_MIN(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); - return Vec128{wasm_v128_load(min)}; + return Vec256{wasm_v128_load(min)}; } // Float -template -HWY_API Vec128 Min(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_f32x4_min(a.raw, b.raw)}; +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_min(a.raw, b.raw)}; } // ------------------------------ Maximum // Unsigned -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u8x16_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_u8x16_max(a.raw, b.raw)}; } -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_u16x8_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_max(a.raw, b.raw)}; } -template -HWY_API Vec128 Max(const Vec128 a, - 
const Vec128 b) { - return Vec128{wasm_u32x4_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u32x4_max(a.raw, b.raw)}; } -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - alignas(16) float max[4]; +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + alignas(32) float max[4]; max[0] = HWY_MAX(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); max[1] = HWY_MAX(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); - return Vec128{wasm_v128_load(max)}; + return Vec256{wasm_v128_load(max)}; } // Signed -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i8x16_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i8x16_max(a.raw, b.raw)}; } -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i16x8_max(a.raw, b.raw)}; } -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i32x4_max(a.raw, b.raw)}; } -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - alignas(16) float max[4]; +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + alignas(32) float max[4]; max[0] = HWY_MAX(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); max[1] = HWY_MAX(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); - return Vec128{wasm_v128_load(max)}; + return Vec256{wasm_v128_load(max)}; } // Float -template -HWY_API Vec128 Max(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_f32x4_max(a.raw, b.raw)}; +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_max(a.raw, b.raw)}; } // ------------------------------ Integer multiplication // Unsigned -template -HWY_API Vec128 operator*(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_mul(a.raw, b.raw)}; } -template -HWY_API Vec128 operator*(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_mul(a.raw, b.raw)}; } // Signed -template -HWY_API Vec128 operator*(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_mul(a.raw, b.raw)}; } -template -HWY_API Vec128 operator*(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_mul(a.raw, b.raw)}; } // Returns the upper 16 bits of a * b in each lane. -template -HWY_API Vec128 MulHigh(const Vec128 a, - const Vec128 b) { +HWY_API Vec256 MulHigh(const Vec256 a, + const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto al = wasm_u32x4_extend_low_u16x8(a.raw); const auto ah = wasm_u32x4_extend_high_u16x8(a.raw); @@ -658,12 +575,10 @@ HWY_API Vec128 MulHigh(const Vec128 a, const auto l = wasm_i32x4_mul(al, bl); const auto h = wasm_i32x4_mul(ah, bh); // TODO(eustas): shift-right + narrow? 
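// A scalar sketch of the MulHigh emulation above: widen the 16-bit lanes to
// 32 bits, multiply, then keep the upper 16 bits of each product (that is
// what the shuffle of the odd 16-bit pieces, indices 1, 3, 5, ..., does in
// little-endian lane order). Names are illustrative only.
#include <cstdint>

uint16_t MulHighU16Ref(uint16_t a, uint16_t b) {
  const uint32_t product = static_cast<uint32_t>(a) * b;
  return static_cast<uint16_t>(product >> 16);
}

int16_t MulHighI16Ref(int16_t a, int16_t b) {
  const int32_t product = static_cast<int32_t>(a) * b;
  // Take the upper 16 bits of the two's-complement product.
  return static_cast<int16_t>(static_cast<uint32_t>(product) >> 16);
}
// E.g. MulHighU16Ref(0xFFFF, 0xFFFF) == 0xFFFE.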
- return Vec128{ - wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; + return Vec256{wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; } -template -HWY_API Vec128 MulHigh(const Vec128 a, - const Vec128 b) { +HWY_API Vec256 MulHigh(const Vec256 a, + const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto al = wasm_i32x4_extend_low_i16x8(a.raw); const auto ah = wasm_i32x4_extend_high_i16x8(a.raw); @@ -672,117 +587,96 @@ HWY_API Vec128 MulHigh(const Vec128 a, const auto l = wasm_i32x4_mul(al, bl); const auto h = wasm_i32x4_mul(ah, bh); // TODO(eustas): shift-right + narrow? - return Vec128{ - wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; + return Vec256{wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; } // Multiplies even lanes (0, 2 ..) and returns the double-width result. -template -HWY_API Vec128 MulEven(const Vec128 a, - const Vec128 b) { +HWY_API Vec256 MulEven(const Vec256 a, + const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); const auto ae = wasm_v128_and(a.raw, kEvenMask); const auto be = wasm_v128_and(b.raw, kEvenMask); - return Vec128{wasm_i64x2_mul(ae, be)}; + return Vec256{wasm_i64x2_mul(ae, be)}; } -template -HWY_API Vec128 MulEven(const Vec128 a, - const Vec128 b) { +HWY_API Vec256 MulEven(const Vec256 a, + const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); const auto ae = wasm_v128_and(a.raw, kEvenMask); const auto be = wasm_v128_and(b.raw, kEvenMask); - return Vec128{wasm_i64x2_mul(ae, be)}; + return Vec256{wasm_i64x2_mul(ae, be)}; } // ------------------------------ Negate template HWY_API Vec256 Neg(const Vec256 v) { - return Xor(v, SignBit(Simd())); + return Xor(v, SignBit(Full256())); } -template -HWY_API Vec128 Neg(const Vec128 v) { - return Vec128{wasm_i8x16_neg(v.raw)}; +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i8x16_neg(v.raw)}; } -template -HWY_API Vec128 Neg(const Vec128 v) { - return Vec128{wasm_i16x8_neg(v.raw)}; +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i16x8_neg(v.raw)}; } -template -HWY_API Vec128 Neg(const Vec128 v) { - return Vec128{wasm_i32x4_neg(v.raw)}; +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i32x4_neg(v.raw)}; } -template -HWY_API Vec128 Neg(const Vec128 v) { - return Vec128{wasm_i64x2_neg(v.raw)}; +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i64x2_neg(v.raw)}; } // ------------------------------ Floating-point mul / div -template -HWY_API Vec128 operator*(Vec128 a, Vec128 b) { - return Vec128{wasm_f32x4_mul(a.raw, b.raw)}; +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{wasm_f32x4_mul(a.raw, b.raw)}; } -template -HWY_API Vec128 operator/(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_f32x4_div(a.raw, b.raw)}; +HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_div(a.raw, b.raw)}; } // Approximate reciprocal -template -HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { - const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + const Vec256 one = Vec256{wasm_f32x4_splat(1.0f)}; return one / v; } // Absolute value of difference. 
-template -HWY_API Vec128 AbsDiff(const Vec128 a, - const Vec128 b) { +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants // Returns mul * x + add -template -HWY_API Vec128 MulAdd(const Vec128 mul, - const Vec128 x, - const Vec128 add) { +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { // TODO(eustas): replace, when implemented in WASM. // TODO(eustas): is it wasm_f32x4_qfma? return mul * x + add; } // Returns add - mul * x -template -HWY_API Vec128 NegMulAdd(const Vec128 mul, - const Vec128 x, - const Vec128 add) { +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { // TODO(eustas): replace, when implemented in WASM. return add - mul * x; } // Returns mul * x - sub -template -HWY_API Vec128 MulSub(const Vec128 mul, - const Vec128 x, - const Vec128 sub) { +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { // TODO(eustas): replace, when implemented in WASM. // TODO(eustas): is it wasm_f32x4_qfms? return mul * x - sub; } // Returns -mul * x - sub -template -HWY_API Vec128 NegMulSub(const Vec128 mul, - const Vec128 x, - const Vec128 sub) { +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { // TODO(eustas): replace, when implemented in WASM. return Neg(mul) * x - sub; } @@ -790,57 +684,51 @@ HWY_API Vec128 NegMulSub(const Vec128 mul, // ------------------------------ Floating-point square root // Full precision square root -template -HWY_API Vec128 Sqrt(const Vec128 v) { - return Vec128{wasm_f32x4_sqrt(v.raw)}; +HWY_API Vec256 Sqrt(const Vec256 v) { + return Vec256{wasm_f32x4_sqrt(v.raw)}; } // Approximate reciprocal square root -template -HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { +HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { // TODO(eustas): find cheaper a way to calculate this. - const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + const Vec256 one = Vec256{wasm_f32x4_splat(1.0f)}; return one / Sqrt(v); } // ------------------------------ Floating-point rounding // Toward nearest integer, ties to even -template -HWY_API Vec128 Round(const Vec128 v) { - return Vec128{wasm_f32x4_nearest(v.raw)}; +HWY_API Vec256 Round(const Vec256 v) { + return Vec256{wasm_f32x4_nearest(v.raw)}; } // Toward zero, aka truncate -template -HWY_API Vec128 Trunc(const Vec128 v) { - return Vec128{wasm_f32x4_trunc(v.raw)}; +HWY_API Vec256 Trunc(const Vec256 v) { + return Vec256{wasm_f32x4_trunc(v.raw)}; } // Toward +infinity, aka ceiling -template -HWY_API Vec128 Ceil(const Vec128 v) { - return Vec128{wasm_f32x4_ceil(v.raw)}; +HWY_API Vec256 Ceil(const Vec256 v) { + return Vec256{wasm_f32x4_ceil(v.raw)}; } // Toward -infinity, aka floor -template -HWY_API Vec128 Floor(const Vec128 v) { - return Vec128{wasm_f32x4_floor(v.raw)}; +HWY_API Vec256 Floor(const Vec256 v) { + return Vec256{wasm_f32x4_floor(v.raw)}; } // ================================================== COMPARE // Comparisons fill a lane with 1-bits if the condition is true, else 0. 
-template -HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { +template +HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); - return Mask128{m.raw}; + return Mask256{m.raw}; } template -HWY_API Mask128 TestBit(Vec256 v, Vec256 bit) { +HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return (v & bit) == bit; } @@ -848,110 +736,90 @@ HWY_API Mask128 TestBit(Vec256 v, Vec256 bit) { // ------------------------------ Equality // Unsigned -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_eq(a.raw, b.raw)}; } -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i16x8_eq(a.raw, b.raw)}; } -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_eq(a.raw, b.raw)}; } // Signed -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_eq(a.raw, b.raw)}; } -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{wasm_i16x8_eq(a.raw, b.raw)}; } -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_eq(a.raw, b.raw)}; } // Float -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_f32x4_eq(a.raw, b.raw)}; +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_eq(a.raw, b.raw)}; } // ------------------------------ Inequality // Unsigned -template -HWY_API Mask128 operator!=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_ne(a.raw, b.raw)}; } -template -HWY_API Mask128 operator!=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i16x8_ne(a.raw, b.raw)}; } -template -HWY_API Mask128 operator!=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_ne(a.raw, b.raw)}; } // Signed -template -HWY_API Mask128 operator!=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_ne(a.raw, b.raw)}; } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{wasm_i16x8_ne(a.raw, b.raw)}; } -template -HWY_API Mask128 operator!=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +HWY_API 
Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_ne(a.raw, b.raw)}; } // Float -template -HWY_API Mask128 operator!=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_f32x4_ne(a.raw, b.raw)}; +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_ne(a.raw, b.raw)}; } // ------------------------------ Strict inequality -template -HWY_API Mask128 operator>(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i8x16_gt(a.raw, b.raw)}; +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_gt(a.raw, b.raw)}; } -template -HWY_API Mask128 operator>(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i16x8_gt(a.raw, b.raw)}; +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i16x8_gt(a.raw, b.raw)}; } -template -HWY_API Mask128 operator>(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_i32x4_gt(a.raw, b.raw)}; +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_gt(a.raw, b.raw)}; } -template -HWY_API Mask128 operator>(const Vec128 a, - const Vec128 b) { - const Simd d32; +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + const Rebind < int32_t, DFromV d32; const auto a32 = BitCast(d32, a); const auto b32 = BitCast(d32, b); // If the upper half is less than or greater, this is the answer. @@ -964,46 +832,42 @@ HWY_API Mask128 operator>(const Vec128 a, const auto gt = Or(lo_gt, m_gt); // Copy result in upper 32 bits to lower 32 bits. - return Mask128{wasm_i32x4_shuffle(gt, gt, 3, 3, 1, 1)}; + return Mask256{wasm_i32x4_shuffle(gt, gt, 3, 3, 1, 1)}; } template -HWY_API Mask128 operator>(Vec256 a, Vec256 b) { - const Simd du; +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + const Full256 du; const RebindToSigned di; const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); } -template -HWY_API Mask128 operator>(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_f32x4_gt(a.raw, b.raw)}; +HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { + return Mask256{wasm_f32x4_gt(a.raw, b.raw)}; } template -HWY_API Mask128 operator<(const Vec256 a, const Vec256 b) { +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { return operator>(b, a); } // ------------------------------ Weak inequality // Float <= >= -template -HWY_API Mask128 operator<=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_f32x4_le(a.raw, b.raw)}; +HWY_API Mask256 operator<=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_le(a.raw, b.raw)}; } -template -HWY_API Mask128 operator>=(const Vec128 a, - const Vec128 b) { - return Mask128{wasm_f32x4_ge(a.raw, b.raw)}; +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_ge(a.raw, b.raw)}; } // ------------------------------ FirstN (Iota, Lt) template -HWY_API Mask128 FirstN(const Simd d, size_t num) { +HWY_API Mask256 FirstN(const Full256 d, size_t num) { const RebindToSigned di; // Signed comparisons may be cheaper. 
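// A scalar sketch of the unsigned 64-bit operator> above: flipping the most
// significant bit maps unsigned order onto signed order, so a signed compare
// of the XOR-ed values gives the unsigned result. The name is illustrative
// only.
#include <cstdint>

bool GreaterU64Ref(uint64_t a, uint64_t b) {
  const uint64_t msb = 0x8000000000000000ull;  // (LimitsMax >> 1) + 1
  return static_cast<int64_t>(a ^ msb) > static_cast<int64_t>(b ^ msb);
}
// E.g. GreaterU64Ref(0xFFFFFFFFFFFFFFFFull, 1) is true, even though the same
// bit patterns compared as signed values (-1 > 1) would be false.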
return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); } @@ -1046,6 +910,20 @@ HWY_API Vec256 Xor(Vec256 a, Vec256 b) { return Vec256{wasm_v128_xor(a.raw, b.raw)}; } +// ------------------------------ OrAnd + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + // ------------------------------ Operator overloads (internal-only if float) template @@ -1068,14 +946,14 @@ HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { template HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - const auto msb = SignBit(Simd()); + const auto msb = SignBit(Full256()); return Or(AndNot(msb, magn), And(msb, sign)); } template HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - return Or(abs, And(SignBit(Simd()), sign)); + return Or(abs, And(SignBit(Full256()), sign)); } // ------------------------------ BroadcastSignBit (compare) @@ -1084,83 +962,82 @@ template HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return ShiftRight(v); } -template -HWY_API Vec128 BroadcastSignBit(const Vec128 v) { - return VecFromMask(Simd(), v < Zero(Simd())); +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return VecFromMask(Full256(), v < Zero(Full256())); } // ------------------------------ Mask // Mask and Vec are the same (true = FF..FF). template -HWY_API Mask128 MaskFromVec(const Vec256 v) { - return Mask128{v.raw}; +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; } template -HWY_API Vec256 VecFromMask(Simd /* tag */, Mask128 v) { - return Vec256{v.raw}; -} - -// DEPRECATED -template -HWY_API Vec256 VecFromMask(const Mask128 v) { +HWY_API Vec256 VecFromMask(Full256 /* tag */, Mask256 v) { return Vec256{v.raw}; } // mask ? yes : no template -HWY_API Vec256 IfThenElse(Mask128 mask, Vec256 yes, Vec256 no) { +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; } // mask ? yes : 0 template -HWY_API Vec256 IfThenElseZero(Mask128 mask, Vec256 yes) { - return yes & VecFromMask(Simd(), mask); +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); } // mask ? 
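// A scalar sketch of FirstN above: build the mask by comparing a 0, 1, 2, ...
// index vector (Iota) against the requested count, using a signed compare
// because it may be cheaper. The name is illustrative only.
#include <cstddef>
#include <vector>

std::vector<bool> FirstNRef(size_t num_lanes, size_t num) {
  std::vector<bool> mask(num_lanes);
  for (size_t i = 0; i < num_lanes; ++i) mask[i] = (i < num);  // Iota < num
  return mask;
}
// E.g. FirstNRef(8, 3) -> {1, 1, 1, 0, 0, 0, 0, 0}.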
0 : no template -HWY_API Vec256 IfThenZeroElse(Mask128 mask, Vec256 no) { - return AndNot(VecFromMask(Simd(), mask), no); +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template + HWY_API Vec256 < + T IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + HWY_ASSERT(0); } template HWY_API Vec256 ZeroIfNegative(Vec256 v) { - const Simd d; + const Full256 d; const auto zero = Zero(d); - return IfThenElse(Mask128{(v > zero).raw}, v, zero); + return IfThenElse(Mask256{(v > zero).raw}, v, zero); } // ------------------------------ Mask logical template -HWY_API Mask128 Not(const Mask128 m) { - return MaskFromVec(Not(VecFromMask(Simd(), m))); +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); } template -HWY_API Mask128 And(const Mask128 a, Mask128 b) { - const Simd d; +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template -HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { - const Simd d; +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template -HWY_API Mask128 Or(const Mask128 a, Mask128 b) { - const Simd d; +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template -HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { - const Simd d; +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } @@ -1176,8 +1053,8 @@ HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { template HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { - const Simd d; - Mask128 mask; + const Full256 d; + Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. @@ -1201,8 +1078,8 @@ HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { template HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { - const Simd d; - Mask128 mask; + const Full256 d; + Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. @@ -1232,8 +1109,8 @@ HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { template HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { - const Simd d; - Mask128 mask; + const Full256 d; + Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. @@ -1257,8 +1134,8 @@ HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { template HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { - const Simd d; - Mask128 mask; + const Full256 d; + Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. 
@@ -1289,57 +1166,38 @@ HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { // ------------------------------ Load template -HWY_API Vec128 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { - return Vec128{wasm_v128_load(aligned)}; +HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec256{wasm_v128_load(aligned)}; } template -HWY_API Vec256 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } -// Partial load. -template -HWY_API Vec256 Load(Simd /* tag */, const T* HWY_RESTRICT p) { - Vec256 v; - CopyBytes(p, &v); - return v; -} - // LoadU == Load. template -HWY_API Vec256 LoadU(Simd d, const T* HWY_RESTRICT p) { +HWY_API Vec256 LoadU(Full256 d, const T* HWY_RESTRICT p) { return Load(d, p); } // 128-bit SIMD => nothing to duplicate, same as an unaligned load. -template -HWY_API Vec256 LoadDup128(Simd d, const T* HWY_RESTRICT p) { +template +HWY_API Vec256 LoadDup128(Full256 d, const T* HWY_RESTRICT p) { return Load(d, p); } // ------------------------------ Store template -HWY_API void Store(Vec128 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { +HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { wasm_v128_store(aligned, v.raw); } -// Partial store. -template -HWY_API void Store(Vec256 v, Simd /* tag */, T* HWY_RESTRICT p) { - CopyBytes(&v, p); -} - -HWY_API void Store(const Vec128 v, Simd /* tag */, - float* HWY_RESTRICT p) { - *p = wasm_f32x4_extract_lane(v.raw, 0); -} - // StoreU == Store. template -HWY_API void StoreU(Vec256 v, Simd d, T* HWY_RESTRICT p) { +HWY_API void StoreU(Vec256 v, Full256 d, T* HWY_RESTRICT p) { Store(v, d, p); } @@ -1348,23 +1206,23 @@ HWY_API void StoreU(Vec256 v, Simd d, T* HWY_RESTRICT p) { // Same as aligned stores on non-x86. 
template -HWY_API void Stream(Vec256 v, Simd /* tag */, +HWY_API void Stream(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { wasm_v128_store(aligned, v.raw); } // ------------------------------ Scatter (Store) -template -HWY_API void ScatterOffset(Vec256 v, Simd d, T* HWY_RESTRICT base, - const Vec128 offset) { +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); - alignas(16) T lanes[N]; + alignas(32) T lanes[32 / sizeof(T)]; Store(v, d, lanes); - alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + alignas(32) Offset offset_lanes[32 / sizeof(T)]; + Store(offset, Full256(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { @@ -1372,16 +1230,16 @@ HWY_API void ScatterOffset(Vec256 v, Simd d, T* HWY_RESTRICT base, } } -template -HWY_API void ScatterIndex(Vec256 v, Simd d, T* HWY_RESTRICT base, - const Vec128 index) { +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); - alignas(16) T lanes[N]; + alignas(32) T lanes[32 / sizeof(T)]; Store(v, d, lanes); - alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + alignas(32) Index index_lanes[32 / sizeof(T)]; + Store(index, Full256(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; @@ -1391,14 +1249,14 @@ HWY_API void ScatterIndex(Vec256 v, Simd d, T* HWY_RESTRICT base, // ------------------------------ Gather (Load/Store) template -HWY_API Vec256 GatherOffset(const Simd d, const T* HWY_RESTRICT base, - const Vec128 offset) { +HWY_API Vec256 GatherOffset(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); - alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + alignas(32) Offset offset_lanes[32 / sizeof(T)]; + Store(offset, Full256(), offset_lanes); - alignas(16) T lanes[N]; + alignas(32) T lanes[32 / sizeof(T)]; const uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); @@ -1407,14 +1265,14 @@ HWY_API Vec256 GatherOffset(const Simd d, const T* HWY_RESTRICT base, } template -HWY_API Vec256 GatherIndex(const Simd d, const T* HWY_RESTRICT base, - const Vec128 index) { +HWY_API Vec256 GatherIndex(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); - alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + alignas(32) Index index_lanes[32 / sizeof(T)]; + Store(index, Full256(), index_lanes); - alignas(16) T lanes[N]; + alignas(32) T lanes[32 / sizeof(T)]; for (size_t i = 0; i < N; ++i) { lanes[i] = base[index_lanes[i]]; } @@ -1426,61 +1284,52 @@ HWY_API Vec256 GatherIndex(const Simd d, const T* HWY_RESTRICT base, // ------------------------------ Extract lane // Gets the single value stored in a vector/part. 
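// Illustrative sketch, not part of the vendored sources: the Scatter/Gather
// functions above are emulated by storing the index (or offset) vector to an
// aligned stack array and then accessing one lane at a time. A scalar model
// of the GatherIndex semantics, with illustrative names only.
#include <cstddef>

template <typename T, typename Index>
void GatherIndexScalar(const T* base, const Index* index_lanes, T* out_lanes,
                       size_t num_lanes) {
  for (size_t i = 0; i < num_lanes; ++i) {
    out_lanes[i] = base[index_lanes[i]];  // lane i reads base[index[i]]
  }
}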
-template -HWY_API uint8_t GetLane(const Vec128 v) { +HWY_API uint8_t GetLane(const Vec256 v) { return wasm_i8x16_extract_lane(v.raw, 0); } -template -HWY_API int8_t GetLane(const Vec128 v) { +HWY_API int8_t GetLane(const Vec256 v) { return wasm_i8x16_extract_lane(v.raw, 0); } -template -HWY_API uint16_t GetLane(const Vec128 v) { +HWY_API uint16_t GetLane(const Vec256 v) { return wasm_i16x8_extract_lane(v.raw, 0); } -template -HWY_API int16_t GetLane(const Vec128 v) { +HWY_API int16_t GetLane(const Vec256 v) { return wasm_i16x8_extract_lane(v.raw, 0); } -template -HWY_API uint32_t GetLane(const Vec128 v) { +HWY_API uint32_t GetLane(const Vec256 v) { return wasm_i32x4_extract_lane(v.raw, 0); } -template -HWY_API int32_t GetLane(const Vec128 v) { +HWY_API int32_t GetLane(const Vec256 v) { return wasm_i32x4_extract_lane(v.raw, 0); } -template -HWY_API uint64_t GetLane(const Vec128 v) { +HWY_API uint64_t GetLane(const Vec256 v) { return wasm_i64x2_extract_lane(v.raw, 0); } -template -HWY_API int64_t GetLane(const Vec128 v) { +HWY_API int64_t GetLane(const Vec256 v) { return wasm_i64x2_extract_lane(v.raw, 0); } -template -HWY_API float GetLane(const Vec128 v) { +HWY_API float GetLane(const Vec256 v) { return wasm_f32x4_extract_lane(v.raw, 0); } // ------------------------------ LowerHalf template -HWY_API Vec128 LowerHalf(Simd /* tag */, Vec256 v) { - return Vec128{v.raw}; +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{v.raw}; } template -HWY_API Vec128 LowerHalf(Vec256 v) { - return LowerHalf(Simd(), v); +HWY_API Vec128 LowerHalf(Vec256 v) { + return LowerHalf(Full128(), v); } // ------------------------------ ShiftLeftBytes // 0x01..0F, kBytes = 1 => 0x02..0F00 template -HWY_API Vec256 ShiftLeftBytes(Simd /* tag */, Vec256 v) { +HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); const __i8x16 zero = wasm_i8x16_splat(0); switch (kBytes) { @@ -1554,20 +1403,20 @@ HWY_API Vec256 ShiftLeftBytes(Simd /* tag */, Vec256 v) { template HWY_API Vec256 ShiftLeftBytes(Vec256 v) { - return ShiftLeftBytes(Simd(), v); + return ShiftLeftBytes(Full256(), v); } // ------------------------------ ShiftLeftLanes template -HWY_API Vec256 ShiftLeftLanes(Simd d, const Vec256 v) { +HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { - return ShiftLeftLanes(Simd(), v); + return ShiftLeftLanes(Full256(), v); } // ------------------------------ ShiftRightBytes @@ -1651,18 +1500,13 @@ HWY_API __i8x16 ShrBytes(const Vec256 v) { // 0x01..0F, kBytes = 1 => 0x0001..0E template -HWY_API Vec256 ShiftRightBytes(Simd /* tag */, Vec256 v) { - // For partial vectors, clear upper lanes so we shift in zeros. - if (N != 16 / sizeof(T)) { - const Vec128 vfull{v.raw}; - v = Vec256{IfThenElseZero(FirstN(Full256(), N), vfull).raw}; - } +HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, Vec256 v) { return Vec256{detail::ShrBytes(v)}; } // ------------------------------ ShiftRightLanes template -HWY_API Vec256 ShiftRightLanes(Simd d, const Vec256 v) { +HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(BitCast(d8, v))); } @@ -1671,28 +1515,18 @@ HWY_API Vec256 ShiftRightLanes(Simd d, const Vec256 v) { // Full input: copy hi into lo (smaller instruction encoding than shifts). 
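// Illustrative sketch, not part of the vendored sources: a scalar model of
// the ShiftLeftBytes switch above. Viewing the vector as 16 byte lanes, every
// byte moves up by kBytes lane positions and zeros are shifted in at lane 0,
// matching the "0x01..0F, kBytes = 1 => 0x02..0F00" comment. Names are
// illustrative only.
#include <cstddef>
#include <cstdint>

void ShiftLeftBytesScalar(const uint8_t in[16], uint8_t out[16],
                          size_t kBytes) {
  for (size_t i = 0; i < 16; ++i) {
    out[i] = (i >= kBytes) ? in[i - kBytes] : 0;  // zeros enter at the bottom
  }
}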
template -HWY_API Vec128 UpperHalf(Half> /* tag */, - const Vec128 v) { +HWY_API Vec128 UpperHalf(Full128 /* tag */, + const Vec256 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; } -HWY_API Vec128 UpperHalf(Half> /* tag */, +HWY_API Vec128 UpperHalf(Full128 /* tag */, const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; } -// Partial -template -HWY_API Vec128 UpperHalf(Half> /* tag */, - Vec256 v) { - const Simd d; - const auto vu = BitCast(RebindToUnsigned(), v); - const auto upper = BitCast(d, ShiftRightBytes(vu)); - return Vec128{upper.raw}; -} - // ------------------------------ CombineShiftRightBytes -template > +template > HWY_API V CombineShiftRightBytes(Full256 /* tag */, V hi, V lo) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); switch (kBytes) { @@ -1762,55 +1596,41 @@ HWY_API V CombineShiftRightBytes(Full256 /* tag */, V hi, V lo) { return hi; } -template > -HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { - constexpr size_t kSize = N * sizeof(T); - static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); - const Repartition d8; - const Full256 d_full8; - using V8 = VFromD; - const V8 hi8{BitCast(d8, hi).raw}; - // Move into most-significant bytes - const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); - const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); - return V{BitCast(Full256(), r).raw}; -} - // ------------------------------ Broadcast/splat any lane // Unsigned -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{wasm_i16x8_shuffle( + return Vec256{wasm_i16x8_shuffle( v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; } -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{ + return Vec256{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } // Signed -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{wasm_i16x8_shuffle( - v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; + return Vec256{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; } -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{ + return Vec256{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } // Float -template -HWY_API Vec128 Broadcast(const Vec128 v) { +template +HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); - return Vec128{ + return Vec256{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } @@ -1818,35 +1638,35 @@ HWY_API Vec128 Broadcast(const Vec128 v) { // Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. // lane indices in [0, 16). 
-template -HWY_API Vec128 TableLookupBytes(const Vec256 bytes, - const Vec128 from) { +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, + const Vec256 from) { // Not yet available in all engines, see // https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md // V8 implementation of this had a bug, fixed on 2021-04-03: // https://chromium-review.googlesource.com/c/v8/v8/+/2822951 #if 0 - return Vec128{wasm_i8x16_swizzle(bytes.raw, from.raw)}; + return Vec256{wasm_i8x16_swizzle(bytes.raw, from.raw)}; #else - alignas(16) uint8_t control[16]; - alignas(16) uint8_t input[16]; - alignas(16) uint8_t output[16]; + alignas(32) uint8_t control[16]; + alignas(32) uint8_t input[16]; + alignas(32) uint8_t output[16]; wasm_v128_store(control, from.raw); wasm_v128_store(input, bytes.raw); for (size_t i = 0; i < 16; ++i) { output[i] = control[i] < 16 ? input[control[i]] : 0; } - return Vec128{wasm_v128_load(output)}; + return Vec256{wasm_v128_load(output)}; #endif } -template -HWY_API Vec128 TableLookupBytesOr0(const Vec256 bytes, - const Vec128 from) { - const Simd d; +template +HWY_API Vec256 TableLookupBytesOr0(const Vec256 bytes, + const Vec256 from) { + const Full256 d; // Mask size must match vector type, so cast everything to this type. Repartition di8; - Repartition> d_bytes8; + Repartition> d_bytes8; const auto msb = BitCast(di8, from) < Zero(di8); const auto lookup = TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); @@ -1922,115 +1742,109 @@ struct Indices256 { __v128_u raw; }; -template -HWY_API Indices256 IndicesFromVec(Simd d, Vec256 vec) { +template +HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); - return Indices256{}; + return Indices256{}; } -template -HWY_API Indices256 SetTableIndices(Simd d, const TI* idx) { +template +HWY_API Indices256 SetTableIndices(Full256 d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template -HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { using TI = MakeSigned; - const Simd d; - const Simd di; - return BitCast(d, TableLookupBytes(BitCast(di, v), Vec256{idx.raw})); + const Full256 d; + const Full256 di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec256{idx.raw})); } // ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) -// Single lane: no change -template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { - return v; -} - -// Two lanes: shuffle -template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { - return Vec128{Shuffle2301(Vec128{v.raw}).raw}; -} - template -HWY_API Vec128 Reverse(Full256 /* tag */, const Vec128 v) { +HWY_API Vec256 Reverse(Full256 /* tag */, const Vec256 v) { return Shuffle01(v); } // Four lanes: shuffle template -HWY_API Vec128 Reverse(Full256 /* tag */, const Vec128 v) { +HWY_API Vec256 Reverse(Full256 /* tag */, const Vec256 v) { return Shuffle0123(v); } // 16-bit template -HWY_API Vec256 Reverse(Simd d, const Vec256 v) { +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { const RepartitionToWide> du32; return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); } +// ------------------------------ Reverse2 + +template +HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { + HWY_ASSERT(0); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 
v) { + HWY_ASSERT(0); +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + HWY_ASSERT(0); +} + // ------------------------------ InterleaveLower -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i8x16_shuffle( - a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 1, 17, 2, 18, + 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{ +HWY_API Vec256 InterleaveLower(Vec256 a, + Vec256 b) { + return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +HWY_API Vec256 InterleaveLower(Vec256 a, + Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +HWY_API Vec256 InterleaveLower(Vec256 a, + Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i8x16_shuffle( - a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, + 19, 4, 20, 5, 21, 6, 22, 7, 23)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{ +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } -// Additional overload for the optional Simd<> tag. +// Additional overload for the optional tag. template > -HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { +HWY_API V InterleaveLower(Full256 /* tag */, V a, V b) { return InterleaveLower(a, b); } @@ -2039,89 +1853,66 @@ HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { // All functions inside detail lack the required D parameter. 
namespace detail { -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, - 26, 11, 27, 12, 28, 13, 29, 14, - 30, 15, 31)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, 26, + 11, 27, 12, 28, 13, 29, 14, 30, 15, + 31)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{ +HWY_API Vec256 InterleaveUpper(Vec256 a, + Vec256 b) { + return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, + Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, + Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, - 26, 11, 27, 12, 28, 13, 29, 14, - 30, 15, 31)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, 26, + 11, 27, 12, 28, 13, 29, 14, 30, 15, + 31)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{ +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; } -template -HWY_API Vec128 InterleaveUpper(Vec128 a, - Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; } } // namespace detail -// Full -template > +template > HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } -// Partial -template > -HWY_API V InterleaveUpper(Simd d, V a, V b) { - const Half d2; - return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); -} - // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
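// Illustrative sketch, not part of the vendored sources: a scalar picture of
// the double-width lanes that ZipLower/ZipUpper (below) produce, assuming the
// little-endian lane order WASM uses. Interleaving u8 inputs a and b and
// reinterpreting each adjacent pair as u16 gives lanes that combine both
// sources; the function name is illustrative only.
#include <cstdint>

uint16_t ZipLowerLaneU8(uint8_t a_i, uint8_t b_i) {
  // Lane i of ZipLower(a, b) viewed as u16: low byte from a, high byte from b.
  return static_cast<uint16_t>(a_i | (uint16_t{b_i} << 8));
}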
-template >> +template >> HWY_API VFromD ZipLower(Vec256 a, Vec256 b) { return BitCast(DW(), InterleaveLower(a, b)); } -template , class DW = RepartitionToWide> +template , class DW = RepartitionToWide> HWY_API VFromD ZipLower(DW dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveLower(D(), a, b)); } -template , class DW = RepartitionToWide> +template , class DW = RepartitionToWide> HWY_API VFromD ZipUpper(DW dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } @@ -2132,8 +1923,7 @@ HWY_API VFromD ZipUpper(DW dw, Vec256 a, Vec256 b) { // N = N/2 + N/2 (upper half undefined) template -HWY_API Vec256 Combine(Simd d, Vec128 hi_half, - Vec128 lo_half) { +HWY_API Vec256 Combine(Full256 d, Vec128 hi_half, Vec128 lo_half) { const Half d2; const RebindToUnsigned du2; // Treat half-width input as one lane, and expand to two lanes. @@ -2146,79 +1936,54 @@ HWY_API Vec256 Combine(Simd d, Vec128 hi_half, // ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) template -HWY_API Vec256 ZeroExtendVector(Simd d, Vec128 lo) { - return IfThenElseZero(FirstN(d, N / 2), Vec256{lo.raw}); +HWY_API Vec256 ZeroExtendVector(Full256 d, Vec128 lo) { + return IfThenElseZero(FirstN(d, 16 / sizeof(T)), Vec256{lo.raw}); } // ------------------------------ ConcatLowerLower // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template -HWY_API Vec128 ConcatLowerLower(Full256 /* tag */, const Vec128 hi, - const Vec128 lo) { - return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; -} -template -HWY_API Vec256 ConcatLowerLower(Simd d, const Vec256 hi, +HWY_API Vec256 ConcatLowerLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { - const Half d2; - return Combine(LowerHalf(d2, hi), LowerHalf(d2, lo)); + return Vec256{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; } // ------------------------------ ConcatUpperUpper template -HWY_API Vec128 ConcatUpperUpper(Full256 /* tag */, const Vec128 hi, - const Vec128 lo) { - return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; -} -template -HWY_API Vec256 ConcatUpperUpper(Simd d, const Vec256 hi, +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { - const Half d2; - return Combine(UpperHalf(d2, hi), UpperHalf(d2, lo)); + return Vec256{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; } // ------------------------------ ConcatLowerUpper template -HWY_API Vec128 ConcatLowerUpper(Full256 d, const Vec128 hi, - const Vec128 lo) { - return CombineShiftRightBytes<8>(d, hi, lo); -} -template -HWY_API Vec256 ConcatLowerUpper(Simd d, const Vec256 hi, +HWY_API Vec256 ConcatLowerUpper(Full256 d, const Vec256 hi, const Vec256 lo) { - const Half d2; - return Combine(LowerHalf(d2, hi), UpperHalf(d2, lo)); + return CombineShiftRightBytes<8>(d, hi, lo); } // ------------------------------ ConcatUpperLower template -HWY_API Vec256 ConcatUpperLower(Simd d, const Vec256 hi, +HWY_API Vec256 ConcatUpperLower(Full256 d, const Vec256 hi, const Vec256 lo) { return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); } // ------------------------------ ConcatOdd -// 32-bit full +// 32-bit template -HWY_API Vec128 ConcatOdd(Full256 /* tag */, Vec128 hi, Vec128 lo) { - return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; -} - -// 32-bit partial -template -HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, - Vec128 lo) { - return InterleaveUpper(Simd(), lo, hi); +HWY_API Vec256 ConcatOdd(Full256 /* tag */, Vec256 hi, Vec256 lo) { + return Vec256{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; } // 64-bit full - no partial because 
we need at least two inputs to have // even/odd. template -HWY_API Vec128 ConcatOdd(Full256 /* tag */, Vec128 hi, Vec128 lo) { +HWY_API Vec256 ConcatOdd(Full256 /* tag */, Vec256 hi, Vec256 lo) { return InterleaveUpper(Full256(), lo, hi); } @@ -2226,24 +1991,29 @@ HWY_API Vec128 ConcatOdd(Full256 /* tag */, Vec128 hi, Vec128 lo) { // 32-bit full template -HWY_API Vec128 ConcatEven(Full256 /* tag */, Vec128 hi, Vec128 lo) { - return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; -} - -// 32-bit partial -template -HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, - Vec128 lo) { - return InterleaveLower(Simd(), lo, hi); +HWY_API Vec256 ConcatEven(Full256 /* tag */, Vec256 hi, Vec256 lo) { + return Vec256{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; } // 64-bit full - no partial because we need at least two inputs to have // even/odd. template -HWY_API Vec128 ConcatEven(Full256 /* tag */, Vec128 hi, Vec128 lo) { +HWY_API Vec256 ConcatEven(Full256 /* tag */, Vec256 hi, Vec256 lo) { return InterleaveLower(Full256(), lo, hi); } +// ------------------------------ DupEven +template +HWY_API Vec256 DupEven(Vec256 v) { + HWY_ASSERT(0); +} + +// ------------------------------ DupOdd +template +HWY_API Vec256 DupOdd(Vec256 v) { + HWY_ASSERT(0); +} + // ------------------------------ OddEven namespace detail { @@ -2251,9 +2021,9 @@ namespace detail { template HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, const Vec256 b) { - const Simd d; + const Full256 d; const Repartition d8; - alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); } @@ -2279,10 +2049,8 @@ template HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return detail::OddEven(hwy::SizeTag(), a, b); } -template -HWY_API Vec128 OddEven(const Vec128 a, - const Vec128 b) { - return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; } // ------------------------------ OddEvenBlocks @@ -2298,76 +2066,72 @@ HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return v; } +// ------------------------------ ReverseBlocks + +template +HWY_API Vec256 ReverseBlocks(Full256 /* tag */, const Vec256 v) { + return v; +} + // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) // Unsigned: zero-extend. 
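// Illustrative sketch, not part of the vendored sources: most promotions
// below are plain zero- or sign-extensions, but the float16 -> float32
// PromoteTo further below has no single WASM instruction and is assembled
// from integer operations on the bit pattern. A hedged scalar sketch for
// normal, finite inputs only (the vector code also handles subnormals);
// names are illustrative only.
#include <cstdint>
#include <cstring>

float PromoteF16BitsToF32(uint16_t bits16) {
  const uint32_t sign = bits16 >> 15;
  const uint32_t biased_exp = (bits16 >> 10) & 0x1F;  // float16 exponent, bias 15
  const uint32_t mantissa = bits16 & 0x3FF;           // 10-bit mantissa
  const uint32_t bits32 = (sign << 31) |
                          ((biased_exp + (127 - 15)) << 23) |  // rebias to 127
                          (mantissa << (23 - 10));             // realign mantissa
  float result;
  std::memcpy(&result, &bits32, sizeof(result));
  return result;
}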
-template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u16x8_extend_low_u8x16(v.raw)}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{ +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{ wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u16x8_extend_low_u8x16(v.raw)}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{ +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{ wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u32x4_extend_low_u16x8(v.raw)}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u32x4_extend_low_u16x8(v.raw)}; } // Signed: replicate sign bit. -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_i16x8_extend_low_i8x16(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_i16x8_extend_low_i8x16(v.raw)}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{ +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{ wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_i32x4_extend_low_i16x8(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_i32x4_extend_low_i16x8(v.raw)}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_f64x2_convert_low_i32x4(v.raw)}; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_f64x2_convert_low_i32x4(v.raw)}; } -template -HWY_API Vec128 PromoteTo(Simd /* tag */, - const Vec128 v) { - const Simd di32; - const Simd du32; - const Simd df32; +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + const Full256 di32; + const Full256 du32; + const Full256 df32; // Expand to u32 so we can shift. 
- const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto bits16 = PromoteTo(du32, Vec256{v.raw}); const auto sign = ShiftRight<15>(bits16); const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); const auto mantissa = bits16 & Set(du32, 0x3FF); @@ -2382,9 +2146,8 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, return BitCast(df32, ShiftLeft<31>(sign) | bits32); } -template -HWY_API Vec128 PromoteTo(Simd df32, - const Vec128 v) { +HWY_API Vec256 PromoteTo(Full256 df32, + const Vec128 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); @@ -2392,57 +2155,48 @@ HWY_API Vec128 PromoteTo(Simd df32, // ------------------------------ Demotions (full -> part w/ narrow lanes) -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; } -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; } -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); - return Vec128{ - wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; + return Vec128{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; } -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; } -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); - return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; + return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; } -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; } -template -HWY_API Vec128 DemoteTo(Simd /* di */, - const Vec128 v) { - return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +HWY_API Vec128 DemoteTo(Full128 /* di */, + const Vec256 v) { + return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; } -template -HWY_API Vec128 DemoteTo(Simd /* tag */, - const Vec128 v) { - const Simd di; - const Simd du; - const Simd du16; +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const Full256 di; + const Full256 du; + const Full256 du16; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); @@ -2464,12 +2218,11 @@ HWY_API Vec128 DemoteTo(Simd /* tag */, const auto sign16 = ShiftLeft<15>(sign); const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); - return Vec128{DemoteTo(du16, bits16).raw}; + return Vec128{DemoteTo(du16, bits16).raw}; } -template -HWY_API Vec128 DemoteTo(Simd dbf16, - const Vec128 v) { +HWY_API Vec128 DemoteTo(Full128 dbf16, + const Vec256 v) { 
const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; @@ -2477,40 +2230,34 @@ HWY_API Vec128 DemoteTo(Simd dbf16, return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } -template -HWY_API Vec128 ReorderDemote2To( - Simd dbf16, Vec128 a, Vec128 b) { +HWY_API Vec128 ReorderDemote2To(Full128 dbf16, + Vec256 a, Vec256 b) { const RebindToUnsigned du16; const Repartition du32; - const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } // For already range-limited input [0, 255]. -template -HWY_API Vec128 U8FromU32(const Vec128 v) { +HWY_API Vec256 U8FromU32(const Vec256 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); - return Vec128{ - wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; + return Vec256{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; } // ------------------------------ Convert i32 <=> f32 (Round) -template -HWY_API Vec128 ConvertTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_f32x4_convert_i32x4(v.raw)}; +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{wasm_f32x4_convert_i32x4(v.raw)}; } // Truncates (rounds toward zero). -template -HWY_API Vec128 ConvertTo(Simd /* tag */, - const Vec128 v) { - return Vec128{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{wasm_i32x4_trunc_sat_f32x4(v.raw)}; } -template -HWY_API Vec128 NearestInt(const Vec128 v) { - return ConvertTo(Simd(), Round(v)); +HWY_API Vec256 NearestInt(const Vec256 v) { + return ConvertTo(Full256(), Round(v)); } // ================================================== MISC @@ -2520,49 +2267,49 @@ HWY_API Vec128 NearestInt(const Vec128 v) { namespace detail { template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; // Easier than Set(), which would require an >8-bit type, which would not // compile for T=uint8_t, N=1. const Vec256 vbits{wasm_i32x4_splat(static_cast(bits))}; // Replicate bytes 8x such that each byte contains the bit that governs it. 
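// Illustrative sketch, not part of the vendored sources: the replicate-and-
// TestBit steps below implement, per byte lane, this scalar predicate (byte
// lane i of the mask is all-ones iff bit i of `bits` is set); the name is
// illustrative only.
#include <cstddef>
#include <cstdint>

uint8_t MaskByteFromBits(uint64_t bits, size_t lane) {
  return ((bits >> lane) & 1) ? 0xFF : 0x00;
}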
- alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + alignas(32) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}; const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); - alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; - alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + alignas(32) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; - alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; - alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + alignas(32) constexpr uint64_t kBit[8] = {1, 2}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); } } // namespace detail // `p` points to at least 8 readable bytes, not all of which need be valid. -template -HWY_API Mask128 LoadMaskBits(Simd d, - const uint8_t* HWY_RESTRICT bits) { +template +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; CopyBytes<(N + 7) / 8>(bits, &mask_bits); return detail::LoadMaskBits(d, mask_bits); @@ -2576,7 +2323,7 @@ namespace detail { template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { - alignas(16) uint64_t lanes[2]; + alignas(32) uint64_t lanes[2]; wasm_v128_store(lanes, mask.raw); constexpr uint64_t kMagic = 0x103070F1F3F80ULL; @@ -2585,53 +2332,27 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, return (hi + lo); } -// 64-bit -template -HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, - const Mask128 mask) { - constexpr uint64_t kMagic = 0x103070F1F3F80ULL; - return (wasm_i64x2_extract_lane(mask.raw, 0) * kMagic) >> 56; -} - -// 32-bit or less: need masking -template -HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, - const Mask128 mask) { - uint64_t bytes = wasm_i64x2_extract_lane(mask.raw, 0); - // Clear potentially undefined bytes. - bytes &= (1ULL << (N * 8)) - 1; - constexpr uint64_t kMagic = 0x103070F1F3F80ULL; - return (bytes * kMagic) >> 56; -} - template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, - const Mask128 mask) { + const Mask256 mask) { // Remove useless lower half of each u16 while preserving the sign bit. 
const __i16x8 zero = wasm_i16x8_splat(0); - const Mask128 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + const Mask256 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; return BitsFromMask(hwy::SizeTag<1>(), mask8); } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, - const Mask128 mask) { + const Mask256 mask) { const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); - alignas(16) uint32_t lanes[4]; + alignas(32) uint32_t lanes[4]; wasm_v128_store(lanes, sliced_mask); return lanes[0] | lanes[1] | lanes[2] | lanes[3]; } -// Returns the lowest N bits for the BitsFromMask result. -template -constexpr uint64_t OnlyActive(uint64_t bits) { - return ((N * sizeof(T)) == 16) ? bits : bits & ((1ull << N) - 1); -} - // Returns 0xFF for bytes with index >= N, otherwise 0. -template constexpr __i8x16 BytesAbove() { return /**/ (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) @@ -2661,8 +2382,8 @@ constexpr __i8x16 BytesAbove() { } template -HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { - return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + return BitsFromMask(hwy::SizeTag(), mask); } template @@ -2679,7 +2400,7 @@ template HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); - alignas(16) uint64_t lanes[2]; + alignas(32) uint64_t lanes[2]; wasm_v128_store(lanes, shifted_bits); return PopCount(lanes[0] | lanes[1]); } @@ -2688,8 +2409,8 @@ HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { // `p` points to at least 8 writable bytes. template -HWY_API size_t StoreMaskBits(const Simd /* tag */, - const Mask128 mask, uint8_t* bits) { +HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, + uint8_t* bits) { const uint64_t mask_bits = detail::BitsFromMask(mask); const size_t kNumBytes = (N + 7) / 8; CopyBytes(&mask_bits, bits); @@ -2697,19 +2418,10 @@ HWY_API size_t StoreMaskBits(const Simd /* tag */, } template -HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 m) { +HWY_API size_t CountTrue(const Full256 /* tag */, const Mask128 m) { return detail::CountTrue(hwy::SizeTag(), m); } -// Partial vector -template -HWY_API size_t CountTrue(const Simd d, const Mask128 m) { - // Ensure all undefined bytes are 0. - const Mask128 mask{detail::BytesAbove()}; - return CountTrue(d, Mask128{AndNot(mask, m).raw}); -} - -// Full vector template HWY_API bool AllFalse(const Full256 d, const Mask128 m) { #if 0 @@ -2742,29 +2454,13 @@ HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { } // namespace detail template -HWY_API bool AllTrue(const Simd /* tag */, const Mask128 m) { +HWY_API bool AllTrue(const Full256 /* tag */, const Mask128 m) { return detail::AllTrue(hwy::SizeTag(), m); } -// Partial vectors - -template -HWY_API bool AllFalse(Simd /* tag */, const Mask128 m) { - // Ensure all undefined bytes are 0. - const Mask128 mask{detail::BytesAbove()}; - return AllFalse(Mask128{AndNot(mask, m).raw}); -} - -template -HWY_API bool AllTrue(const Simd d, const Mask128 m) { - // Ensure all undefined bytes are FF. 
- const Mask128 mask{detail::BytesAbove()}; - return AllTrue(d, Mask128{Or(mask, m).raw}); -} - template -HWY_API intptr_t FindFirstTrue(const Simd /* tag */, - const Mask128 mask) { +HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, + const Mask256 mask) { const uint64_t bits = detail::BitsFromMask(mask); return bits ? Num0BitsBelowLS1Bit_Nonzero64(bits) : -1; } @@ -2776,16 +2472,16 @@ namespace detail { template HWY_INLINE Vec256 Idx16x8FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); - const Simd d; + const Full256 d; const Rebind d8; - const Simd du; + const Full256 du; // We need byte indices for TableLookupBytes (one vector's worth for each of // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We // can instead store lane indices and convert to byte indices (2*lane + 0..1), // with the doubling baked into the table. Unpacking nibbles is likely more // costly than the higher cache footprint from storing bytes. - alignas(16) constexpr uint8_t table[256 * 8] = { + alignas(32) constexpr uint8_t table[256 * 8] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, @@ -2901,8 +2597,8 @@ HWY_INLINE Vec256 Idx16x8FromBits(const uint64_t mask_bits) { 4, 6, 8, 10, 12, 14, 0, 0, 0, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14, 0, 0, 2, 4, 6, 8, 10, 12, 14}; - const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; - const Vec128 pairs = ZipLower(byte_idx, byte_idx); + const Vec256 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec256 pairs = ZipLower(byte_idx, byte_idx); return BitCast(d, pairs + Set(du, 0x0100)); } @@ -2911,7 +2607,7 @@ HWY_INLINE Vec256 Idx32x4FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 16); // There are only 4 lanes, so we can afford to load the index vector directly. - alignas(16) constexpr uint8_t packed_array[16 * 16] = { + alignas(32) constexpr uint8_t packed_array[16 * 16] = { 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // @@ -2929,25 +2625,25 @@ HWY_INLINE Vec256 Idx32x4FromBits(const uint64_t mask_bits) { 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const Simd d; + const Full256 d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } -#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 template HWY_INLINE Vec256 Idx64x2FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 4); // There are only 2 lanes, so we can afford to load the index vector directly. 
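// Illustrative sketch, not part of the vendored sources: the Idx*FromBits
// tables above and below encode, for every possible mask, the byte shuffle
// that packs selected lanes toward lane 0. The semantics being implemented,
// as a scalar model with illustrative names only.
#include <cstddef>
#include <cstdint>

template <typename T>
size_t CompressScalar(const T* lanes, uint64_t mask_bits, size_t num_lanes,
                      T* out) {
  size_t count = 0;
  for (size_t i = 0; i < num_lanes; ++i) {
    if ((mask_bits >> i) & 1) out[count++] = lanes[i];  // keep selected lanes, in order
  }
  return count;  // number of lanes written
}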
- alignas(16) constexpr uint8_t packed_array[4 * 16] = { + alignas(32) constexpr uint8_t packed_array[4 * 16] = { 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const Simd d; + const Full256 d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } @@ -2960,8 +2656,8 @@ HWY_INLINE Vec256 Idx64x2FromBits(const uint64_t mask_bits) { template HWY_INLINE Vec256 Compress(hwy::SizeTag<2> /*tag*/, Vec256 v, const uint64_t mask_bits) { - const auto idx = detail::Idx16x8FromBits(mask_bits); - using D = Simd; + const auto idx = detail::Idx16x8FromBits(mask_bits); + using D = Full256; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } @@ -2969,20 +2665,20 @@ HWY_INLINE Vec256 Compress(hwy::SizeTag<2> /*tag*/, Vec256 v, template HWY_INLINE Vec256 Compress(hwy::SizeTag<4> /*tag*/, Vec256 v, const uint64_t mask_bits) { - const auto idx = detail::Idx32x4FromBits(mask_bits); - using D = Simd; + const auto idx = detail::Idx32x4FromBits(mask_bits); + using D = Full256; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } -#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 template -HWY_INLINE Vec128 Compress(hwy::SizeTag<8> /*tag*/, - Vec128 v, - const uint64_t mask_bits) { - const auto idx = detail::Idx64x2FromBits(mask_bits); - using D = Simd; +HWY_INLINE Vec256 Compress(hwy::SizeTag<8> /*tag*/, + Vec256 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx64x2FromBits(mask_bits); + using D = Full256; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } @@ -2992,7 +2688,7 @@ HWY_INLINE Vec128 Compress(hwy::SizeTag<8> /*tag*/, } // namespace detail template -HWY_API Vec256 Compress(Vec256 v, const Mask128 mask) { +HWY_API Vec256 Compress(Vec256 v, const Mask256 mask) { const uint64_t mask_bits = detail::BitsFromMask(mask); return detail::Compress(hwy::SizeTag(), v, mask_bits); } @@ -3013,8 +2709,8 @@ HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { // ------------------------------ CompressStore template -HWY_API size_t CompressStore(Vec256 v, const Mask128 mask, - Simd d, T* HWY_RESTRICT unaligned) { +HWY_API size_t CompressStore(Vec256 v, const Mask256 mask, Full256 d, + T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(mask); const auto c = detail::Compress(hwy::SizeTag(), v, mask_bits); StoreU(c, d, unaligned); @@ -3023,16 +2719,16 @@ HWY_API size_t CompressStore(Vec256 v, const Mask128 mask, // ------------------------------ CompressBlendedStore template -HWY_API size_t CompressBlendedStore(Vec256 v, Mask128 m, Simd d, +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; // so we can support fp16/bf16 using TU = TFromD; const uint64_t mask_bits = detail::BitsFromMask(m); const size_t count = PopCount(mask_bits); - const Mask128 store_mask = FirstN(du, count); - const Vec128 compressed = + const Mask256 store_mask = FirstN(du, count); + const Vec256 compressed = detail::Compress(hwy::SizeTag(), BitCast(du, v), mask_bits); - const Vec128 prev = BitCast(du, LoadU(d, unaligned)); + const Vec256 prev = BitCast(du, LoadU(d, unaligned)); StoreU(BitCast(d, IfThenElse(store_mask, compressed, prev)), d, 
unaligned); return count; } @@ -3041,7 +2737,7 @@ HWY_API size_t CompressBlendedStore(Vec256 v, Mask128 m, Simd d, template HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, - Simd d, T* HWY_RESTRICT unaligned) { + Full256 d, T* HWY_RESTRICT unaligned) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(bits, &mask_bits); @@ -3057,19 +2753,18 @@ HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, // TableLookupBytes) -// 128 bits -HWY_API void StoreInterleaved3(const Vec128 a, const Vec128 b, - const Vec128 c, Full256 d, +HWY_API void StoreInterleaved3(const Vec256 a, const Vec256 b, + const Vec256 c, Full256 d, uint8_t* HWY_RESTRICT unaligned) { const auto k5 = Set(d, 5); const auto k6 = Set(d, 6); // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. // 0x80 so lanes to be filled from other vectors are 0 for blending. - alignas(16) static constexpr uint8_t tbl_r0[16] = { + alignas(32) static constexpr uint8_t tbl_r0[16] = { 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; - alignas(16) static constexpr uint8_t tbl_g0[16] = { + alignas(32) static constexpr uint8_t tbl_g0[16] = { 0x80, 0, 0x80, 0x80, 1, 0x80, // 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; const auto shuf_r0 = Load(d, tbl_r0); @@ -3102,86 +2797,12 @@ HWY_API void StoreInterleaved3(const Vec128 a, const Vec128 b, StoreU(int2, d, unaligned + 2 * 16); } -// 64 bits -HWY_API void StoreInterleaved3(const Vec128 a, - const Vec128 b, - const Vec128 c, Simd d, - uint8_t* HWY_RESTRICT unaligned) { - // Use full vectors for the shuffles and first result. - const Full256 d_full; - const auto k5 = Set(d_full, 5); - const auto k6 = Set(d_full, 6); - - const Vec128 full_a{a.raw}; - const Vec128 full_b{b.raw}; - const Vec128 full_c{c.raw}; - - // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. - // 0x80 so lanes to be filled from other vectors are 0 for blending. - alignas(16) static constexpr uint8_t tbl_r0[16] = { - 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // - 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; - alignas(16) static constexpr uint8_t tbl_g0[16] = { - 0x80, 0, 0x80, 0x80, 1, 0x80, // - 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; - const auto shuf_r0 = Load(d_full, tbl_r0); - const auto shuf_g0 = Load(d_full, tbl_g0); // cannot reuse r0 due to 5 in MSB - const auto shuf_b0 = CombineShiftRightBytes<15>(d_full, shuf_g0, shuf_g0); - const auto r0 = TableLookupBytes(full_a, shuf_r0); // 5..4..3..2..1..0 - const auto g0 = TableLookupBytes(full_b, shuf_g0); // ..4..3..2..1..0. - const auto b0 = TableLookupBytes(full_c, shuf_b0); // .4..3..2..1..0.. - const auto int0 = r0 | g0 | b0; - StoreU(int0, d_full, unaligned + 0 * 16); - - // Second (HALF) vector: bgr[7:6], b5,g5 - const auto shuf_r1 = shuf_b0 + k6; // ..7..6.. - const auto shuf_g1 = shuf_r0 + k5; // .7..6..5 - const auto shuf_b1 = shuf_g0 + k5; // 7..6..5. - const auto r1 = TableLookupBytes(full_a, shuf_r1); - const auto g1 = TableLookupBytes(full_b, shuf_g1); - const auto b1 = TableLookupBytes(full_c, shuf_b1); - const decltype(Zero(d)) int1{(r1 | g1 | b1).raw}; - StoreU(int1, d, unaligned + 1 * 16); -} - -// <= 32 bits -template -HWY_API void StoreInterleaved3(const Vec128 a, - const Vec128 b, - const Vec128 c, - Simd /*tag*/, - uint8_t* HWY_RESTRICT unaligned) { - // Use full vectors for the shuffles and result. 
- const Full256 d_full; - - const Vec128 full_a{a.raw}; - const Vec128 full_b{b.raw}; - const Vec128 full_c{c.raw}; - - // Shuffle (a,b,c) vector bytes to bgr[3:0]. - // 0x80 so lanes to be filled from other vectors are 0 for blending. - alignas(16) static constexpr uint8_t tbl_r0[16] = { - 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, // - 0x80, 0x80, 0x80, 0x80}; - const auto shuf_r0 = Load(d_full, tbl_r0); - const auto shuf_g0 = CombineShiftRightBytes<15>(d_full, shuf_r0, shuf_r0); - const auto shuf_b0 = CombineShiftRightBytes<14>(d_full, shuf_r0, shuf_r0); - const auto r0 = TableLookupBytes(full_a, shuf_r0); // ......3..2..1..0 - const auto g0 = TableLookupBytes(full_b, shuf_g0); // .....3..2..1..0. - const auto b0 = TableLookupBytes(full_c, shuf_b0); // ....3..2..1..0.. - const auto int0 = r0 | g0 | b0; - alignas(16) uint8_t buf[16]; - StoreU(int0, d_full, buf); - CopyBytes(buf, unaligned); -} - // ------------------------------ StoreInterleaved4 -// 128 bits -HWY_API void StoreInterleaved4(const Vec128 v0, - const Vec128 v1, - const Vec128 v2, - const Vec128 v3, Full256 d8, +HWY_API void StoreInterleaved4(const Vec256 v0, + const Vec256 v1, + const Vec256 v2, + const Vec256 v3, Full256 d8, uint8_t* HWY_RESTRICT unaligned) { const RepartitionToWide d16; const RepartitionToWide d32; @@ -3200,69 +2821,20 @@ HWY_API void StoreInterleaved4(const Vec128 v0, StoreU(BitCast(d8, dcba_C), d8, unaligned + 3 * 16); } -// 64 bits -HWY_API void StoreInterleaved4(const Vec128 in0, - const Vec128 in1, - const Vec128 in2, - const Vec128 in3, - Simd /* tag */, - uint8_t* HWY_RESTRICT unaligned) { - // Use full vectors to reduce the number of stores. - const Full256 d_full8; - const RepartitionToWide d16; - const RepartitionToWide d32; - const Vec128 v0{in0.raw}; - const Vec128 v1{in1.raw}; - const Vec128 v2{in2.raw}; - const Vec128 v3{in3.raw}; - // let a,b,c,d denote v0..3. - const auto ba0 = ZipLower(d16, v0, v1); // b7 a7 .. b0 a0 - const auto dc0 = ZipLower(d16, v2, v3); // d7 c7 .. d0 c0 - const auto dcba_0 = ZipLower(d32, ba0, dc0); // d..a3 d..a0 - const auto dcba_4 = ZipUpper(d32, ba0, dc0); // d..a7 d..a4 - StoreU(BitCast(d_full8, dcba_0), d_full8, unaligned + 0 * 16); - StoreU(BitCast(d_full8, dcba_4), d_full8, unaligned + 1 * 16); -} - -// <= 32 bits -template -HWY_API void StoreInterleaved4(const Vec128 in0, - const Vec128 in1, - const Vec128 in2, - const Vec128 in3, - Simd /*tag*/, - uint8_t* HWY_RESTRICT unaligned) { - // Use full vectors to reduce the number of stores. - const Full256 d_full8; - const RepartitionToWide d16; - const RepartitionToWide d32; - const Vec128 v0{in0.raw}; - const Vec128 v1{in1.raw}; - const Vec128 v2{in2.raw}; - const Vec128 v3{in3.raw}; - // let a,b,c,d denote v0..3. - const auto ba0 = ZipLower(d16, v0, v1); // b3 a3 .. b0 a0 - const auto dc0 = ZipLower(d16, v2, v3); // d3 c3 .. 
d0 c0 - const auto dcba_0 = ZipLower(d32, ba0, dc0); // d..a3 d..a0 - alignas(16) uint8_t buf[16]; - StoreU(BitCast(d_full8, dcba_0), d_full8, buf); - CopyBytes<4 * N>(buf, unaligned); -} - // ------------------------------ MulEven/Odd (Load) -HWY_INLINE Vec128 MulEven(const Vec128 a, - const Vec128 b) { - alignas(16) uint64_t mul[2]; +HWY_INLINE Vec256 MulEven(const Vec256 a, + const Vec256 b) { + alignas(32) uint64_t mul[2]; mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); return Load(Full256(), mul); } -HWY_INLINE Vec128 MulOdd(const Vec128 a, - const Vec128 b) { - alignas(16) uint64_t mul[2]; +HWY_INLINE Vec256 MulOdd(const Vec256 a, + const Vec256 b) { + alignas(32) uint64_t mul[2]; mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); @@ -3271,19 +2843,18 @@ HWY_INLINE Vec128 MulOdd(const Vec128 a, // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) -template -HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, - Vec128 a, - Vec128 b, - const Vec128 sum0, - Vec128& sum1) { +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 df32, + Vec256 a, + Vec256 b, + const Vec256 sum0, + Vec256& sum1) { const Repartition du16; const RebindToUnsigned du32; - const Vec128 zero = Zero(du16); - const Vec128 a0 = ZipLower(du32, zero, BitCast(du16, a)); - const Vec128 a1 = ZipUpper(du32, zero, BitCast(du16, a)); - const Vec128 b0 = ZipLower(du32, zero, BitCast(du16, b)); - const Vec128 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + const Vec256 zero = Zero(du16); + const Vec256 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec256 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec256 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec256 b1 = ZipUpper(du32, zero, BitCast(du16, b)); sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); } @@ -3292,220 +2863,100 @@ HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, namespace detail { -// N=1 for any T: no-op -template -HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, - const Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, - const Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, - const Vec128 v) { - return v; -} - // u32/i32/f32: -// N=2 template -HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, - const Vec128 v10) { - return v10 + Vec128{Shuffle2301(Vec128{v10.raw}).raw}; -} -template -HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, - const Vec128 v10) { - return Min(v10, Vec128{Shuffle2301(Vec128{v10.raw}).raw}); -} -template -HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, - const Vec128 v10) { - return Max(v10, Vec128{Shuffle2301(Vec128{v10.raw}).raw}); -} - -// N=4 (full) -template -HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, - const Vec128 v3210) { - const Vec128 v1032 = Shuffle1032(v3210); - const Vec128 v31_20_31_20 = v3210 + v1032; - const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const Vec256 v1032 = Shuffle1032(v3210); + const Vec256 v31_20_31_20 = v3210 + v1032; + const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); return v20_31_20_31 + v31_20_31_20; } template -HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, - const Vec128 v3210) { - const Vec128 v1032 = Shuffle1032(v3210); - 
const Vec128 v31_20_31_20 = Min(v3210, v1032); - const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const Vec256 v1032 = Shuffle1032(v3210); + const Vec256 v31_20_31_20 = Min(v3210, v1032); + const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); return Min(v20_31_20_31, v31_20_31_20); } template -HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, - const Vec128 v3210) { - const Vec128 v1032 = Shuffle1032(v3210); - const Vec128 v31_20_31_20 = Max(v3210, v1032); - const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const Vec256 v1032 = Shuffle1032(v3210); + const Vec256 v31_20_31_20 = Max(v3210, v1032); + const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); return Max(v20_31_20_31, v31_20_31_20); } // u64/i64/f64: -// N=2 (full) template -HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, - const Vec128 v10) { - const Vec128 v01 = Shuffle01(v10); +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const Vec256 v01 = Shuffle01(v10); return v10 + v01; } template -HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, - const Vec128 v10) { - const Vec128 v01 = Shuffle01(v10); +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const Vec256 v01 = Shuffle01(v10); return Min(v10, v01); } template -HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, - const Vec128 v10) { - const Vec128 v01 = Shuffle01(v10); +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const Vec256 v01 = Shuffle01(v10); return Max(v10, v01); } // u16/i16 -template +template HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { - const Repartition> d32; + const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(d32, Min(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(Full256(), Or(min, ShiftLeft<16>(min))); } -template +template HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { - const Repartition> d32; + const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(d32, Max(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(Full256(), Or(min, ShiftLeft<16>(min))); } } // namespace detail // Supported for u/i/f 32/64. Returns the same value in each lane. 
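// Illustrative sketch (not part of the patch): the shuffle-and-fold pattern the
// SumOfLanes/MinOfLanes/MaxOfLanes specializations above rely on, written here
// directly with SSE2 intrinsics for four uint32_t lanes. Each step combines two
// partial results; after two steps every lane holds the full reduction.
#include <emmintrin.h>
#include <cstdint>

static inline uint32_t SumOfLanesU32(__m128i v3210) {
  // Swap the 64-bit halves: lanes become {s2, s3, s0, s1}.
  const __m128i v1032 = _mm_shuffle_epi32(v3210, _MM_SHUFFLE(1, 0, 3, 2));
  const __m128i sum2 = _mm_add_epi32(v3210, v1032);  // {s0+s2, s1+s3, s0+s2, s1+s3}
  // Rotate by one lane so the other partial sum lines up in every lane.
  const __m128i rot = _mm_shuffle_epi32(sum2, _MM_SHUFFLE(0, 3, 2, 1));
  const __m128i total = _mm_add_epi32(sum2, rot);    // every lane = s0+s1+s2+s3
  return static_cast<uint32_t>(_mm_cvtsi128_si32(total));
}
// The Min/Max variants follow the same shape with the add replaced by a
// per-lane min/max.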
template -HWY_API Vec256 SumOfLanes(Simd /* tag */, const Vec256 v) { +HWY_API Vec256 SumOfLanes(Full256 /* tag */, const Vec256 v) { return detail::SumOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec256 MinOfLanes(Simd /* tag */, const Vec256 v) { +HWY_API Vec256 MinOfLanes(Full256 /* tag */, const Vec256 v) { return detail::MinOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec256 MaxOfLanes(Simd /* tag */, const Vec256 v) { +HWY_API Vec256 MaxOfLanes(Full256 /* tag */, const Vec256 v) { return detail::MaxOfLanes(hwy::SizeTag(), v); } -// ================================================== DEPRECATED +// ------------------------------ Lt128 template -HWY_API size_t StoreMaskBits(const Mask128 mask, uint8_t* bits) { - return StoreMaskBits(Simd(), mask, bits); -} +HWY_INLINE Mask256 Lt128(Full256 d, Vec256 a, Vec256 b) {} template -HWY_API bool AllTrue(const Mask128 mask) { - return AllTrue(Simd(), mask); -} +HWY_INLINE Vec256 Min128(Full256 d, Vec256 a, Vec256 b) {} template -HWY_API bool AllFalse(const Mask128 mask) { - return AllFalse(Simd(), mask); -} - -template -HWY_API size_t CountTrue(const Mask128 mask) { - return CountTrue(Simd(), mask); -} - -template -HWY_API Vec256 SumOfLanes(const Vec256 v) { - return SumOfLanes(Simd(), v); -} -template -HWY_API Vec256 MinOfLanes(const Vec256 v) { - return MinOfLanes(Simd(), v); -} -template -HWY_API Vec256 MaxOfLanes(const Vec256 v) { - return MaxOfLanes(Simd(), v); -} - -template -HWY_API Vec128 UpperHalf(Vec256 v) { - return UpperHalf(Half>(), v); -} - -template -HWY_API Vec256 ShiftRightBytes(const Vec256 v) { - return ShiftRightBytes(Simd(), v); -} - -template -HWY_API Vec256 ShiftRightLanes(const Vec256 v) { - return ShiftRightLanes(Simd(), v); -} - -template -HWY_API Vec256 CombineShiftRightBytes(Vec256 hi, Vec256 lo) { - return CombineShiftRightBytes(Simd(), hi, lo); -} - -template -HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { - return InterleaveUpper(Simd(), a, b); -} - -template > -HWY_API VFromD> ZipUpper(Vec256 a, Vec256 b) { - return InterleaveUpper(RepartitionToWide(), a, b); -} - -template -HWY_API Vec128 Combine(Vec128 hi2, Vec128 lo2) { - return Combine(Simd(), hi2, lo2); -} - -template -HWY_API Vec128 ZeroExtendVector(Vec128 lo) { - return ZeroExtendVector(Simd(), lo); -} - -template -HWY_API Vec256 ConcatLowerLower(Vec256 hi, Vec256 lo) { - return ConcatLowerLower(Simd(), hi, lo); -} - -template -HWY_API Vec256 ConcatUpperUpper(Vec256 hi, Vec256 lo) { - return ConcatUpperUpper(Simd(), hi, lo); -} - -template -HWY_API Vec256 ConcatLowerUpper(const Vec256 hi, const Vec256 lo) { - return ConcatLowerUpper(Simd(), hi, lo); -} - -template -HWY_API Vec256 ConcatUpperLower(Vec256 hi, Vec256 lo) { - return ConcatUpperLower(Simd(), hi, lo); -} +HWY_INLINE Vec256 Max128(Full256 d, Vec256 a, Vec256 b) {} // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE diff --git a/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/highway/hwy/ops/x86_128-inl.h index 0bb7e269c8fd..5ef18d8562b8 100644 --- a/third_party/highway/hwy/ops/x86_128-inl.h +++ b/third_party/highway/hwy/ops/x86_128-inl.h @@ -43,7 +43,23 @@ namespace hwy { namespace HWY_NAMESPACE { template -using Full128 = Simd; +using Full32 = Simd; + +template +using Full64 = Simd; + +template +using Full128 = Simd; + +#if HWY_TARGET <= HWY_AVX2 +template +using Full256 = Simd; +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +using Full512 = Simd; +#endif namespace detail { @@ -94,14 +110,15 @@ class Vec128 { Raw raw; }; -// Forward-declare 
for use by DeduceD, see below. template -class Vec256; -template -class Vec512; +using Vec64 = Vec128; #if HWY_TARGET <= HWY_AVX3 +// Forward-declare for use by DeduceD, see below. +template +class Vec512; + namespace detail { // Template arg: sizeof(lane type) @@ -147,24 +164,34 @@ struct Mask128 { #endif // HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX2 +// Forward-declare for use by DeduceD, see below. +template +class Vec256; +#endif + namespace detail { -// Deduce Simd from Vec* (pointers because Vec256/512 may be +// Deduce Simd from Vec* (pointers because Vec256/512 may be // incomplete types at this point; this is simpler than avoiding multiple // definitions of DFromV via #if) struct DeduceD { template - Simd operator()(const Vec128*) const { - return Simd(); + Simd operator()(const Vec128*) const { + return Simd(); } +#if HWY_TARGET <= HWY_AVX2 template - Simd operator()(const Vec256*) const { - return Simd(); + Full256 operator()(const hwy::HWY_NAMESPACE::Vec256*) const { + return Full256(); } +#endif +#if HWY_TARGET <= HWY_AVX3 template - Simd operator()(const Vec512*) const { - return Simd(); + Full512 operator()(const hwy::HWY_NAMESPACE::Vec512*) const { + return Full512(); } +#endif }; // Workaround for MSVC v19.14: alias with a dependent type fails to specialize. @@ -209,7 +236,7 @@ struct BitCastFromInteger128 { }; template -HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128{BitCastFromInteger128()(v.raw)}; } @@ -217,7 +244,7 @@ HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, } // namespace detail template -HWY_API Vec128 BitCast(Simd d, +HWY_API Vec128 BitCast(Simd d, Vec128 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } @@ -226,15 +253,15 @@ HWY_API Vec128 BitCast(Simd d, // Returns an all-zero vector/part. template -HWY_API Vec128 Zero(Simd /* tag */) { +HWY_API Vec128 Zero(Simd /* tag */) { return Vec128{_mm_setzero_si128()}; } template -HWY_API Vec128 Zero(Simd /* tag */) { +HWY_API Vec128 Zero(Simd /* tag */) { return Vec128{_mm_setzero_ps()}; } template -HWY_API Vec128 Zero(Simd /* tag */) { +HWY_API Vec128 Zero(Simd /* tag */) { return Vec128{_mm_setzero_pd()}; } @@ -245,45 +272,48 @@ using VFromD = decltype(Zero(D())); // Returns a vector/part with all lanes set to "t". 
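// Illustrative sketch (not part of the patch): why the Full32/Full64/Full128
// aliases above are plain zero-sized tag types. The descriptor argument costs
// nothing at runtime and only steers overload resolution, so Load/Zero/Set can
// pick the right register width from the tag alone. All names here (Desc,
// LoadLanes, U32x2/U32x4) are made up for this example, not Highway API.
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T, size_t kLanes>
struct Desc {};  // empty: carries only the lane type and count

template <typename T> using Half64 = Desc<T, 8 / sizeof(T)>;
template <typename T> using Whole128 = Desc<T, 16 / sizeof(T)>;

struct U32x2 { uint32_t lane[2]; };
struct U32x4 { uint32_t lane[4]; };

static inline U32x2 LoadLanes(Half64<uint32_t>, const uint32_t* p) {
  U32x2 v; std::memcpy(v.lane, p, sizeof(v.lane)); return v;
}
static inline U32x4 LoadLanes(Whole128<uint32_t>, const uint32_t* p) {
  U32x4 v; std::memcpy(v.lane, p, sizeof(v.lane)); return v;
}
// Callers write LoadLanes(Whole128<uint32_t>(), ptr); the empty tag object is
// optimized away entirely.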
template -HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT } template -HWY_API Vec128 Set(Simd /* tag */, const uint16_t t) { +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT } template -HWY_API Vec128 Set(Simd /* tag */, const uint32_t t) { +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { return Vec128{_mm_set1_epi32(static_cast(t))}; } template -HWY_API Vec128 Set(Simd /* tag */, const uint64_t t) { +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { return Vec128{ _mm_set1_epi64x(static_cast(t))}; // NOLINT } template -HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT } template -HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT } template -HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { return Vec128{_mm_set1_epi32(t)}; } template -HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { return Vec128{ _mm_set1_epi64x(static_cast(t))}; // NOLINT } template -HWY_API Vec128 Set(Simd /* tag */, const float t) { +HWY_API Vec128 Set(Simd /* tag */, const float t) { return Vec128{_mm_set1_ps(t)}; } template -HWY_API Vec128 Set(Simd /* tag */, const double t) { +HWY_API Vec128 Set(Simd /* tag */, const double t) { return Vec128{_mm_set1_pd(t)}; } @@ -292,17 +322,17 @@ HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. template -HWY_API Vec128 Undefined(Simd /* tag */) { +HWY_API Vec128 Undefined(Simd /* tag */) { // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC // generate an XOR instruction. 
return Vec128{_mm_undefined_si128()}; } template -HWY_API Vec128 Undefined(Simd /* tag */) { +HWY_API Vec128 Undefined(Simd /* tag */) { return Vec128{_mm_undefined_ps()}; } template -HWY_API Vec128 Undefined(Simd /* tag */) { +HWY_API Vec128 Undefined(Simd /* tag */) { return Vec128{_mm_undefined_pd()}; } @@ -343,7 +373,7 @@ template HWY_API uint64_t GetLane(const Vec128 v) { #if HWY_ARCH_X86_32 alignas(16) uint64_t lanes[2]; - Store(v, Simd(), lanes); + Store(v, Simd(), lanes); return lanes[0]; #else return static_cast(_mm_cvtsi128_si64(v.raw)); @@ -353,7 +383,7 @@ template HWY_API int64_t GetLane(const Vec128 v) { #if HWY_ARCH_X86_32 alignas(16) int64_t lanes[2]; - Store(v, Simd(), lanes); + Store(v, Simd(), lanes); return lanes[0]; #else return _mm_cvtsi128_si64(v.raw); @@ -441,13 +471,47 @@ HWY_API Vec128 Xor(const Vec128 a, template HWY_API Vec128 Not(const Vec128 v) { - using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; #if HWY_TARGET <= HWY_AVX3 - const __m128i vu = BitCast(Simd(), v).raw; - return BitCast(Simd(), - Vec128{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); + const __m128i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); #else - return Xor(v, BitCast(Simd(), Vec128{_mm_set1_epi32(-1)})); + return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); #endif } @@ -517,12 +581,12 @@ HWY_API Vec128 PopulationCount(Vec128 v) { template HWY_API Vec128 Neg(const Vec128 v) { - return Xor(v, SignBit(Simd())); + return Xor(v, SignBit(DFromV())); } template HWY_API Vec128 Neg(const Vec128 v) { - return Zero(Simd()) - v; + return Zero(DFromV()) - v; } // ------------------------------ Abs @@ -532,7 +596,7 @@ template HWY_API Vec128 Abs(const Vec128 v) { #if HWY_COMPILER_MSVC // Workaround for incorrect codegen? 
(reaches breakpoint) - const auto zero = Zero(Simd()); + const auto zero = Zero(DFromV()); return Vec128{_mm_max_epi8(v.raw, (zero - v).raw)}; #else return Vec128{_mm_abs_epi8(v.raw)}; @@ -550,12 +614,12 @@ HWY_API Vec128 Abs(const Vec128 v) { template HWY_API Vec128 Abs(const Vec128 v) { const Vec128 mask{_mm_set1_epi32(0x7FFFFFFF)}; - return v & BitCast(Simd(), mask); + return v & BitCast(DFromV(), mask); } template HWY_API Vec128 Abs(const Vec128 v) { const Vec128 mask{_mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; - return v & BitCast(Simd(), mask); + return v & BitCast(DFromV(), mask); } // ------------------------------ CopySign @@ -565,11 +629,11 @@ HWY_API Vec128 CopySign(const Vec128 magn, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); - const Simd d; + const DFromV d; const auto msb = SignBit(d); #if HWY_TARGET <= HWY_AVX3 - const Rebind, decltype(d)> du; + const RebindToUnsigned du; // Truth table for msb, magn, sign | bitwise msb ? sign : mag // 0 0 0 | 0 // 0 0 1 | 0 @@ -582,7 +646,7 @@ HWY_API Vec128 CopySign(const Vec128 magn, // The lane size does not matter because we are not using predication. const __m128i out = _mm_ternarylogic_epi32( BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); - return BitCast(d, decltype(Zero(du)){out}); + return BitCast(d, VFromD{out}); #else return Or(AndNot(msb, magn), And(msb, sign)); #endif @@ -595,7 +659,7 @@ HWY_API Vec128 CopySignToAbs(const Vec128 abs, // AVX3 can also handle abs < 0, so no extra action needed. return CopySign(abs, sign); #else - return Or(abs, And(SignBit(Simd()), sign)); + return Or(abs, And(SignBit(DFromV()), sign)); #endif } @@ -946,7 +1010,7 @@ HWY_API Vec128 VecFromMask(const Mask128 v) { } template -HWY_API Vec128 VecFromMask(const Simd /* tag */, +HWY_API Vec128 VecFromMask(const Simd /* tag */, const Mask128 v) { return Vec128{v.raw}; } @@ -957,7 +1021,7 @@ HWY_API Vec128 VecFromMask(const Simd /* tag */, template HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, Vec128 no) { - const auto vmask = VecFromMask(Simd(), mask); + const auto vmask = VecFromMask(DFromV(), mask); return Or(And(vmask, yes), AndNot(vmask, no)); } @@ -987,43 +1051,43 @@ HWY_API Vec128 IfThenElse(const Mask128 mask, // mask ? yes : 0 template HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { - return yes & VecFromMask(Simd(), mask); + return yes & VecFromMask(DFromV(), mask); } // mask ? 
0 : no template HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { - return AndNot(VecFromMask(Simd(), mask), no); + return AndNot(VecFromMask(DFromV(), mask), no); } // ------------------------------ Mask logical template HWY_API Mask128 Not(const Mask128 m) { - return MaskFromVec(Not(VecFromMask(Simd(), m))); + return MaskFromVec(Not(VecFromMask(Simd(), m))); } template HWY_API Mask128 And(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Or(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { - const Simd d; + const Simd d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } @@ -1114,7 +1178,7 @@ HWY_API Vec128 Shuffle0123(const Vec128 v) { // Comparisons set a mask bit to 1 if the condition is true, else 0. template -HWY_API Mask128 RebindMask(Simd /*tag*/, +HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask128{m.raw}; @@ -1316,11 +1380,13 @@ HWY_API Mask128 MaskFromVec(const Vec128 v) { // There do not seem to be native floating-point versions of these instructions. template HWY_API Mask128 MaskFromVec(const Vec128 v) { - return Mask128{MaskFromVec(BitCast(Simd(), v)).raw}; + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; } template HWY_API Mask128 MaskFromVec(const Vec128 v) { - return Mask128{MaskFromVec(BitCast(Simd(), v)).raw}; + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; } template @@ -1354,7 +1420,8 @@ HWY_API Vec128 VecFromMask(const Mask128 v) { } template -HWY_API Vec128 VecFromMask(Simd /* tag */, const Mask128 v) { +HWY_API Vec128 VecFromMask(Simd /* tag */, + const Mask128 v) { return VecFromMask(v); } @@ -1363,10 +1430,11 @@ HWY_API Vec128 VecFromMask(Simd /* tag */, const Mask128 v) { // Comparisons fill a lane with 1-bits if the condition is true, else 0. template -HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); - const Simd d; - return MaskFromVec(BitCast(Simd(), VecFromMask(d, m))); + const Simd d; + return MaskFromVec(BitCast(Simd(), VecFromMask(d, m))); } template @@ -1397,8 +1465,8 @@ template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { #if HWY_TARGET == HWY_SSSE3 - const Simd d32; - const Simd d64; + const Simd d32; + const Simd d64; const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); const auto cmp64 = cmp32 & Shuffle2301(cmp32); return MaskFromVec(BitCast(d64, cmp64)); @@ -1427,8 +1495,9 @@ template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { // Same as signed ==; avoid duplicating the SSSE3 version. 
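// Illustrative sketch (not part of the patch): how the _mm_ternarylogic_epi*
// immediates used above are derived (0x55 in Not, 0xF8 in OrAnd, 0xCA in
// IfVecThenElse, 0xAC in CopySign). Bit i of the immediate is the desired
// output for the input combination (a, b, c) given by the bits of i, with a as
// the most significant input bit. Not passes the same vector for all three
// operands, so ~c there is simply ~v.
#include <cstdint>
#include <cstdio>

template <class F>
static inline uint8_t TernLogImm(F f) {
  uint8_t imm = 0;
  for (int i = 0; i < 8; ++i) {
    const bool a = (i >> 2) & 1, b = (i >> 1) & 1, c = i & 1;
    if (f(a, b, c)) imm = static_cast<uint8_t>(imm | (1u << i));
  }
  return imm;
}

int main() {
  const uint8_t kNot = TernLogImm([](bool, bool, bool c) { return !c; });
  const uint8_t kOrAnd = TernLogImm([](bool a, bool b, bool c) { return a | (b & c); });
  const uint8_t kIfVec = TernLogImm([](bool a, bool b, bool c) { return a ? b : c; });
  const uint8_t kCopySign = TernLogImm([](bool a, bool b, bool c) { return a ? c : b; });
  std::printf("%02X %02X %02X %02X\n", kNot, kOrAnd, kIfVec, kCopySign);  // 55 F8 CA AC
}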
- const Simd du; - return RebindMask(Simd(), BitCast(du, a) == BitCast(du, b)); + const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); } // Float @@ -1481,7 +1550,7 @@ HWY_API Mask128 operator>(Vec128 a, template HWY_API Mask128 operator>(Vec128 a, Vec128 b) { - const Simd du; + const DFromV du; const RebindToSigned di; const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); @@ -1546,7 +1615,7 @@ HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { // ------------------------------ FirstN (Iota, Lt) template -HWY_API Mask128 FirstN(const Simd d, size_t num) { +HWY_API Mask128 FirstN(const Simd d, size_t num) { #if HWY_TARGET <= HWY_AVX3 (void)d; const uint64_t all = (1ull << N) - 1; @@ -1608,19 +1677,17 @@ HWY_API Vec128 LoadU(Full128 /* tag */, } template -HWY_API Vec128 Load(Simd /* tag */, - const T* HWY_RESTRICT p) { +HWY_API Vec64 Load(Full64 /* tag */, const T* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE __m128i v = _mm_setzero_si128(); CopyBytes<8>(p, &v); - return Vec128{v}; + return Vec64{v}; #else - return Vec128{ - _mm_loadl_epi64(reinterpret_cast(p))}; + return Vec64{_mm_loadl_epi64(reinterpret_cast(p))}; #endif } -HWY_API Vec128 Load(Simd /* tag */, +HWY_API Vec128 Load(Full64 /* tag */, const float* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE __m128 v = _mm_setzero_ps(); @@ -1632,18 +1699,18 @@ HWY_API Vec128 Load(Simd /* tag */, #endif } -HWY_API Vec128 Load(Simd /* tag */, - const double* HWY_RESTRICT p) { +HWY_API Vec64 Load(Full64 /* tag */, + const double* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE __m128d v = _mm_setzero_pd(); CopyBytes<8>(p, &v); - return Vec128{v}; + return Vec64{v}; #else - return Vec128{_mm_load_sd(p)}; + return Vec64{_mm_load_sd(p)}; #endif } -HWY_API Vec128 Load(Simd /* tag */, +HWY_API Vec128 Load(Full32 /* tag */, const float* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE __m128 v = _mm_setzero_ps(); @@ -1656,7 +1723,7 @@ HWY_API Vec128 Load(Simd /* tag */, // Any <= 32 bit except template -HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { constexpr size_t kSize = sizeof(T) * N; #if HWY_SAFE_PARTIAL_LOAD_STORE __m128 v = _mm_setzero_ps(); @@ -1671,19 +1738,19 @@ HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { // For < 128 bit, LoadU == Load. template -HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { return Load(d, p); } // 128-bit SIMD => nothing to duplicate, same as an unaligned load. template -HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { return LoadU(d, p); } // Returns a vector with lane i=[0, N) set to "first" + i. 
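// Illustrative sketch (not part of the patch): the Iota + compare fallback that
// FirstN above uses when no mask registers are available, shown for four
// int32_t lanes with SSE2 intrinsics. The AVX-512 branch above works on mask
// registers directly instead.
#include <emmintrin.h>
#include <cstddef>
#include <cstdint>

// Returns a vector whose first `num` lanes are all-ones and the rest all-zero.
static inline __m128i FirstN4(size_t num) {
  const __m128i iota = _mm_set_epi32(3, 2, 1, 0);  // lane i holds i
  const __m128i bound = _mm_set1_epi32(static_cast<int32_t>(num));
  return _mm_cmplt_epi32(iota, bound);             // lane i = (i < num) ? ~0 : 0
}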
template -HWY_API Vec128 Iota(const Simd d, const T2 first) { +HWY_API Vec128 Iota(const Simd d, const T2 first) { HWY_ALIGN T lanes[16 / sizeof(T)]; for (size_t i = 0; i < 16 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); @@ -1696,40 +1763,40 @@ HWY_API Vec128 Iota(const Simd d, const T2 first) { #if HWY_TARGET <= HWY_AVX3 template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, const T* HWY_RESTRICT aligned) { return Vec128{_mm_maskz_load_epi32(m.raw, aligned)}; } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, const T* HWY_RESTRICT aligned) { return Vec128{_mm_maskz_load_epi64(m.raw, aligned)}; } template HWY_API Vec128 MaskedLoad(Mask128 m, - Simd /* tag */, + Simd /* tag */, const float* HWY_RESTRICT aligned) { return Vec128{_mm_maskz_load_ps(m.raw, aligned)}; } template HWY_API Vec128 MaskedLoad(Mask128 m, - Simd /* tag */, + Simd /* tag */, const double* HWY_RESTRICT aligned) { return Vec128{_mm_maskz_load_pd(m.raw, aligned)}; } // There is no load_epi8/16, so use loadu instead. template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, const T* HWY_RESTRICT aligned) { return Vec128{_mm_maskz_loadu_epi8(m.raw, aligned)}; } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, const T* HWY_RESTRICT aligned) { return Vec128{_mm_maskz_loadu_epi16(m.raw, aligned)}; } @@ -1737,21 +1804,21 @@ HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, #elif HWY_TARGET == HWY_AVX2 template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, const T* HWY_RESTRICT aligned) { auto aligned_p = reinterpret_cast(aligned); // NOLINT return Vec128{_mm_maskload_epi32(aligned_p, m.raw)}; } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, const T* HWY_RESTRICT aligned) { auto aligned_p = reinterpret_cast(aligned); // NOLINT return Vec128{_mm_maskload_epi64(aligned_p, m.raw)}; } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const float* HWY_RESTRICT aligned) { const Vec128 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); @@ -1759,7 +1826,7 @@ HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, } template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const double* HWY_RESTRICT aligned) { const Vec128 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); @@ -1768,7 +1835,7 @@ HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, // There is no maskload_epi8/16, so blend instead. template * = nullptr> -HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } @@ -1777,7 +1844,7 @@ HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). 
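// Illustrative sketch (not part of the patch): the contract the MaskedLoad
// overloads above provide, as a scalar reference. The non-AVX fallback
// (IfThenElseZero of a full Load) produces the same lane values; it simply
// reads all N lanes from `aligned` and zeroes the unselected ones.
#include <cstddef>
#include <cstdint>

static inline void MaskedLoadRef(const bool* mask, const int32_t* p,
                                 int32_t* out, size_t n) {
  for (size_t i = 0; i < n; ++i) out[i] = mask[i] ? p[i] : 0;
}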
template -HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } @@ -1813,15 +1880,14 @@ HWY_API void StoreU(const Vec128 v, Full128 /* tag */, } template -HWY_API void Store(Vec128 v, Simd /* tag */, - T* HWY_RESTRICT p) { +HWY_API void Store(Vec64 v, Full64 /* tag */, T* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE CopyBytes<8>(&v, p); #else _mm_storel_epi64(reinterpret_cast<__m128i*>(p), v.raw); #endif } -HWY_API void Store(const Vec128 v, Simd /* tag */, +HWY_API void Store(const Vec128 v, Full64 /* tag */, float* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE CopyBytes<8>(&v, p); @@ -1829,7 +1895,7 @@ HWY_API void Store(const Vec128 v, Simd /* tag */, _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); #endif } -HWY_API void Store(const Vec128 v, Simd /* tag */, +HWY_API void Store(const Vec64 v, Full64 /* tag */, double* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE CopyBytes<8>(&v, p); @@ -1840,10 +1906,10 @@ HWY_API void Store(const Vec128 v, Simd /* tag */, // Any <= 32 bit except template -HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { CopyBytes(&v, p); } -HWY_API void Store(const Vec128 v, Simd /* tag */, +HWY_API void Store(const Vec128 v, Full32 /* tag */, float* HWY_RESTRICT p) { #if HWY_SAFE_PARTIAL_LOAD_STORE CopyBytes<4>(&v, p); @@ -1854,7 +1920,7 @@ HWY_API void Store(const Vec128 v, Simd /* tag */, // For < 128 bit, StoreU == Store. template -HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { Store(v, d, p); } @@ -1976,7 +2042,13 @@ HWY_API Vec128 operator-(const Vec128 a, return Vec128{_mm_sub_pd(a.raw, b.raw)}; } -// ------------------------------ Saturating addition +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -2004,7 +2076,7 @@ HWY_API Vec128 SaturatedAdd(const Vec128 a, return Vec128{_mm_adds_epi16(a.raw, b.raw)}; } -// ------------------------------ Saturating subtraction +// ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. @@ -2086,7 +2158,7 @@ HWY_API Vec128 MulEven(const Vec128 a, template // N=1 or 2 HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { - return Set(Simd(), int64_t(GetLane(a)) * GetLane(b)); + return Set(Simd(), int64_t(GetLane(a)) * GetLane(b)); } HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { @@ -2139,8 +2211,9 @@ template HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { // Same as unsigned; avoid duplicating the SSSE3 code. - const Simd du; - return BitCast(Simd(), BitCast(du, a) * BitCast(du, b)); + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); } // ------------------------------ ShiftLeft @@ -2175,7 +2248,7 @@ HWY_API Vec128 ShiftLeft(const Vec128 v) { template HWY_API Vec128 ShiftLeft(const Vec128 v) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. 
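// Illustrative sketch (not part of the patch): what the new SumsOf8 above
// computes. _mm_sad_epu8 against zero adds each group of eight consecutive
// uint8_t lanes into one uint64_t lane; the scalar loop is the reference.
#include <emmintrin.h>
#include <cstdint>

static inline __m128i SumsOf8_SSE2(__m128i v) {
  return _mm_sad_epu8(v, _mm_setzero_si128());  // {sum(v[0..7]), sum(v[8..15])}
}

static inline void SumsOf8_Ref(const uint8_t v[16], uint64_t out[2]) {
  out[0] = out[1] = 0;
  for (int i = 0; i < 16; ++i) out[i / 8] += v[i];
}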
const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; return kBits == 1 @@ -2200,7 +2273,7 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { template HWY_API Vec128 ShiftRight(const Vec128 v) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. const Vec128 shifted{ ShiftRight(Vec128{v.raw}).raw}; @@ -2218,8 +2291,8 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { template HWY_API Vec128 ShiftRight(const Vec128 v) { - const Simd di; - const Simd du; + const DFromV di; + const RebindToUnsigned du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; @@ -2255,7 +2328,8 @@ HWY_API Vec128 RotateRight(const Vec128 v) { template HWY_API Vec128 BroadcastSignBit(const Vec128 v) { - return VecFromMask(v < Zero(Simd())); + const DFromV d; + return VecFromMask(v < Zero(d)); } template @@ -2270,14 +2344,16 @@ HWY_API Vec128 BroadcastSignBit(const Vec128 v) { template HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; #if HWY_TARGET <= HWY_AVX3 + (void)d; return Vec128{_mm_srai_epi64(v.raw, 63)}; #elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 - return VecFromMask(v < Zero(Simd())); + return VecFromMask(v < Zero(d)); #else // Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift // avoids generating a zero. - const Simd d32; + const RepartitionToNarrow d32; const auto sign = ShiftRight<31>(BitCast(d32, v)); return Vec128{ _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; @@ -2289,7 +2365,7 @@ HWY_API Vec128 Abs(const Vec128 v) { #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_abs_epi64(v.raw)}; #else - const auto zero = Zero(Simd()); + const auto zero = Zero(DFromV()); return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } @@ -2299,8 +2375,8 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_srai_epi64(v.raw, kBits)}; #else - const Simd di; - const Simd du; + const DFromV di; + const RebindToUnsigned du; const auto right = BitCast(di, ShiftRight(BitCast(du, v))); const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); return right | sign; @@ -2310,7 +2386,7 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { // ------------------------------ ZeroIfNegative (BroadcastSignBit) template HWY_API Vec128 ZeroIfNegative(Vec128 v) { - const Simd d; + const DFromV d; #if HWY_TARGET == HWY_SSSE3 const RebindToSigned di; const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); @@ -2320,6 +2396,39 @@ HWY_API Vec128 ZeroIfNegative(Vec128 v) { return IfThenElse(mask, Zero(d), v); } +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(const Vec128 v, + const Vec128 yes, + const Vec128 no) { + // int8: IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. 
+ return BitCast(d, IfThenElse(MaskFromVec(BitCast(df, v)), BitCast(df, yes), + BitCast(df, no))); +} + // ------------------------------ ShiftLeftSame template @@ -2358,7 +2467,7 @@ HWY_API Vec128 ShiftLeftSame(const Vec128 v, template HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. const Vec128 shifted{ ShiftLeftSame(Vec128>{v.raw}, bits).raw}; @@ -2386,7 +2495,7 @@ HWY_API Vec128 ShiftRightSame(const Vec128 v, template HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { - const Simd d8; + const DFromV d8; // Use raw instead of BitCast to support N=1. const Vec128 shifted{ ShiftRightSame(Vec128{v.raw}, bits).raw}; @@ -2410,8 +2519,8 @@ HWY_API Vec128 ShiftRightSame(const Vec128 v, #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; #else - const Simd di; - const Simd du; + const DFromV di; + const RebindToUnsigned du; const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); return right | sign; @@ -2420,8 +2529,8 @@ HWY_API Vec128 ShiftRightSame(const Vec128 v, template HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { - const Simd di; - const Simd du; + const DFromV di; + const RebindToUnsigned du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, static_cast(0x80 >> bits))); @@ -2443,9 +2552,8 @@ HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128{_mm_mul_pd(a.raw, b.raw)}; } -HWY_API Vec128 operator*(const Vec128 a, - const Vec128 b) { - return Vec128{_mm_mul_sd(a.raw, b.raw)}; +HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { + return Vec64{_mm_mul_sd(a.raw, b.raw)}; } template @@ -2462,9 +2570,8 @@ HWY_API Vec128 operator/(const Vec128 a, const Vec128 b) { return Vec128{_mm_div_pd(a.raw, b.raw)}; } -HWY_API Vec128 operator/(const Vec128 a, - const Vec128 b) { - return Vec128{_mm_div_sd(a.raw, b.raw)}; +HWY_API Vec64 operator/(const Vec64 a, const Vec64 b) { + return Vec64{_mm_div_sd(a.raw, b.raw)}; } // Approximate reciprocal @@ -2587,8 +2694,8 @@ template HWY_API Vec128 Sqrt(const Vec128 v) { return Vec128{_mm_sqrt_pd(v.raw)}; } -HWY_API Vec128 Sqrt(const Vec128 v) { - return Vec128{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +HWY_API Vec64 Sqrt(const Vec64 v) { + return Vec64{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; } // Approximate reciprocal square root @@ -2607,8 +2714,9 @@ namespace detail { template HWY_INLINE HWY_MAYBE_UNUSED Vec128 MinU(const Vec128 a, const Vec128 b) { - const Simd du; - const RebindToSigned di; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); return IfThenElse(gt, b, a); @@ -2702,8 +2810,9 @@ namespace detail { template HWY_INLINE HWY_MAYBE_UNUSED Vec128 MaxU(const Vec128 a, const Vec128 b) { - const Simd du; - const RebindToSigned di; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); return IfThenElse(gt, a, b); @@ -2798,17 +2907,17 @@ HWY_API Vec128 Max(const Vec128 a, // On clang6, we see incorrect code generated for _mm_stream_pi, so // round even partial vectors up to 16 bytes. 
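// Illustrative sketch (not part of the patch): the sign-bit flip used by
// detail::MinU/MaxU above, for uint32_t on plain SSE2 (which has signed
// compares but no _mm_min_epu32). XORing the sign bit maps unsigned order onto
// signed order, so a signed compare plus a blend yields the unsigned minimum.
#include <emmintrin.h>

static inline __m128i MinU32_SSE2(__m128i a, __m128i b) {
  const __m128i msb = _mm_set1_epi32(static_cast<int>(0x80000000u));
  const __m128i a_gt_b = _mm_cmpgt_epi32(_mm_xor_si128(a, msb),
                                         _mm_xor_si128(b, msb));
  // Select b where a > b (unsigned), else a: (mask & b) | (~mask & a).
  return _mm_or_si128(_mm_and_si128(a_gt_b, b), _mm_andnot_si128(a_gt_b, a));
}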
template -HWY_API void Stream(Vec128 v, Simd /* tag */, +HWY_API void Stream(Vec128 v, Simd /* tag */, T* HWY_RESTRICT aligned) { _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), v.raw); } template -HWY_API void Stream(const Vec128 v, Simd /* tag */, +HWY_API void Stream(const Vec128 v, Simd /* tag */, float* HWY_RESTRICT aligned) { _mm_stream_ps(aligned, v.raw); } template -HWY_API void Stream(const Vec128 v, Simd /* tag */, +HWY_API void Stream(const Vec128 v, Simd /* tag */, double* HWY_RESTRICT aligned) { _mm_stream_pd(aligned, v.raw); } @@ -2828,7 +2937,7 @@ namespace detail { template HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128 v, - Simd /* tag */, T* HWY_RESTRICT base, + Simd /* tag */, T* HWY_RESTRICT base, const Vec128 offset) { if (N == 4) { _mm_i32scatter_epi32(base, offset.raw, v.raw, 1); @@ -2839,7 +2948,7 @@ HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128 v, } template HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128 v, - Simd /* tag */, T* HWY_RESTRICT base, + Simd /* tag */, T* HWY_RESTRICT base, const Vec128 index) { if (N == 4) { _mm_i32scatter_epi32(base, index.raw, v.raw, 4); @@ -2851,7 +2960,7 @@ HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128 v, template HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128 v, - Simd /* tag */, T* HWY_RESTRICT base, + Simd /* tag */, T* HWY_RESTRICT base, const Vec128 offset) { if (N == 2) { _mm_i64scatter_epi64(base, offset.raw, v.raw, 1); @@ -2862,7 +2971,7 @@ HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128 v, } template HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128 v, - Simd /* tag */, T* HWY_RESTRICT base, + Simd /* tag */, T* HWY_RESTRICT base, const Vec128 index) { if (N == 2) { _mm_i64scatter_epi64(base, index.raw, v.raw, 8); @@ -2875,20 +2984,21 @@ HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128 v, } // namespace detail template -HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); } template -HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); } template -HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, float* HWY_RESTRICT base, const Vec128 offset) { if (N == 4) { @@ -2899,7 +3009,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, } } template -HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, float* HWY_RESTRICT base, const Vec128 index) { if (N == 4) { @@ -2911,7 +3021,7 @@ HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, } template -HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, double* HWY_RESTRICT base, const Vec128 offset) { if (N == 2) { @@ -2922,7 +3032,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, } } template -HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, double* HWY_RESTRICT base, const Vec128 index) { if (N == 2) { @@ -2935,7 +3045,8 @@ HWY_API void 
ScatterIndex(Vec128 v, Simd /* tag */, #else // HWY_TARGET <= HWY_AVX3 template -HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); @@ -2943,7 +3054,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Rebind(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { @@ -2952,7 +3063,7 @@ HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, } template -HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); @@ -2960,7 +3071,7 @@ HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Rebind(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; @@ -2974,13 +3085,13 @@ HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, #if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 template -HWY_API Vec128 GatherOffset(const Simd d, +HWY_API Vec128 GatherOffset(const Simd d, const T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(16) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Rebind(), offset_lanes); alignas(16) T lanes[N]; const uint8_t* base_bytes = reinterpret_cast(base); @@ -2991,12 +3102,13 @@ HWY_API Vec128 GatherOffset(const Simd d, } template -HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(16) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Rebind(), index_lanes); alignas(16) T lanes[N]; for (size_t i = 0; i < N; ++i) { @@ -3011,7 +3123,7 @@ namespace detail { template HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<4> /* tag */, - Simd /* d */, + Simd /* d */, const T* HWY_RESTRICT base, const Vec128 offset) { return Vec128{_mm_i32gather_epi32( @@ -3019,7 +3131,7 @@ HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<4> /* tag */, } template HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<4> /* tag */, - Simd /* d */, + Simd /* d */, const T* HWY_RESTRICT base, const Vec128 index) { return Vec128{_mm_i32gather_epi32( @@ -3028,7 +3140,7 @@ HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<4> /* tag */, template HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<8> /* tag */, - Simd /* d */, + Simd /* d */, const T* HWY_RESTRICT base, const Vec128 offset) { return Vec128{_mm_i64gather_epi64( @@ -3036,7 +3148,7 @@ HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<8> /* tag */, } template HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<8> /* tag */, - Simd /* d */, + Simd /* d */, const T* HWY_RESTRICT base, const Vec128 index) { return Vec128{_mm_i64gather_epi64( @@ -3046,37 +3158,37 @@ HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<8> /* tag */, } // namespace detail template -HWY_API Vec128 GatherOffset(Simd d, const T* HWY_RESTRICT base, +HWY_API Vec128 GatherOffset(Simd d, const T* HWY_RESTRICT base, const Vec128 offset) { 
return detail::GatherOffset(hwy::SizeTag(), d, base, offset); } template -HWY_API Vec128 GatherIndex(Simd d, const T* HWY_RESTRICT base, +HWY_API Vec128 GatherIndex(Simd d, const T* HWY_RESTRICT base, const Vec128 index) { return detail::GatherIndex(hwy::SizeTag(), d, base, index); } template -HWY_API Vec128 GatherOffset(Simd /* tag */, +HWY_API Vec128 GatherOffset(Simd /* tag */, const float* HWY_RESTRICT base, const Vec128 offset) { return Vec128{_mm_i32gather_ps(base, offset.raw, 1)}; } template -HWY_API Vec128 GatherIndex(Simd /* tag */, +HWY_API Vec128 GatherIndex(Simd /* tag */, const float* HWY_RESTRICT base, const Vec128 index) { return Vec128{_mm_i32gather_ps(base, index.raw, 4)}; } template -HWY_API Vec128 GatherOffset(Simd /* tag */, +HWY_API Vec128 GatherOffset(Simd /* tag */, const double* HWY_RESTRICT base, const Vec128 offset) { return Vec128{_mm_i64gather_pd(base, offset.raw, 1)}; } template -HWY_API Vec128 GatherIndex(Simd /* tag */, +HWY_API Vec128 GatherIndex(Simd /* tag */, const double* HWY_RESTRICT base, const Vec128 index) { return Vec128{_mm_i64gather_pd(base, index.raw, 8)}; @@ -3092,44 +3204,45 @@ HWY_DIAGNOSTICS(pop) // Returns upper/lower half of a vector. template -HWY_API Vec128 LowerHalf(Simd /* tag */, Vec128 v) { +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { return Vec128{v.raw}; } template HWY_API Vec128 LowerHalf(Vec128 v) { - return LowerHalf(Simd(), v); + return LowerHalf(Simd(), v); } // ------------------------------ ShiftLeftBytes template -HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); return Vec128{_mm_slli_si128(v.raw, kBytes)}; } template HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { - return ShiftLeftBytes(Simd(), v); + return ShiftLeftBytes(DFromV(), v); } // ------------------------------ ShiftLeftLanes template -HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { - return ShiftLeftLanes(Simd(), v); + return ShiftLeftLanes(DFromV(), v); } // ------------------------------ ShiftRightBytes template -HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); // For partial vectors, clear upper lanes so we shift in zeros. if (N != 16 / sizeof(T)) { @@ -3141,34 +3254,33 @@ HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { // ------------------------------ ShiftRightLanes template -HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { const Repartition d8; - return BitCast(d, ShiftRightBytes(BitCast(d8, v))); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ UpperHalf (ShiftRightBytes) // Full input: copy hi into lo (smaller instruction encoding than shifts). 
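// Illustrative sketch (not part of the patch): the store-indices-then-loop
// shape of the SSSE3/SSE4 GatherIndex fallback above, without intrinsics.
// GatherOffset is the same except the per-lane values are byte offsets added
// to the base address rather than element indices.
#include <cstddef>
#include <cstdint>
#include <cstring>

static inline void GatherIndexRef(const int32_t* base, const int32_t idx[4],
                                  int32_t out[4]) {
  for (size_t i = 0; i < 4; ++i) out[i] = base[idx[i]];
}

static inline void GatherOffsetRef(const int32_t* base, const int32_t offset[4],
                                   int32_t out[4]) {
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(base);
  for (size_t i = 0; i < 4; ++i) {
    std::memcpy(&out[i], bytes + offset[i], sizeof(int32_t));
  }
}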
template -HWY_API Vec128 UpperHalf(Half> /* tag */, - Vec128 v) { - return Vec128{_mm_unpackhi_epi64(v.raw, v.raw)}; +HWY_API Vec64 UpperHalf(Half> /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_epi64(v.raw, v.raw)}; } -HWY_API Vec128 UpperHalf(Simd /* tag */, Vec128 v) { +HWY_API Vec128 UpperHalf(Full64 /* tag */, Vec128 v) { return Vec128{_mm_movehl_ps(v.raw, v.raw)}; } -HWY_API Vec128 UpperHalf(Simd /* tag */, - Vec128 v) { - return Vec128{_mm_unpackhi_pd(v.raw, v.raw)}; +HWY_API Vec64 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_pd(v.raw, v.raw)}; } // Partial template -HWY_API Vec128 UpperHalf(Half> /* tag */, +HWY_API Vec128 UpperHalf(Half> /* tag */, Vec128 v) { - const Simd d; - const auto vu = BitCast(RebindToUnsigned(), v); - const auto upper = BitCast(d, ShiftRightBytes(vu)); + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); return Vec128{upper.raw}; } @@ -3183,7 +3295,7 @@ HWY_API V CombineShiftRightBytes(Full128 d, V hi, V lo) { template > -HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { constexpr size_t kSize = N * sizeof(T); static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); const Repartition d8; @@ -3280,10 +3392,10 @@ struct Indices128 { template -HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD - const Simd di; + const Rebind di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, N)))); #endif @@ -3312,13 +3424,14 @@ HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { template -HWY_API Indices128 IndicesFromVec(Simd /* tag */, - Vec128 vec) { +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD - const Simd di; + const Rebind di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#else + (void)d; #endif // No change - even without AVX3, we can shuffle+blend. 
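// Illustrative sketch (not part of the patch): permuting 32-bit lanes with a
// byte shuffle, the kind of pre-AVX2 fallback TableLookupLanes above relies
// on. Each lane index expands to the four byte indices 4*i .. 4*i+3 for
// _mm_shuffle_epi8 (the instruction behind TableLookupBytes).
#include <tmmintrin.h>  // SSSE3; pulls in the SSE2 intrinsics as well
#include <cstdint>

static inline __m128i PermuteU32(__m128i v, const uint32_t idx[4]) {
  alignas(16) uint8_t bytes[16];
  for (int lane = 0; lane < 4; ++lane) {
    for (int b = 0; b < 4; ++b) {
      bytes[4 * lane + b] = static_cast<uint8_t>(4 * idx[lane] + b);
    }
  }
  const __m128i byte_idx =
      _mm_load_si128(reinterpret_cast<const __m128i*>(bytes));
  return _mm_shuffle_epi8(v, byte_idx);  // out lane i = v lane idx[i]
}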
@@ -3326,7 +3439,7 @@ HWY_API Indices128 IndicesFromVec(Simd /* tag */, } template -HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } @@ -3334,8 +3447,8 @@ HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { template HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { #if HWY_TARGET <= HWY_AVX2 - const Simd d; - const Simd df; + const DFromV d; + const RebindToFloat df; const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; return BitCast(d, perm); #else @@ -3349,8 +3462,8 @@ HWY_API Vec128 TableLookupLanes(Vec128 v, #if HWY_TARGET <= HWY_AVX2 return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; #else - const Simd di; - const Simd df; + const DFromV df; + const RebindToSigned di; return BitCast(df, TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); #endif @@ -3402,17 +3515,25 @@ HWY_API Vec128 TableLookupLanes(Vec128 v, #endif } +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + // ------------------------------ Reverse (Shuffle0123, Shuffle2301) // Single lane: no change template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return v; } // Two lanes: shuffle template -HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 Reverse(Full64 /* tag */, const Vec128 v) { return Vec128{Shuffle2301(Vec128{v.raw}).raw}; } @@ -3429,7 +3550,7 @@ HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { // 16-bit template -HWY_API Vec128 Reverse(Simd d, const Vec128 v) { +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { #if HWY_TARGET <= HWY_AVX3 if (N == 1) return v; if (N == 2) { @@ -3447,6 +3568,79 @@ HWY_API Vec128 Reverse(Simd d, const Vec128 v) { #endif } +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToSigned di; + // 4x 16-bit: a single shufflelo suffices. 
+ if (N == 4) { + return BitCast(d, Vec128{_mm_shufflelo_epi16( + BitCast(di, v).raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET <= HWY_AVX3 + alignas(16) constexpr int16_t kReverse4[8] = {3, 2, 1, 0, 7, 6, 5, 4}; + const Vec128 idx = Load(di, kReverse4); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +// 4x 32-bit: use Shuffle0123 +template +HWY_API Vec128 Reverse4(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec128 idx = Load(di, kReverse8); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec128 Reverse8(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides @@ -3506,9 +3700,9 @@ HWY_API Vec128 InterleaveLower(const Vec128 a, return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; } -// Additional overload for the optional Simd<> tag. -template > -HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { +// Additional overload for the optional tag (also for 256/512). +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { return InterleaveLower(a, b); } @@ -3570,7 +3764,7 @@ HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { // Partial template > -HWY_API V InterleaveUpper(Simd d, V a, V b) { +HWY_API V InterleaveUpper(Simd d, V a, V b) { const Half d2; return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); } @@ -3579,19 +3773,17 @@ HWY_API V InterleaveUpper(Simd d, V a, V b) { // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
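// Illustrative sketch (not part of the patch): scalar reference for the
// ZipLower/ZipUpper semantics described above, for uint16_t on a little-endian
// target. Interleaving produces double-width lanes whose low half comes from
// `a` and whose high half comes from `b`.
#include <cstdint>

static inline void ZipLowerU16Ref(const uint16_t a[8], const uint16_t b[8],
                                  uint32_t out[4]) {
  for (int i = 0; i < 4; ++i) {
    out[i] = static_cast<uint32_t>(a[i]) |
             (static_cast<uint32_t>(b[i]) << 16);
  }
}
// ZipUpper does the same for lanes 4..7 of a and b.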
-template >> -HWY_API VFromD ZipLower(Vec128 a, Vec128 b) { +template >> +HWY_API VFromD ZipLower(V a, V b) { return BitCast(DW(), InterleaveLower(a, b)); } -template , - class DW = RepartitionToWide> -HWY_API VFromD ZipLower(DW dw, Vec128 a, Vec128 b) { +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { return BitCast(dw, InterleaveLower(D(), a, b)); } -template , - class DW = RepartitionToWide> -HWY_API VFromD ZipUpper(DW dw, Vec128 a, Vec128 b) { +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } @@ -3601,7 +3793,7 @@ HWY_API VFromD ZipUpper(DW dw, Vec128 a, Vec128 b) { // N = N/2 + N/2 (upper half undefined) template -HWY_API Vec128 Combine(Simd d, Vec128 hi_half, +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, Vec128 lo_half) { const Half d2; const RebindToUnsigned du2; @@ -3615,19 +3807,18 @@ HWY_API Vec128 Combine(Simd d, Vec128 hi_half, // ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) template -HWY_API Vec128 ZeroExtendVector(Full128 /* tag */, - Vec128 lo) { +HWY_API Vec128 ZeroExtendVector(Full128 /* tag */, Vec64 lo) { return Vec128{_mm_move_epi64(lo.raw)}; } template -HWY_API Vec128 ZeroExtendVector(Full128 d, Vec128 lo) { +HWY_API Vec128 ZeroExtendVector(Full128 d, Vec64 lo) { const RebindToUnsigned du; return BitCast(d, ZeroExtendVector(du, BitCast(Half(), lo))); } template -HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); } @@ -3680,31 +3871,31 @@ HWY_API Vec128 ConcatUpperLower(Full128 /* tag */, // ------------------------------ Concat partial (Combine, LowerHalf) template -HWY_API Vec128 ConcatLowerLower(Simd d, Vec128 hi, +HWY_API Vec128 ConcatLowerLower(Simd d, Vec128 hi, Vec128 lo) { const Half d2; - return Combine(LowerHalf(d2, hi), LowerHalf(d2, lo)); + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); } template -HWY_API Vec128 ConcatUpperUpper(Simd d, Vec128 hi, +HWY_API Vec128 ConcatUpperUpper(Simd d, Vec128 hi, Vec128 lo) { const Half d2; - return Combine(UpperHalf(d2, hi), UpperHalf(d2, lo)); + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); } template -HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, const Vec128 lo) { const Half d2; - return Combine(LowerHalf(d2, hi), UpperHalf(d2, lo)); + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); } template -HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, Vec128 lo) { const Half d2; - return Combine(UpperHalf(d2, hi), LowerHalf(d2, lo)); + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); } // ------------------------------ ConcatOdd @@ -3725,7 +3916,7 @@ HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, // 32-bit partial template -HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, Vec128 lo) { +HWY_API Vec128 ConcatOdd(Full64 d, Vec128 hi, Vec128 lo) { return InterleaveUpper(d, lo, hi); } @@ -3754,8 +3945,7 @@ HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, // 32-bit partial template -HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, - Vec128 lo) { +HWY_API Vec128 ConcatEven(Full64 d, Vec128 hi, Vec128 lo) { return InterleaveLower(d, lo, hi); } @@ -3766,6 +3956,40 @@ HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { return InterleaveLower(d, lo, hi); } +// ------------------------------ DupEven 
(InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + // ------------------------------ OddEven (IfThenElse) namespace detail { @@ -3773,7 +3997,7 @@ namespace detail { template HWY_INLINE Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, const Vec128 b) { - const Simd d; + const DFromV d; const Repartition d8; alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; @@ -3783,7 +4007,7 @@ template HWY_INLINE Vec128 OddEven(hwy::SizeTag<2> /* tag */, const Vec128 a, const Vec128 b) { #if HWY_TARGET == HWY_SSSE3 - const Simd d; + const DFromV d; const Repartition d8; alignas(16) constexpr uint8_t mask[16] = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; @@ -3866,7 +4090,7 @@ namespace detail { // Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. template HWY_INLINE Vec128, N> Pow2(const Vec128 v) { - const Simd d; + const DFromV d; const RepartitionToWide dw; const Rebind df; const auto zero = Zero(d); @@ -3885,7 +4109,7 @@ HWY_INLINE Vec128, N> Pow2(const Vec128 v) { // Same, for 32-bit shifts. template HWY_INLINE Vec128, N> Pow2(const Vec128 v) { - const Simd d; + const DFromV d; const auto exp = ShiftLeft<23>(v); const auto f = exp + Set(d, 0x3F800000); // 1.0f // Do not use ConvertTo because we rely on the native 0x80..00 overflow @@ -3937,16 +4161,16 @@ HWY_API Vec128 operator<<(const Vec128 v, return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; #endif } -HWY_API Vec128 operator<<(const Vec128 v, - const Vec128 bits) { - return Vec128{_mm_sll_epi64(v.raw, bits.raw)}; +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_sll_epi64(v.raw, bits.raw)}; } // Signed left shift is the same as unsigned. template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { - const Simd di; - const Simd, N> du; + const DFromV di; + const RebindToUnsigned du; return BitCast(di, BitCast(du, v) << BitCast(du, bits)); } @@ -3963,7 +4187,7 @@ HWY_API Vec128 operator>>(const Vec128 in, #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; #else - const Simd d; + const Simd d; // For bits=0, we cannot mul by 2^16, so fix the result later. const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); // Replace output with input where bits == 0. @@ -3980,7 +4204,7 @@ HWY_API Vec128 operator>>(const Vec128 in, const Vec128 bits) { #if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 // 32x32 -> 64 bit mul, then shift right by 32. - const Simd d32; + const Simd d32; // Move odd lanes into position for the second mul. Shuffle more gracefully // handles N=1 than repartitioning to u64 and shifting 32 bits right. 
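// Illustrative sketch (not part of the patch): the idea behind detail::Pow2
// above — adding the shift count to a float's exponent field yields 2^v per
// lane, so a variable left shift can be carried out as a multiplication (and a
// variable right shift as a multiply-high) on targets without per-lane
// variable shifts. Scalar reference:
#include <cstdint>
#include <cstring>

static inline uint32_t Pow2U32(uint32_t v) {  // requires v in [0, 31]
  const uint32_t bits = (v << 23) + 0x3F800000u;  // exponent of 1.0f, plus v
  float f;
  std::memcpy(&f, &bits, sizeof(f));  // bit pattern of 2^v as a float
  return static_cast<uint32_t>(f);    // exactly 2^v
}
// (x << v) equals (x * Pow2U32(v)) modulo 2^32, which vector code can evaluate
// with an ordinary per-lane multiply.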
const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; @@ -4014,9 +4238,9 @@ HWY_API Vec128 operator>>(const Vec128 v, return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; #endif } -HWY_API Vec128 operator>>(const Vec128 v, - const Vec128 bits) { - return Vec128{_mm_srl_epi64(v.raw, bits.raw)}; +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_srl_epi64(v.raw, bits.raw)}; } #if HWY_TARGET > HWY_AVX3 // AVX2 or older @@ -4043,7 +4267,7 @@ HWY_API Vec128 operator>>(const Vec128 v, #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; #else - return detail::SignedShr(Simd(), v, bits); + return detail::SignedShr(Simd(), v, bits); #endif } HWY_API Vec128 operator>>(const Vec128 v, @@ -4057,7 +4281,7 @@ HWY_API Vec128 operator>>(const Vec128 v, #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; #else - return detail::SignedShr(Simd(), v, bits); + return detail::SignedShr(Simd(), v, bits); #endif } HWY_API Vec128 operator>>(const Vec128 v, @@ -4071,7 +4295,7 @@ HWY_API Vec128 operator>>(const Vec128 v, #if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; #else - return detail::SignedShr(Simd(), v, bits); + return detail::SignedShr(Simd(), v, bits); #endif } @@ -4096,7 +4320,7 @@ HWY_INLINE Vec128 MulOdd(const Vec128 a, // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) template -HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, Vec128 a, Vec128 b, const Vec128 sum0, @@ -4121,7 +4345,7 @@ HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, // Unsigned: zero-extend. template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 const __m128i zero = _mm_setzero_si128(); @@ -4131,7 +4355,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #endif } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 return Vec128{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; @@ -4140,7 +4364,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #endif } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 return Vec128{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; @@ -4149,7 +4373,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #endif } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 const __m128i zero = _mm_setzero_si128(); @@ -4162,24 +4386,24 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, // Unsigned to signed: same plus cast. template -HWY_API Vec128 PromoteTo(Simd di, +HWY_API Vec128 PromoteTo(Simd di, const Vec128 v) { - return BitCast(di, PromoteTo(Simd(), v)); + return BitCast(di, PromoteTo(Simd(), v)); } template -HWY_API Vec128 PromoteTo(Simd di, +HWY_API Vec128 PromoteTo(Simd di, const Vec128 v) { - return BitCast(di, PromoteTo(Simd(), v)); + return BitCast(di, PromoteTo(Simd(), v)); } template -HWY_API Vec128 PromoteTo(Simd di, +HWY_API Vec128 PromoteTo(Simd di, const Vec128 v) { - return BitCast(di, PromoteTo(Simd(), v)); + return BitCast(di, PromoteTo(Simd(), v)); } // Signed: replicate sign bit. 
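The SSSE3 branches of the PromoteTo overloads above widen lanes using only unpack instructions: zero-extend by interleaving with a zero vector, sign-extend by interleaving a vector with itself and then arithmetic-shifting the doubled lanes. A self-contained sketch with raw SSE2 intrinsics:

#include <emmintrin.h>  // SSE2

// u8 -> u16: interleave the low 8 bytes with zero bytes.
static __m128i ZeroExtendU8ToU16(__m128i v) {
  return _mm_unpacklo_epi8(v, _mm_setzero_si128());
}

// i8 -> i16: duplicate each byte into a 16-bit lane, then an arithmetic
// right shift by 8 replicates the sign bit into the upper byte.
static __m128i SignExtendI8ToI16(__m128i v) {
  return _mm_srai_epi16(_mm_unpacklo_epi8(v, v), 8);
}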
template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 return ShiftRight<8>(Vec128{_mm_unpacklo_epi8(v.raw, v.raw)}); @@ -4188,7 +4412,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #endif } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 return ShiftRight<16>(Vec128{_mm_unpacklo_epi16(v.raw, v.raw)}); @@ -4197,7 +4421,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #endif } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 return ShiftRight<32>(Vec128{_mm_unpacklo_epi32(v.raw, v.raw)}); @@ -4206,7 +4430,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #endif } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); @@ -4226,7 +4450,7 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, #define HWY_INLINE_F16 HWY_INLINE #endif template -HWY_INLINE_F16 Vec128 PromoteTo(Simd df32, +HWY_INLINE_F16 Vec128 PromoteTo(Simd df32, const Vec128 v) { #if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) const RebindToSigned di32; @@ -4252,7 +4476,7 @@ HWY_INLINE_F16 Vec128 PromoteTo(Simd df32, } template -HWY_API Vec128 PromoteTo(Simd df32, +HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; @@ -4260,13 +4484,13 @@ HWY_API Vec128 PromoteTo(Simd df32, } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_cvtps_pd(v.raw)}; } template -HWY_API Vec128 PromoteTo(Simd /* tag */, +HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_cvtepi32_pd(v.raw)}; } @@ -4274,11 +4498,11 @@ HWY_API Vec128 PromoteTo(Simd /* tag */, // ------------------------------ Demotions (full -> part w/ narrow lanes) template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { #if HWY_TARGET == HWY_SSSE3 - const Simd di32; - const Simd du16; + const Simd di32; + const Simd du16; const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); const auto clamped = Or(zero_if_neg, too_big); @@ -4293,39 +4517,39 @@ HWY_API Vec128 DemoteTo(Simd /* tag */, } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_packs_epi32(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); return Vec128{_mm_packus_epi16(i16, i16)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_packus_epi16(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); return Vec128{_mm_packs_epi16(i16, i16)}; } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_packs_epi16(v.raw, v.raw)}; } template -HWY_API Vec128 DemoteTo(Simd df16, +HWY_API Vec128 DemoteTo(Simd df16, const Vec128 v) { #if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) const RebindToUnsigned du16; @@ 
-4360,7 +4584,7 @@ HWY_API Vec128 DemoteTo(Simd df16, } template -HWY_API Vec128 DemoteTo(Simd dbf16, +HWY_API Vec128 DemoteTo(Simd dbf16, const Vec128 v) { // TODO(janwas): _mm_cvtneps_pbh once we have avx512bf16. const Rebind di32; @@ -4372,7 +4596,7 @@ HWY_API Vec128 DemoteTo(Simd dbf16, template HWY_API Vec128 ReorderDemote2To( - Simd dbf16, Vec128 a, Vec128 b) { + Simd dbf16, Vec128 a, Vec128 b) { // TODO(janwas): _mm_cvtne2ps_pbh once we have avx512bf16. const RebindToUnsigned du16; const Repartition du32; @@ -4381,7 +4605,7 @@ HWY_API Vec128 ReorderDemote2To( } template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_cvtpd_ps(v.raw)}; } @@ -4391,7 +4615,7 @@ namespace detail { // For well-defined float->int demotion in all x86_*-inl.h. template -HWY_INLINE auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) +HWY_INLINE auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) -> decltype(Zero(d)) { // The max can be exactly represented in binary64, so clamping beforehand // prevents x86 conversion from raising an exception and returning 80..00. @@ -4401,35 +4625,43 @@ HWY_INLINE auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) // For ConvertTo float->int of same size, clamping before conversion would // change the result because the max integer value is not exactly representable. // Instead detect the overflow result after conversion and fix it. -template , N>> -HWY_INLINE auto FixConversionOverflow(Simd di, - decltype(Zero(DF())) original, +template > +HWY_INLINE auto FixConversionOverflow(DI di, VFromD original, decltype(Zero(di).raw) converted_raw) - -> decltype(Zero(di)) { + -> VFromD { // Combinations of original and output sign: // --: normal <0 or -huge_val to 80..00: OK // -+: -0 to 0 : OK // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF // ++: normal >0 : OK - const auto converted = decltype(Zero(di)){converted_raw}; + const auto converted = VFromD{converted_raw}; const auto sign_wrong = AndNot(BitCast(di, original), converted); - return BitCast(di, Xor(converted, BroadcastSignBit(sign_wrong))); +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG + // Critical GCC 11 compiler bug (possibly also GCC 10): omits the Xor; also + // Add() if using that instead. Work around with one more instruction. + const RebindToUnsigned du; + const VFromD mask = BroadcastSignBit(sign_wrong); + const VFromD max = BitCast(di, ShiftRight<1>(BitCast(du, mask))); + return IfVecThenElse(mask, max, converted); +#else + return Xor(converted, BroadcastSignBit(sign_wrong)); +#endif } } // namespace detail template -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { - const auto clamped = detail::ClampF64ToI32Max(Simd(), v); + const auto clamped = detail::ClampF64ToI32Max(Simd(), v); return Vec128{_mm_cvttpd_epi32(clamped.raw)}; } // For already range-limited input [0, 255]. template HWY_API Vec128 U8FromU32(const Vec128 v) { - const Simd d32; - const Simd d8; + const Simd d32; + const Simd d8; alignas(16) static constexpr uint32_t k8From32[4] = { 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; // Also replicate bytes into all 32 bit lanes for safety. 
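FixConversionOverflow above patches up x86's float-to-int conversions, which return the single sentinel 0x80..00 for out-of-range or NaN inputs: when a non-negative input comes back "negative", flip it to the maximum value. A scalar model of the sign-mismatch test and the XOR fix (ignoring the GCC-specific workaround in the hunk above):

#include <cmath>
#include <cstdint>

static int32_t FixI32Overflow(float original, int32_t converted) {
  // Sign bits disagree only when a non-negative input overflowed to 0x80000000.
  const bool sign_wrong = !std::signbit(original) && converted < 0;
  // XOR with all-one bits (i.e. bitwise NOT) turns 0x80000000 into 0x7FFFFFFF.
  return sign_wrong ? ~converted : converted;
}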
@@ -4440,13 +4672,13 @@ HWY_API Vec128 U8FromU32(const Vec128 v) { // ------------------------------ Integer <=> fp (ShiftRight, OddEven) template -HWY_API Vec128 ConvertTo(Simd /* tag */, +HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128{_mm_cvtepi32_ps(v.raw)}; } template -HWY_API Vec128 ConvertTo(Simd dd, +HWY_API Vec128 ConvertTo(Simd dd, const Vec128 v) { #if HWY_TARGET <= HWY_AVX3 (void)dd; @@ -4471,7 +4703,7 @@ HWY_API Vec128 ConvertTo(Simd dd, // Truncates (rounds toward zero). template -HWY_API Vec128 ConvertTo(const Simd di, +HWY_API Vec128 ConvertTo(const Simd di, const Vec128 v) { return detail::FixConversionOverflow(di, v, _mm_cvttps_epi32(v.raw)); } @@ -4486,7 +4718,7 @@ HWY_API Vec128 ConvertTo(Full128 di, const Vec128 v) { const __m128i i1 = _mm_cvtsi64_si128(_mm_cvttsd_si64(UpperHalf(dd2, v).raw)); return detail::FixConversionOverflow(di, v, _mm_unpacklo_epi64(i0, i1)); #else - using VI = decltype(Zero(di)); + using VI = VFromD; const VI k0 = Zero(di); const VI k1 = Set(di, 1); const VI k51 = Set(di, 51); @@ -4521,22 +4753,21 @@ HWY_API Vec128 ConvertTo(Full128 di, const Vec128 v) { return (magnitude ^ sign_mask) - sign_mask; #endif } -HWY_API Vec128 ConvertTo(Simd di, - const Vec128 v) { +HWY_API Vec64 ConvertTo(Full64 di, const Vec64 v) { // Only need to specialize for non-AVX3, 64-bit (single scalar op) #if HWY_TARGET > HWY_AVX3 && HWY_ARCH_X86_64 - const Vec128 i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; + const Vec64 i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; return detail::FixConversionOverflow(di, v, i0.raw); #else (void)di; const auto full = ConvertTo(Full128(), Vec128{v.raw}); - return Vec128{full.raw}; + return Vec64{full.raw}; #endif } template HWY_API Vec128 NearestInt(const Vec128 v) { - const Simd di; + const Simd di; return detail::FixConversionOverflow(di, v, _mm_cvtps_epi32(v.raw)); } @@ -4550,7 +4781,7 @@ HWY_API Vec128 Round(const Vec128 v) { // Rely on rounding after addition with a large value such that no mantissa // bits remain (assuming the current mode is nearest-even). We may need a // compiler flag for precise floating-point to prevent "optimizing" this out. - const Simd df; + const Simd df; const auto max = Set(df, MantissaEnd()); const auto large = CopySignToAbs(max, v); const auto added = large + v; @@ -4566,7 +4797,7 @@ namespace detail { // (because mantissa >> exponent is zero). 
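Round above relies on a classic trick: adding 2^23 (MantissaEnd for float, the magnitude at which no fractional bits remain) forces the hardware to round in the current mode (nearest-even), and subtracting it back yields the rounded value. A simplified scalar sketch; the real code also restores the sign of zero, and, as its comment warns, needs precise floating-point semantics so the compiler does not fold the add/subtract away:

#include <cmath>

static float RoundNearestEven(float v) {
  const float kMantissaEnd = 8388608.0f;               // 2^23
  if (!(std::fabs(v) < kMantissaEnd)) return v;        // huge or NaN: already integral
  const float large = std::copysign(kMantissaEnd, v);  // same sign avoids cancellation
  return (large + v) - large;
}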
template HWY_INLINE Mask128 UseInt(const Vec128 v) { - return Abs(v) < Set(Simd(), MantissaEnd()); + return Abs(v) < Set(Simd(), MantissaEnd()); } } // namespace detail @@ -4574,7 +4805,7 @@ HWY_INLINE Mask128 UseInt(const Vec128 v) { // Toward zero, aka truncate template HWY_API Vec128 Trunc(const Vec128 v) { - const Simd df; + const Simd df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 @@ -4586,7 +4817,7 @@ HWY_API Vec128 Trunc(const Vec128 v) { // Toward +infinity, aka ceiling template HWY_API Vec128 Ceil(const Vec128 v) { - const Simd df; + const Simd df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 @@ -4601,7 +4832,7 @@ HWY_API Vec128 Ceil(const Vec128 v) { // Toward -infinity, aka floor template HWY_API Vec128 Floor(const Vec128 v) { - const Simd df; + const Simd df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 @@ -4681,6 +4912,11 @@ HWY_API Vec128 AESRound(Vec128 state, return Vec128{_mm_aesenc_si128(state.raw, round_key.raw)}; } +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + template HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { @@ -4703,7 +4939,7 @@ HWY_API Vec128 CLMulUpper(Vec128 a, // `p` points to at least 8 readable bytes, not all of which need be valid. template -HWY_API Mask128 LoadMaskBits(Simd /* tag */, +HWY_API Mask128 LoadMaskBits(Simd /* tag */, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; @@ -4719,7 +4955,7 @@ HWY_API Mask128 LoadMaskBits(Simd /* tag */, // `p` points to at least 8 writable bytes. template -HWY_API size_t StoreMaskBits(const Simd /* tag */, +HWY_API size_t StoreMaskBits(const Simd /* tag */, const Mask128 mask, uint8_t* bits) { constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(&mask.raw, bits); @@ -4738,26 +4974,27 @@ HWY_API size_t StoreMaskBits(const Simd /* tag */, // Beware: the suffix indicates the number of mask bits, not lane size! template -HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 mask) { +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); return PopCount(mask_bits); } template -HWY_API intptr_t FindFirstTrue(const Simd /* tag */, +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, const Mask128 mask) { const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); return mask.raw ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; } template -HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); return mask_bits == 0; } template -HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); // Cannot use _kortestc because we may have less than 8 mask bits. return mask_bits == (1u << N) - 1; @@ -4769,7 +5006,7 @@ HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { namespace detail { // Returns permutevar_epi16 indices for 16-bit Compress. Also used by x86_256. -HWY_INLINE Vec128 IndicesForCompress16(uint64_t mask_bits) { +HWY_INLINE Vec128 IndicesForCompress16(uint64_t mask_bits) { Full128 du16; // Table of u16 indices packed into bytes to reduce L1 usage. 
Will be unpacked // to u16. Ideally we would broadcast 8*3 (half of the 8 bytes currently used) @@ -4866,7 +5103,7 @@ HWY_INLINE Vec128 IndicesForCompress16(uint64_t mask_bits) { template HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { - const Simd d; + const Simd d; const Rebind du; const auto vu = BitCast(du, v); // (required for float16_t inputs) @@ -4905,14 +5142,14 @@ HWY_API Vec128 Compress(Vec128 v, template HWY_API Vec128 CompressBits(Vec128 v, const uint8_t* HWY_RESTRICT bits) { - return Compress(v, LoadMaskBits(Simd(), bits)); + return Compress(v, LoadMaskBits(Simd(), bits)); } // ------------------------------ CompressStore template -HWY_API size_t CompressStore(Vec128 v, Mask128 mask, Simd d, - T* HWY_RESTRICT unaligned) { +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { const Rebind du; const auto vu = BitCast(du, v); // (required for float16_t inputs) @@ -4930,21 +5167,23 @@ HWY_API size_t CompressStore(Vec128 v, Mask128 mask, Simd d, template HWY_API size_t CompressStore(Vec128 v, Mask128 mask, - Simd /* tag */, T* HWY_RESTRICT unaligned) { + Simd /* tag */, + T* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); } template HWY_API size_t CompressStore(Vec128 v, Mask128 mask, - Simd /* tag */, T* HWY_RESTRICT unaligned) { + Simd /* tag */, + T* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); } template HWY_API size_t CompressStore(Vec128 v, Mask128 mask, - Simd /* tag */, + Simd /* tag */, float* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); @@ -4952,7 +5191,7 @@ HWY_API size_t CompressStore(Vec128 v, Mask128 mask, template HWY_API size_t CompressStore(Vec128 v, Mask128 mask, - Simd /* tag */, + Simd /* tag */, double* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); @@ -4961,7 +5200,8 @@ HWY_API size_t CompressStore(Vec128 v, Mask128 mask, // ------------------------------ CompressBlendedStore (CompressStore) template HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, - Simd d, T* HWY_RESTRICT unaligned) { + Simd d, + T* HWY_RESTRICT unaligned) { // AVX-512 already does the blending at no extra cost (latency 11, // rthroughput 2 - same as compress plus store). 
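The CompressStore/CompressBlendedStore hunks above all implement the same contract, which a scalar model makes explicit: lanes whose mask bit is set are packed to the front in order, and the return value is the popcount of the active mask bits. The blended variant additionally leaves the remaining destination lanes untouched. Illustrative only, with uint16_t lanes:

#include <cstddef>
#include <cstdint>

static size_t CompressStoreScalar(const uint16_t* v, uint64_t mask_bits,
                                  size_t n, uint16_t* out) {
  size_t count = 0;
  for (size_t i = 0; i < n; ++i) {
    if ((mask_bits >> i) & 1) out[count++] = v[i];  // keep selected lanes, in order
  }
  return count;  // == PopCount(mask_bits & ((1 << n) - 1))
}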
if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { @@ -4971,7 +5211,7 @@ HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, } return CompressStore(v, m, d, unaligned); } else { - const size_t count = CountTrue(m); + const size_t count = CountTrue(d, m); const Vec128 compressed = Compress(v, m); const Vec128 prev = LoadU(d, unaligned); StoreU(IfThenElse(FirstN(d, count), compressed, prev), d, unaligned); @@ -4983,8 +5223,8 @@ HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, template HWY_API size_t CompressBitsStore(Vec128 v, - const uint8_t* HWY_RESTRICT bits, Simd d, - T* HWY_RESTRICT unaligned) { + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); } @@ -4995,7 +5235,7 @@ HWY_API size_t CompressBitsStore(Vec128 v, namespace detail { template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; // Easier than Set(), which would require an >8-bit type, which would not // compile for T=uint8_t, N=1. @@ -5012,7 +5252,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; const auto vmask_bits = Set(du, static_cast(mask_bits)); @@ -5020,7 +5260,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; const auto vmask_bits = Set(du, static_cast(mask_bits)); @@ -5028,7 +5268,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { } template -HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint64_t kBit[8] = {1, 2}; return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); @@ -5038,7 +5278,7 @@ HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { // `p` points to at least 8 readable bytes, not all of which need be valid. 
template -HWY_API Mask128 LoadMaskBits(Simd d, +HWY_API Mask128 LoadMaskBits(Simd d, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; @@ -5061,7 +5301,7 @@ constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { - const Simd d; + const Simd d; const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; return U64FromInt(_mm_movemask_epi8(sign_bits)); } @@ -5077,8 +5317,8 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { - const Simd d; - const Simd df; + const Simd d; + const Simd df; const auto sign_bits = BitCast(df, VecFromMask(d, mask)); return U64FromInt(_mm_movemask_ps(sign_bits.raw)); } @@ -5086,8 +5326,8 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 mask) { - const Simd d; - const Simd df; + const Simd d; + const Simd df; const auto sign_bits = BitCast(df, VecFromMask(d, mask)); return U64FromInt(_mm_movemask_pd(sign_bits.raw)); } @@ -5107,7 +5347,7 @@ HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { // `p` points to at least 8 writable bytes. template -HWY_API size_t StoreMaskBits(const Simd /* tag */, +HWY_API size_t StoreMaskBits(const Simd /* tag */, const Mask128 mask, uint8_t* bits) { constexpr size_t kNumBytes = (N + 7) / 8; const uint64_t mask_bits = detail::BitsFromMask(mask); @@ -5118,25 +5358,26 @@ HWY_API size_t StoreMaskBits(const Simd /* tag */, // ------------------------------ Mask testing template -HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { // Cheaper than PTEST, which is 2 uop / 3L. return detail::BitsFromMask(mask) == 0; } template -HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { constexpr uint64_t kAllBits = detail::OnlyActive((1ull << (16 / sizeof(T))) - 1); return detail::BitsFromMask(mask) == kAllBits; } template -HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 mask) { +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { return PopCount(detail::BitsFromMask(mask)); } template -HWY_API intptr_t FindFirstTrue(const Simd /* tag */, +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, const Mask128 mask) { const uint64_t mask_bits = detail::BitsFromMask(mask); return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; @@ -5147,10 +5388,10 @@ HWY_API intptr_t FindFirstTrue(const Simd /* tag */, namespace detail { template -HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); const Rebind d8; - const Simd du; + const Simd du; // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need // byte indices for PSHUFB (one vector's worth for each of 256 combinations of @@ -5281,8 +5522,8 @@ HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { return BitCast(d, pairs + Set(du, 0x0100)); } -template -HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { HWY_DASSERT(mask_bits < 16); // There are only 4 lanes, so we can afford to load the index vector directly. 
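BitsFromMask above funnels each lane's sign bit through movemask, after which AllTrue/CountTrue/FindFirstTrue are ordinary bit manipulation. A small sketch for byte lanes (assumes C++20 <bit>; the intrinsic is the same one used in the hunk above):

#include <emmintrin.h>  // SSE2
#include <bit>

static int CountTrueBytes(__m128i sign_bits) {
  const unsigned bits = static_cast<unsigned>(_mm_movemask_epi8(sign_bits));
  return std::popcount(bits);                 // lanes whose MSB is set
}

static int FindFirstTrueByte(__m128i sign_bits) {
  const unsigned bits = static_cast<unsigned>(_mm_movemask_epi8(sign_bits));
  return bits ? std::countr_zero(bits) : -1;  // index of lowest set lane, or none
}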
@@ -5308,8 +5549,8 @@ HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } -template -HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { HWY_DASSERT(mask_bits < 4); // There are only 2 lanes, so we can afford to load the index vector directly. @@ -5327,7 +5568,7 @@ HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { template HWY_API Vec128 Compress(Vec128 v, Mask128 m) { - const Simd d; + const Simd d; const RebindToUnsigned du; const uint64_t mask_bits = detail::BitsFromMask(m); @@ -5340,7 +5581,7 @@ HWY_API Vec128 Compress(Vec128 v, Mask128 m) { template HWY_API Vec128 CompressBits(Vec128 v, const uint8_t* HWY_RESTRICT bits) { - const Simd d; + const Simd d; const RebindToUnsigned du; uint64_t mask_bits = 0; @@ -5357,7 +5598,7 @@ HWY_API Vec128 CompressBits(Vec128 v, // ------------------------------ CompressStore, CompressBitsStore template -HWY_API size_t CompressStore(Vec128 v, Mask128 m, Simd d, +HWY_API size_t CompressStore(Vec128 v, Mask128 m, Simd d, T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; @@ -5373,7 +5614,8 @@ HWY_API size_t CompressStore(Vec128 v, Mask128 m, Simd d, template HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, - Simd d, T* HWY_RESTRICT unaligned) { + Simd d, + T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; const uint64_t mask_bits = detail::BitsFromMask(m); @@ -5391,8 +5633,8 @@ HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, template HWY_API size_t CompressBitsStore(Vec128 v, - const uint8_t* HWY_RESTRICT bits, Simd d, - T* HWY_RESTRICT unaligned) { + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; uint64_t mask_bits = 0; @@ -5461,9 +5703,8 @@ HWY_API void StoreInterleaved3(const Vec128 v0, } // 64 bits -HWY_API void StoreInterleaved3(const Vec128 v0, - const Vec128 v1, - const Vec128 v2, Simd d, +HWY_API void StoreInterleaved3(const Vec64 v0, const Vec64 v1, + const Vec64 v2, Full64 d, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors for the shuffles and first result. const Full128 d_full; @@ -5507,7 +5748,7 @@ template HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, const Vec128 v2, - Simd /*tag*/, + Simd /*tag*/, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors for the shuffles and result. const Full128 d_full; @@ -5559,11 +5800,11 @@ HWY_API void StoreInterleaved4(const Vec128 v0, } // 64 bits -HWY_API void StoreInterleaved4(const Vec128 in0, - const Vec128 in1, - const Vec128 in2, - const Vec128 in3, - Simd /*tag*/, +HWY_API void StoreInterleaved4(const Vec64 in0, + const Vec64 in1, + const Vec64 in2, + const Vec64 in3, + Full64 /*tag*/, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors to reduce the number of stores. const Full128 d_full8; @@ -5588,7 +5829,7 @@ HWY_API void StoreInterleaved4(const Vec128 in0, const Vec128 in1, const Vec128 in2, const Vec128 in3, - Simd /*tag*/, + Simd /*tag*/, uint8_t* HWY_RESTRICT unaligned) { // Use full vectors to reduce the number of stores. 
const Full128 d_full8; @@ -5698,133 +5939,85 @@ HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, // u16/i16 template HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d32; + const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(d32, Min(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); } template HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { - const Repartition> d32; + const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(d32, Max(even, odd)); // Also broadcast into odd lanes. - return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); + return BitCast(Simd(), Or(min, ShiftLeft<16>(min))); } } // namespace detail // Supported for u/i/f 32/64. Returns the same value in each lane. template -HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { return detail::SumOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { return detail::MinOfLanes(hwy::SizeTag(), v); } template -HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { return detail::MaxOfLanes(hwy::SizeTag(), v); } -// ================================================== DEPRECATED +// ------------------------------ Lt128 -template -HWY_API size_t StoreMaskBits(const Mask128 mask, uint8_t* bits) { - return StoreMaskBits(Simd(), mask, bits); +namespace detail { + +// Returns vector-mask for Lt128. Also used by x86_256/x86_512. +template > +HWY_INLINE V Lt128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, "Use u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const V eqHL = VecFromMask(d, Eq(a, b)); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = OrAnd(ltHL, eqHL, ltLX); + return InterleaveUpper(d, vecHx, vecHx); } -template -HWY_API bool AllTrue(const Mask128 mask) { - return AllTrue(Simd(), mask); +} // namespace detail + +template > +HWY_API MFromD Lt128(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); } -template -HWY_API bool AllFalse(const Mask128 mask) { - return AllFalse(Simd(), mask); +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. 
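The truth table above reduces a 128-bit unsigned compare to per-64-bit-lane Eq/Lt results: the high lane decides unless it compares equal, in which case the low lane does. ShiftLeftLanes<1> is what moves the low lane's Lt result into the high lane's position, and the final InterleaveUpper broadcasts the verdict to both lanes. A scalar model of Lt128Vec, plus Min128/Max128 simply selecting with that verdict (struct and names are illustrative):

#include <cstdint>

// One 128-bit value as (hi, lo) 64-bit halves, matching the u64 lane pairs above.
struct U128 { uint64_t hi, lo; };

static bool Lt128Scalar(U128 a, U128 b) {
  const bool eqH = a.hi == b.hi;
  const bool ltH = a.hi < b.hi;
  const bool ltL = a.lo < b.lo;
  return ltH || (eqH && ltL);  // out = cH | (=H & cL), as in the truth table
}

static U128 Min128Scalar(U128 a, U128 b) { return Lt128Scalar(a, b) ? a : b; }
static U128 Max128Scalar(U128 a, U128 b) { return Lt128Scalar(a, b) ? b : a; }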
+template > +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); } -template -HWY_API size_t CountTrue(const Mask128 mask) { - return CountTrue(Simd(), mask); -} - -template -HWY_API Vec128 SumOfLanes(const Vec128 v) { - return SumOfLanes(Simd(), v); -} -template -HWY_API Vec128 MinOfLanes(const Vec128 v) { - return MinOfLanes(Simd(), v); -} -template -HWY_API Vec128 MaxOfLanes(const Vec128 v) { - return MaxOfLanes(Simd(), v); -} - -template -HWY_API Vec128 UpperHalf(Vec128 v) { - return UpperHalf(Half>(), v); -} - -template -HWY_API Vec128 ShiftRightBytes(const Vec128 v) { - return ShiftRightBytes(Simd(), v); -} - -template -HWY_API Vec128 ShiftRightLanes(const Vec128 v) { - return ShiftRightLanes(Simd(), v); -} - -template -HWY_API Vec128 CombineShiftRightBytes(Vec128 hi, Vec128 lo) { - return CombineShiftRightBytes(Simd(), hi, lo); -} - -template -HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { - return InterleaveUpper(Simd(), a, b); -} - -template > -HWY_API VFromD> ZipUpper(Vec128 a, Vec128 b) { - return InterleaveUpper(RepartitionToWide(), a, b); -} - -template -HWY_API Vec128 Combine(Vec128 hi2, Vec128 lo2) { - return Combine(Simd(), hi2, lo2); -} - -template -HWY_API Vec128 ZeroExtendVector(Vec128 lo) { - return ZeroExtendVector(Simd(), lo); -} - -template -HWY_API Vec128 ConcatLowerLower(Vec128 hi, Vec128 lo) { - return ConcatLowerLower(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatUpperUpper(Vec128 hi, Vec128 lo) { - return ConcatUpperUpper(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatLowerUpper(const Vec128 hi, - const Vec128 lo) { - return ConcatLowerUpper(Simd(), hi, lo); -} - -template -HWY_API Vec128 ConcatUpperLower(Vec128 hi, Vec128 lo) { - return ConcatUpperLower(Simd(), hi, lo); +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), b, a); } // ================================================== Operator wrapper diff --git a/third_party/highway/hwy/ops/x86_256-inl.h b/third_party/highway/hwy/ops/x86_256-inl.h index 2a5315ad0e1a..d420ec0cd0ef 100644 --- a/third_party/highway/hwy/ops/x86_256-inl.h +++ b/third_party/highway/hwy/ops/x86_256-inl.h @@ -44,10 +44,6 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { - -template -using Full256 = Simd; - namespace detail { template @@ -326,6 +322,38 @@ HWY_API Vec256 Not(const Vec256 v) { #endif } +// ------------------------------ OrAnd + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + // ------------------------------ Operator overloads (internal-only if float) template @@ -785,6 +813,7 @@ HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { template HWY_API Vec256 ZeroIfNegative(Vec256 v) { const auto zero = Zero(Full256()); + // AVX2 IfThenElse only looks at 
the MSB for 32/64-bit lanes return IfThenElse(MaskFromVec(v), zero, v); } @@ -1395,7 +1424,12 @@ HWY_API Vec256 operator-(const Vec256 a, return Vec256{_mm256_sub_pd(a.raw, b.raw)}; } -// ------------------------------ Saturating addition +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; +} + +// ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -1419,7 +1453,7 @@ HWY_API Vec256 SaturatedAdd(const Vec256 a, return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; } -// ------------------------------ Saturating subtraction +// ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. @@ -1685,6 +1719,35 @@ HWY_API Vec256 Abs(const Vec256 v) { #endif } +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, + Vec256 no) { + // int8: AVX2 IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Full256 d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Full256 d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + const MFromD msb = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); +} + // ------------------------------ ShiftLeftSame HWY_API Vec256 ShiftLeftSame(const Vec256 v, @@ -2234,7 +2297,7 @@ HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(32) Offset offset_lanes[N]; - Store(offset, Simd(), offset_lanes); + Store(offset, Full256(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { @@ -2252,7 +2315,7 @@ HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, Store(v, d, lanes); alignas(32) Index index_lanes[N]; - Store(index, Simd(), index_lanes); + Store(index, Full256(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; @@ -2473,7 +2536,7 @@ HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, const Vec256 v) { template HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { const Repartition d8; - return BitCast(d, ShiftRightBytes(BitCast(d8, v))); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ CombineShiftRightBytes @@ -2733,6 +2796,81 @@ HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { #endif } +// ------------------------------ Reverse2 + +template +HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { + const Full256 du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse4[16] = 
{3, 2, 1, 0, 7, 6, 5, 4, + 11, 10, 9, 8, 15, 14, 13, 12}; + const Vec256 idx = Load(di, kReverse4); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { + return Vec256{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +HWY_API Vec256 Reverse4(Full256 /* tag */, Vec256 v) { + return Vec256{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec256 idx = Load(di, kReverse8); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + return Reverse(d, v); +} + +template +HWY_API Vec256 Reverse8(Full256 /* tag */, const Vec256 /* v */) { + HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes +} + // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides @@ -2782,12 +2920,6 @@ HWY_API Vec256 InterleaveLower(const Vec256 a, return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; } -// Additional overload for the optional Simd<> tag. -template > -HWY_API V InterleaveLower(Full256 /* tag */, V a, V b) { - return InterleaveLower(a, b); -} - // ------------------------------ InterleaveUpper // All functions inside detail lack the required D parameter. @@ -2849,11 +2981,11 @@ HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { // this is necessary because the single-lane scalar cannot return two values. 
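Reverse4 above is just a fixed permutation: _MM_SHUFFLE packs four 2-bit source-lane indices, highest output lane first, so _MM_SHUFFLE(0, 1, 2, 3) makes output lane i read source lane 3 - i. A scalar model of how the immediate is decoded (as by _mm256_permute4x64_epi64):

#include <array>
#include <cstdint>

static std::array<uint64_t, 4> Permute4x64(const std::array<uint64_t, 4>& v,
                                           int control) {
  std::array<uint64_t, 4> out{};
  for (int i = 0; i < 4; ++i) {
    out[i] = v[(control >> (2 * i)) & 3];  // 2-bit field i selects the source lane
  }
  return out;
}

// _MM_SHUFFLE(0, 1, 2, 3) == 0b00'01'10'11 == 0x1B: out = {v[3], v[2], v[1], v[0]}.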
template > HWY_API Vec256 ZipLower(Vec256 a, Vec256 b) { - return BitCast(Full256(), InterleaveLower(Full256(), a, b)); + return BitCast(Full256(), InterleaveLower(a, b)); } template > HWY_API Vec256 ZipLower(Full256 dw, Vec256 a, Vec256 b) { - return BitCast(dw, InterleaveLower(Full256(), a, b)); + return BitCast(dw, InterleaveLower(a, b)); } template > @@ -3063,6 +3195,38 @@ HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, #endif } +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec256 DupEven(const Vec256 v) { + return InterleaveLower(Full256(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec256 DupOdd(const Vec256 v) { + return InterleaveUpper(Full256(), v, v); +} + // ------------------------------ OddEven namespace detail { @@ -3140,6 +3304,13 @@ HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return Vec256{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(1, 0, 3, 2))}; } +// ------------------------------ ReverseBlocks (ConcatLowerUpper) + +template +HWY_API Vec256 ReverseBlocks(Full256 d, Vec256 v) { + return ConcatLowerUpper(d, v, v); +} + // ------------------------------ TableLookupBytes (ZeroExtendVector) // Both full @@ -3436,7 +3607,7 @@ HWY_API Vec128 DemoteTo(Full128 /* tag */, _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; } -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec256 v) { const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); // Concatenate lower 64 bits of each 128-bit block @@ -3455,7 +3626,7 @@ HWY_API Vec128 DemoteTo(Full128 /* tag */, _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; } -HWY_API Vec128 DemoteTo(Simd /* tag */, +HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec256 v) { const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); // Concatenate lower 64 bits of each 128-bit block @@ -3553,7 +3724,7 @@ HWY_API Vec128 U8FromU32(const Vec256 v) { const auto lo = LowerHalf(quad); const auto hi = UpperHalf(Full128(), quad); const auto pair = LowerHalf(lo | hi); - return BitCast(Simd(), pair); + return BitCast(Full64(), pair); } // ------------------------------ Integer <=> fp (ShiftRight, OddEven) @@ -3691,6 +3862,19 @@ HWY_API Vec256 AESRound(Vec256 state, #endif } +HWY_API Vec256 AESLastRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; @@ -4019,7 +4203,7 @@ HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, #if HWY_TARGET <= HWY_AVX3_DL return CompressStore(v, m, d, unaligned); // also native #else - const size_t count = CountTrue(m); + const size_t count = CountTrue(d, m); const Vec256 
compressed = Compress(v, m); // There is no 16-bit MaskedStore, so blend. const Vec256 prev = LoadU(d, unaligned); @@ -4244,7 +4428,7 @@ HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, namespace detail { template -HWY_INLINE Indices256 IndicesFromBits(Simd d, +HWY_INLINE Indices256 IndicesFromBits(Full256 d, uint64_t mask_bits) { const RebindToUnsigned d32; // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT @@ -4307,7 +4491,7 @@ HWY_INLINE Indices256 IndicesFromBits(Simd d, } template -HWY_INLINE Indices256 IndicesFromBits(Simd d, +HWY_INLINE Indices256 IndicesFromBits(Full256 d, uint64_t mask_bits) { const Repartition d32; @@ -4353,8 +4537,8 @@ HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { const auto compressed1 = Compress(promoted1, mask_bits1); const Half dh; - const auto demoted0 = ZeroExtendVector(DemoteTo(dh, compressed0)); - const auto demoted1 = ZeroExtendVector(DemoteTo(dh, compressed1)); + const auto demoted0 = ZeroExtendVector(du, DemoteTo(dh, compressed0)); + const auto demoted1 = ZeroExtendVector(du, DemoteTo(dh, compressed1)); const size_t count0 = PopCount(mask_bits0); // Now combine by shifting demoted1 up. AVX2 lacks VPERMW, so start with @@ -4625,101 +4809,6 @@ HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 vHL) { return detail::MaxOfLanes(hwy::SizeTag(), Max(vLH, vHL)); } -// ================================================== DEPRECATED - -template -HWY_API size_t StoreMaskBits(const Mask256 mask, uint8_t* bits) { - return StoreMaskBits(Full256(), mask, bits); -} - -template -HWY_API bool AllTrue(const Mask256 mask) { - return AllTrue(Full256(), mask); -} - -template -HWY_API bool AllFalse(const Mask256 mask) { - return AllFalse(Full256(), mask); -} - -template -HWY_API size_t CountTrue(const Mask256 mask) { - return CountTrue(Full256(), mask); -} - -template -HWY_API Vec256 SumOfLanes(const Vec256 vHL) { - return SumOfLanes(Full256(), vHL); -} -template -HWY_API Vec256 MinOfLanes(const Vec256 vHL) { - return MinOfLanes(Full256(), vHL); -} -template -HWY_API Vec256 MaxOfLanes(const Vec256 vHL) { - return MaxOfLanes(Full256(), vHL); -} - -template -HWY_API Vec128 UpperHalf(Vec256 v) { - return UpperHalf(Full128(), v); -} - -template -HWY_API Vec256 ShiftRightBytes(const Vec256 v) { - return ShiftRightBytes(Full256(), v); -} - -template -HWY_API Vec256 ShiftRightLanes(const Vec256 v) { - return ShiftRightLanes(Full256(), v); -} - -template -HWY_API Vec256 CombineShiftRightBytes(Vec256 hi, Vec256 lo) { - return CombineShiftRightBytes(Full256(), hi, lo); -} - -template -HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { - return InterleaveUpper(Full256(), a, b); -} - -template -HWY_API Vec256> ZipUpper(Vec256 a, Vec256 b) { - return InterleaveUpper(Full256>(), a, b); -} - -template -HWY_API Vec256 Combine(Vec128 hi, Vec128 lo) { - return Combine(Full256(), hi, lo); -} - -template -HWY_API Vec256 ZeroExtendVector(Vec128 lo) { - return ZeroExtendVector(Full256(), lo); -} - -template -HWY_API Vec256 ConcatLowerLower(Vec256 hi, Vec256 lo) { - return ConcatLowerLower(Full256(), hi, lo); -} - -template -HWY_API Vec256 ConcatLowerUpper(Vec256 hi, Vec256 lo) { - return ConcatLowerUpper(Full256(), hi, lo); -} - -template -HWY_API Vec256 ConcatUpperLower(Vec256 hi, Vec256 lo) { - return ConcatUpperLower(Full256(), hi, lo); -} - -template -HWY_API Vec256 ConcatUpperUpper(Vec256 hi, Vec256 lo) { - return ConcatUpperUpper(Full256(), hi, lo); -} - // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace 
HWY_NAMESPACE } // namespace hwy diff --git a/third_party/highway/hwy/ops/x86_512-inl.h b/third_party/highway/hwy/ops/x86_512-inl.h index 10bad0d1720f..cb4b2cfef302 100644 --- a/third_party/highway/hwy/ops/x86_512-inl.h +++ b/third_party/highway/hwy/ops/x86_512-inl.h @@ -57,9 +57,6 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { -template -using Full512 = Simd; - namespace detail { template @@ -313,6 +310,30 @@ HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_pd(a.raw, b.raw)}; } +// ------------------------------ OrAnd + +template +HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +} + // ------------------------------ Operator overloads (internal-only if float) template @@ -579,6 +600,13 @@ HWY_API Vec512 IfThenZeroElse(const Mask512 mask, return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; } +template +HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { + static_assert(IsSigned(), "Only works for signed/float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenElse(MaskFromVec(v), yes, no); +} + template HWY_API Vec512 ZeroIfNegative(const Vec512 v) { // AVX3 MaskFromVec only looks at the MSB @@ -681,7 +709,12 @@ HWY_API Vec512 operator-(const Vec512 a, return Vec512{_mm512_sub_pd(a.raw, b.raw)}; } -// ------------------------------ Saturating addition +// ------------------------------ SumsOf8 +HWY_API Vec512 SumsOf8(const Vec512 v) { + return Vec512{_mm512_sad_epu8(v.raw, _mm512_setzero_si512())}; +} + +// ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -705,7 +738,7 @@ HWY_API Vec512 SaturatedAdd(const Vec512 a, return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; } -// ------------------------------ Saturating subtraction +// ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. 
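OrAnd and IfVecThenElse above lean on _mm512_ternarylogic_epi64, whose 8-bit immediate is literally the truth table of the three inputs: each output bit is bit ((a<<2)|(b<<1)|c) of the immediate. A scalar check that 0xF8 encodes a | (b & c) and 0xCA encodes the bitwise select a ? b : c:

#include <cstdint>

static uint64_t TernaryLogic(uint64_t a, uint64_t b, uint64_t c, uint8_t imm) {
  uint64_t r = 0;
  for (int i = 0; i < 64; ++i) {
    const unsigned idx =
        (((a >> i) & 1) << 2) | (((b >> i) & 1) << 1) | ((c >> i) & 1);
    r |= ((uint64_t{imm} >> idx) & 1) << i;
  }
  return r;
}

// TernaryLogic(o, a1, a2, 0xF8) == (o | (a1 & a2))           -> OrAnd
// TernaryLogic(m, yes, no, 0xCA) == ((m & yes) | (~m & no))  -> IfVecThenElse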
@@ -1820,7 +1853,7 @@ HWY_API Vec512 LoadDup128(Full512 /* tag */, // https://gcc.godbolt.org/z/-Jt_-F #if HWY_LOADDUP_ASM __m512i out; - asm("vbroadcasti128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); + asm("vbroadcasti128 %1, %[reg]" : [reg] "=x"(out) : "m"(p[0])); return Vec512{out}; #else const auto x4 = LoadU(Full128(), p); @@ -1831,7 +1864,7 @@ HWY_API Vec512 LoadDup128(Full512 /* tag */, const float* const HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m512 out; - asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); + asm("vbroadcastf128 %1, %[reg]" : [reg] "=x"(out) : "m"(p[0])); return Vec512{out}; #else const __m128 x4 = _mm_loadu_ps(p); @@ -1843,7 +1876,7 @@ HWY_API Vec512 LoadDup128(Full512 /* tag */, const double* const HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m512d out; - asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); + asm("vbroadcastf128 %1, %[reg]" : [reg] "=x"(out) : "m"(p[0])); return Vec512{out}; #else const __m128d x2 = _mm_loadu_pd(p); @@ -2007,7 +2040,7 @@ HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<8> /* tag */, template HWY_API Vec512 GatherOffset(Full512 d, const T* HWY_RESTRICT base, const Vec512 offset) { -static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::GatherOffset(hwy::SizeTag(), d, base, offset); } template @@ -2173,7 +2206,7 @@ HWY_API Vec512 ShiftRightBytes(Full512 /* tag */, const Vec512 v) { template HWY_API Vec512 ShiftRightLanes(Full512 d, const Vec512 v) { const Repartition d8; - return BitCast(d, ShiftRightBytes(BitCast(d8, v))); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ CombineShiftRightBytes @@ -2396,6 +2429,78 @@ HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { return TableLookupLanes(v, SetTableIndices(d, kReverse)); } +// ------------------------------ Reverse2 + +template +HWY_API Vec512 Reverse2(Full512 d, const Vec512 v) { + const Full512 du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec512 Reverse4(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse4[32] = { + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, + 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; + const Vec512 idx = Load(di, kReverse4); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { + return Vec512{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +HWY_API Vec512 Reverse4(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse8[32] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + 
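The Reverse2/4/8 additions above use two flavors of permutation: small fixed shuffles for 32/64-bit lanes, and permutexvar with a constant index vector for 16-bit lanes, where the index vector spells out which source lane each output lane reads. A scalar model of the index-driven form (table values as in kReverse8 above):

#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int16_t> PermuteVar(const std::vector<int16_t>& idx,
                                       const std::vector<int16_t>& v) {
  std::vector<int16_t> out(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    out[i] = v[static_cast<size_t>(idx[i])];  // output lane i <- source lane idx[i]
  }
  return out;
}

// idx = {7,6,5,4,3,2,1,0, 15,14,13,12,11,10,9,8, ...} reverses each group of 8 lanes.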
+template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int32_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + return Reverse(d, v); +} + // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides @@ -2445,12 +2550,6 @@ HWY_API Vec512 InterleaveLower(const Vec512 a, return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; } -// Additional overload for the optional Simd<> tag. -template > -HWY_API V InterleaveLower(Full512 /* tag */, V a, V b) { - return InterleaveLower(a, b); -} - // ------------------------------ InterleaveUpper // All functions inside detail lack the required D parameter. @@ -2515,8 +2614,8 @@ HWY_API Vec512 ZipLower(Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveLower(a, b)); } template > -HWY_API Vec512 ZipLower(Full512 d, Vec512 a, Vec512 b) { - return BitCast(Full512(), InterleaveLower(d, a, b)); +HWY_API Vec512 ZipLower(Full512 /* d */, Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveLower(a, b)); } template > @@ -2564,17 +2663,17 @@ HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, template HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { - return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, 0x4E)}; + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; } HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { - return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, 0x4E)}; + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; } HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { - return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, 0x4E)}; + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; } // hiH,hiL loH,loL |-> hiH,loL (= outer halves) @@ -2675,6 +2774,36 @@ HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, __mmask8{0xFF}, hi.raw)}; } +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; +} +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; +} + +template +HWY_API Vec512 DupEven(const Vec512 v) { + return InterleaveLower(Full512(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; +} +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; +} + +template +HWY_API Vec512 DupOdd(const Vec512 v) { + return InterleaveUpper(Full512(), v, v); +} + // ------------------------------ OddEven template @@ -2705,17 +2834,29 @@ HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { template HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { - return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(2, 3, 0, 1))}; + return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { - return Vec512{ - _mm512_shuffle_f32x4(v.raw, v.raw, _MM_SHUFFLE(2, 3, 0, 1))}; + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { - return 
Vec512{ - _mm512_shuffle_f64x2(v.raw, v.raw, _MM_SHUFFLE(2, 3, 0, 1))}; + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +// ------------------------------ ReverseBlocks + +template +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, + Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; } // ------------------------------ TableLookupBytes (ZeroExtendVector) @@ -3012,17 +3153,23 @@ HWY_API Vec512 AESRound(Vec512 state, #if HWY_TARGET == HWY_AVX3_DL return Vec512{_mm512_aesenc_epi128(state.raw, round_key.raw)}; #else - alignas(64) uint8_t a[64]; - alignas(64) uint8_t b[64]; const Full512 d; - const Full128 d128; - Store(state, d, a); - Store(round_key, d, b); - for (size_t i = 0; i < 64; i += 16) { - const auto enc = AESRound(Load(d128, a + i), Load(d128, b + i)); - Store(enc, d128, a + i); - } - return Load(d, a); + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); #endif } @@ -3264,8 +3411,8 @@ HWY_API Vec512 Compress(Vec512 v, const Mask512 mask) { const auto compressed0 = Compress(promoted0, mask0); const auto compressed1 = Compress(promoted1, mask1); - const auto demoted0 = ZeroExtendVector(DemoteTo(duh, compressed0)); - const auto demoted1 = ZeroExtendVector(DemoteTo(duh, compressed1)); + const auto demoted0 = ZeroExtendVector(du, DemoteTo(duh, compressed0)); + const auto demoted1 = ZeroExtendVector(du, DemoteTo(duh, compressed1)); // Concatenate into single vector by shifting upper with writemask. const size_t num0 = CountTrue(dw, mask0); @@ -3363,7 +3510,7 @@ HWY_API size_t CompressBlendedStore(Vec512 v, Mask512 m, Full512 d, if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { return CompressStore(v, m, d, unaligned); } else { - const size_t count = CountTrue(m); + const size_t count = CountTrue(d, m); const Vec512 compressed = Compress(v, m); const Vec512 prev = LoadU(d, unaligned); StoreU(IfThenElse(FirstN(d, count), compressed, prev), d, unaligned); @@ -3422,9 +3569,9 @@ HWY_API void StoreInterleaved3(const Vec512 a, const Vec512 b, const auto k = (r2 | g2 | b2).raw; // low byte in each 128bit: 3A 2A 1A 0A // To obtain 10 0A 05 00 in one vector, transpose "rows" into "columns". - const auto k3_k0_i3_i0 = _mm512_shuffle_i64x2(i, k, _MM_SHUFFLE(3, 0, 3, 0)); - const auto i1_i2_j0_j1 = _mm512_shuffle_i64x2(j, i, _MM_SHUFFLE(1, 2, 0, 1)); - const auto j2_j3_k1_k2 = _mm512_shuffle_i64x2(k, j, _MM_SHUFFLE(2, 3, 1, 2)); + const auto k3_k0_i3_i0 = _mm512_shuffle_i64x2(i, k, _MM_PERM_DADA); + const auto i1_i2_j0_j1 = _mm512_shuffle_i64x2(j, i, _MM_PERM_BCAB); + const auto j2_j3_k1_k2 = _mm512_shuffle_i64x2(k, j, _MM_PERM_CDBC); // Alternating order, most-significant 128 bits from the second arg. 
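The AESRound/AESLastRound fallbacks above follow one pattern: when wide VAES instructions are unavailable, split the vector into 128-bit halves, run the AES-NI instruction on each, and recombine. A sketch of that pattern with raw intrinsics for a 256-bit state (requires AES-NI and AVX2; an illustration of the idea, not Highway's own code path):

#include <immintrin.h>

static __m256i AESLastRound256(__m256i state, __m256i round_key) {
  const __m128i lo = _mm_aesenclast_si128(_mm256_castsi256_si128(state),
                                          _mm256_castsi256_si128(round_key));
  const __m128i hi = _mm_aesenclast_si128(_mm256_extracti128_si256(state, 1),
                                          _mm256_extracti128_si256(round_key, 1));
  return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1);
}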
const __mmask8 m = 0xCC; @@ -3456,12 +3603,12 @@ HWY_API void StoreInterleaved4(const Vec512 v0, const auto k = ZipLower(d32, ba8, dc8).raw; // 4x128bit: d..aB d..a8 const auto l = ZipUpper(d32, ba8, dc8).raw; // 4x128bit: d..aF d..aC // 128-bit blocks were independent until now; transpose 4x4. - const auto j1_j0_i1_i0 = _mm512_shuffle_i64x2(i, j, _MM_SHUFFLE(1, 0, 1, 0)); - const auto l1_l0_k1_k0 = _mm512_shuffle_i64x2(k, l, _MM_SHUFFLE(1, 0, 1, 0)); - const auto j3_j2_i3_i2 = _mm512_shuffle_i64x2(i, j, _MM_SHUFFLE(3, 2, 3, 2)); - const auto l3_l2_k3_k2 = _mm512_shuffle_i64x2(k, l, _MM_SHUFFLE(3, 2, 3, 2)); - constexpr int k20 = _MM_SHUFFLE(2, 0, 2, 0); - constexpr int k31 = _MM_SHUFFLE(3, 1, 3, 1); + const auto j1_j0_i1_i0 = _mm512_shuffle_i64x2(i, j, _MM_PERM_BABA); + const auto l1_l0_k1_k0 = _mm512_shuffle_i64x2(k, l, _MM_PERM_BABA); + const auto j3_j2_i3_i2 = _mm512_shuffle_i64x2(i, j, _MM_PERM_DCDC); + const auto l3_l2_k3_k2 = _mm512_shuffle_i64x2(k, l, _MM_PERM_DCDC); + constexpr _MM_PERM_ENUM k20 = _MM_PERM_CACA; + constexpr _MM_PERM_ENUM k31 = _MM_PERM_DBDB; const auto l0_k0_j0_i0 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k20); const auto l1_k1_j1_i1 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k31); const auto l2_k2_j2_i2 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k20); @@ -3631,103 +3778,6 @@ HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { return BitCast(d, Or(min, ShiftLeft<16>(min))); } -// ================================================== DEPRECATED - -template -HWY_API size_t StoreMaskBits(const Mask512 mask, uint8_t* bits) { - return StoreMaskBits(Full512(), mask, bits); -} - -template -HWY_API bool AllTrue(const Mask512 mask) { - return AllTrue(Full512(), mask); -} - -template -HWY_API bool AllFalse(const Mask512 mask) { - return AllFalse(Full512(), mask); -} - -template -HWY_API size_t CountTrue(const Mask512 mask) { - return CountTrue(Full512(), mask); -} - -template -HWY_API Vec512 SumOfLanes(Vec512 v) { - return SumOfLanes(Full512(), v); -} - -template -HWY_API Vec512 MinOfLanes(Vec512 v) { - return MinOfLanes(Full512(), v); -} - -template -HWY_API Vec512 MaxOfLanes(Vec512 v) { - return MaxOfLanes(Full512(), v); -} - -template -HWY_API Vec256 UpperHalf(Vec512 v) { - return UpperHalf(Full256(), v); -} - -template -HWY_API Vec512 ShiftRightBytes(const Vec512 v) { - return ShiftRightBytes(Full512(), v); -} - -template -HWY_API Vec512 ShiftRightLanes(const Vec512 v) { - return ShiftRightBytes(Full512(), v); -} - -template -HWY_API Vec512 CombineShiftRightBytes(Vec512 hi, Vec512 lo) { - return CombineShiftRightBytes(Full512(), hi, lo); -} - -template -HWY_API Vec512 InterleaveUpper(Vec512 a, Vec512 b) { - return InterleaveUpper(Full512(), a, b); -} - -template -HWY_API Vec512> ZipUpper(Vec512 a, Vec512 b) { - return InterleaveUpper(Full512>(), a, b); -} - -template -HWY_API Vec512 Combine(Vec256 hi, Vec256 lo) { - return Combine(Full512(), hi, lo); -} - -template -HWY_API Vec512 ZeroExtendVector(Vec256 lo) { - return ZeroExtendVector(Full512(), lo); -} - -template -HWY_API Vec512 ConcatLowerLower(Vec512 hi, Vec512 lo) { - return ConcatLowerLower(Full512(), hi, lo); -} - -template -HWY_API Vec512 ConcatLowerUpper(Vec512 hi, Vec512 lo) { - return ConcatLowerUpper(Full512(), hi, lo); -} - -template -HWY_API Vec512 ConcatUpperLower(Vec512 hi, Vec512 lo) { - return ConcatUpperLower(Full512(), hi, lo); -} - -template -HWY_API Vec512 ConcatUpperUpper(Vec512 hi, Vec512 lo) { - return ConcatUpperUpper(Full512(), hi, lo); -} - // 
NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy diff --git a/third_party/highway/hwy/targets.cc b/third_party/highway/hwy/targets.cc index daab3a67bef8..6719eb81e24a 100644 --- a/third_party/highway/hwy/targets.cc +++ b/third_party/highway/hwy/targets.cc @@ -15,23 +15,25 @@ #include "hwy/targets.h" #include +#include #include #include #include -#include -#include -#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ - defined(THREAD_SANITIZER) +#include "hwy/base.h" + +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN #include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace -#endif // defined(*_SANITIZER) +#endif + +#include // abort / exit #if HWY_ARCH_X86 #include #if HWY_COMPILER_MSVC #include -#else // HWY_COMPILER_MSVC +#else // !HWY_COMPILER_MSVC #include #endif // HWY_COMPILER_MSVC #endif // HWY_ARCH_X86 @@ -93,7 +95,7 @@ std::atomic supported_{0}; // Not yet initialized uint32_t supported_targets_for_test_ = 0; // Mask of targets disabled at runtime with DisableTargets. -uint32_t supported_mask_{std::numeric_limits::max()}; +uint32_t supported_mask_{LimitsMax()}; #if HWY_ARCH_X86 // Arbritrary bit indices indicating which instruction set extensions are @@ -190,21 +192,22 @@ HWY_NORETURN void HWY_FORMAT(3, 4) va_end(args); fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); -#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ - defined(THREAD_SANITIZER) - // If compiled with any sanitizer print a stack trace. This call doesn't crash - // the program, instead the trap below will crash it also allowing gdb to - // break there. + +// If compiled with any sanitizer, they can also print a stack trace. +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN __sanitizer_print_stack_trace(); -#endif // defined(*_SANITIZER) +#endif // HWY_IS_* fflush(stderr); -#if HWY_COMPILER_MSVC - abort(); // Compile error without this due to HWY_NORETURN. -#elif HWY_ARCH_RVV - exit(1); // trap/abort just freeze Spike -#else +// Now terminate the program: +#if HWY_ARCH_RVV + exit(1); // trap/abort just freeze Spike. +#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC + // Facilitates breaking into a debugger, but don't use this in non-debug + // builds because it looks like "illegal instruction", which is misleading. __builtin_trap(); +#else + abort(); // Compile error without this due to HWY_NORETURN. #endif } @@ -213,7 +216,7 @@ void DisableTargets(uint32_t disabled_targets) { // We can call Update() here to initialize the mask but that will trigger a // call to SupportedTargets() which we use in tests to tell whether any of the // highway dynamic dispatch functions were used. - chosen_target.DeInit(); + GetChosenTarget().DeInit(); } void SetSupportedTargetsForTest(uint32_t targets) { @@ -222,7 +225,7 @@ void SetSupportedTargetsForTest(uint32_t targets) { // if not zero. 
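The dispatch-control functions changed above (DisableTargets, SetSupportedTargetsForTest) now reset the cached choice through GetChosenTarget().DeInit() instead of the removed global. A hedged usage sketch — LimitToAvx2OrBelow is a made-up name; only hwy::DisableTargets and the HWY_AVX3* target bits come from the library:

#include "hwy/targets.h"

// Restrict runtime dispatch to AVX2 or below, e.g. to reproduce a report from
// an older CPU. Per the header comments, targets in HWY_ENABLED_BASELINE
// cannot be disabled, so SupportedTargets() still reports at least those.
// The call also resets the cached choice, so the next HWY_DYNAMIC_DISPATCH
// re-selects among the remaining targets.
void LimitToAvx2OrBelow() {
  hwy::DisableTargets(HWY_AVX3 | HWY_AVX3_DL);
}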
supported_.store(0, std::memory_order_release); supported_targets_for_test_ = targets; - chosen_target.DeInit(); + GetChosenTarget().DeInit(); } bool SupportedTargetsCalledForTest() { @@ -344,8 +347,10 @@ uint32_t SupportedTargets() { return bits & supported_mask_; } -// Declared in targets.h -ChosenTarget chosen_target; +HWY_DLLEXPORT ChosenTarget& GetChosenTarget() { + static ChosenTarget chosen_target; + return chosen_target; +} void ChosenTarget::Update() { // The supported variable contains the current CPU supported targets shifted diff --git a/third_party/highway/hwy/targets.h b/third_party/highway/hwy/targets.h index 95381e455377..90386a7ae49b 100644 --- a/third_party/highway/hwy/targets.h +++ b/third_party/highway/hwy/targets.h @@ -22,6 +22,7 @@ #include "hwy/base.h" #include "hwy/detect_targets.h" +#include "hwy/highway_export.h" namespace hwy { @@ -29,7 +30,7 @@ namespace hwy { // Implemented in targets.cc; unconditionally compiled to support the use case // of binary-only distributions. The HWY_SUPPORTED_TARGETS wrapper may allow // eliding calls to this function. -uint32_t SupportedTargets(); +HWY_DLLEXPORT uint32_t SupportedTargets(); // Evaluates to a function call, or literal if there is a single target. #if (HWY_TARGETS & (HWY_TARGETS - 1)) == 0 @@ -44,7 +45,7 @@ uint32_t SupportedTargets(); // lower target is desired. For this reason, attempts to disable targets which // are in HWY_ENABLED_BASELINE have no effect so SupportedTargets() always // returns at least the baseline target. -void DisableTargets(uint32_t disabled_targets); +HWY_DLLEXPORT void DisableTargets(uint32_t disabled_targets); // Set the mock mask of CPU supported targets instead of the actual CPU // supported targets computed in SupportedTargets(). The return value of @@ -52,11 +53,11 @@ void DisableTargets(uint32_t disabled_targets); // regardless of this mock, to prevent accidentally adding targets that are // known to be buggy in the current CPU. Call with a mask of 0 to disable the // mock and use the actual CPU supported targets instead. -void SetSupportedTargetsForTest(uint32_t targets); +HWY_DLLEXPORT void SetSupportedTargetsForTest(uint32_t targets); // Returns whether the SupportedTargets() function was called since the last // SetSupportedTargetsForTest() call. -bool SupportedTargetsCalledForTest(); +HWY_DLLEXPORT bool SupportedTargetsCalledForTest(); // Return the list of targets in HWY_TARGETS supported by the CPU as a list of // individual HWY_* target macros such as HWY_SCALAR or HWY_NEON. This list @@ -225,7 +226,7 @@ struct ChosenTarget { public: // Update the ChosenTarget mask based on the current CPU supported // targets. - void Update(); + HWY_DLLEXPORT void Update(); // Reset the ChosenTarget to the uninitialized state. void DeInit() { mask_.store(1); } @@ -245,11 +246,12 @@ struct ChosenTarget { } private: - // Initialized to 1 so GetChosenTargetIndex() returns 0. + // Initialized to 1 so GetIndex() returns 0. std::atomic mask_{1}; }; -extern ChosenTarget chosen_target; +// For internal use (e.g. by FunctionCache and DisableTargets). 
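targets.h now exposes the singleton through HWY_DLLEXPORT ChosenTarget& GetChosenTarget() rather than an exported data object, presumably to play well with the new highway_export.h shared-library support. A generic sketch of the accessor pattern adopted here — MyState/GetMyState are illustrative names, not Highway code:

struct MyState {
  void Reset() {}  // stand-in for ChosenTarget::DeInit
};

// The function-local static is constructed on first use (thread-safe since
// C++11); a shared-library build exports one function symbol instead of a
// data symbol, and every caller linking the same definition sees one instance.
MyState& GetMyState() {
  static MyState instance;
  return instance;
}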
+HWY_DLLEXPORT ChosenTarget& GetChosenTarget(); } // namespace hwy diff --git a/third_party/highway/hwy/targets_test.cc b/third_party/highway/hwy/targets_test.cc index 5e6b4437d6db..62e677c47966 100644 --- a/third_party/highway/hwy/targets_test.cc +++ b/third_party/highway/hwy/targets_test.cc @@ -44,11 +44,11 @@ void CheckFakeFunction() { hwy::SetSupportedTargetsForTest(HWY_##TGT); \ /* Calling Update() first to make &HWY_DYNAMIC_DISPATCH() return */ \ /* the pointer to the already cached function. */ \ - hwy::chosen_target.Update(); \ + hwy::GetChosenTarget().Update(); \ EXPECT_EQ(uint32_t(HWY_##TGT), HWY_DYNAMIC_DISPATCH(FakeFunction)(42)); \ /* Calling DeInit() will test that the initializer function */ \ /* also calls the right function. */ \ - hwy::chosen_target.DeInit(); \ + hwy::GetChosenTarget().DeInit(); \ EXPECT_EQ(uint32_t(HWY_##TGT), HWY_DYNAMIC_DISPATCH(FakeFunction)(42)); \ /* Second call uses the cached value from the previous call. */ \ EXPECT_EQ(uint32_t(HWY_##TGT), HWY_DYNAMIC_DISPATCH(FakeFunction)(42)); \ diff --git a/third_party/highway/hwy/tests/arithmetic_test.cc b/third_party/highway/hwy/tests/arithmetic_test.cc index 6408acbd54e9..aa980473180e 100644 --- a/third_party/highway/hwy/tests/arithmetic_test.cc +++ b/third_party/highway/hwy/tests/arithmetic_test.cc @@ -177,387 +177,6 @@ HWY_NOINLINE void TestAllAbs() { ForFloatTypes(ForPartialVectors()); } -template -struct TestLeftShifts { - template - HWY_NOINLINE void operator()(T t, D d) { - if (kSigned) { - // Also test positive values - TestLeftShifts()(t, d); - } - - using TI = MakeSigned; - using TU = MakeUnsigned; - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - - const auto values = Iota(d, kSigned ? -TI(N) : TI(0)); // value to shift - constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; - - // 0 - HWY_ASSERT_VEC_EQ(d, values, ShiftLeft<0>(values)); - HWY_ASSERT_VEC_EQ(d, values, ShiftLeftSame(values, 0)); - - // 1 - for (size_t i = 0; i < N; ++i) { - const T value = kSigned ? T(T(i) - T(N)) : T(i); - expected[i] = T(TU(value) << 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<1>(values)); - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, 1)); - - // max - for (size_t i = 0; i < N; ++i) { - const T value = kSigned ? T(T(i) - T(N)) : T(i); - expected[i] = T(TU(value) << kMaxShift); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft(values)); - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, kMaxShift)); - } -}; - -template -struct TestVariableLeftShifts { - template - HWY_NOINLINE void operator()(T t, D d) { - if (kSigned) { - // Also test positive values - TestVariableLeftShifts()(t, d); - } - - using TI = MakeSigned; - using TU = MakeUnsigned; - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - - const auto v0 = Zero(d); - const auto v1 = Set(d, 1); - const auto values = Iota(d, kSigned ? -TI(N) : TI(0)); // value to shift - - constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; - const auto max_shift = Set(d, kMaxShift); - const auto small_shifts = And(Iota(d, 0), max_shift); - const auto large_shifts = max_shift - small_shifts; - - // Same: 0 - HWY_ASSERT_VEC_EQ(d, values, Shl(values, v0)); - - // Same: 1 - for (size_t i = 0; i < N; ++i) { - const T value = kSigned ? T(i) - T(N) : T(i); - expected[i] = T(TU(value) << 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, v1)); - - // Same: max - for (size_t i = 0; i < N; ++i) { - const T value = kSigned ? 
T(i) - T(N) : T(i); - expected[i] = T(TU(value) << kMaxShift); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, max_shift)); - - // Variable: small - for (size_t i = 0; i < N; ++i) { - const T value = kSigned ? T(i) - T(N) : T(i); - expected[i] = T(TU(value) << (i & kMaxShift)); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, small_shifts)); - - // Variable: large - for (size_t i = 0; i < N; ++i) { - expected[i] = T(TU(1) << (kMaxShift - (i & kMaxShift))); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(v1, large_shifts)); - } -}; - -struct TestUnsignedRightShifts { - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - - const auto values = Iota(d, 0); - - const T kMax = LimitsMax(); - constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; - - // Shift by 0 - HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); - HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); - - // Shift by 1 - for (size_t i = 0; i < N; ++i) { - expected[i] = T(T(i & kMax) >> 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); - - // max - for (size_t i = 0; i < N; ++i) { - expected[i] = T(T(i & kMax) >> kMaxShift); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight(values)); - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, kMaxShift)); - } -}; - -struct TestRotateRight { - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - - constexpr size_t kBits = sizeof(T) * 8; - const auto mask_shift = Set(d, T{kBits}); - // Cover as many bit positions as possible to test shifting out - const auto values = Shl(Set(d, T{1}), And(Iota(d, 0), mask_shift)); - - // Rotate by 0 - HWY_ASSERT_VEC_EQ(d, values, RotateRight<0>(values)); - - // Rotate by 1 - Store(values, d, expected.get()); - for (size_t i = 0; i < N; ++i) { - expected[i] = (expected[i] >> 1) | (expected[i] << (kBits - 1)); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight<1>(values)); - - // Rotate by half - Store(values, d, expected.get()); - for (size_t i = 0; i < N; ++i) { - expected[i] = (expected[i] >> (kBits / 2)) | (expected[i] << (kBits / 2)); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight(values)); - - // Rotate by max - Store(values, d, expected.get()); - for (size_t i = 0; i < N; ++i) { - expected[i] = (expected[i] >> (kBits - 1)) | (expected[i] << 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight(values)); - } -}; - -struct TestVariableUnsignedRightShifts { - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - - const auto v0 = Zero(d); - const auto v1 = Set(d, 1); - const auto values = Iota(d, 0); - - const T kMax = LimitsMax(); - const auto max = Set(d, kMax); - - constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; - const auto max_shift = Set(d, kMaxShift); - const auto small_shifts = And(Iota(d, 0), max_shift); - const auto large_shifts = max_shift - small_shifts; - - // Same: 0 - HWY_ASSERT_VEC_EQ(d, values, Shr(values, v0)); - - // Same: 1 - for (size_t i = 0; i < N; ++i) { - expected[i] = T(T(i & kMax) >> 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, v1)); - - // Same: max - HWY_ASSERT_VEC_EQ(d, v0, Shr(values, max_shift)); - - // Variable: small - for (size_t i = 0; i < N; ++i) { - expected[i] = T(i) >> (i & kMaxShift); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, 
small_shifts)); - - // Variable: Large - for (size_t i = 0; i < N; ++i) { - expected[i] = kMax >> (kMaxShift - (i & kMaxShift)); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(max, large_shifts)); - } -}; - -template -T RightShiftNegative(T val) { - // C++ shifts are implementation-defined for negative numbers, and we have - // seen divisions replaced with shifts, so resort to bit operations. - using TU = hwy::MakeUnsigned; - TU bits; - CopyBytes(&val, &bits); - - const TU shifted = TU(bits >> kAmount); - - const TU all = TU(~TU(0)); - const size_t num_zero = sizeof(TU) * 8 - 1 - kAmount; - const TU sign_extended = static_cast((all << num_zero) & LimitsMax()); - - bits = shifted | sign_extended; - CopyBytes(&bits, &val); - return val; -} - -class TestSignedRightShifts { - public: - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - constexpr T kMin = LimitsMin(); - constexpr T kMax = LimitsMax(); - constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; - - // First test positive values, negative are checked below. - const auto v0 = Zero(d); - const auto values = And(Iota(d, 0), Set(d, kMax)); - - // Shift by 0 - HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); - HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); - - // Shift by 1 - for (size_t i = 0; i < N; ++i) { - expected[i] = T(T(i & kMax) >> 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); - - // max - HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(values)); - HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(values, kMaxShift)); - - // Even negative value - Test<0>(kMin, d, __LINE__); - Test<1>(kMin, d, __LINE__); - Test<2>(kMin, d, __LINE__); - Test(kMin, d, __LINE__); - - const T odd = static_cast(kMin + 1); - Test<0>(odd, d, __LINE__); - Test<1>(odd, d, __LINE__); - Test<2>(odd, d, __LINE__); - Test(odd, d, __LINE__); - } - - private: - template - void Test(T val, D d, int line) { - const auto expected = Set(d, RightShiftNegative(val)); - const auto in = Set(d, val); - const char* file = __FILE__; - AssertVecEqual(d, expected, ShiftRight(in), file, line); - AssertVecEqual(d, expected, ShiftRightSame(in, kAmount), file, line); - } -}; - -struct TestVariableSignedRightShifts { - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - using TU = MakeUnsigned; - const size_t N = Lanes(d); - auto expected = AllocateAligned(N); - - constexpr T kMin = LimitsMin(); - constexpr T kMax = LimitsMax(); - - constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; - - // First test positive values, negative are checked below. 
- const auto v0 = Zero(d); - const auto positive = Iota(d, 0) & Set(d, kMax); - - // Shift by 0 - HWY_ASSERT_VEC_EQ(d, positive, ShiftRight<0>(positive)); - HWY_ASSERT_VEC_EQ(d, positive, ShiftRightSame(positive, 0)); - - // Shift by 1 - for (size_t i = 0; i < N; ++i) { - expected[i] = T(T(i & kMax) >> 1); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(positive)); - HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(positive, 1)); - - // max - HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(positive)); - HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(positive, kMaxShift)); - - const auto max_shift = Set(d, kMaxShift); - const auto small_shifts = And(Iota(d, 0), max_shift); - const auto large_shifts = max_shift - small_shifts; - - const auto negative = Iota(d, kMin); - - // Test varying negative to shift - for (size_t i = 0; i < N; ++i) { - expected[i] = RightShiftNegative<1>(static_cast(kMin + i)); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(negative, Set(d, 1))); - - // Shift MSB right by small amounts - for (size_t i = 0; i < N; ++i) { - const size_t amount = i & kMaxShift; - const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); - CopyBytes(&shifted, &expected[i]); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), small_shifts)); - - // Shift MSB right by large amounts - for (size_t i = 0; i < N; ++i) { - const size_t amount = kMaxShift - (i & kMaxShift); - const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); - CopyBytes(&shifted, &expected[i]); - } - HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), large_shifts)); - } -}; - -HWY_NOINLINE void TestAllShifts() { - ForUnsignedTypes(ForPartialVectors>()); - ForSignedTypes(ForPartialVectors>()); - ForUnsignedTypes(ForPartialVectors()); - ForSignedTypes(ForPartialVectors()); -} - -HWY_NOINLINE void TestAllVariableShifts() { - const ForPartialVectors> shl_u; - const ForPartialVectors> shl_s; - const ForPartialVectors shr_u; - const ForPartialVectors shr_s; - - shl_u(uint16_t()); - shr_u(uint16_t()); - - shl_u(uint32_t()); - shr_u(uint32_t()); - - shl_s(int16_t()); - shr_s(int16_t()); - - shl_s(int32_t()); - shr_s(int32_t()); - -#if HWY_CAP_INTEGER64 - shl_u(uint64_t()); - shr_u(uint64_t()); - - shl_s(int64_t()); - shr_s(int64_t()); -#endif -} - -HWY_NOINLINE void TestAllRotateRight() { - const ForPartialVectors test; - test(uint32_t()); -#if HWY_CAP_INTEGER64 - test(uint64_t()); -#endif -} - struct TestUnsignedMinMax { template HWY_NOINLINE void operator()(T /*unused*/, D d) { @@ -644,6 +263,84 @@ HWY_NOINLINE void TestAllMinMax() { ForFloatTypes(ForPartialVectors()); } +class TestMinMax128 { + template + static HWY_NOINLINE Vec Make128(D d, uint64_t hi, uint64_t lo) { + alignas(16) uint64_t in[2]; + in[0] = lo; + in[1] = hi; + return LoadDup128(d, in); + } + + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const size_t N = Lanes(d); + auto a_lanes = AllocateAligned(N); + auto b_lanes = AllocateAligned(N); + auto min_lanes = AllocateAligned(N); + auto max_lanes = AllocateAligned(N); + RandomState rng; + + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + // Same arg + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Min128(d, v11, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Max128(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v10, 
v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v11, v11)); + + // First arg less + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v10, v11)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v10, v11)); + + // Second arg less + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v11, v10)); + + // Also check 128-bit blocks are independent + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + a_lanes[i] = Random64(&rng); + b_lanes[i] = Random64(&rng); + } + const V a = Load(d, a_lanes.get()); + const V b = Load(d, b_lanes.get()); + for (size_t i = 0; i < N; i += 2) { + const bool lt = a_lanes[i + 1] == b_lanes[i + 1] + ? (a_lanes[i] < b_lanes[i]) + : (a_lanes[i + 1] < b_lanes[i + 1]); + min_lanes[i + 0] = lt ? a_lanes[i + 0] : b_lanes[i + 0]; + min_lanes[i + 1] = lt ? a_lanes[i + 1] : b_lanes[i + 1]; + max_lanes[i + 0] = lt ? b_lanes[i + 0] : a_lanes[i + 0]; + max_lanes[i + 1] = lt ? b_lanes[i + 1] : a_lanes[i + 1]; + } + HWY_ASSERT_VEC_EQ(d, min_lanes.get(), Min128(d, a, b)); + HWY_ASSERT_VEC_EQ(d, max_lanes.get(), Max128(d, a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMinMax128() { + ForGEVectors<128, TestMinMax128>()(uint64_t()); +} + struct TestUnsignedMul { template HWY_NOINLINE void operator()(T /*unused*/, D d) { @@ -834,11 +531,11 @@ struct TestMulEvenOdd64 { }; HWY_NOINLINE void TestAllMulEven() { - ForExtendableVectors test; + ForGEVectors<64, TestMulEven> test; test(int32_t()); test(uint32_t()); - ForGE128Vectors()(uint64_t()); + ForGEVectors<128, TestMulEvenOdd64>()(uint64_t()); } struct TestMulAdd { @@ -1113,7 +810,6 @@ AlignedFreeUniquePtr RoundTestCases(T /*unused*/, D d, size_t& padded) { // negative +/- epsilon T(-1) + eps, T(-1) - eps, -#if !defined(HWY_EMULATE_SVE) // these are not safe to just cast to int // +/- huge (but still fits in float) T(1E34), T(-1E35), @@ -1122,7 +818,6 @@ AlignedFreeUniquePtr RoundTestCases(T /*unused*/, D d, size_t& padded) { -std::numeric_limits::infinity(), // qNaN GetLane(NaN(d)) -#endif }; const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); const size_t N = Lanes(d); @@ -1369,6 +1064,41 @@ HWY_NOINLINE void TestAllAbsDiff() { ForPartialVectors()(float()); } +struct TestSumsOf8 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const size_t N = Lanes(d); + if (N < 8) return; + const Repartition du64; + + auto in_lanes = AllocateAligned(N); + auto sum_lanes = AllocateAligned(N / 8); + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = Random64(&rng) & 0xFF; + } + + for (size_t idx_sum = 0; idx_sum < N / 8; ++idx_sum) { + uint64_t sum = 0; + for (size_t i = 0; i < 8; ++i) { + sum += in_lanes[idx_sum * 8 + i]; + } + sum_lanes[idx_sum] = sum; + } + + const Vec in = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(du64, sum_lanes.get(), SumsOf8(in)); + } + } +}; + +HWY_NOINLINE void TestAllSumsOf8() { + ForGEVectors<64, TestSumsOf8>()(uint8_t()); +} + struct TestNeg { template HWY_NOINLINE void operator()(T /*unused*/, D d) { @@ -1397,10 +1127,8 @@ namespace hwy { HWY_BEFORE_TEST(HwyArithmeticTest); 
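The new TestSumsOf8 above checks the op against a plain scalar reduction. Spelled out as a stand-alone sketch (it mirrors the test's inner loops, and is not library code): each group of 8 consecutive uint8_t lanes becomes one uint64_t lane.

#include <cstddef>
#include <cstdint>

// Scalar model of SumsOf8: out has num_u8 / 8 elements; num_u8 is assumed to
// be a multiple of 8, as guaranteed for vectors of at least 64 bits.
void SumsOf8Scalar(const uint8_t* in, size_t num_u8, uint64_t* out) {
  for (size_t g = 0; g < num_u8 / 8; ++g) {
    uint64_t sum = 0;
    for (size_t i = 0; i < 8; ++i) {
      sum += in[g * 8 + i];
    }
    out[g] = sum;
  }
}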
HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllPlusMinus); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSaturatingArithmetic); -HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllShifts); -HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllVariableShifts); -HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllRotateRight); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax128); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAverage); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAbs); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMul); @@ -1420,6 +1148,7 @@ HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllTrunc); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllCeil); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllFloor); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAbsDiff); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSumsOf8); HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllNeg); } // namespace hwy diff --git a/third_party/highway/hwy/tests/blockwise_test.cc b/third_party/highway/hwy/tests/blockwise_test.cc index eb4e0ee80b56..2a063a276585 100644 --- a/third_party/highway/hwy/tests/blockwise_test.cc +++ b/third_party/highway/hwy/tests/blockwise_test.cc @@ -393,7 +393,7 @@ HWY_NOINLINE void TestAllZip() { lower_unsigned(uint8_t()); #endif lower_unsigned(uint16_t()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 lower_unsigned(uint32_t()); // generates u64 #endif @@ -402,7 +402,7 @@ HWY_NOINLINE void TestAllZip() { lower_signed(int8_t()); #endif lower_signed(int16_t()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 lower_signed(int32_t()); // generates i64 #endif @@ -411,7 +411,7 @@ HWY_NOINLINE void TestAllZip() { upper_unsigned(uint8_t()); #endif upper_unsigned(uint16_t()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 upper_unsigned(uint32_t()); // generates u64 #endif @@ -420,19 +420,20 @@ HWY_NOINLINE void TestAllZip() { upper_signed(int8_t()); #endif upper_signed(int16_t()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 upper_signed(int32_t()); // generates i64 #endif // No float - concatenating f32 does not result in a f64 } -template -struct TestCombineShiftRightBytesR { - template - HWY_NOINLINE void operator()(T t, D d) { // Scalar does not define CombineShiftRightBytes. #if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template +struct TestCombineShiftRightBytes { + template + HWY_NOINLINE void operator()(T, D d) { const size_t kBlockSize = 16; static_assert(kBytes < kBlockSize, "Shift count is per block"); const Repartition d8; @@ -461,21 +462,13 @@ struct TestCombineShiftRightBytesR { const auto expected = BitCast(d, Load(d8, expected_bytes.get())); HWY_ASSERT_VEC_EQ(d, expected, CombineShiftRightBytes(d, hi, lo)); } - - TestCombineShiftRightBytesR()(t, d); -#else - (void)t; - (void)d; -#endif // #if HWY_TARGET != HWY_SCALAR } }; template -struct TestCombineShiftRightLanesR { +struct TestCombineShiftRightLanes { template - HWY_NOINLINE void operator()(T t, D d) { -// Scalar does not define CombineShiftRightBytes (needed for *Lanes). 
-#if HWY_TARGET != HWY_SCALAR || HWY_IDE + HWY_NOINLINE void operator()(T, D d) { const Repartition d8; const size_t N8 = Lanes(d8); if (N8 < 16) return; @@ -505,33 +498,29 @@ struct TestCombineShiftRightLanesR { const auto expected = BitCast(d, Load(d8, expected_bytes.get())); HWY_ASSERT_VEC_EQ(d, expected, CombineShiftRightLanes(d, hi, lo)); } - - TestCombineShiftRightLanesR()(t, d); -#else - (void)t; - (void)d; -#endif // #if HWY_TARGET != HWY_SCALAR } }; -template <> -struct TestCombineShiftRightBytesR<0> { - template - void operator()(T /*unused*/, D /*unused*/) {} -}; - -template <> -struct TestCombineShiftRightLanesR<0> { - template - void operator()(T /*unused*/, D /*unused*/) {} -}; +#endif // #if HWY_TARGET != HWY_SCALAR struct TestCombineShiftRight { template HWY_NOINLINE void operator()(T t, D d) { +// Scalar does not define CombineShiftRightBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE constexpr int kMaxBytes = HWY_MIN(16, int(MaxLanes(d) * sizeof(T))); - TestCombineShiftRightBytesR()(t, d); - TestCombineShiftRightLanesR()(t, d); + constexpr int kMaxLanes = kMaxBytes / static_cast(sizeof(T)); + TestCombineShiftRightBytes()(t, d); + TestCombineShiftRightBytes()(t, d); + TestCombineShiftRightBytes<1>()(t, d); + + TestCombineShiftRightLanes()(t, d); + TestCombineShiftRightLanes()(t, d); + TestCombineShiftRightLanes<1>()(t, d); +#else + (void)t; + (void)d; +#endif } }; @@ -553,11 +542,13 @@ class TestSpecialShuffle32 { } private: + // HWY_INLINE works around a Clang SVE compiler bug where all but the first + // 128 bits (the NEON register) of actual are zero. template - HWY_NOINLINE void VerifyLanes32(D d, VecArg actual, const size_t i3, - const size_t i2, const size_t i1, - const size_t i0, const char* filename, - const int line) { + HWY_INLINE void VerifyLanes32(D d, VecArg actual, const size_t i3, + const size_t i2, const size_t i1, + const size_t i0, const char* filename, + const int line) { using T = TFromD; constexpr size_t kBlockN = 16 / sizeof(T); const size_t N = Lanes(d); @@ -582,10 +573,12 @@ class TestSpecialShuffle64 { } private: + // HWY_INLINE works around a Clang SVE compiler bug where all but the first + // 128 bits (the NEON register) of actual are zero. 
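The restructured CombineShiftRight tests above (explicit kMaxBytes / kMaxLanes / 1 instantiations instead of the recursive ...R helpers) still build the expected bytes the same way. My reading of the op, as a per-block scalar sketch rather than library code: within each 16-byte block, the result is the 32-byte concatenation hi:lo shifted right by kBytes, keeping the low 16 bytes (x86 palignr semantics).

#include <cstddef>
#include <cstdint>

// Per-block scalar model of CombineShiftRightBytes<kBytes>(d, hi, lo).
template <int kBytes>
void CombineShiftRightBytesBlock(const uint8_t lo[16], const uint8_t hi[16],
                                 uint8_t out[16]) {
  static_assert(0 <= kBytes && kBytes <= 16, "shift amount is per 16-byte block");
  for (size_t i = 0; i < 16; ++i) {
    const size_t src = i + static_cast<size_t>(kBytes);
    out[i] = (src < 16) ? lo[src] : hi[src - 16];
  }
}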
template - HWY_NOINLINE void VerifyLanes64(D d, VecArg actual, const size_t i1, - const size_t i0, const char* filename, - const int line) { + HWY_INLINE void VerifyLanes64(D d, VecArg actual, const size_t i1, + const size_t i0, const char* filename, + const int line) { using T = TFromD; constexpr size_t kBlockN = 16 / sizeof(T); const size_t N = Lanes(d); @@ -600,19 +593,19 @@ class TestSpecialShuffle64 { }; HWY_NOINLINE void TestAllSpecialShuffles() { - const ForGE128Vectors test32; + const ForGEVectors<128, TestSpecialShuffle32> test32; test32(uint32_t()); test32(int32_t()); test32(float()); -#if HWY_CAP_INTEGER64 - const ForGE128Vectors test64; +#if HWY_HAVE_INTEGER64 + const ForGEVectors<128, TestSpecialShuffle64> test64; test64(uint64_t()); test64(int64_t()); #endif -#if HWY_CAP_FLOAT64 - const ForGE128Vectors test_d; +#if HWY_HAVE_FLOAT64 + const ForGEVectors<128, TestSpecialShuffle64> test_d; test_d(double()); #endif } diff --git a/third_party/highway/hwy/tests/combine_test.cc b/third_party/highway/hwy/tests/combine_test.cc index ba37f39ef257..1bc6315e0207 100644 --- a/third_party/highway/hwy/tests/combine_test.cc +++ b/third_party/highway/hwy/tests/combine_test.cc @@ -22,9 +22,6 @@ #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" -// Not yet implemented -#if HWY_TARGET != HWY_RVV - HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { @@ -85,8 +82,8 @@ struct TestLowerQuarter { }; HWY_NOINLINE void TestAllLowerHalf() { - ForAllTypes(ForDemoteVectors()); - ForAllTypes(ForDemoteVectors()); + ForAllTypes(ForHalfVectors()); + ForAllTypes(ForHalfVectors()); } struct TestUpperHalf { @@ -95,21 +92,14 @@ struct TestUpperHalf { // Scalar does not define UpperHalf. #if HWY_TARGET != HWY_SCALAR const Half d2; - - const auto v = Iota(d, 1); - const size_t N = Lanes(d); - auto lanes = AllocateAligned(N); - std::fill(lanes.get(), lanes.get() + N, T(0)); - - Store(UpperHalf(d2, v), d2, lanes.get()); + const size_t N2 = Lanes(d2); + HWY_ASSERT(N2 * 2 == Lanes(d)); + auto expected = AllocateAligned(N2); size_t i = 0; - for (; i < Lanes(d2); ++i) { - HWY_ASSERT_EQ(T(Lanes(d2) + 1 + i), lanes[i]); - } - // Other half remains unchanged - for (; i < N; ++i) { - HWY_ASSERT_EQ(T(0), lanes[i]); + for (; i < N2; ++i) { + expected[i] = static_cast(N2 + 1 + i); } + HWY_ASSERT_VEC_EQ(d2, expected.get(), UpperHalf(d2, Iota(d, 1))); #else (void)d; #endif @@ -117,7 +107,7 @@ struct TestUpperHalf { }; HWY_NOINLINE void TestAllUpperHalf() { - ForAllTypes(ForShrinkableVectors()); + ForAllTypes(ForHalfVectors()); } struct TestZeroExtendVector { @@ -126,23 +116,23 @@ struct TestZeroExtendVector { const Twice d2; const auto v = Iota(d, 1); + const size_t N = Lanes(d); const size_t N2 = Lanes(d2); + // If equal, then N was already MaxLanes(d) and it's not clear what + // Combine or ZeroExtendVector should return. 
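As the reworked combine_test.cc below notes, ZeroExtendVector is now skipped when Lanes(d2) == Lanes(d), because with no wider descriptor available it is unclear what the op should return. For orientation, a small usage sketch of the two ops involved (assumed to sit inside a HWY_NAMESPACE block of per-target code; WidenWithZeros is a made-up name):

template <class D>
Vec<Twice<D>> WidenWithZeros(D /* d */, Vec<D> lo_half) {
  const Twice<D> d2;                     // descriptor with twice the lanes
  return ZeroExtendVector(d2, lo_half);  // lower half = lo_half, upper half = 0
}

// Combine(d2, hi, lo) is the general form: hi fills the upper half of the
// twice-as-wide vector, lo the lower half.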
+ if (N2 == N) return; + HWY_ASSERT(N2 == 2 * N); auto lanes = AllocateAligned(N2); Store(v, d, &lanes[0]); - Store(v, d, &lanes[N2 / 2]); + Store(v, d, &lanes[N]); const auto ext = ZeroExtendVector(d2, v); Store(ext, d2, lanes.get()); - size_t i = 0; // Lower half is unchanged - for (; i < N2 / 2; ++i) { - HWY_ASSERT_EQ(T(1 + i), lanes[i]); - } + HWY_ASSERT_VEC_EQ(d, v, Load(d, &lanes[0])); // Upper half is zero - for (; i < N2; ++i) { - HWY_ASSERT_EQ(T(0), lanes[i]); - } + HWY_ASSERT_VEC_EQ(d, Zero(d), Load(d, &lanes[N])); } }; @@ -158,7 +148,7 @@ struct TestCombine { auto lanes = AllocateAligned(N2); const auto lo = Iota(d, 1); - const auto hi = Iota(d, N2 / 2 + 1); + const auto hi = Iota(d, static_cast(N2 / 2 + 1)); const auto combined = Combine(d2, hi, lo); Store(combined, d2, lanes.get()); @@ -232,7 +222,7 @@ struct TestConcatOddEven { HWY_NOINLINE void operator()(T /*unused*/, D d) { #if HWY_TARGET != HWY_RVV && HWY_TARGET != HWY_SCALAR const size_t N = Lanes(d); - const auto hi = Iota(d, N); + const auto hi = Iota(d, static_cast(N)); const auto lo = Iota(d, 0); const auto even = Add(Iota(d, 0), Iota(d, 0)); const auto odd = Add(even, Set(d, 1)); @@ -272,7 +262,3 @@ int main(int argc, char **argv) { } #endif // HWY_ONCE - -#else -int main(int, char**) { return 0; } -#endif // HWY_TARGET != HWY_RVV diff --git a/third_party/highway/hwy/tests/compare_test.cc b/third_party/highway/hwy/tests/compare_test.cc index 85cc802536f2..24383b979d1d 100644 --- a/third_party/highway/hwy/tests/compare_test.cc +++ b/third_party/highway/hwy/tests/compare_test.cc @@ -218,6 +218,63 @@ HWY_NOINLINE void TestAllWeakFloat() { ForFloatTypes(ForPartialVectors()); } +class TestLt128 { + template + static HWY_NOINLINE Vec Make128(D d, uint64_t hi, uint64_t lo) { + alignas(16) uint64_t in[2]; + in[0] = lo; + in[1] = hi; + return LoadDup128(d, in); + } + + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax(), LimitsMax()); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v10, 
vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllLt128() { ForGEVectors<128, TestLt128>()(uint64_t()); } + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy @@ -232,6 +289,7 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictUnsigned); HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictInt); HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictFloat); HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllWeakFloat); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128); } // namespace hwy // Ought not to be necessary, but without this, no tests run on RVV. diff --git a/third_party/highway/hwy/tests/convert_test.cc b/third_party/highway/hwy/tests/convert_test.cc index aeed5cc81061..8c4b8bdedc30 100644 --- a/third_party/highway/hwy/tests/convert_test.cc +++ b/third_party/highway/hwy/tests/convert_test.cc @@ -57,17 +57,17 @@ struct TestBitCastFrom { TestBitCast()(t, d); TestBitCast()(t, d); TestBitCast()(t, d); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 TestBitCast()(t, d); #endif TestBitCast()(t, d); TestBitCast()(t, d); TestBitCast()(t, d); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 TestBitCast()(t, d); #endif TestBitCast()(t, d); -#if HWY_CAP_FLOAT64 +#if HWY_HAVE_FLOAT64 TestBitCast()(t, d); #endif } @@ -103,39 +103,39 @@ HWY_NOINLINE void TestAllBitCast() { to_i32(int32_t()); to_i32(float()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 const ForPartialVectors> to_u64; to_u64(uint64_t()); to_u64(int64_t()); -#if HWY_CAP_FLOAT64 +#if HWY_HAVE_FLOAT64 to_u64(double()); #endif const ForPartialVectors> to_i64; to_i64(uint64_t()); to_i64(int64_t()); -#if HWY_CAP_FLOAT64 +#if HWY_HAVE_FLOAT64 to_i64(double()); #endif -#endif // HWY_CAP_INTEGER64 +#endif // HWY_HAVE_INTEGER64 const ForPartialVectors> to_float; to_float(uint32_t()); to_float(int32_t()); to_float(float()); -#if HWY_CAP_FLOAT64 +#if HWY_HAVE_FLOAT64 const ForPartialVectors> to_double; to_double(double()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 to_double(uint64_t()); to_double(int64_t()); -#endif // HWY_CAP_INTEGER64 -#endif // HWY_CAP_FLOAT64 +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_HAVE_FLOAT64 #if HWY_TARGET != HWY_SCALAR // For non-scalar vectors, we can cast all types to all. - ForAllTypes(ForGE64Vectors()); + ForAllTypes(ForGEVectors<64, TestBitCastFrom>()); #endif } @@ -165,39 +165,39 @@ struct TestPromoteTo { }; HWY_NOINLINE void TestAllPromoteTo() { - const ForPromoteVectors, 2> to_u16div2; + const ForPromoteVectors, 1> to_u16div2; to_u16div2(uint8_t()); - const ForPromoteVectors, 4> to_u32div4; + const ForPromoteVectors, 2> to_u32div4; to_u32div4(uint8_t()); - const ForPromoteVectors, 2> to_u32div2; + const ForPromoteVectors, 1> to_u32div2; to_u32div2(uint16_t()); - const ForPromoteVectors, 2> to_i16div2; + const ForPromoteVectors, 1> to_i16div2; to_i16div2(uint8_t()); to_i16div2(int8_t()); - const ForPromoteVectors, 2> to_i32div2; + const ForPromoteVectors, 1> to_i32div2; to_i32div2(uint16_t()); to_i32div2(int16_t()); - const ForPromoteVectors, 4> to_i32div4; + const ForPromoteVectors, 2> to_i32div4; to_i32div4(uint8_t()); to_i32div4(int8_t()); // Must test f16/bf16 separately because we can only load/store/convert them. 
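TestLt128 above, like TestMinMax128 in arithmetic_test.cc, treats each 128-bit block as a single unsigned number held in two u64 lanes (low lane first). The reference predicate both tests encode, written out as a scalar sketch:

#include <cstdint>

// Per-block model: compare the high u64 lanes first; only on a tie compare the
// low lanes. Min128/Max128 then copy the whole winning (lo, hi) pair.
bool Lt128Scalar(uint64_t a_lo, uint64_t a_hi, uint64_t b_lo, uint64_t b_hi) {
  return (a_hi == b_hi) ? (a_lo < b_lo) : (a_hi < b_hi);
}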
-#if HWY_CAP_INTEGER64 - const ForPromoteVectors, 2> to_u64div2; +#if HWY_HAVE_INTEGER64 + const ForPromoteVectors, 1> to_u64div2; to_u64div2(uint32_t()); - const ForPromoteVectors, 2> to_i64div2; + const ForPromoteVectors, 1> to_i64div2; to_i64div2(int32_t()); #endif -#if HWY_CAP_FLOAT64 - const ForPromoteVectors, 2> to_f64div2; +#if HWY_HAVE_FLOAT64 + const ForPromoteVectors, 1> to_f64div2; to_f64div2(int32_t()); to_f64div2(float()); #endif @@ -213,111 +213,6 @@ bool IsFinite(T /*unused*/) { return true; } -template -struct TestDemoteTo { - template - HWY_NOINLINE void operator()(T /*unused*/, D from_d) { - static_assert(!IsFloat(), "Use TestDemoteToFloat for float output"); - static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); - const Rebind to_d; - - const size_t N = Lanes(from_d); - auto from = AllocateAligned(N); - auto expected = AllocateAligned(N); - - // Narrower range in the wider type, for clamping before we cast - const T min = LimitsMin(); - const T max = LimitsMax(); - - const auto value_ok = [&](T& value) { - if (!IsFinite(value)) return false; -#if HWY_EMULATE_SVE - // farm_sve just casts, which is undefined if the value is out of range. - value = HWY_MIN(HWY_MAX(min, value), max); -#endif - return true; - }; - - RandomState rng; - for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { - for (size_t i = 0; i < N; ++i) { - do { - const uint64_t bits = rng(); - memcpy(&from[i], &bits, sizeof(T)); - } while (!value_ok(from[i])); - expected[i] = static_cast(HWY_MIN(HWY_MAX(min, from[i]), max)); - } - - HWY_ASSERT_VEC_EQ(to_d, expected.get(), - DemoteTo(to_d, Load(from_d, from.get()))); - } - } -}; - -HWY_NOINLINE void TestAllDemoteToInt() { - ForDemoteVectors>()(int16_t()); - ForDemoteVectors, 4>()(int32_t()); - - ForDemoteVectors>()(int16_t()); - ForDemoteVectors, 4>()(int32_t()); - - const ForDemoteVectors> to_u16; - to_u16(int32_t()); - - const ForDemoteVectors> to_i16; - to_i16(int32_t()); -} - -HWY_NOINLINE void TestAllDemoteToMixed() { -#if HWY_CAP_FLOAT64 - const ForDemoteVectors> to_i32; - to_i32(double()); -#endif -} - -template -struct TestDemoteToFloat { - template - HWY_NOINLINE void operator()(T /*unused*/, D from_d) { - // For floats, we clamp differently and cannot call LimitsMin. - static_assert(IsFloat(), "Use TestDemoteTo for integer output"); - static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); - const Rebind to_d; - - const size_t N = Lanes(from_d); - auto from = AllocateAligned(N); - auto expected = AllocateAligned(N); - - RandomState rng; - for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { - for (size_t i = 0; i < N; ++i) { - do { - const uint64_t bits = rng(); - memcpy(&from[i], &bits, sizeof(T)); - } while (!IsFinite(from[i])); - const T magn = std::abs(from[i]); - const T max_abs = HighestValue(); - // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see - // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html - const T clipped = copysign(HWY_MIN(magn, max_abs), from[i]); - expected[i] = static_cast(clipped); - } - - HWY_ASSERT_VEC_EQ(to_d, expected.get(), - DemoteTo(to_d, Load(from_d, from.get()))); - } - } -}; - -HWY_NOINLINE void TestAllDemoteToFloat() { - // Must test f16 separately because we can only load/store/convert them. 
- -#if HWY_CAP_FLOAT64 - const ForDemoteVectors, 2> to_float; - to_float(double()); -#endif -} - template AlignedFreeUniquePtr F16TestCases(D d, size_t& padded) { const float test_cases[] = { @@ -352,7 +247,7 @@ AlignedFreeUniquePtr F16TestCases(D d, size_t& padded) { struct TestF16 { template HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { -#if HWY_CAP_FLOAT16 +#if HWY_HAVE_FLOAT16 size_t padded; auto in = F16TestCases(d32, padded); using TF16 = float16_t; @@ -406,7 +301,7 @@ AlignedFreeUniquePtr BF16TestCases(D d, size_t& padded) { struct TestBF16 { template HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { -#if HWY_TARGET != HWY_RVV +#if !defined(HWY_EMULATE_SVE) size_t padded; auto in = BF16TestCases(d32, padded); using TBF16 = bfloat16_t; @@ -417,6 +312,7 @@ struct TestBF16 { #endif const Half dbf16_half; const size_t N = Lanes(d32); + HWY_ASSERT(Lanes(dbf16_half) <= N); auto temp16 = AllocateAligned(N); for (size_t i = 0; i < padded; i += N) { @@ -434,124 +330,6 @@ struct TestBF16 { HWY_NOINLINE void TestAllBF16() { ForShrinkableVectors()(float()); } -template -AlignedFreeUniquePtr ReorderBF16TestCases(D d, size_t& padded) { - const float test_cases[] = { - // Same as BF16TestCases: - // +/- 1 - 1.0f, - -1.0f, - // +/- 0 - 0.0f, - -0.0f, - // near 0 - 0.25f, - -0.25f, - // +/- integer - 4.0f, - -32.0f, - // positive +/- delta - 2.015625f, - 3.984375f, - // negative +/- delta - -2.015625f, - -3.984375f, - - // No huge values - would interfere with sum. But add more to fill 2 * N: - -2.0f, - -10.0f, - 0.03125f, - 1.03125f, - 1.5f, - 2.0f, - 4.0f, - 5.0f, - 6.0f, - 8.0f, - 10.0f, - 256.0f, - 448.0f, - 2080.0f, - }; - const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); - const size_t N = Lanes(d); - padded = RoundUpTo(kNumTestCases, 2 * N); // allow loading pairs of vectors - auto in = AllocateAligned(padded); - auto expected = AllocateAligned(padded); - std::copy(test_cases, test_cases + kNumTestCases, in.get()); - std::fill(in.get() + kNumTestCases, in.get() + padded, 0.0f); - return in; -} - -class TestReorderDemote2To { - // In-place N^2 selection sort to avoid dependencies - void Sort(float* p, size_t count) { - for (size_t i = 0; i < count - 1; ++i) { - // Find min_element - size_t idx_min = i; - for (size_t j = i + 1; j < count; j++) { - if (p[j] < p[idx_min]) { - idx_min = j; - } - } - - // Swap with current - const float tmp = p[i]; - p[i] = p[idx_min]; - p[idx_min] = tmp; - } - } - - public: - template - HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { -#if HWY_TARGET != HWY_SCALAR - size_t padded; - auto in = ReorderBF16TestCases(d32, padded); - - using TBF16 = bfloat16_t; - const Repartition dbf16; - const Half dbf16_half; - const size_t N = Lanes(d32); - auto temp16 = AllocateAligned(2 * N); - auto expected = AllocateAligned(2 * N); - auto actual = AllocateAligned(2 * N); - - for (size_t i = 0; i < padded; i += 2 * N) { - const auto f0 = Load(d32, &in[i + 0]); - const auto f1 = Load(d32, &in[i + N]); - const auto v16 = ReorderDemote2To(dbf16, f0, f1); - Store(v16, dbf16, temp16.get()); - const auto promoted0 = PromoteTo(d32, Load(dbf16_half, temp16.get() + 0)); - const auto promoted1 = PromoteTo(d32, Load(dbf16_half, temp16.get() + N)); - - // Smoke test: sum should be same (with tolerance for non-associativity) - const auto sum_expected = - GetLane(SumOfLanes(d32, Add(promoted0, promoted1))); - const auto sum_actual = GetLane(SumOfLanes(d32, Add(f0, f1))); - HWY_ASSERT(sum_actual - 1E-4 <= sum_actual && - sum_expected <= sum_actual + 1E-4); 
- - // Ensure values are the same after sorting to undo the Reorder - Store(f0, d32, expected.get() + 0); - Store(f1, d32, expected.get() + N); - Store(promoted0, d32, actual.get() + 0); - Store(promoted1, d32, actual.get() + N); - Sort(expected.get(), 2 * N); - Sort(actual.get(), 2 * N); - HWY_ASSERT_VEC_EQ(d32, expected.get() + 0, Load(d32, actual.get() + 0)); - HWY_ASSERT_VEC_EQ(d32, expected.get() + N, Load(d32, actual.get() + N)); - } -#else // HWY_SCALAR - (void)d32; -#endif - } -}; - -HWY_NOINLINE void TestAllReorderDemote2To() { - ForShrinkableVectors()(float()); -} - struct TestConvertU8 { template HWY_NOINLINE void operator()(T /*unused*/, const D du32) { @@ -564,7 +342,7 @@ struct TestConvertU8 { }; HWY_NOINLINE void TestAllConvertU8() { - ForDemoteVectors()(uint32_t()); + ForDemoteVectors()(uint32_t()); } // Separate function to attempt to work around a compiler bug on ARM: when this @@ -574,19 +352,23 @@ struct TestIntFromFloatHuge { HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { // Still does not work, although ARMv7 manual says that float->int // saturates, i.e. chooses the nearest representable value. Also causes - // out-of-memory for MSVC, and unsafe cast in farm_sve. -#if HWY_TARGET != HWY_NEON && !HWY_COMPILER_MSVC && !defined(HWY_EMULATE_SVE) + // out-of-memory for MSVC. +#if HWY_TARGET != HWY_NEON && !HWY_COMPILER_MSVC using TI = MakeSigned; const Rebind di; - // Huge positive (lvalue works around GCC bug, tested with 10.2.1, where - // the expected i32 value is otherwise 0x80..00). - const auto expected_max = Set(di, LimitsMax()); - HWY_ASSERT_VEC_EQ(di, expected_max, ConvertTo(di, Set(df, TF(1E20)))); + // Workaround for incorrect 32-bit GCC codegen for SSSE3 - Print-ing + // the expected lvalue also seems to prevent the issue. + const size_t N = Lanes(df); + auto expected = AllocateAligned(N); - // Huge negative (also lvalue for safety, but GCC bug was not triggered) - const auto expected_min = Set(di, LimitsMin()); - HWY_ASSERT_VEC_EQ(di, expected_min, ConvertTo(di, Set(df, TF(-1E20)))); + // Huge positive + Store(Set(di, LimitsMax()), di, expected.get()); + HWY_ASSERT_VEC_EQ(di, expected.get(), ConvertTo(di, Set(df, TF(1E20)))); + + // Huge negative + Store(Set(di, LimitsMin()), di, expected.get()); + HWY_ASSERT_VEC_EQ(di, expected.get(), ConvertTo(di, Set(df, TF(-1E20)))); #else (void)df; #endif @@ -634,10 +416,6 @@ class TestIntFromFloat { const uint64_t bits = rng(); memcpy(&from[i], &bits, sizeof(TF)); } while (!std::isfinite(from[i])); -#if defined(HWY_EMULATE_SVE) - // farm_sve just casts, which is undefined if the value is out of range. 
- from[i] = HWY_MIN(HWY_MAX(min / 2, from[i]), max / 2); -#endif if (from[i] >= max) { expected[i] = LimitsMax(); } else if (from[i] <= min) { @@ -725,30 +503,21 @@ struct TestI32F64 { const size_t N = Lanes(df); // Integer positive - HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), DemoteTo(di, Iota(df, TF(4.0)))); HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); // Integer negative - HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), DemoteTo(di, Iota(df, -TF(N)))); HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), PromoteTo(df, Iota(di, -TI(N)))); // Above positive - HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), DemoteTo(di, Iota(df, TF(2.001)))); HWY_ASSERT_VEC_EQ(df, Iota(df, TF(2.0)), PromoteTo(df, Iota(di, TI(2)))); // Below positive - HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), DemoteTo(di, Iota(df, TF(3.9999)))); HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); - const TF eps = static_cast(0.0001); // Above negative - HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), - DemoteTo(di, Iota(df, -TF(N + 1) + eps))); HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-4.0)), PromoteTo(df, Iota(di, TI(-4)))); // Below negative - HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), - DemoteTo(di, Iota(df, -TF(N + 1) - eps))); HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-2.0)), PromoteTo(df, Iota(di, TI(-2)))); // Max positive int @@ -758,22 +527,11 @@ struct TestI32F64 { // Min negative int HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin())), PromoteTo(df, Set(di, LimitsMin()))); - - // farm_sve just casts, which is undefined if the value is out of range. -#if !defined(HWY_EMULATE_SVE) - // Huge positive float - HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMax()), - DemoteTo(di, Set(df, TF(1E12)))); - - // Huge negative float - HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMin()), - DemoteTo(di, Set(df, TF(-1E12)))); -#endif } }; HWY_NOINLINE void TestAllI32F64() { -#if HWY_CAP_FLOAT64 +#if HWY_HAVE_FLOAT64 ForDemoteVectors()(double()); #endif } @@ -790,12 +548,8 @@ namespace hwy { HWY_BEFORE_TEST(HwyConvertTest); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBitCast); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllPromoteTo); -HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllDemoteToInt); -HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllDemoteToMixed); -HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllDemoteToFloat); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllF16); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBF16); -HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllReorderDemote2To); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllConvertU8); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllIntFromFloat); HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllFloatFromInt); diff --git a/third_party/highway/hwy/tests/crypto_test.cc b/third_party/highway/hwy/tests/crypto_test.cc index c85d63af953e..892deb3edbfb 100644 --- a/third_party/highway/hwy/tests/crypto_test.cc +++ b/third_party/highway/hwy/tests/crypto_test.cc @@ -74,7 +74,7 @@ class TestAES { } for (size_t i = 0; i < 256; i += N) { - const auto in = Iota(d, i); + const auto in = Iota(d, static_cast(i)); HWY_ASSERT_VEC_EQ(d, expected.get() + i, detail::SubBytes(in)); } } @@ -89,11 +89,17 @@ class TestAES { 0x42, 0xCA, 0x6B, 0x99, 0x7A, 0x5C, 0x58, 0x16}; const auto test = LoadDup128(d, test_lanes); + // = ShiftRow result + alignas(16) constexpr uint8_t expected_sr_lanes[16] = { + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE}; + const auto expected_sr = LoadDup128(d, expected_sr_lanes); + // = MixColumn result - alignas(16) constexpr uint8_t expected0_lanes[16] = { + 
alignas(16) constexpr uint8_t expected_mc_lanes[16] = { 0x52, 0x9F, 0x16, 0xC2, 0x97, 0x86, 0x15, 0xCA, 0xE0, 0x1A, 0xAE, 0x54, 0xBA, 0x1A, 0x26, 0x59}; - const auto expected0 = LoadDup128(d, expected0_lanes); + const auto expected_mc = LoadDup128(d, expected_mc_lanes); // = KeyAddition result alignas(16) constexpr uint8_t expected_lanes[16] = { @@ -103,17 +109,20 @@ class TestAES { alignas(16) uint8_t key_lanes[16]; for (size_t i = 0; i < 16; ++i) { - key_lanes[i] = expected0_lanes[i] ^ expected_lanes[i]; + key_lanes[i] = expected_mc_lanes[i] ^ expected_lanes[i]; } const auto round_key = LoadDup128(d, key_lanes); - HWY_ASSERT_VEC_EQ(d, expected0, AESRound(test, Zero(d))); + HWY_ASSERT_VEC_EQ(d, expected_mc, AESRound(test, Zero(d))); HWY_ASSERT_VEC_EQ(d, expected, AESRound(test, round_key)); + HWY_ASSERT_VEC_EQ(d, expected_sr, AESLastRound(test, Zero(d))); + HWY_ASSERT_VEC_EQ(d, Xor(expected_sr, round_key), + AESLastRound(test, round_key)); TestSBox(t, d); } }; -HWY_NOINLINE void TestAllAES() { ForGE128Vectors()(uint8_t()); } +HWY_NOINLINE void TestAllAES() { ForGEVectors<128, TestAES>()(uint8_t()); } #else HWY_NOINLINE void TestAllAES() {} @@ -123,7 +132,7 @@ struct TestCLMul { template HWY_NOINLINE void operator()(T /*unused*/, D d) { // needs 64 bit lanes and 128-bit result -#if HWY_TARGET != HWY_SCALAR && HWY_CAP_INTEGER64 +#if HWY_TARGET != HWY_SCALAR && HWY_HAVE_INTEGER64 const size_t N = Lanes(d); if (N == 1) return; @@ -525,7 +534,7 @@ struct TestCLMul { } }; -HWY_NOINLINE void TestAllCLMul() { ForGE128Vectors()(uint64_t()); } +HWY_NOINLINE void TestAllCLMul() { ForGEVectors<128, TestCLMul>()(uint64_t()); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE diff --git a/third_party/highway/hwy/tests/demote_test.cc b/third_party/highway/hwy/tests/demote_test.cc new file mode 100644 index 000000000000..635b806714d7 --- /dev/null +++ b/third_party/highway/hwy/tests/demote_test.cc @@ -0,0 +1,333 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/demote_test.cc" +#include "hwy/foreach_target.h" + +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +// Causes build timeout. +#if !HWY_IS_MSAN + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +bool IsFinite(T t) { + return std::isfinite(t); +} +// Wrapper avoids calling std::isfinite for integer types (ambiguous). 
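Back in crypto_test.cc above: AESLastRound with a zero key is expected to return the ShiftRows/SubBytes result because the final AES round omits MixColumns. A hedged sketch of how the two ops compose into AES-128 encryption of every 128-bit block of a vector — not library code; it assumes per-target code inside a HWY_NAMESPACE block and 11 pre-expanded round keys, since key expansion is outside Highway's scope:

template <class V>
V EncryptAes128Blocks(V state, const V (&rk)[11]) {
  state = Xor(state, rk[0]);           // initial AddRoundKey
  for (int r = 1; r <= 9; ++r) {
    state = AESRound(state, rk[r]);    // SubBytes, ShiftRows, MixColumns, key
  }
  return AESLastRound(state, rk[10]);  // final round: no MixColumns
}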
+template +bool IsFinite(T /*unused*/) { + return true; +} + +template +struct TestDemoteTo { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(!IsFloat(), "Use TestDemoteToFloat for float output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Narrower range in the wider type, for clamping before we cast + const T min = LimitsMin(); + const T max = LimitsMax(); + + const auto value_ok = [&](T& value) { + if (!IsFinite(value)) return false; + return true; + }; + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + memcpy(&from[i], &bits, sizeof(T)); + } while (!value_ok(from[i])); + expected[i] = static_cast(HWY_MIN(HWY_MAX(min, from[i]), max)); + } + + const auto in = Load(from_d, from.get()); + HWY_ASSERT_VEC_EQ(to_d, expected.get(), DemoteTo(to_d, in)); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToInt() { + ForDemoteVectors>()(int16_t()); + ForDemoteVectors, 2>()(int32_t()); + + ForDemoteVectors>()(int16_t()); + ForDemoteVectors, 2>()(int32_t()); + + const ForDemoteVectors> to_u16; + to_u16(int32_t()); + + const ForDemoteVectors> to_i16; + to_i16(int32_t()); +} + +HWY_NOINLINE void TestAllDemoteToMixed() { +#if HWY_HAVE_FLOAT64 + const ForDemoteVectors> to_i32; + to_i32(double()); +#endif +} + +template +struct TestDemoteToFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + // For floats, we clamp differently and cannot call LimitsMin. + static_assert(IsFloat(), "Use TestDemoteTo for integer output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + memcpy(&from[i], &bits, sizeof(T)); + } while (!IsFinite(from[i])); + const T magn = std::abs(from[i]); + const T max_abs = HighestValue(); + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + const T clipped = copysign(HWY_MIN(magn, max_abs), from[i]); + expected[i] = static_cast(clipped); + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + DemoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToFloat() { + // Must test f16 separately because we can only load/store/convert them. + +#if HWY_HAVE_FLOAT64 + const ForDemoteVectors, 1> to_float; + to_float(double()); +#endif +} + +template +AlignedFreeUniquePtr ReorderBF16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // Same as BF16TestCases: + // +/- 1 + 1.0f, + -1.0f, + // +/- 0 + 0.0f, + -0.0f, + // near 0 + 0.25f, + -0.25f, + // +/- integer + 4.0f, + -32.0f, + // positive +/- delta + 2.015625f, + 3.984375f, + // negative +/- delta + -2.015625f, + -3.984375f, + + // No huge values - would interfere with sum. 
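TestDemoteTo above computes each expected lane by clamping the wide value to the narrow type's range and only then casting, i.e. integer DemoteTo is expected to saturate rather than wrap. A scalar reference sketch of that computation (the helper name DemoteSaturate is illustrative, not part of Highway):

#include <cassert>
#include <cstdint>
#include <limits>

// Scalar reference for integer DemoteTo: clamp in the wide type, then cast.
template <typename Narrow, typename Wide>
Narrow DemoteSaturate(Wide v) {
  const Wide lo = static_cast<Wide>(std::numeric_limits<Narrow>::min());
  const Wide hi = static_cast<Wide>(std::numeric_limits<Narrow>::max());
  if (v < lo) v = lo;
  if (v > hi) v = hi;
  return static_cast<Narrow>(v);
}

int main() {
  assert((DemoteSaturate<int16_t, int32_t>(100000) == 32767));
  assert((DemoteSaturate<int16_t, int32_t>(-100000) == -32768));
  assert((DemoteSaturate<uint8_t, int32_t>(-1) == 0));
  assert((DemoteSaturate<int16_t, int32_t>(1234) == 1234));
  return 0;
}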
But add more to fill 2 * N: + -2.0f, + -10.0f, + 0.03125f, + 1.03125f, + 1.5f, + 2.0f, + 4.0f, + 5.0f, + 6.0f, + 8.0f, + 10.0f, + 256.0f, + 448.0f, + 2080.0f, + }; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, 2 * N); // allow loading pairs of vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, 0.0f); + return in; +} + +class TestReorderDemote2To { + // In-place N^2 selection sort to avoid dependencies + void Sort(float* p, size_t count) { + for (size_t i = 0; i < count - 1; ++i) { + // Find min_element + size_t idx_min = i; + for (size_t j = i + 1; j < count; j++) { + if (p[j] < p[idx_min]) { + idx_min = j; + } + } + + // Swap with current + const float tmp = p[i]; + p[i] = p[idx_min]; + p[idx_min] = tmp; + } + } + + public: + template + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if HWY_TARGET != HWY_SCALAR + + size_t padded; + auto in = ReorderBF16TestCases(d32, padded); + + using TBF16 = bfloat16_t; + const Repartition dbf16; + const Half dbf16_half; + const size_t N = Lanes(d32); + auto temp16 = AllocateAligned(2 * N); + auto expected = AllocateAligned(2 * N); + auto actual = AllocateAligned(2 * N); + + for (size_t i = 0; i < padded; i += 2 * N) { + const auto f0 = Load(d32, &in[i + 0]); + const auto f1 = Load(d32, &in[i + N]); + const auto v16 = ReorderDemote2To(dbf16, f0, f1); + Store(v16, dbf16, temp16.get()); + const auto promoted0 = PromoteTo(d32, Load(dbf16_half, temp16.get() + 0)); + const auto promoted1 = PromoteTo(d32, Load(dbf16_half, temp16.get() + N)); + + // Smoke test: sum should be same (with tolerance for non-associativity) + const auto sum_expected = + GetLane(SumOfLanes(d32, Add(promoted0, promoted1))); + const auto sum_actual = GetLane(SumOfLanes(d32, Add(f0, f1))); + HWY_ASSERT(sum_actual - 1E-4 <= sum_actual && + sum_expected <= sum_actual + 1E-4); + + // Ensure values are the same after sorting to undo the Reorder + Store(f0, d32, expected.get() + 0); + Store(f1, d32, expected.get() + N); + Store(promoted0, d32, actual.get() + 0); + Store(promoted1, d32, actual.get() + N); + Sort(expected.get(), 2 * N); + Sort(actual.get(), 2 * N); + HWY_ASSERT_VEC_EQ(d32, expected.get() + 0, Load(d32, actual.get() + 0)); + HWY_ASSERT_VEC_EQ(d32, expected.get() + N, Load(d32, actual.get() + N)); + } +#else // HWY_SCALAR + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllReorderDemote2To() { + ForShrinkableVectors()(float()); +} + +struct TestI32F64 { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = int32_t; + const Rebind di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), DemoteTo(di, Iota(df, TF(4.0)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), DemoteTo(di, Iota(df, -TF(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), DemoteTo(di, Iota(df, TF(2.001)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), DemoteTo(di, Iota(df, TF(3.9999)))); + + const TF eps = static_cast(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + DemoteTo(di, Iota(df, -TF(N + 1) + eps))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + DemoteTo(di, Iota(df, -TF(N + 1) - eps))); + + // Huge positive float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMax()), + 
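The ReorderDemote2To test round-trips float -> bfloat16 -> float and then compares sorted lanes, so every constant in the list above must be exactly representable in bfloat16 (at most 8 significant mantissa bits). A standalone sketch of a truncating bfloat16 conversion that makes this property visible; real conversions may round rather than truncate, but for these exactly representable values the result is the same:

#include <cassert>
#include <cstdint>
#include <cstring>

// bfloat16 is the upper 16 bits of an IEEE-754 binary32 value.
uint16_t F32ToBF16Truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

float BF16ToF32(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const float cases[] = {1.0f, -0.25f, 2.015625f, -3.984375f, 448.0f, 2080.0f};
  for (float f : cases) {
    // Each constant needs at most 8 significant bits, so its low 16 float bits
    // are zero and the truncated bfloat16 round-trips exactly.
    assert(BF16ToF32(F32ToBF16Truncate(f)) == f);
  }
  return 0;
}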
DemoteTo(di, Set(df, TF(1E12)))); + + // Huge negative float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMin()), + DemoteTo(di, Set(df, TF(-1E12)))); + } +}; + +HWY_NOINLINE void TestAllI32F64() { +#if HWY_HAVE_FLOAT64 + ForDemoteVectors()(double()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // !HWY_IS_MSAN + +#if HWY_ONCE + +namespace hwy { +#if !HWY_IS_MSAN +HWY_BEFORE_TEST(HwyDemoteTest); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToInt); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToMixed); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToFloat); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllReorderDemote2To); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllI32F64); +#endif // !HWY_IS_MSAN +} // namespace hwy + +// Ought not to be necessary, but without this, no tests run on RVV. +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +#endif diff --git a/third_party/highway/hwy/tests/logical_test.cc b/third_party/highway/hwy/tests/logical_test.cc index bc9835e5081d..6450a8abb426 100644 --- a/third_party/highway/hwy/tests/logical_test.cc +++ b/third_party/highway/hwy/tests/logical_test.cc @@ -17,7 +17,6 @@ #include // memcmp #include "hwy/aligned_allocator.h" -#include "hwy/base.h" #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "tests/logical_test.cc" @@ -59,6 +58,15 @@ struct TestLogicalInteger { HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, v0)); HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, vi)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(v0, vi, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, vi, vi)); + auto v = vi; v = And(v, vi); HWY_ASSERT_VEC_EQ(d, vi, v); @@ -156,6 +164,43 @@ struct TestCopySign { } }; +struct TestIfVecThenElse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TU = MakeUnsigned; // For all-one mask + const Rebind du; + const size_t N = Lanes(d); + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + auto vec_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast(Random32(&rng)); + in2[i] = static_cast(Random32(&rng)); + vec_lanes[i] = (Random32(&rng) & 16) ? static_cast(~TU(0)) : TU(0); + } + + const auto v1 = Load(d, in1.get()); + const auto v2 = Load(d, in2.get()); + const auto vec = BitCast(d, Load(du, vec_lanes.get())); + + for (size_t i = 0; i < N; ++i) { + expected[i] = vec_lanes[i] ? 
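The new OrAnd assertions above are consistent with a lane-wise o | (a1 & a2). A scalar check of the same eight identities, using 0 and all-ones in place of the v0/vi vectors (illustrative only, not the Highway implementation):

#include <cassert>
#include <cstdint>

uint32_t OrAndScalar(uint32_t o, uint32_t a1, uint32_t a2) {
  return o | (a1 & a2);
}

int main() {
  const uint32_t v0 = 0u;
  const uint32_t vi = ~0u;  // all bits set, like the test's vi lanes
  assert(OrAndScalar(v0, v0, v0) == v0);
  assert(OrAndScalar(v0, vi, v0) == v0);
  assert(OrAndScalar(v0, v0, vi) == v0);
  assert(OrAndScalar(v0, vi, vi) == vi);
  assert(OrAndScalar(vi, v0, v0) == vi);
  assert(OrAndScalar(vi, vi, v0) == vi);
  assert(OrAndScalar(vi, v0, vi) == vi);
  assert(OrAndScalar(vi, vi, vi) == vi);
  return 0;
}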
in1[i] : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfVecThenElse(vec, v1, v2)); + } + } +}; + +HWY_NOINLINE void TestAllIfVecThenElse() { + ForAllTypes(ForPartialVectors()); +} + HWY_NOINLINE void TestAllCopySign() { ForFloatTypes(ForPartialVectors()); } @@ -180,6 +225,31 @@ HWY_NOINLINE void TestAllZeroIfNegative() { ForFloatTypes(ForPartialVectors()); } +struct TestIfNegative { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Or(vp, SignBit(d)); + + // Zero and positive remain unchanged + HWY_ASSERT_VEC_EQ(d, v0, IfNegativeThenElse(v0, vn, v0)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(v0, v0, vn)); + HWY_ASSERT_VEC_EQ(d, vp, IfNegativeThenElse(vp, vn, vp)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(vp, vp, vn)); + + // Negative are replaced with 2nd arg + HWY_ASSERT_VEC_EQ(d, v0, IfNegativeThenElse(vn, v0, vp)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(vn, vn, v0)); + HWY_ASSERT_VEC_EQ(d, vp, IfNegativeThenElse(vn, vp, vn)); + } +}; + +HWY_NOINLINE void TestAllIfNegative() { + ForFloatTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); +} + struct TestBroadcastSignBit { template HWY_NOINLINE void operator()(T /*unused*/, D d) { @@ -234,16 +304,11 @@ HWY_NOINLINE void TestAllTestBit() { struct TestPopulationCount { template HWY_NOINLINE void operator()(T /*unused*/, D d) { -#if HWY_TARGET == HWY_RVV || HWY_IS_DEBUG_BUILD - constexpr size_t kNumTests = 1 << 14; -#else - constexpr size_t kNumTests = 1 << 20; -#endif RandomState rng; size_t N = Lanes(d); auto data = AllocateAligned(N); auto popcnt = AllocateAligned(N); - for (size_t i = 0; i < kNumTests / N; i++) { + for (size_t i = 0; i < AdjustedReps(1 << 18) / N; i++) { for (size_t i = 0; i < N; i++) { data[i] = static_cast(rng()); popcnt[i] = static_cast(PopCount(data[i])); @@ -268,8 +333,10 @@ namespace hwy { HWY_BEFORE_TEST(HwyLogicalTest); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalInteger); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalFloat); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllIfVecThenElse); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllCopySign); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllZeroIfNegative); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllIfNegative); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllBroadcastSignBit); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllPopulationCount); diff --git a/third_party/highway/hwy/tests/mask_test.cc b/third_party/highway/hwy/tests/mask_test.cc index 569f85ba5734..d6f7ceb5365d 100644 --- a/third_party/highway/hwy/tests/mask_test.cc +++ b/third_party/highway/hwy/tests/mask_test.cc @@ -17,8 +17,6 @@ #include #include // memcmp -#include "hwy/base.h" - #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "tests/mask_test.cc" #include "hwy/foreach_target.h" @@ -55,13 +53,18 @@ struct TestFirstN { template HWY_NOINLINE void operator()(T /*unused*/, D d) { const size_t N = Lanes(d); - const RebindToSigned di; using TI = TFromD; using TN = SignedFromSize; const size_t max_len = static_cast(LimitsMax()); - for (size_t len = 0; len <= HWY_MIN(2 * N, max_len); ++len) { +// TODO(janwas): 8-bit FirstN (using SlideUp) causes spike to freeze. 
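TestIfNegative above expects IfNegativeThenElse to select its second argument for negative lanes and its third otherwise. A scalar model of that behaviour (the vector op keys off the sign bit, so a floating-point -0.0 would also count as negative; the test only feeds zero, positive and sign-bit-set values):

#include <cassert>

// Scalar model: negative lanes take `yes`, zero/positive lanes take `no`.
template <typename T>
T IfNegativeThenElseScalar(T v, T yes, T no) {
  return v < T(0) ? yes : no;
}

int main() {
  assert(IfNegativeThenElseScalar(0, -1, 7) == 7);    // zero keeps `no`
  assert(IfNegativeThenElseScalar(5, -1, 7) == 7);    // positive keeps `no`
  assert(IfNegativeThenElseScalar(-5, -1, 7) == -1);  // negative takes `yes`
  return 0;
}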
+#if HWY_TARGET == HWY_RVV + if (sizeof(T) == 1) return; +#endif + + const size_t max_lanes = AdjustedReps(HWY_MIN(2 * N, size_t(64))); + for (size_t len = 0; len <= HWY_MIN(max_lanes, max_len); ++len) { const auto expected = RebindMask(d, Lt(Iota(di, 0), Set(di, static_cast(len)))); const auto actual = FirstN(d, len); @@ -368,7 +371,7 @@ struct TestFindFirstTrue { memset(bool_lanes.get(), 0, N * sizeof(TI)); // For all combinations of zero/nonzero state of subset of lanes: - const size_t max_lanes = HWY_MIN(N, size_t(10)); + const size_t max_lanes = AdjustedLog2Reps(HWY_MIN(N, size_t(9))); HWY_ASSERT_EQ(intptr_t(-1), FindFirstTrue(d, MaskFalse(d))); HWY_ASSERT_EQ(intptr_t(0), FindFirstTrue(d, MaskTrue(d))); @@ -407,7 +410,7 @@ struct TestLogicalMask { HWY_ASSERT_MASK_EQ(d, m_all, Not(m0)); // For all combinations of zero/nonzero state of subset of lanes: - const size_t max_lanes = HWY_MIN(N, size_t(6)); + const size_t max_lanes = AdjustedLog2Reps(HWY_MIN(N, size_t(6))); for (size_t code = 0; code < (1ull << max_lanes); ++code) { for (size_t i = 0; i < max_lanes; ++i) { bool_lanes[i] = (code & (1ull << i)) ? TI(1) : TI(0); diff --git a/third_party/highway/hwy/tests/memory_test.cc b/third_party/highway/hwy/tests/memory_test.cc index 3f72809a29d1..8213c8ee9c76 100644 --- a/third_party/highway/hwy/tests/memory_test.cc +++ b/third_party/highway/hwy/tests/memory_test.cc @@ -36,7 +36,7 @@ struct TestLoadStore { template HWY_NOINLINE void operator()(T /*unused*/, D d) { const size_t N = Lanes(d); - const auto hi = Iota(d, 1 + N); + const auto hi = Iota(d, static_cast(1 + N)); const auto lo = Iota(d, 1); auto lanes = AllocateAligned(2 * N); Store(hi, d, &lanes[N]); @@ -135,7 +135,7 @@ struct TestStoreInterleaved3 { HWY_NOINLINE void TestAllStoreInterleaved3() { #if HWY_TARGET == HWY_RVV // Segments are limited to 8 registers, so we can only go up to LMUL=2. - const ForExtendableVectors test; + const ForExtendableVectors test; #else const ForPartialVectors test; #endif @@ -198,7 +198,7 @@ struct TestStoreInterleaved4 { HWY_NOINLINE void TestAllStoreInterleaved4() { #if HWY_TARGET == HWY_RVV // Segments are limited to 8 registers, so we can only go up to LMUL=2. - const ForExtendableVectors test; + const ForExtendableVectors test; #else const ForPartialVectors test; #endif @@ -230,7 +230,7 @@ struct TestLoadDup128 { }; HWY_NOINLINE void TestAllLoadDup128() { - ForAllTypes(ForGE128Vectors()); + ForAllTypes(ForGEVectors<128, TestLoadDup128>()); } struct TestStream { @@ -245,7 +245,7 @@ struct TestStream { std::fill(out.get(), out.get() + 2 * affected_lanes, T(0)); Stream(v, d, out.get()); - StoreFence(); + FlushStream(); const auto actual = Load(d, out.get()); HWY_ASSERT_VEC_EQ(d, v, actual); // Ensure Stream didn't modify more memory than expected @@ -386,7 +386,7 @@ HWY_NOINLINE void TestAllGather() { HWY_NOINLINE void TestAllCache() { LoadFence(); - StoreFence(); + FlushStream(); int test = 0; Prefetch(&test); FlushCacheline(&test); diff --git a/third_party/highway/hwy/tests/shift_test.cc b/third_party/highway/hwy/tests/shift_test.cc new file mode 100644 index 000000000000..4eb3502f3273 --- /dev/null +++ b/third_party/highway/hwy/tests/shift_test.cc @@ -0,0 +1,433 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
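TestFirstN above builds its expected mask as Lt(Iota(di, 0), Set(di, len)), i.e. FirstN(d, n) should be true exactly for lane indices below n, and an oversized n simply yields an all-true mask. A scalar sketch of that reference:

#include <cassert>
#include <cstddef>
#include <vector>

// Reference: lane i of FirstN(d, n) is true iff i < n.
std::vector<bool> FirstNRef(size_t num_lanes, size_t n) {
  std::vector<bool> mask(num_lanes);
  for (size_t i = 0; i < num_lanes; ++i) mask[i] = (i < n);
  return mask;
}

int main() {
  const auto m = FirstNRef(/*num_lanes=*/8, /*n=*/3);
  assert(m[0] && m[1] && m[2] && !m[3] && !m[7]);
  // n larger than the vector simply yields an all-true mask.
  const auto all = FirstNRef(8, 100);
  assert(all[0] && all[7]);
  return 0;
}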
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/shift_test.cc" +#include "hwy/foreach_target.h" +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct TestLeftShifts { + template + HWY_NOINLINE void operator()(T t, D d) { + if (kSigned) { + // Also test positive values + TestLeftShifts()(t, d); + } + + using TI = MakeSigned; + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto values = Iota(d, kSigned ? -TI(N) : TI(0)); // value to shift + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftLeft<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftLeftSame(values, 0)); + + // 1 + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(T(i) - T(N)) : T(i); + expected[i] = T(TU(value) << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(T(i) - T(N)) : T(i); + expected[i] = T(TU(value) << kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, kMaxShift)); + } +}; + +template +struct TestVariableLeftShifts { + template + HWY_NOINLINE void operator()(T t, D d) { + if (kSigned) { + // Also test positive values + TestVariableLeftShifts()(t, d); + } + + using TI = MakeSigned; + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, kSigned ? -TI(N) : TI(0)); // value to shift + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shl(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, v1)); + + // Same: max + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? 
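Both left-shift tests above compute expected lanes as T(TU(value) << shift): the shift is done in the unsigned twin type and cast back, because left-shifting negative signed values is not portable C++ (undefined before C++20). A standalone sketch of the pattern (PortableShl is an illustrative name):

#include <cassert>
#include <cstdint>
#include <type_traits>

// Portable left shift for signed or unsigned T: shift the unsigned
// representation, then convert back (two's complement wrap-around semantics).
template <typename T>
T PortableShl(T value, unsigned amount) {
  using TU = typename std::make_unsigned<T>::type;
  return static_cast<T>(static_cast<TU>(static_cast<TU>(value) << amount));
}

int main() {
  assert(PortableShl<int16_t>(-3, 1) == -6);
  assert(PortableShl<int32_t>(-1, 31) == INT32_MIN);  // sign bit set, no UB
  assert(PortableShl<uint8_t>(0x81, 1) == 0x02);      // high bit shifted out
  return 0;
}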
T(i) - T(N) : T(i); + expected[i] = T(TU(value) << (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, small_shifts)); + + // Variable: large + for (size_t i = 0; i < N; ++i) { + expected[i] = T(TU(1) << (kMaxShift - (i & kMaxShift))); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(v1, large_shifts)); + } +}; + +struct TestUnsignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto values = Iota(d, 0); + + const T kMax = LimitsMax(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, kMaxShift)); + } +}; + +struct TestRotateRight { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + constexpr size_t kBits = sizeof(T) * 8; + const auto mask_shift = Set(d, T{kBits}); + // Cover as many bit positions as possible to test shifting out + const auto values = Shl(Set(d, T{1}), And(Iota(d, 0), mask_shift)); + + // Rotate by 0 + HWY_ASSERT_VEC_EQ(d, values, RotateRight<0>(values)); + + // Rotate by 1 + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> 1) | (expected[i] << (kBits - 1)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight<1>(values)); + + // Rotate by half + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> (kBits / 2)) | (expected[i] << (kBits / 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight(values)); + + // Rotate by max + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> (kBits - 1)) | (expected[i] << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight(values)); + } +}; + +struct TestVariableUnsignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, 0); + + const T kMax = LimitsMax(); + const auto max = Set(d, kMax); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shr(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, v1)); + + // Same: max + HWY_ASSERT_VEC_EQ(d, v0, Shr(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + expected[i] = T(i) >> (i & kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, small_shifts)); + + // Variable: Large + for (size_t i = 0; i < N; ++i) { + expected[i] = kMax >> (kMaxShift - (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(max, large_shifts)); + } +}; + +template +T 
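TestRotateRight above checks rotations by 0, 1, half the width and width-1 against the usual (x >> k) | (x << (bits - k)) formula; the k == 0 case needs special handling because shifting by the full width is undefined. A minimal sketch:

#include <cassert>
#include <cstdint>

template <typename T>
T RotateRightScalar(T x, unsigned k) {
  constexpr unsigned kBits = sizeof(T) * 8;
  k %= kBits;
  if (k == 0) return x;  // avoid shifting by kBits, which is undefined
  return static_cast<T>((x >> k) | (x << (kBits - k)));
}

int main() {
  assert(RotateRightScalar<uint32_t>(0x80000001u, 1) == 0xC0000000u);
  assert(RotateRightScalar<uint32_t>(0x12345678u, 16) == 0x56781234u);
  assert(RotateRightScalar<uint64_t>(1u, 1) == 0x8000000000000000ull);
  assert(RotateRightScalar<uint32_t>(0xDEADBEEFu, 0) == 0xDEADBEEFu);
  return 0;
}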
RightShiftNegative(T val) { + // C++ shifts are implementation-defined for negative numbers, and we have + // seen divisions replaced with shifts, so resort to bit operations. + using TU = hwy::MakeUnsigned; + TU bits; + CopyBytes(&val, &bits); + + const TU shifted = TU(bits >> kAmount); + + const TU all = TU(~TU(0)); + const size_t num_zero = sizeof(TU) * 8 - 1 - kAmount; + const TU sign_extended = static_cast((all << num_zero) & LimitsMax()); + + bits = shifted | sign_extended; + CopyBytes(&bits, &val); + return val; +} + +class TestSignedRightShifts { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + constexpr T kMin = LimitsMin(); + constexpr T kMax = LimitsMax(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. + const auto v0 = Zero(d); + const auto values = And(Iota(d, 0), Set(d, kMax)); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(values)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(values, kMaxShift)); + + // Even negative value + Test<0>(kMin, d, __LINE__); + Test<1>(kMin, d, __LINE__); + Test<2>(kMin, d, __LINE__); + Test(kMin, d, __LINE__); + + const T odd = static_cast(kMin + 1); + Test<0>(odd, d, __LINE__); + Test<1>(odd, d, __LINE__); + Test<2>(odd, d, __LINE__); + Test(odd, d, __LINE__); + } + + private: + template + void Test(T val, D d, int line) { + const auto expected = Set(d, RightShiftNegative(val)); + const auto in = Set(d, val); + const char* file = __FILE__; + AssertVecEqual(d, expected, ShiftRight(in), file, line); + AssertVecEqual(d, expected, ShiftRightSame(in, kAmount), file, line); + } +}; + +struct TestVariableSignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + constexpr T kMin = LimitsMin(); + constexpr T kMax = LimitsMax(); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. 
+ const auto v0 = Zero(d); + const auto positive = Iota(d, 0) & Set(d, kMax); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, positive, ShiftRight<0>(positive)); + HWY_ASSERT_VEC_EQ(d, positive, ShiftRightSame(positive, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(positive)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(positive, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(positive)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(positive, kMaxShift)); + + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + const auto negative = Iota(d, kMin); + + // Test varying negative to shift + for (size_t i = 0; i < N; ++i) { + expected[i] = RightShiftNegative<1>(static_cast(kMin + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(negative, Set(d, 1))); + + // Shift MSB right by small amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = i & kMaxShift; + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopyBytes(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), small_shifts)); + + // Shift MSB right by large amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = kMaxShift - (i & kMaxShift); + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopyBytes(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), large_shifts)); + } +}; + +HWY_NOINLINE void TestAllShifts() { + ForUnsignedTypes(ForPartialVectors>()); + ForSignedTypes(ForPartialVectors>()); + ForUnsignedTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); +} + +HWY_NOINLINE void TestAllVariableShifts() { + const ForPartialVectors> shl_u; + const ForPartialVectors> shl_s; + const ForPartialVectors shr_u; + const ForPartialVectors shr_s; + + shl_u(uint16_t()); + shr_u(uint16_t()); + + shl_u(uint32_t()); + shr_u(uint32_t()); + + shl_s(int16_t()); + shr_s(int16_t()); + + shl_s(int32_t()); + shr_s(int32_t()); + +#if HWY_HAVE_INTEGER64 + shl_u(uint64_t()); + shr_u(uint64_t()); + + shl_s(int64_t()); + shr_s(int64_t()); +#endif +} + +HWY_NOINLINE void TestAllRotateRight() { + const ForPartialVectors test; + test(uint32_t()); +#if HWY_HAVE_INTEGER64 + test(uint64_t()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyShiftTest); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllShifts); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllVariableShifts); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllRotateRight); +} // namespace hwy + +// Ought not to be necessary, but without this, no tests run on RVV. 
+int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +#endif diff --git a/third_party/highway/hwy/tests/swizzle_test.cc b/third_party/highway/hwy/tests/swizzle_test.cc index ea14514d6402..e899ae7341ba 100644 --- a/third_party/highway/hwy/tests/swizzle_test.cc +++ b/third_party/highway/hwy/tests/swizzle_test.cc @@ -19,6 +19,8 @@ #include // IWYU pragma: keep +#include "hwy/base.h" + #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "tests/swizzle_test.cc" #include "hwy/foreach_target.h" @@ -44,12 +46,48 @@ HWY_NOINLINE void TestAllGetLane() { ForAllTypes(ForPartialVectors()); } +struct TestDupEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((static_cast(i) & ~1) + 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), DupEven(Iota(d, 1))); + } +}; + +HWY_NOINLINE void TestAllDupEven() { + ForUIF3264(ForShrinkableVectors()); +} + +struct TestDupOdd { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((static_cast(i) & ~1) + 2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), DupOdd(Iota(d, 1))); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllDupOdd() { + ForUIF3264(ForShrinkableVectors()); +} + struct TestOddEven { template HWY_NOINLINE void operator()(T /*unused*/, D d) { const size_t N = Lanes(d); const auto even = Iota(d, 1); - const auto odd = Iota(d, 1 + N); + const auto odd = Iota(d, static_cast(1 + N)); auto expected = AllocateAligned(N); for (size_t i = 0; i < N; ++i) { expected[i] = static_cast(1 + i + ((i & 1) ? N : 0)); @@ -67,7 +105,7 @@ struct TestOddEvenBlocks { HWY_NOINLINE void operator()(T /*unused*/, D d) { const size_t N = Lanes(d); const auto even = Iota(d, 1); - const auto odd = Iota(d, 1 + N); + const auto odd = Iota(d, static_cast(1 + N)); auto expected = AllocateAligned(N); for (size_t i = 0; i < N; ++i) { const size_t idx_block = i / (16 / sizeof(T)); @@ -78,7 +116,7 @@ struct TestOddEvenBlocks { }; HWY_NOINLINE void TestAllOddEvenBlocks() { - ForAllTypes(ForShrinkableVectors()); + ForAllTypes(ForGEVectors<128, TestOddEvenBlocks>()); } struct TestSwapAdjacentBlocks { @@ -100,7 +138,7 @@ struct TestSwapAdjacentBlocks { }; HWY_NOINLINE void TestAllSwapAdjacentBlocks() { - ForAllTypes(ForPartialVectors()); + ForAllTypes(ForGEVectors<128, TestSwapAdjacentBlocks>()); } struct TestTableLookupLanes { @@ -197,23 +235,131 @@ struct TestReverse { } }; +struct TestReverse2 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 1]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse2(d, v)); + } +}; + +struct TestReverse4 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. 
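The DupEven/DupOdd tests above expect lane i of the result to equal input lane i & ~1, respectively i | 1: each even (odd) source lane is duplicated into its neighbouring lane. A scalar sketch of those index maps:

#include <cassert>
#include <cstddef>
#include <vector>

std::vector<int> DupEvenRef(const std::vector<int>& in) {
  std::vector<int> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) out[i] = in[i & ~size_t(1)];
  return out;
}

std::vector<int> DupOddRef(const std::vector<int>& in) {
  std::vector<int> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) out[i] = in[i | size_t(1)];
  return out;
}

int main() {
  const std::vector<int> iota{1, 2, 3, 4};  // like Iota(d, 1)
  assert((DupEvenRef(iota) == std::vector<int>{1, 1, 3, 3}));
  assert((DupOddRef(iota) == std::vector<int>{2, 2, 4, 4}));
  return 0;
}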
+ auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 3]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse4(d, v)); + } +}; + +struct TestReverse8 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 7]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse8(d, v)); + } +}; + HWY_NOINLINE void TestAllReverse() { // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, // which requires 16 bits. ForUIF163264(ForPartialVectors()); } +HWY_NOINLINE void TestAllReverse2() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<128, TestReverse2>()); + ForUIF32(ForGEVectors<64, TestReverse2>()); + ForUIF16(ForGEVectors<32, TestReverse2>()); +} + +HWY_NOINLINE void TestAllReverse4() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<256, TestReverse4>()); + ForUIF32(ForGEVectors<128, TestReverse4>()); + ForUIF16(ForGEVectors<64, TestReverse4>()); +} + +HWY_NOINLINE void TestAllReverse8() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<512, TestReverse8>()); + ForUIF32(ForGEVectors<256, TestReverse8>()); + ForUIF16(ForGEVectors<128, TestReverse8>()); +} + +struct TestReverseBlocks { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + const size_t num_blocks = N / kLanesPerBlock; + HWY_ASSERT(num_blocks != 0); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / kLanesPerBlock; + const size_t base = (num_blocks - 1 - idx_block) * kLanesPerBlock; + expected[i] = copy[base + (i % kLanesPerBlock)]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ReverseBlocks(d, v)); + } +}; + +HWY_NOINLINE void TestAllReverseBlocks() { + ForAllTypes(ForGEVectors<128, TestReverseBlocks>()); +} + class TestCompress { - template - void CheckStored(Simd d, Simd di, size_t expected_pos, - size_t actual_pos, const AlignedFreeUniquePtr& in, + template , typename TI = TFromD> + void CheckStored(D d, DI di, size_t expected_pos, size_t actual_pos, + const AlignedFreeUniquePtr& in, const AlignedFreeUniquePtr& mask_lanes, const AlignedFreeUniquePtr& expected, const T* actual_u, int line) { if (expected_pos != actual_pos) { - hwy::Abort(__FILE__, line, - "Size mismatch for %s: expected %" PRIu64 ", actual %" PRIu64 "\n", - TypeName(T(), N).c_str(), static_cast(expected_pos), static_cast(actual_pos)); + hwy::Abort( + __FILE__, line, + "Size mismatch for %s: expected %" PRIu64 ", actual %" PRIu64 "\n", + TypeName(T(), Lanes(d)).c_str(), static_cast(expected_pos), + static_cast(actual_pos)); } // Upper lanes are undefined. Modified from AssertVecEqual. 
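TestReverse2/4/8 above all build the expected output as copy[i ^ (k - 1)]: reversing within aligned groups of k lanes, for k a power of two, is just an XOR of the lane index with k - 1. A tiny demonstration of the index trick:

#include <cassert>
#include <cstddef>

int main() {
  // Reverse within groups of k = 4 lanes: index i maps to i ^ 3.
  const int in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int out[8];
  const size_t k = 4;
  for (size_t i = 0; i < 8; ++i) out[i] = in[i ^ (k - 1)];
  const int expected[8] = {3, 2, 1, 0, 7, 6, 5, 4};
  for (size_t i = 0; i < 8; ++i) assert(out[i] == expected[i]);
  return 0;
}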
for (size_t i = 0; i < expected_pos; ++i) { @@ -222,6 +368,7 @@ class TestCompress { "Mismatch at i=%" PRIu64 " of %" PRIu64 ", line %d:\n\n", static_cast(i), static_cast(expected_pos), line); + const size_t N = Lanes(d); Print(di, "mask", Load(di, mask_lanes.get()), 0, N); Print(d, "in", Load(d, in.get()), 0, N); Print(d, "expect", Load(d, expected.get()), 0, N); @@ -251,7 +398,10 @@ class TestCompress { auto expected = AllocateAligned(N); auto actual_a = AllocateAligned(misalign + N); T* actual_u = actual_a.get() + misalign; - auto bits = AllocateAligned(HWY_MAX(8, (N + 7) / 8)); + + const size_t bits_size = RoundUpTo((N + 7) / 8, 8); + auto bits = AllocateAligned(bits_size); + memset(bits.get(), 0, bits_size); // for MSAN // Each lane should have a chance of having mask=true. for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { @@ -465,7 +615,7 @@ HWY_NOINLINE void TestAllCompress() { test(uint16_t()); test(int16_t()); -#if HWY_CAP_FLOAT16 +#if HWY_HAVE_FLOAT16 test(float16_t()); #endif @@ -482,11 +632,17 @@ HWY_AFTER_NAMESPACE(); namespace hwy { HWY_BEFORE_TEST(HwySwizzleTest); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllGetLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllDupEven); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllDupOdd); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEven); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEvenBlocks); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllSwapAdjacentBlocks); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllTableLookupLanes); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllReverse); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllReverse2); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllReverse4); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllReverse8); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllReverseBlocks); HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllCompress); } // namespace hwy diff --git a/third_party/highway/hwy/tests/test_util-inl.h b/third_party/highway/hwy/tests/test_util-inl.h index 8a82267b3d93..6a02886a44ed 100644 --- a/third_party/highway/hwy/tests/test_util-inl.h +++ b/third_party/highway/hwy/tests/test_util-inl.h @@ -41,7 +41,7 @@ HWY_NOINLINE void PrintValue(T value) { fprintf(stderr, "0x%02X,", byte); } -#if HWY_CAP_FLOAT16 +#if HWY_HAVE_FLOAT16 HWY_NOINLINE void PrintValue(float16_t value) { uint16_t bits; CopyBytes<2>(&value, &bits); @@ -70,9 +70,11 @@ void Print(const D d, const char* caption, VecArg v, size_t lane_u = 0, } // Compare expected vector to vector. +// HWY_INLINE works around a Clang SVE compiler bug where all but the first +// 128 bits (the NEON register) of actual are zero. template , class V = Vec> -void AssertVecEqual(D d, const T* expected, VecArg actual, - const char* filename, const int line) { +HWY_INLINE void AssertVecEqual(D d, const T* expected, VecArg actual, + const char* filename, const int line) { const size_t N = Lanes(d); auto actual_lanes = AllocateAligned(N); Store(actual, d, actual_lanes.get()); @@ -84,9 +86,11 @@ void AssertVecEqual(D d, const T* expected, VecArg actual, } // Compare expected lanes to vector. +// HWY_INLINE works around a Clang SVE compiler bug where all but the first +// 128 bits (the NEON register) of actual are zero. 
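TestCompress above checks that lanes whose mask bit is set are packed towards lane 0 in their original order; the store variants also report how many lanes were written, and anything past that count is left unspecified. A scalar reference sketch of that contract (CompressRef is an illustrative name):

#include <cassert>
#include <cstddef>
#include <vector>

// Scalar reference: keep the lanes whose mask is true, in order; the number of
// kept lanes is the "expected_pos" the test compares against.
template <typename T>
size_t CompressRef(const std::vector<T>& in, const std::vector<bool>& mask,
                   std::vector<T>& out) {
  size_t pos = 0;
  for (size_t i = 0; i < in.size(); ++i) {
    if (mask[i]) out[pos++] = in[i];
  }
  return pos;  // lanes at index >= pos are unspecified
}

int main() {
  const std::vector<int> in{10, 11, 12, 13};
  const std::vector<bool> mask{true, false, false, true};
  std::vector<int> out(in.size(), -1);
  const size_t count = CompressRef(in, mask, out);
  assert(count == 2 && out[0] == 10 && out[1] == 13);
  return 0;
}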
template , class V = Vec> -HWY_NOINLINE void AssertVecEqual(D d, VecArg expected, VecArg actual, - const char* filename, int line) { +HWY_INLINE void AssertVecEqual(D d, VecArg expected, VecArg actual, + const char* filename, int line) { auto expected_lanes = AllocateAligned(Lanes(d)); Store(expected, d, expected_lanes.get()); AssertVecEqual(d, expected_lanes.get(), actual, filename, line); @@ -96,7 +100,10 @@ HWY_NOINLINE void AssertVecEqual(D d, VecArg expected, VecArg actual, template HWY_NOINLINE void AssertMaskEqual(D d, VecArg> a, VecArg> b, const char* filename, int line) { - AssertVecEqual(d, VecFromMask(d, a), VecFromMask(d, b), filename, line); + // lvalues prevented MSAN failure in farm_sve. + const Vec va = VecFromMask(d, a); + const Vec vb = VecFromMask(d, b); + AssertVecEqual(d, va, vb, filename, line); const char* target_name = hwy::TargetName(HWY_TARGET); AssertEqual(CountTrue(d, a), CountTrue(d, b), target_name, filename, line); @@ -178,169 +185,269 @@ HWY_INLINE Mask MaskFalse(const D d) { // Helpers for instantiating tests with combinations of lane types / counts. -// For ensuring we do not call tests with D such that widening D results in 0 -// lanes. Example: assume T=u32, VLEN=256, and fraction=1/8: there is no 1/8th -// of a u64 vector in this case. -template -HWY_INLINE size_t PromotedLanes(const D d) { - return Lanes(RepartitionToWide()); -} -// Already the widest possible T, cannot widen. -template -HWY_INLINE size_t PromotedLanes(const D d) { - return Lanes(d); -} +// Calls Test for each CappedTag where N is in [kMinLanes, kMul * kMinArg] +// and the resulting Lanes() is in [min_lanes, max_lanes]. The upper bound +// is required to ensure capped vectors remain extendable. Implemented by +// recursively halving kMul until it is zero. +template +struct ForeachCappedR { + static void Do(size_t min_lanes, size_t max_lanes) { + const CappedTag d; -// For all power of two N in [kMinLanes, kMul * kMinLanes] (so that recursion -// stops at kMul == 0). Note that N may be capped or a fraction. -template -struct ForeachSizeR { - static void Do() { - const Simd d; + // If we already don't have enough lanes, stop. + const size_t lanes = Lanes(d); + if (lanes < min_lanes) return; - // Skip invalid fractions (e.g. 1/8th of u32x4). - const size_t lanes = kPromote ? PromotedLanes(d) : Lanes(d); - if (lanes < kMinLanes) return; - - Test()(T(), d); - - static_assert(kMul != 0, "Recursion should have ended already"); - ForeachSizeR::Do(); + if (lanes <= max_lanes) { + Test()(T(), d); + } + ForeachCappedR::Do(min_lanes, max_lanes); } }; // Base case to stop the recursion. -template -struct ForeachSizeR { - static void Do() {} +template +struct ForeachCappedR { + static void Do(size_t, size_t) {} }; +#if HWY_HAVE_SCALABLE + +constexpr int MinVectorSize() { +#if HWY_TARGET == HWY_RVV + // Actually 16 for the application processor profile, but the intrinsics are + // defined as if VLEN might be only 64: there is no vuint64mf2_t. + return 8; +#else + return 16; +#endif +} + +template +constexpr int MinPow2() { + // Highway follows RVV LMUL in that the smallest fraction is 1/8th (encoded + // as kPow2 == -3). The fraction also must not result in zero lanes for the + // smallest possible vector size. + return HWY_MAX(-3, -static_cast(CeilLog2(MinVectorSize() / sizeof(T)))); +} + +// Iterates kPow2 upward through +3. 
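ForeachCappedR above enumerates power-of-two lane counts at compile time by recursively halving kMul, with a kMul == 0 partial specialization terminating the recursion. A stripped-down standalone sketch of that template pattern (names here are illustrative, not the Highway ones):

#include <cstddef>
#include <cstdio>

// Calls Visitor once per power-of-two N in [kMin, kMul * kMin], largest first.
template <size_t kMul, size_t kMin, class Visitor>
struct ForEachPow2 {
  static void Do() {
    Visitor()(kMul * kMin);
    ForEachPow2<kMul / 2, kMin, Visitor>::Do();  // halve until kMul == 0
  }
};

// Base case ends the recursion.
template <size_t kMin, class Visitor>
struct ForEachPow2<0, kMin, Visitor> {
  static void Do() {}
};

struct PrintN {
  void operator()(size_t n) const { std::printf("N = %zu\n", n); }
};

int main() {
  ForEachPow2<8, 1, PrintN>::Do();  // prints 8, 4, 2, 1
  return 0;
}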
+template +struct ForeachShiftR { + static void Do(size_t min_lanes) { + const ScalableTag d; + + // Precondition: [kPow2, 3] + kAddPow2 is a valid fraction of the minimum + // vector size, so we always have enough lanes, except ForGEVectors. + if (Lanes(d) >= min_lanes) { + Test()(T(), d); + } else { + fprintf(stderr, "%d lanes < %d: T=%d pow=%d\n", + static_cast(Lanes(d)), static_cast(min_lanes), + static_cast(sizeof(T)), kPow2 + kAddPow2); + HWY_ASSERT(min_lanes != 1); + } + + ForeachShiftR::Do(min_lanes); + } +}; + +// Base case to stop the recursion. +template +struct ForeachShiftR { + static void Do(size_t) {} +}; +#else +// ForeachCappedR already handled all possible sizes. +#endif // HWY_HAVE_SCALABLE + // These adapters may be called directly, or via For*Types: -// Calls Test for all power of two N in [1, Lanes(d) / kFactor]. This is for +// Calls Test for all power of two N in [1, Lanes(d) >> kPow2]. This is for // ops that widen their input, e.g. Combine (not supported by HWY_SCALAR). -template +template struct ForExtendableVectors { template void operator()(T /*unused*/) const { + constexpr size_t kMaxCapped = HWY_LANES(T); + // Skip CappedTag that are already full vectors. + const size_t max_lanes = Lanes(ScalableTag()) >> kPow2; + (void)kMaxCapped; + (void)max_lanes; #if HWY_TARGET == HWY_SCALAR // not supported #else - constexpr bool kPromote = true; + ForeachCappedR> kPow2), 1, Test>::Do(1, max_lanes); #if HWY_TARGET == HWY_RVV - ForeachSizeR::Do(); - // TODO(janwas): also capped - // ForeachSizeR::Do(); -#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 - // Capped - ForeachSizeR::Do(); - // Fractions - ForeachSizeR::Do(); -#else - ForeachSizeR::Do(); + // For each [MinPow2, 3 - kPow2]; counter is [MinPow2 + kPow2, 3]. + ForeachShiftR() + kPow2, -kPow2, Test>::Do(1); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2, 0 - kPow2]; counter is [MinPow2 + kPow2 + 3, 3]. + ForeachShiftR() + kPow2 + 3, -kPow2 - 3, Test>::Do(1); #endif #endif // HWY_SCALAR } }; -// Calls Test for all power of two N in [kFactor, Lanes(d)]. This is for ops +// Calls Test for all power of two N in [1 << kPow2, Lanes(d)]. This is for ops // that narrow their input, e.g. UpperHalf. -template +template struct ForShrinkableVectors { template void operator()(T /*unused*/) const { + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; #if HWY_TARGET == HWY_SCALAR // not supported -#elif HWY_TARGET == HWY_RVV - ForeachSizeR::Do(); - // TODO(janwas): also capped -#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 - // Capped - ForeachSizeR::Do(); - // Fractions - ForeachSizeR::Do(); -#elif HWY_TARGET == HWY_SCALAR - // not supported #else - ForeachSizeR::Do(); + ForeachCappedR> kPow2), kMinLanes, Test>::Do(kMinLanes, + max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + ForeachShiftR() + kPow2, 0, Test>::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + ForeachShiftR() + kPow2 + 3, -3, Test>::Do(kMinLanes); #endif +#endif // HWY_TARGET == HWY_SCALAR } }; -// Calls Test for all power of two N in [16 / sizeof(T), Lanes(d)]. This is for -// ops that require at least 128 bits, e.g. AES or 64x64 = 128 mul. 
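MinPow2 above caps the smallest usable vector fraction at 1/8th (kPow2 == -3, mirroring RVV LMUL) while also making sure that fraction still yields at least one lane of type T for the minimum vector size. A small constexpr sketch of that calculation, assuming a 16-byte minimum vector as on most targets (the RVV branch in the header uses 8 instead):

#include <cstddef>
#include <cstdio>

constexpr int CeilLog2(size_t x) {  // x > 0
  return x <= 1 ? 0 : 1 + CeilLog2((x + 1) / 2);
}

constexpr size_t kMinVectorBytes = 16;  // assumption; see MinVectorSize() above

template <typename T>
constexpr int MinPow2() {
  // Smallest fraction is 1/8th, but never so small that Lanes() would be zero.
  const int limit = -CeilLog2(kMinVectorBytes / sizeof(T));
  return limit > -3 ? limit : -3;
}

int main() {
  std::printf("MinPow2<uint8_t>  = %d\n", MinPow2<unsigned char>());   // -3 (1/8th)
  std::printf("MinPow2<uint32_t> = %d\n", MinPow2<unsigned int>());    // -2 (1/4)
  std::printf("MinPow2<uint64_t> = %d\n", MinPow2<unsigned long long>());  // -1 (1/2)
  return 0;
}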
-template -struct ForGE128Vectors { +// Calls Test for all supported power of two vectors of at least kMinBits. +// Examples: AES or 64x64 require 128 bits, casts may require 64 bits. +template +struct ForGEVectors { template void operator()(T /*unused*/) const { + constexpr size_t kMaxCapped = HWY_LANES(T); + constexpr size_t kMinLanes = kMinBits / 8 / sizeof(T); + // An upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + (void)max_lanes; #if HWY_TARGET == HWY_SCALAR - // not supported -#elif HWY_TARGET == HWY_RVV - ForeachSizeR::Do(); - // TODO(janwas): also capped - // ForeachSizeR::Do(); -#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 - // Capped - ForeachSizeR::Do(); - // Fractions - ForeachSizeR::Do(); + (void)kMinLanes; // not supported #else - ForeachSizeR::Do(); + ForeachCappedR::Do(kMinLanes, + max_lanes); +#if HWY_TARGET == HWY_RVV + // Can be 0 (handled below) if kMinBits > 64. + constexpr size_t kRatio = MinVectorSize() * 8 / kMinBits; + constexpr int kMinPow2 = + kRatio == 0 ? 0 : -static_cast(CeilLog2(kRatio)); + // For each [kMinPow2, 3]; counter is [kMinPow2, 3]. + ForeachShiftR::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // Can be 0 (handled below) if kMinBits > 128. + constexpr size_t kRatio = MinVectorSize() * 8 / kMinBits; + constexpr int kMinPow2 = + kRatio == 0 ? 0 : -static_cast(CeilLog2(kRatio)); + // For each [kMinPow2, 0]; counter is [kMinPow2 + 3, 3]. + ForeachShiftR::Do(kMinLanes); #endif +#endif // HWY_TARGET == HWY_SCALAR } }; -// Calls Test for all power of two N in [8 / sizeof(T), Lanes(d)]. This is for -// ops that require at least 64 bits, e.g. casts. template -struct ForGE64Vectors { - template - void operator()(T /*unused*/) const { -#if HWY_TARGET == HWY_SCALAR - // not supported -#elif HWY_TARGET == HWY_RVV - ForeachSizeR::Do(); - // TODO(janwas): also capped - // ForeachSizeR::Do(); -#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 - // Capped - ForeachSizeR::Do(); - // Fractions - ForeachSizeR::Do(); -#else - ForeachSizeR::Do(); -#endif - } -}; +using ForGE128Vectors = ForGEVectors<128, Test>; // Calls Test for all N that can be promoted (not the same as Extendable because // HWY_SCALAR has one lane). Also used for ZipLower, but not ZipUpper. -template +template struct ForPromoteVectors { template void operator()(T /*unused*/) const { + constexpr size_t kFactor = size_t{1} << kPow2; + static_assert(kFactor >= 2 && kFactor * sizeof(T) <= sizeof(uint64_t), ""); + constexpr size_t kMaxCapped = HWY_LANES(T); + constexpr size_t kMinLanes = kFactor; + // Skip CappedTag that are already full vectors. + const size_t max_lanes = Lanes(ScalableTag()) >> kPow2; + (void)kMaxCapped; + (void)kMinLanes; + (void)max_lanes; #if HWY_TARGET == HWY_SCALAR - ForeachSizeR::Do(); + ForeachCappedR::Do(1, 1); #else - return ForExtendableVectors()(T()); + // TODO(janwas): call Extendable if kMinLanes check not required? + ForeachCappedR> kPow2), 1, Test>::Do(kMinLanes, max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2, 3 - kPow2]; counter is [MinPow2 + kPow2, 3]. + ForeachShiftR() + kPow2, -kPow2, Test>::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2, 0 - kPow2]; counter is [MinPow2 + kPow2 + 3, 3]. + ForeachShiftR() + kPow2 + 3, -kPow2 - 3, Test>::Do(kMinLanes); #endif +#endif // HWY_SCALAR } }; // Calls Test for all N than can be demoted (not the same as Shrinkable because -// HWY_SCALAR has one lane). Also used for LowerHalf, but not UpperHalf. -template +// HWY_SCALAR has one lane). 
+template struct ForDemoteVectors { template void operator()(T /*unused*/) const { + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; #if HWY_TARGET == HWY_SCALAR - ForeachSizeR::Do(); + ForeachCappedR::Do(1, 1); #else - return ForShrinkableVectors()(T()); + ForeachCappedR> kPow2), kMinLanes, Test>::Do(kMinLanes, + max_lanes); + +// TODO(janwas): call Extendable if kMinLanes check not required? +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + ForeachShiftR() + kPow2, 0, Test>::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + ForeachShiftR() + kPow2 + 3, -3, Test>::Do(kMinLanes); #endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// For LowerHalf/Quarter. +template +struct ForHalfVectors { + template + void operator()(T /*unused*/) const { + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + ForeachCappedR::Do(1, 1); +#else +// ForeachCappedR> kPow2), kMinLanes, Test>::Do(kMinLanes, +// max_lanes); + +// TODO(janwas): call Extendable if kMinLanes check not required? +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + ForeachShiftR() + kPow2, 0, Test>::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + ForeachShiftR() + kPow2 + 3, -3, Test>::Do(kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR } }; @@ -350,7 +457,7 @@ template struct ForPartialVectors { template void operator()(T t) const { - ForExtendableVectors()(t); + ForExtendableVectors()(t); } }; @@ -361,7 +468,7 @@ void ForSignedTypes(const Func& func) { func(int8_t()); func(int16_t()); func(int32_t()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 func(int64_t()); #endif } @@ -371,7 +478,7 @@ void ForUnsignedTypes(const Func& func) { func(uint8_t()); func(uint16_t()); func(uint32_t()); -#if HWY_CAP_INTEGER64 +#if HWY_HAVE_INTEGER64 func(uint64_t()); #endif } @@ -385,7 +492,7 @@ void ForIntegerTypes(const Func& func) { template void ForFloatTypes(const Func& func) { func(float()); -#if HWY_CAP_FLOAT64 +#if HWY_HAVE_FLOAT64 func(double()); #endif } @@ -397,32 +504,49 @@ void ForAllTypes(const Func& func) { } template -void ForUIF3264(const Func& func) { +void ForUIF16(const Func& func) { + func(uint16_t()); + func(int16_t()); +#if HWY_HAVE_FLOAT16 + func(float16_t()); +#endif +} + +template +void ForUIF32(const Func& func) { func(uint32_t()); func(int32_t()); -#if HWY_CAP_INTEGER64 + func(float()); +} + +template +void ForUIF64(const Func& func) { +#if HWY_HAVE_INTEGER64 func(uint64_t()); func(int64_t()); #endif +#if HWY_HAVE_FLOAT64 + func(double()); +#endif +} - ForFloatTypes(func); +template +void ForUIF3264(const Func& func) { + ForUIF32(func); + ForUIF64(func); } template void ForUIF163264(const Func& func) { + ForUIF16(func); ForUIF3264(func); - func(uint16_t()); - func(int16_t()); -#if HWY_CAP_FLOAT16 - func(float16_t()); -#endif } // For tests that involve loops, adjust the trip count so that emulated tests // finish quickly (but always at least 2 iterations to ensure some 
diversity). constexpr size_t AdjustedReps(size_t max_reps) { #if HWY_ARCH_RVV - return HWY_MAX(max_reps / 16, 2); + return HWY_MAX(max_reps / 32, 2); #elif HWY_ARCH_ARM return HWY_MAX(max_reps / 4, 2); #elif HWY_IS_DEBUG_BUILD @@ -432,6 +556,20 @@ constexpr size_t AdjustedReps(size_t max_reps) { #endif } +// Same as above, but the loop trip count will be 1 << max_pow2. +constexpr size_t AdjustedLog2Reps(size_t max_pow2) { + // If "negative" (unsigned wraparound), use original. +#if HWY_ARCH_RVV + return HWY_MIN(max_pow2 - 4, max_pow2); +#elif HWY_ARCH_ARM + return HWY_MIN(max_pow2 - 1, max_pow2); +#elif HWY_IS_DEBUG_BUILD + return HWY_MIN(max_pow2 - 1, max_pow2); +#else + return max_pow2; +#endif +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy diff --git a/third_party/highway/hwy/tests/test_util.cc b/third_party/highway/hwy/tests/test_util.cc index 861f6a4d642a..45757ca821b0 100644 --- a/third_party/highway/hwy/tests/test_util.cc +++ b/third_party/highway/hwy/tests/test_util.cc @@ -30,9 +30,6 @@ bool BytesEqual(const void* p1, const void* p2, const size_t size, const uint8_t* bytes2 = reinterpret_cast(p2); for (size_t i = 0; i < size; ++i) { if (bytes1[i] != bytes2[i]) { - fprintf(stderr, "Mismatch at byte %" PRIu64 " of %" PRIu64 ": %d != %d\n", - static_cast(i), static_cast(size), bytes1[i], - bytes2[i]); if (pos != nullptr) { *pos = i; } diff --git a/third_party/highway/hwy/tests/test_util.h b/third_party/highway/hwy/tests/test_util.h index 076d82b2bb69..fc7d3bb00816 100644 --- a/third_party/highway/hwy/tests/test_util.h +++ b/third_party/highway/hwy/tests/test_util.h @@ -26,6 +26,7 @@ #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/highway.h" +#include "hwy/highway_export.h" namespace hwy { @@ -67,9 +68,7 @@ static HWY_INLINE uint32_t Random32(RandomState* rng) { return static_cast((*rng)()); } -static HWY_INLINE uint64_t Random64(RandomState* rng) { - return (*rng)(); -} +static HWY_INLINE uint64_t Random64(RandomState* rng) { return (*rng)(); } // Prevents the compiler from eliding the computations that led to "output". // Works by indicating to the compiler that "output" is being read and modified. 
@@ -84,8 +83,8 @@ inline void PreventElision(T&& output) { #endif // HWY_COMPILER_MSVC } -bool BytesEqual(const void* p1, const void* p2, const size_t size, - size_t* pos = nullptr); +HWY_TEST_DLLEXPORT bool BytesEqual(const void* p1, const void* p2, + const size_t size, size_t* pos = nullptr); void AssertStringEqual(const char* expected, const char* actual, const char* target_name, const char* filename, int line); @@ -129,25 +128,25 @@ HWY_INLINE TypeInfo MakeTypeInfo() { return info; } -bool IsEqual(const TypeInfo& info, const void* expected_ptr, - const void* actual_ptr); +HWY_TEST_DLLEXPORT bool IsEqual(const TypeInfo& info, const void* expected_ptr, + const void* actual_ptr); -void TypeName(const TypeInfo& info, size_t N, char* string100); +HWY_TEST_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100); -void PrintArray(const TypeInfo& info, const char* caption, - const void* array_void, size_t N, size_t lane_u = 0, - size_t max_lanes = 7); +HWY_TEST_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, + size_t lane_u = 0, size_t max_lanes = 7); -HWY_NORETURN void PrintMismatchAndAbort(const TypeInfo& info, - const void* expected_ptr, - const void* actual_ptr, - const char* target_name, - const char* filename, int line, - size_t lane = 0, size_t num_lanes = 1); +HWY_TEST_DLLEXPORT HWY_NORETURN void PrintMismatchAndAbort( + const TypeInfo& info, const void* expected_ptr, const void* actual_ptr, + const char* target_name, const char* filename, int line, size_t lane = 0, + size_t num_lanes = 1); -void AssertArrayEqual(const TypeInfo& info, const void* expected_void, - const void* actual_void, size_t N, - const char* target_name, const char* filename, int line); +HWY_TEST_DLLEXPORT void AssertArrayEqual(const TypeInfo& info, + const void* expected_void, + const void* actual_void, size_t N, + const char* target_name, + const char* filename, int line); } // namespace detail diff --git a/third_party/highway/hwy/tests/test_util_test.cc b/third_party/highway/hwy/tests/test_util_test.cc index af484adbaeec..704c056d4bf1 100644 --- a/third_party/highway/hwy/tests/test_util_test.cc +++ b/third_party/highway/hwy/tests/test_util_test.cc @@ -52,10 +52,10 @@ HWY_NOINLINE void TestAllName() { ForAllTypes(ForPartialVectors()); } struct TestEqualInteger { template HWY_NOINLINE void operator()(T /*t*/) const { - HWY_ASSERT(IsEqual(T(0), T(0))); - HWY_ASSERT(IsEqual(T(1), T(1))); - HWY_ASSERT(IsEqual(T(-1), T(-1))); - HWY_ASSERT(IsEqual(LimitsMin(), LimitsMin())); + HWY_ASSERT_EQ(T(0), T(0)); + HWY_ASSERT_EQ(T(1), T(1)); + HWY_ASSERT_EQ(T(-1), T(-1)); + HWY_ASSERT_EQ(LimitsMin(), LimitsMin()); HWY_ASSERT(!IsEqual(T(0), T(1))); HWY_ASSERT(!IsEqual(T(1), T(0))); diff --git a/third_party/highway/libhwy-contrib.pc.in b/third_party/highway/libhwy-contrib.pc.in index 260f15fa2f15..89c45f5e42db 100644 --- a/third_party/highway/libhwy-contrib.pc.in +++ b/third_party/highway/libhwy-contrib.pc.in @@ -4,7 +4,7 @@ libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@ includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ Name: libhwy-contrib -Description: Additions to Highway: image and math library +Description: Additions to Highway: dot product, image, math, sort Version: @HWY_LIBRARY_VERSION@ Libs: -L${libdir} -lhwy_contrib Cflags: -I${includedir} diff --git a/third_party/highway/libhwy.pc.in b/third_party/highway/libhwy.pc.in index 2ada0e847cbf..643989275df8 100644 --- a/third_party/highway/libhwy.pc.in +++ b/third_party/highway/libhwy.pc.in @@ -7,4 +7,4 @@ 
Name: libhwy Description: Efficient and performance-portable SIMD wrapper Version: @HWY_LIBRARY_VERSION@ Libs: -L${libdir} -lhwy -Cflags: -I${includedir} +Cflags: -I${includedir} -D@DLLEXPORT_TO_DEFINE@ diff --git a/third_party/highway/preamble.js.lds b/third_party/highway/preamble.js.lds new file mode 100644 index 000000000000..f484a19d2c69 --- /dev/null +++ b/third_party/highway/preamble.js.lds @@ -0,0 +1,9 @@ +/* + * Copyright 2019 Google LLC + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* mock crypto module for benchmarks and unit tests or std::random_device fails at runtime */ +var crypto = { getRandomValues: function(array) { for (var i = 0; i < array.length; i++) array[i] = (Math.random()*256)|0 } }; \ No newline at end of file diff --git a/third_party/highway/run_tests.sh b/third_party/highway/run_tests.sh index 4efae5c77781..017e536acf3d 100644 --- a/third_party/highway/run_tests.sh +++ b/third_party/highway/run_tests.sh @@ -59,7 +59,7 @@ export QEMU_LD_PREFIX=/usr/arm-linux-gnueabihf rm -rf build_arm7 mkdir build_arm7 cd build_arm7 -CC=arm-linux-gnueabihf-gcc CXX=arm-linux-gnueabihf-g++ cmake .. -DHWY_CMAKE_ARM7:BOOL=ON -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON +CC=arm-linux-gnueabihf-gcc-11 CXX=arm-linux-gnueabihf-g++-11 cmake .. -DHWY_CMAKE_ARM7:BOOL=ON -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON make -j8 ctest cd .. @@ -71,7 +71,7 @@ export QEMU_LD_PREFIX=/usr/aarch64-linux-gnu rm -rf build_arm8 mkdir build_arm8 cd build_arm8 -CC=aarch64-linux-gnu-gcc CXX=aarch64-linux-gnu-g++ cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON +CC=aarch64-linux-gnu-gcc-11 CXX=aarch64-linux-gnu-g++-11 cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON make -j8 ctest cd .. diff --git a/third_party/jpeg-xl/.github/workflows/build_test.yml b/third_party/jpeg-xl/.github/workflows/build_test.yml index ba2e2e83cd99..2c99b0ab7a9b 100644 --- a/third_party/jpeg-xl/.github/workflows/build_test.yml +++ b/third_party/jpeg-xl/.github/workflows/build_test.yml @@ -34,7 +34,7 @@ jobs: env_stack_size: 1 max_stack: 3000 # Conformance tooling test requires numpy. - apt_pkgs: python3-numpy + apt_pkgs: graphviz python3-numpy - name: lowprecision mode: release test_in_pr: true @@ -461,8 +461,8 @@ jobs: runs-on: ubuntu-latest env: CCACHE_DIR: ${{ github.workspace }}/.ccache - EM_VERSION: 2.0.23 - V8_VERSION: 9.3.22 + EM_VERSION: 3.1.4 + V8_VERSION: 9.8.177 V8: ${{ github.workspace }}/.jsvu/v8 BUILD_TARGET: wasm32 @@ -506,7 +506,7 @@ jobs: ${{ runner.os }}-${{ steps.git-env.outputs.parent }}-${{ matrix.variant }} - name: Install emsdk - uses: mymindstorm/setup-emsdk@v10 + uses: mymindstorm/setup-emsdk@v11 # TODO(deymo): We could cache this action but it doesn't work when running # in a matrix. with: diff --git a/third_party/jpeg-xl/deps.sh b/third_party/jpeg-xl/deps.sh index 1abf187426fe..e2bbd755a987 100755 --- a/third_party/jpeg-xl/deps.sh +++ b/third_party/jpeg-xl/deps.sh @@ -14,7 +14,7 @@ MYDIR=$(dirname $(realpath "$0")) # Git revisions we use for the given submodules. Update these whenever you # update a git submodule. 
THIRD_PARTY_GFLAGS="827c769e5fc98e0f2a34c47cef953cc6328abced" -THIRD_PARTY_HIGHWAY="e69083a12a05caf037cabecdf1b248b7579705a5" +THIRD_PARTY_HIGHWAY="f13e3b956eb226561ac79427893ec0afd66f91a8" THIRD_PARTY_SKCMS="64374756e03700d649f897dbd98c95e78c30c7da" THIRD_PARTY_SJPEG="868ab558fad70fcbe8863ba4e85179eeb81cc840" THIRD_PARTY_ZLIB="cacf7f1d4e3d44d871b605da3b647f07d718623f" diff --git a/third_party/jpeg-xl/lib/extras/codec.cc b/third_party/jpeg-xl/lib/extras/codec.cc index e23344aa8058..933defe2ae26 100644 --- a/third_party/jpeg-xl/lib/extras/codec.cc +++ b/third_party/jpeg-xl/lib/extras/codec.cc @@ -5,6 +5,12 @@ #include "lib/extras/codec.h" +#include "jxl/decode.h" +#include "jxl/types.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" + #if JPEGXL_ENABLE_APNG #include "lib/extras/enc/apng.h" #endif @@ -68,6 +74,14 @@ Status Encode(const CodecInOut& io, const extras::Codec codec, JXL_WARNING("Writing JPEG data as pixels"); } + extras::PackedPixelFile ppf; + size_t num_channels = io.metadata.m.color_encoding.Channels(); + JxlPixelFormat format = { + static_cast(num_channels), + bits_per_sample <= 8 ? JXL_TYPE_UINT8 : JXL_TYPE_UINT16, + JXL_NATIVE_ENDIAN, 0}; + std::vector bytes_vector; + const bool floating_point = bits_per_sample > 16; switch (codec) { case extras::Codec::kPNG: #if JPEGXL_ENABLE_APNG @@ -87,8 +101,24 @@ Status Encode(const CodecInOut& io, const extras::Codec codec, return JXL_FAILURE("JPEG XL was built without JPEG support"); #endif case extras::Codec::kPNM: - return extras::EncodeImagePNM(&io, c_desired, bits_per_sample, pool, - bytes); + + // Choose native for PFM; PGM/PPM require big-endian (N/A for PBM) + format.endianness = floating_point ? JXL_NATIVE_ENDIAN : JXL_BIG_ENDIAN; + if (floating_point) { + format.data_type = JXL_TYPE_FLOAT; + } + if (!c_desired.IsSRGB()) { + JXL_WARNING( + "PNM encoder cannot store custom ICC profile; decoder\n" + "will need hint key=color_space to get the same values"); + } + JXL_RETURN_IF_ERROR(extras::ConvertCodecInOutToPackedPixelFile( + io, format, c_desired, pool, &ppf)); + JXL_RETURN_IF_ERROR( + extras::EncodeImagePNM(ppf, bits_per_sample, pool, &bytes_vector)); + bytes->assign(bytes_vector.data(), + bytes_vector.data() + bytes_vector.size()); + return true; case extras::Codec::kPGX: return extras::EncodeImagePGX(&io, c_desired, bits_per_sample, pool, bytes); diff --git a/third_party/jpeg-xl/lib/extras/dec/color_description.cc b/third_party/jpeg-xl/lib/extras/dec/color_description.cc index 2d0aa3a9db9a..2325b50f3b1b 100644 --- a/third_party/jpeg-xl/lib/extras/dec/color_description.cc +++ b/third_party/jpeg-xl/lib/extras/dec/color_description.cc @@ -61,7 +61,6 @@ const EnumName kJxlRenderingIntentNames[] = { template Status ParseEnum(const std::string& token, const EnumName* enum_values, size_t enum_len, T* value) { - std::string str; for (size_t i = 0; i < enum_len; i++) { if (enum_values[i].name == token) { *value = enum_values[i].value; diff --git a/third_party/jpeg-xl/lib/extras/dec/gif.cc b/third_party/jpeg-xl/lib/extras/dec/gif.cc index 92b7017d9b39..4245df4aae1a 100644 --- a/third_party/jpeg-xl/lib/extras/dec/gif.cc +++ b/third_party/jpeg-xl/lib/extras/dec/gif.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -32,8 +33,12 @@ struct PackedRgba { uint8_t r, g, b, a; }; -// Gif does not support partial transparency, so this considers anything non-0 -// as opaque. 
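// --------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream patch): the gif.cc hunks below
// switch decoded GIF frames from interleaved RGBA to RGB plus an alpha plane
// that is created lazily, only once a transparent pixel is actually seen, and
// that starts out fully opaque. A standalone model of that idea (FrameSketch,
// EnsureAlpha and SetPixelAlpha are made-up names for illustration):
#include <cstddef>
#include <cstdint>
#include <vector>

struct FrameSketch {
  size_t xsize = 0, ysize = 0;
  std::vector<uint8_t> rgb;    // 3 bytes per pixel, always present
  std::vector<uint8_t> alpha;  // empty until transparency is encountered
};

void EnsureAlpha(FrameSketch* f) {
  if (!f->alpha.empty()) return;
  // "No alpha plane" means fully opaque, so a new plane starts at 255.
  f->alpha.assign(f->xsize * f->ysize, 255);
}

void SetPixelAlpha(FrameSketch* f, size_t x, size_t y, uint8_t a) {
  // Opaque pixels need no plane; rely on "no alpha channel = no transparency".
  if (a == 255 && f->alpha.empty()) return;
  EnsureAlpha(f);
  f->alpha[y * f->xsize + x] = a;
}

int main() {
  FrameSketch f;
  f.xsize = f.ysize = 2;
  f.rgb.assign(f.xsize * f.ysize * 3, 0);
  SetPixelAlpha(&f, 0, 0, 255);  // no plane created yet
  SetPixelAlpha(&f, 1, 1, 0);    // first transparent pixel creates the plane
  return (f.alpha.size() == 4 && f.alpha[3] == 0 && f.alpha[0] == 255) ? 0 : 1;
}
// --------------------------------------------------------------------------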
+struct PackedRgb { + uint8_t r, g, b; +}; + +// Gif does not support partial transparency, so this considers any nonzero +// alpha channel value as opaque. bool AllOpaque(const PackedImage& color) { for (size_t y = 0; y < color.ysize; ++y) { const PackedRgba* const JXL_RESTRICT row = @@ -47,6 +52,21 @@ bool AllOpaque(const PackedImage& color) { return true; } +void ensure_have_alpha(PackedFrame* frame) { + if (!frame->extra_channels.empty()) return; + const JxlPixelFormat alpha_format{ + /*num_channels=*/1u, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + frame->extra_channels.emplace_back(frame->color.xsize, frame->color.ysize, + alpha_format); + // We need to set opaque-by-default. + std::fill_n(static_cast(frame->extra_channels[0].pixels()), + frame->color.xsize * frame->color.ysize, 255u); +} + } // namespace Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, @@ -138,13 +158,26 @@ Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, ppf->info.num_color_channels = 3; - const JxlPixelFormat format{ + // Pixel format for the 'canvas' onto which we paint + // the (potentially individually cropped) GIF frames + // of an animation. + const JxlPixelFormat canvas_format{ /*num_channels=*/4u, /*data_type=*/JXL_TYPE_UINT8, /*endianness=*/JXL_NATIVE_ENDIAN, /*align=*/0, }; + // Pixel format for the JXL PackedFrame that goes into the + // PackedPixelFile. Here, we use 3 color channels, and provide + // the alpha channel as an extra_channel wherever it is used. + const JxlPixelFormat packed_frame_format{ + /*num_channels=*/3u, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + GifColorType background_color; if (gif->SColorMap == nullptr || gif->SBackGroundColor >= gif->SColorMap->ColorCount) { @@ -154,14 +187,13 @@ Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, } const PackedRgba background_rgba{background_color.Red, background_color.Green, background_color.Blue, 0}; - PackedFrame canvas(gif->SWidth, gif->SHeight, format); + PackedFrame canvas(gif->SWidth, gif->SHeight, canvas_format); std::fill_n(static_cast(canvas.color.pixels()), canvas.color.xsize * canvas.color.ysize, background_rgba); Rect canvas_rect{0, 0, canvas.color.xsize, canvas.color.ysize}; Rect previous_rect_if_restore_to_background; - bool has_alpha = false; bool replace = true; bool last_base_was_none = true; for (int i = 0; i < gif->ImageCount; ++i) { @@ -199,8 +231,22 @@ Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, } // Allocates the frame buffer. - ppf->frames.emplace_back(total_rect.xsize(), total_rect.ysize(), format); - auto* frame = &ppf->frames.back(); + ppf->frames.emplace_back(total_rect.xsize(), total_rect.ysize(), + packed_frame_format); + PackedFrame* frame = &ppf->frames.back(); + + // We cannot tell right from the start whether there will be a + // need for an alpha channel. This is discovered only as soon as + // we see a transparent pixel. We hence initialize alpha lazily. + auto set_pixel_alpha = [&frame](size_t x, size_t y, uint8_t a) { + // If we do not have an alpha-channel and a==255 (fully opaque), + // we can skip setting this pixel-value and rely on + // "no alpha channel = no transparency". + if (a == 255 && !frame->extra_channels.empty()) return; + ensure_have_alpha(frame); + static_cast( + frame->extra_channels[0].pixels())[y * frame->color.xsize + x] = a; + }; const ColorMapObject* const color_map = image.ImageDesc.ColorMap ? 
image.ImageDesc.ColorMap : gif->SColorMap; @@ -270,23 +316,26 @@ Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, } } const PackedImage& sub_frame_image = frame->color; - bool blend_alpha = false; if (replace) { // Copy from the new canvas image to the subframe for (size_t y = 0; y < total_rect.ysize(); ++y) { const PackedRgba* row_in = static_cast(new_canvas_image.pixels()) + (y + total_rect.y0()) * new_canvas_image.xsize + total_rect.x0(); - PackedRgba* row_out = - static_cast(sub_frame_image.pixels()) + - y * sub_frame_image.xsize; - memcpy(row_out, row_in, sub_frame_image.xsize * sizeof(PackedRgba)); + PackedRgb* row_out = static_cast(sub_frame_image.pixels()) + + y * sub_frame_image.xsize; + for (size_t x = 0; x < sub_frame_image.xsize; ++x) { + row_out[x].r = row_in[x].r; + row_out[x].g = row_in[x].g; + row_out[x].b = row_in[x].b; + set_pixel_alpha(x, y, row_in[x].a); + } } } else { for (size_t y = 0, byte_index = 0; y < image_rect.ysize(); ++y) { // Assumes format.align == 0 - PackedRgba* row = static_cast(sub_frame_image.pixels()) + - y * sub_frame_image.xsize; + PackedRgb* row = static_cast(sub_frame_image.pixels()) + + y * sub_frame_image.xsize; for (size_t x = 0; x < image_rect.xsize(); ++x, ++byte_index) { const GifByteType byte = image.RasterBits[byte_index]; if (byte > color_map->ColorCount) { @@ -296,22 +345,19 @@ Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, row[x].r = 0; row[x].g = 0; row[x].b = 0; - row[x].a = 0; - blend_alpha = - true; // need to use alpha channel if BlendMode blend is used + set_pixel_alpha(x, y, 0); continue; } GifColorType color = color_map->Colors[byte]; row[x].r = color.Red; row[x].g = color.Green; row[x].b = color.Blue; - row[x].a = 255; + set_pixel_alpha(x, y, 255); } } } - if (!has_alpha && (!AllOpaque(sub_frame_image) || blend_alpha)) { - has_alpha = true; + if (!frame->extra_channels.empty()) { ppf->info.alpha_bits = 8; } @@ -335,7 +381,20 @@ Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, canvas.color.xsize * canvas.color.ysize, background_rgba); } } - + // Finally, if any frame has an alpha-channel, every frame will need + // to have an alpha-channel. + bool seen_alpha = false; + for (const PackedFrame& frame : ppf->frames) { + if (!frame.extra_channels.empty()) { + seen_alpha = true; + break; + } + } + if (seen_alpha) { + for (PackedFrame& frame : ppf->frames) { + ensure_have_alpha(&frame); + } + } return true; } diff --git a/third_party/jpeg-xl/lib/extras/dec/pgx.cc b/third_party/jpeg-xl/lib/extras/dec/pgx.cc index 9df36e06c9b7..7b79eaf88cb6 100644 --- a/third_party/jpeg-xl/lib/extras/dec/pgx.cc +++ b/third_party/jpeg-xl/lib/extras/dec/pgx.cc @@ -127,9 +127,9 @@ class Parser { } size_t numpixels = header->xsize * header->ysize; - size_t bytes_per_pixel = header->bits_per_sample <= 8 - ? 1 - : header->bits_per_sample <= 16 ? 2 : 4; + size_t bytes_per_pixel = header->bits_per_sample <= 8 ? 1 + : header->bits_per_sample <= 16 ? 
2 + : 4; if (pos_ + numpixels * bytes_per_pixel > end_) { return JXL_FAILURE("PGX: data too small"); } diff --git a/third_party/jpeg-xl/lib/extras/enc/apng.cc b/third_party/jpeg-xl/lib/extras/enc/apng.cc index 7dc872cfe6ec..30750ee55fae 100644 --- a/third_party/jpeg-xl/lib/extras/enc/apng.cc +++ b/third_party/jpeg-xl/lib/extras/enc/apng.cc @@ -132,7 +132,9 @@ Status EncodeImageAPNG(const CodecInOut* io, const ColorEncoding& c_desired, size_t anim_chunks = 0; int W = 0, H = 0; - for (auto& frame : io->frames) { + for (size_t i = 0; i < io->frames.size(); i++) { + auto& frame = io->frames[i]; + if (!have_anim && i + 1 < io->frames.size()) continue; png_structp png_ptr; png_infop info_ptr; diff --git a/third_party/jpeg-xl/lib/extras/enc/jpg.cc b/third_party/jpeg-xl/lib/extras/enc/jpg.cc index 9e3478d25b6d..27d6eff78faf 100644 --- a/third_party/jpeg-xl/lib/extras/enc/jpg.cc +++ b/third_party/jpeg-xl/lib/extras/enc/jpg.cc @@ -46,120 +46,6 @@ constexpr int kExifMarker = JPEG_APP0 + 1; constexpr float kJPEGSampleMin = 0; constexpr float kJPEGSampleMax = MAXJSAMPLE; -static inline bool IsJPG(const Span bytes) { - if (bytes.size() < 2) return false; - if (bytes[0] != 0xFF || bytes[1] != 0xD8) return false; - return true; -} - -bool MarkerIsICC(const jpeg_saved_marker_ptr marker) { - return marker->marker == kICCMarker && - marker->data_length >= sizeof kICCSignature + 2 && - std::equal(std::begin(kICCSignature), std::end(kICCSignature), - marker->data); -} -bool MarkerIsExif(const jpeg_saved_marker_ptr marker) { - return marker->marker == kExifMarker && - marker->data_length >= sizeof kExifSignature + 2 && - std::equal(std::begin(kExifSignature), std::end(kExifSignature), - marker->data); -} - -Status ReadICCProfile(jpeg_decompress_struct* const cinfo, - std::vector* const icc) { - constexpr size_t kICCSignatureSize = sizeof kICCSignature; - // ICC signature + uint8_t index + uint8_t max_index. - constexpr size_t kICCHeadSize = kICCSignatureSize + 2; - // Markers are 1-indexed, and we keep them that way in this vector to get a - // convenient 0 at the front for when we compute the offsets later. - std::vector marker_lengths; - int num_markers = 0; - int seen_markers_count = 0; - bool has_num_markers = false; - for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; - marker = marker->next) { - // marker is initialized by libjpeg, which we are not instrumenting with - // msan. - msan::UnpoisonMemory(marker, sizeof(*marker)); - msan::UnpoisonMemory(marker->data, marker->data_length); - if (!MarkerIsICC(marker)) continue; - - const int current_marker = marker->data[kICCSignatureSize]; - if (current_marker == 0) { - return JXL_FAILURE("inconsistent JPEG ICC marker numbering"); - } - const int current_num_markers = marker->data[kICCSignatureSize + 1]; - if (current_marker > current_num_markers) { - return JXL_FAILURE("inconsistent JPEG ICC marker numbering"); - } - if (has_num_markers) { - if (current_num_markers != num_markers) { - return JXL_FAILURE("inconsistent numbers of JPEG ICC markers"); - } - } else { - num_markers = current_num_markers; - has_num_markers = true; - marker_lengths.resize(num_markers + 1); - } - - size_t marker_length = marker->data_length - kICCHeadSize; - - if (marker_length == 0) { - // NB: if we allow empty chunks, then the next check is incorrect. 
- return JXL_FAILURE("Empty ICC chunk"); - } - - if (marker_lengths[current_marker] != 0) { - return JXL_FAILURE("duplicate JPEG ICC marker number"); - } - marker_lengths[current_marker] = marker_length; - seen_markers_count++; - } - - if (marker_lengths.empty()) { - // Not an error. - return false; - } - - if (seen_markers_count != num_markers) { - JXL_DASSERT(has_num_markers); - return JXL_FAILURE("Incomplete set of ICC chunks"); - } - - std::vector offsets = std::move(marker_lengths); - std::partial_sum(offsets.begin(), offsets.end(), offsets.begin()); - icc->resize(offsets.back()); - - for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; - marker = marker->next) { - if (!MarkerIsICC(marker)) continue; - const uint8_t* first = marker->data + kICCHeadSize; - uint8_t current_marker = marker->data[kICCSignatureSize]; - size_t offset = offsets[current_marker - 1]; - size_t marker_length = offsets[current_marker] - offset; - std::copy_n(first, marker_length, icc->data() + offset); - } - - return true; -} - -void ReadExif(jpeg_decompress_struct* const cinfo, - std::vector* const exif) { - constexpr size_t kExifSignatureSize = sizeof kExifSignature; - for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; - marker = marker->next) { - // marker is initialized by libjpeg, which we are not instrumenting with - // msan. - msan::UnpoisonMemory(marker, sizeof(*marker)); - msan::UnpoisonMemory(marker->data, marker->data_length); - if (!MarkerIsExif(marker)) continue; - size_t marker_length = marker->data_length - kExifSignatureSize; - exif->resize(marker_length); - std::copy_n(marker->data + kExifSignatureSize, marker_length, exif->data()); - return; - } -} - // TODO (jon): take orientation into account when writing jpeg output // TODO (jon): write Exif blob also in sjpeg encoding // TODO (jon): overwrite orientation in Exif blob to avoid double orientation @@ -213,149 +99,8 @@ Status SetChromaSubsampling(const YCbCrChromaSubsampling& chroma_subsampling, return true; } -void MyErrorExit(j_common_ptr cinfo) { - jmp_buf* env = static_cast(cinfo->client_data); - (*cinfo->err->output_message)(cinfo); - jpeg_destroy_decompress(reinterpret_cast(cinfo)); - longjmp(*env, 1); -} - -void MyOutputMessage(j_common_ptr cinfo) { -#if JXL_DEBUG_WARNING == 1 - char buf[JMSG_LENGTH_MAX]; - (*cinfo->err->format_message)(cinfo, buf); - JXL_WARNING("%s", buf); -#endif -} - } // namespace -Status DecodeImageJPG(const Span bytes, - const ColorHints& color_hints, - const SizeConstraints& constraints, - PackedPixelFile* ppf) { - // Don't do anything for non-JPEG files (no need to report an error) - if (!IsJPG(bytes)) return false; - - // TODO(veluca): use JPEGData also for pixels? - - // We need to declare all the non-trivial destructor local variables before - // the call to setjmp(). - ColorEncoding color_encoding; - PaddedBytes icc; - Image3F image; - std::unique_ptr row; - - const auto try_catch_block = [&]() -> bool { - jpeg_decompress_struct cinfo; - // cinfo is initialized by libjpeg, which we are not instrumenting with - // msan, therefore we need to initialize cinfo here. - msan::UnpoisonMemory(&cinfo, sizeof(cinfo)); - // Setup error handling in jpeg library so we can deal with broken jpegs in - // the fuzzer. 
- jpeg_error_mgr jerr; - jmp_buf env; - cinfo.err = jpeg_std_error(&jerr); - jerr.error_exit = &MyErrorExit; - jerr.output_message = &MyOutputMessage; - if (setjmp(env)) { - return false; - } - cinfo.client_data = static_cast(&env); - - jpeg_create_decompress(&cinfo); - jpeg_mem_src(&cinfo, reinterpret_cast(bytes.data()), - bytes.size()); - jpeg_save_markers(&cinfo, kICCMarker, 0xFFFF); - jpeg_save_markers(&cinfo, kExifMarker, 0xFFFF); - const auto failure = [&cinfo](const char* str) -> Status { - jpeg_abort_decompress(&cinfo); - jpeg_destroy_decompress(&cinfo); - return JXL_FAILURE("%s", str); - }; - int read_header_result = jpeg_read_header(&cinfo, TRUE); - // TODO(eustas): what about JPEG_HEADER_TABLES_ONLY? - if (read_header_result == JPEG_SUSPENDED) { - return failure("truncated JPEG input"); - } - if (!VerifyDimensions(&constraints, cinfo.image_width, - cinfo.image_height)) { - return failure("image too big"); - } - // Might cause CPU-zip bomb. - if (cinfo.arith_code) { - return failure("arithmetic code JPEGs are not supported"); - } - int nbcomp = cinfo.num_components; - if (nbcomp != 1 && nbcomp != 3) { - return failure("unsupported number of components in JPEG"); - } - if (!ReadICCProfile(&cinfo, &ppf->icc)) { - ppf->icc.clear(); - // Default to SRGB - // Actually, (cinfo.output_components == nbcomp) will be checked after - // `jpeg_start_decompress`. - ppf->color_encoding.color_space = - (nbcomp == 1) ? JXL_COLOR_SPACE_GRAY : JXL_COLOR_SPACE_RGB; - ppf->color_encoding.white_point = JXL_WHITE_POINT_D65; - ppf->color_encoding.primaries = JXL_PRIMARIES_SRGB; - ppf->color_encoding.transfer_function = JXL_TRANSFER_FUNCTION_SRGB; - ppf->color_encoding.rendering_intent = JXL_RENDERING_INTENT_PERCEPTUAL; - } - ReadExif(&cinfo, &ppf->metadata.exif); - if (!ApplyColorHints(color_hints, /*color_already_set=*/true, - /*is_gray=*/false, ppf)) { - return failure("ApplyColorHints failed"); - } - - ppf->info.xsize = cinfo.image_width; - ppf->info.ysize = cinfo.image_height; - // Original data is uint, so exponent_bits_per_sample = 0. - ppf->info.bits_per_sample = BITS_IN_JSAMPLE; - JXL_ASSERT(BITS_IN_JSAMPLE == 8 || BITS_IN_JSAMPLE == 16); - ppf->info.exponent_bits_per_sample = 0; - ppf->info.uses_original_profile = true; - - // No alpha in JPG - ppf->info.alpha_bits = 0; - ppf->info.alpha_exponent_bits = 0; - - ppf->info.num_color_channels = nbcomp; - ppf->info.orientation = JXL_ORIENT_IDENTITY; - - jpeg_start_decompress(&cinfo); - JXL_ASSERT(cinfo.output_components == nbcomp); - - const JxlPixelFormat format{ - /*num_channels=*/static_cast(nbcomp), - /*data_type=*/BITS_IN_JSAMPLE == 8 ? JXL_TYPE_UINT8 : JXL_TYPE_UINT16, - /*endianness=*/JXL_NATIVE_ENDIAN, - /*align=*/0, - }; - ppf->frames.clear(); - // Allocates the frame buffer. 
- ppf->frames.emplace_back(cinfo.image_width, cinfo.image_height, format); - const auto& frame = ppf->frames.back(); - JXL_ASSERT(sizeof(JSAMPLE) * cinfo.output_components * cinfo.image_width <= - frame.color.stride); - - for (size_t y = 0; y < cinfo.image_height; ++y) { - JSAMPROW rows[] = {reinterpret_cast( - static_cast(frame.color.pixels()) + - frame.color.stride * y)}; - jpeg_read_scanlines(&cinfo, rows, 1); - msan::UnpoisonMemory(rows[0], sizeof(JSAMPLE) * cinfo.output_components * - cinfo.image_width); - } - - jpeg_finish_decompress(&cinfo); - jpeg_destroy_decompress(&cinfo); - return true; - }; - - return try_catch_block(); -} - Status EncodeWithLibJpeg(const ImageBundle* ib, const CodecInOut* io, size_t quality, const YCbCrChromaSubsampling& chroma_subsampling, diff --git a/third_party/jpeg-xl/lib/extras/enc/pnm.cc b/third_party/jpeg-xl/lib/extras/enc/pnm.cc index a6db9610ffc4..686101e16b80 100644 --- a/third_party/jpeg-xl/lib/extras/enc/pnm.cc +++ b/third_party/jpeg-xl/lib/extras/enc/pnm.cc @@ -9,7 +9,9 @@ #include #include +#include +#include "lib/extras/packed_image.h" #include "lib/jxl/base/byte_order.h" #include "lib/jxl/base/compiler_specific.h" #include "lib/jxl/base/file_io.h" @@ -30,35 +32,40 @@ namespace { constexpr size_t kMaxHeaderSize = 200; -Status EncodeHeader(const ImageBundle& ib, const size_t bits_per_sample, +Status EncodeHeader(const PackedPixelFile& ppf, const size_t bits_per_sample, const bool little_endian, char* header, int* JXL_RESTRICT chars_written) { - if (ib.HasAlpha()) return JXL_FAILURE("PNM: can't store alpha"); + if (ppf.info.alpha_bits > 0) return JXL_FAILURE("PNM: can't store alpha"); + bool is_gray = ppf.info.num_color_channels <= 2; + size_t oriented_xsize = + ppf.info.orientation <= 4 ? ppf.info.xsize : ppf.info.ysize; + size_t oriented_ysize = + ppf.info.orientation <= 4 ? ppf.info.ysize : ppf.info.xsize; if (bits_per_sample == 32) { // PFM - const char type = ib.IsGray() ? 'f' : 'F'; + const char type = is_gray ? 'f' : 'F'; const double scale = little_endian ? -1.0 : 1.0; *chars_written = snprintf(header, kMaxHeaderSize, "P%c\n%" PRIuS " %" PRIuS "\n%.1f\n", - type, ib.oriented_xsize(), ib.oriented_ysize(), scale); + type, oriented_xsize, oriented_ysize, scale); JXL_RETURN_IF_ERROR(static_cast(*chars_written) < kMaxHeaderSize); } else if (bits_per_sample == 1) { // PBM - if (!ib.IsGray()) { + if (is_gray) { return JXL_FAILURE("Cannot encode color as PBM"); } *chars_written = snprintf(header, kMaxHeaderSize, "P4\n%" PRIuS " %" PRIuS "\n", - ib.oriented_xsize(), ib.oriented_ysize()); + oriented_xsize, oriented_ysize); JXL_RETURN_IF_ERROR(static_cast(*chars_written) < kMaxHeaderSize); } else { // PGM/PPM const uint32_t max_val = (1U << bits_per_sample) - 1; if (max_val >= 65536) return JXL_FAILURE("PNM cannot have > 16 bits"); - const char type = ib.IsGray() ? '5' : '6'; + const char type = is_gray ? '5' : '6'; *chars_written = snprintf(header, kMaxHeaderSize, "P%c\n%" PRIuS " %" PRIuS "\n%u\n", - type, ib.oriented_xsize(), ib.oriented_ysize(), max_val); + type, oriented_xsize, oriented_ysize, max_val); JXL_RETURN_IF_ERROR(static_cast(*chars_written) < kMaxHeaderSize); } @@ -72,15 +79,16 @@ Span MakeSpan(const char* str) { // Flip the image vertically for loading/saving PFM files which have the // scanlines inverted. 
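// --------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream patch): the new
// VerticallyFlipImage below operates on a raw interleaved float buffer
// (num_channels floats per pixel, rows laid out top to bottom) instead of a
// planar Image3F. A standalone version of the row swap with a tiny usage
// example (FlipRowsInterleaved is a made-up name):
#include <algorithm>
#include <cstddef>
#include <vector>

void FlipRowsInterleaved(float* pixels, size_t xsize, size_t ysize,
                         size_t num_channels) {
  const size_t row_floats = xsize * num_channels;
  for (size_t y = 0; y < ysize / 2; ++y) {
    float* top = pixels + y * row_floats;
    float* bottom = pixels + (ysize - 1 - y) * row_floats;
    std::swap_ranges(top, top + row_floats, bottom);
  }
}

int main() {
  // 2x2 RGB image: row 0 is all 0.f, row 1 is all 1.f; after the flip the
  // first row holds the 1.f values, matching PFM's bottom-up scanline order.
  std::vector<float> img(2 * 2 * 3, 0.f);
  std::fill(img.begin() + 2 * 3, img.end(), 1.f);
  FlipRowsInterleaved(img.data(), 2, 2, 3);
  return img[0] == 1.f ? 0 : 1;
}
// --------------------------------------------------------------------------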
-void VerticallyFlipImage(Image3F* const image) { - for (int c = 0; c < 3; c++) { - for (size_t y = 0; y < image->ysize() / 2; y++) { - float* first_row = image->PlaneRow(c, y); - float* other_row = image->PlaneRow(c, image->ysize() - y - 1); - for (size_t x = 0; x < image->xsize(); ++x) { - float tmp = first_row[x]; - first_row[x] = other_row[x]; - other_row[x] = tmp; +void VerticallyFlipImage(float* const float_image, const size_t xsize, + const size_t ysize, const size_t num_channels) { + for (size_t y = 0; y < ysize / 2; y++) { + float* first_row = &float_image[y * num_channels * xsize]; + float* other_row = &float_image[(ysize - y - 1) * num_channels * xsize]; + for (size_t c = 0; c < num_channels; c++) { + for (size_t x = 0; x < xsize; ++x) { + float tmp = first_row[x * num_channels + c]; + first_row[x * num_channels + c] = other_row[x * num_channels + c]; + other_row[x * num_channels + c] = tmp; } } } @@ -88,59 +96,33 @@ void VerticallyFlipImage(Image3F* const image) { } // namespace -Status EncodeImagePNM(const CodecInOut* io, const ColorEncoding& c_desired, - size_t bits_per_sample, ThreadPool* pool, - PaddedBytes* bytes) { +Status EncodeImagePNM(const PackedPixelFile& ppf, size_t bits_per_sample, + ThreadPool* pool, std::vector* bytes) { const bool floating_point = bits_per_sample > 16; // Choose native for PFM; PGM/PPM require big-endian (N/A for PBM) const JxlEndianness endianness = floating_point ? JXL_NATIVE_ENDIAN : JXL_BIG_ENDIAN; - - ImageMetadata metadata_copy = io->metadata.m; - // AllDefault sets all_default, which can cause a race condition. - if (!Bundle::AllDefault(metadata_copy)) { + if (!ppf.metadata.exif.empty() || !ppf.metadata.iptc.empty() || + !ppf.metadata.jumbf.empty() || !ppf.metadata.xmp.empty()) { JXL_WARNING("PNM encoder ignoring metadata - use a different codec"); } - if (!c_desired.IsSRGB()) { - JXL_WARNING( - "PNM encoder cannot store custom ICC profile; decoder\n" - "will need hint key=color_space to get the same values"); - } - - ImageBundle ib = io->Main().Copy(); - // In case of PFM the image must be flipped upside down since that format - // is designed that way. - const ImageBundle* to_color_transform = &ib; - ImageBundle flipped; - if (floating_point) { - flipped = ib.Copy(); - VerticallyFlipImage(flipped.color()); - to_color_transform = &flipped; - } - ImageMetadata metadata = io->metadata.m; - ImageBundle store(&metadata); - const ImageBundle* transformed; - JXL_RETURN_IF_ERROR(TransformIfNeeded( - *to_color_transform, c_desired, GetJxlCms(), pool, &store, &transformed)); - size_t bytes_per_sample = floating_point ? 4 : bits_per_sample > 8 ? 
2 : 1; - size_t stride = ib.oriented_xsize() * c_desired.Channels() * bytes_per_sample; - PaddedBytes pixels(stride * ib.oriented_ysize()); - JXL_RETURN_IF_ERROR(ConvertToExternal( - *transformed, bits_per_sample, floating_point, c_desired.Channels(), - endianness, stride, pool, pixels.data(), pixels.size(), - /*out_callback=*/nullptr, /*out_opaque=*/nullptr, - metadata.GetOrientation())); char header[kMaxHeaderSize]; int header_size = 0; bool is_little_endian = endianness == JXL_LITTLE_ENDIAN || (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); - JXL_RETURN_IF_ERROR(EncodeHeader(*transformed, bits_per_sample, - is_little_endian, header, &header_size)); - - bytes->resize(static_cast(header_size) + pixels.size()); + JXL_RETURN_IF_ERROR(EncodeHeader(ppf, bits_per_sample, is_little_endian, + header, &header_size)); + bytes->resize(static_cast(header_size) + + ppf.frames[0].color.pixels_size); memcpy(bytes->data(), header, static_cast(header_size)); - memcpy(bytes->data() + header_size, pixels.data(), pixels.size()); + memcpy(bytes->data() + header_size, ppf.frames[0].color.pixels(), + ppf.frames[0].color.pixels_size); + if (floating_point) { + VerticallyFlipImage(reinterpret_cast(bytes->data() + header_size), + ppf.frames[0].color.xsize, ppf.frames[0].color.ysize, + ppf.info.num_color_channels); + } return true; } diff --git a/third_party/jpeg-xl/lib/extras/enc/pnm.h b/third_party/jpeg-xl/lib/extras/enc/pnm.h index c8630d3e61bb..9e41435a8ae8 100644 --- a/third_party/jpeg-xl/lib/extras/enc/pnm.h +++ b/third_party/jpeg-xl/lib/extras/enc/pnm.h @@ -11,6 +11,7 @@ // TODO(janwas): workaround for incorrect Win64 codegen (cause unknown) #include +#include "lib/extras/packed_image.h" #include "lib/jxl/base/data_parallel.h" #include "lib/jxl/base/padded_bytes.h" #include "lib/jxl/base/status.h" @@ -21,9 +22,8 @@ namespace jxl { namespace extras { // Transforms from io->c_current to `c_desired` and encodes into `bytes`. -Status EncodeImagePNM(const CodecInOut* io, const ColorEncoding& c_desired, - size_t bits_per_sample, ThreadPool* pool, - PaddedBytes* bytes); +Status EncodeImagePNM(const PackedPixelFile& ppf, size_t bits_per_sample, + ThreadPool* pool, std::vector* bytes); } // namespace extras } // namespace jxl diff --git a/third_party/jpeg-xl/lib/extras/packed_image_convert.cc b/third_party/jpeg-xl/lib/extras/packed_image_convert.cc index 465f023d2073..92d8d32a171a 100644 --- a/third_party/jpeg-xl/lib/extras/packed_image_convert.cc +++ b/third_party/jpeg-xl/lib/extras/packed_image_convert.cc @@ -81,6 +81,26 @@ Status ConvertPackedPixelFileToCodecInOut(const PackedPixelFile& ppf, io->blobs.xmp.clear(); io->blobs.xmp.append(ppf.metadata.xmp); + // Append all other extra channels. 
+ for (const PackedPixelFile::PackedExtraChannel& info : + ppf.extra_channels_info) { + ExtraChannelInfo out; + out.type = static_cast(info.ec_info.type); + out.bit_depth.bits_per_sample = info.ec_info.bits_per_sample; + out.bit_depth.exponent_bits_per_sample = + info.ec_info.exponent_bits_per_sample; + out.bit_depth.floating_point_sample = + info.ec_info.exponent_bits_per_sample != 0; + out.dim_shift = info.ec_info.dim_shift; + out.name = info.name; + out.alpha_associated = (info.ec_info.alpha_premultiplied != 0); + out.spot_color[0] = info.ec_info.spot_color[0]; + out.spot_color[1] = info.ec_info.spot_color[1]; + out.spot_color[2] = info.ec_info.spot_color[2]; + out.spot_color[3] = info.ec_info.spot_color[3]; + io->metadata.m.extra_channel_info.push_back(std::move(out)); + } + // Convert the pixels io->dec_pixels = 0; io->frames.clear(); @@ -126,8 +146,12 @@ Status ConvertPackedPixelFileToCodecInOut(const PackedPixelFile& ppf, /*flipped_y=*/frame.color.flipped_y, pool, &bundle, /*float_in=*/float_in, /*align=*/0)); - // TODO(deymo): Convert the extra channels. FIXME! - JXL_CHECK(frame.extra_channels.empty()); + for (const auto& ppf_ec : frame.extra_channels) { + bundle.extra_channels().emplace_back(ppf_ec.xsize, ppf_ec.ysize); + JXL_CHECK(BufferToImageF(ppf_ec.format, ppf_ec.xsize, ppf_ec.ysize, + ppf_ec.pixels(), ppf_ec.pixels_size, pool, + &bundle.extra_channels().back())); + } io->frames.push_back(std::move(bundle)); io->dec_pixels += frame.color.xsize * frame.color.ysize; diff --git a/third_party/jpeg-xl/lib/jxl/base/data_parallel.h b/third_party/jpeg-xl/lib/jxl/base/data_parallel.h index 5b602d89ad7e..2fc03de6e2f2 100644 --- a/third_party/jpeg-xl/lib/jxl/base/data_parallel.h +++ b/third_party/jpeg-xl/lib/jxl/base/data_parallel.h @@ -56,7 +56,6 @@ class ThreadPool { static Status NoInit(size_t num_threads) { return true; } private: - // class holding the state of a Run() call to pass to the runner_ as an // opaque_jpegxl pointer. 
template diff --git a/third_party/jpeg-xl/lib/jxl/base/os_macros.h b/third_party/jpeg-xl/lib/jxl/base/os_macros.h index b230f2675873..84d0b82bf5c1 100644 --- a/third_party/jpeg-xl/lib/jxl/base/os_macros.h +++ b/third_party/jpeg-xl/lib/jxl/base/os_macros.h @@ -20,7 +20,7 @@ #define JXL_OS_LINUX 0 #endif -#ifdef __MACH__ +#ifdef __APPLE__ #define JXL_OS_MAC 1 #else #define JXL_OS_MAC 0 diff --git a/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h b/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h index 1840a6c936ca..4534ddf8630f 100644 --- a/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h +++ b/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h @@ -160,9 +160,11 @@ class PaddedBytes { } void append(const uint8_t* begin, const uint8_t* end) { - size_t old_size = size(); - resize(size() + (end - begin)); - memcpy(data() + old_size, begin, end - begin); + if (end - begin > 0) { + size_t old_size = size(); + resize(size() + (end - begin)); + memcpy(data() + old_size, begin, end - begin); + } } private: diff --git a/third_party/jpeg-xl/lib/jxl/blending.cc b/third_party/jpeg-xl/lib/jxl/blending.cc index 94d15151b424..f06592d07c17 100644 --- a/third_party/jpeg-xl/lib/jxl/blending.cc +++ b/third_party/jpeg-xl/lib/jxl/blending.cc @@ -453,7 +453,7 @@ void PerformBlending(const float* const* bg, const float* const* fg, JXL_ABORT("Unreachable"); } for (size_t i = 0; i < 3 + num_ec; i++) { - memcpy(out[i] + x0, tmp.Row(i), xsize * sizeof(**out)); + if (xsize != 0) memcpy(out[i] + x0, tmp.Row(i), xsize * sizeof(**out)); } } diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc index 35081c999fbf..a2eca448c8eb 100644 --- a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc +++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc @@ -692,9 +692,16 @@ Status AdaptToXYZD50(float wx, float wy, float matrix[9]) { MatMul(kBradford, w, 3, 3, 1, lms); MatMul(kBradford, w50, 3, 3, 1, lms50); + if (lms[0] == 0 || lms[1] == 0 || lms[2] == 0) { + return JXL_FAILURE("Invalid white point"); + } float a[9] = { + // /----> 0, 1, 2, 3, /----> 4, 5, 6, 7, /----> 8, lms50[0] / lms[0], 0, 0, 0, lms50[1] / lms[1], 0, 0, 0, lms50[2] / lms[2], }; + if (!std::isfinite(a[0]) || !std::isfinite(a[4]) || !std::isfinite(a[8])) { + return JXL_FAILURE("Invalid white point"); + } float b[9]; MatMul(a, kBradford, 3, 3, 3, b); diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h index 5b4bdef5fe07..2a8ea0745669 100644 --- a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h +++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h @@ -290,6 +290,7 @@ struct ColorEncoding : public Fields { void DecideIfWantICC(); bool IsGray() const { return color_space_ == ColorSpace::kGray; } + bool IsCMYK() const { return cmyk_; } size_t Channels() const { return IsGray() ? 1 : 3; } // Returns false if the field is invalid and unusable. @@ -399,7 +400,7 @@ struct ColorEncoding : public Fields { private: // Returns true if all fields have been initialized (possibly to kUnknown). // Returns false if the ICC profile is invalid or decoding it fails. - // Defined in color_management.cc. + // Defined in enc_color_management.cc. 
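// --------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream patch): the PaddedBytes::append
// and PerformBlending hunks in this chunk guard their memcpy calls with a size
// check. Passing a null pointer to memcpy is undefined behavior under the C
// library rules even when the length is zero, and empty spans or buffers may
// legitimately report data() == nullptr. The same guard in a standalone form
// (AppendBytes is a made-up name):
#include <cstddef>
#include <cstring>
#include <vector>

void AppendBytes(std::vector<unsigned char>* dst, const unsigned char* begin,
                 const unsigned char* end) {
  if (end - begin > 0) {  // skip the memcpy entirely for empty input
    const size_t old_size = dst->size();
    const size_t len = static_cast<size_t>(end - begin);
    dst->resize(old_size + len);
    std::memcpy(dst->data() + old_size, begin, len);
  }
}

int main() {
  std::vector<unsigned char> out;
  AppendBytes(&out, nullptr, nullptr);  // safe no-op instead of UB
  const unsigned char payload[] = {1, 2, 3};
  AppendBytes(&out, payload, payload + 3);
  return out.size() == 3 ? 0 : 1;
}
// --------------------------------------------------------------------------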
Status SetFieldsFromICC(); // If true, the codestream contains an ICC profile and we do not serialize @@ -414,6 +415,7 @@ struct ColorEncoding : public Fields { PaddedBytes icc_; // Valid ICC profile ColorSpace color_space_; // Can be kUnknown + bool cmyk_ = false; // Only used if white_point == kCustom. Customxy white_; diff --git a/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc b/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc index e324ff4b0ada..853570385681 100644 --- a/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc +++ b/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc @@ -77,6 +77,8 @@ void RunRGBRoundTrip(float distance, bool fast) { PassesEncoderState enc_state; JXL_CHECK(InitializePassesSharedState(frame_header, &enc_state.shared)); + JXL_CHECK(enc_state.shared.matrices.EnsureComputed(~0u)); + enc_state.shared.quantizer.SetQuant(4.0f, 4.0f, &enc_state.shared.raw_quant_field); enc_state.shared.ac_strategy.FillDCT8(); diff --git a/third_party/jpeg-xl/lib/jxl/dec_frame.cc b/third_party/jpeg-xl/lib/jxl/dec_frame.cc index a435e3eb15bb..fbc3c49b1e32 100644 --- a/third_party/jpeg-xl/lib/jxl/dec_frame.cc +++ b/third_party/jpeg-xl/lib/jxl/dec_frame.cc @@ -545,6 +545,8 @@ Status FrameDecoder::ProcessACGlobal(BitReader* br) { if (frame_header_.encoding == FrameEncoding::kVarDCT) { JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode( br, &modular_frame_decoder_)); + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.EnsureComputed( + dec_state_->used_acs)); size_t num_histo_bits = CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups); diff --git a/third_party/jpeg-xl/lib/jxl/dec_group.cc b/third_party/jpeg-xl/lib/jxl/dec_group.cc index 88d3fb9f02a2..8ae23c1a02cf 100644 --- a/third_party/jpeg-xl/lib/jxl/dec_group.cc +++ b/third_party/jpeg-xl/lib/jxl/dec_group.cc @@ -97,15 +97,14 @@ void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) { template void DequantLane(Vec scaled_dequant_x, Vec scaled_dequant_y, Vec scaled_dequant_b, - const float* JXL_RESTRICT dequant_matrices, size_t dq_ofs, - size_t size, size_t k, Vec x_cc_mul, Vec b_cc_mul, + const float* JXL_RESTRICT dequant_matrices, size_t size, + size_t k, Vec x_cc_mul, Vec b_cc_mul, const float* JXL_RESTRICT biases, ACPtr qblock[3], float* JXL_RESTRICT block) { - const auto x_mul = Load(d, dequant_matrices + dq_ofs + k) * scaled_dequant_x; - const auto y_mul = - Load(d, dequant_matrices + dq_ofs + size + k) * scaled_dequant_y; + const auto x_mul = Load(d, dequant_matrices + k) * scaled_dequant_x; + const auto y_mul = Load(d, dequant_matrices + size + k) * scaled_dequant_y; const auto b_mul = - Load(d, dequant_matrices + dq_ofs + 2 * size + k) * scaled_dequant_b; + Load(d, dequant_matrices + 2 * size + k) * scaled_dequant_b; Vec quantized_x_int; Vec quantized_y_int; @@ -139,9 +138,8 @@ template void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, float x_dm_multiplier, float b_dm_multiplier, Vec x_cc_mul, Vec b_cc_mul, size_t kind, size_t size, - const Quantizer& quantizer, - const float* JXL_RESTRICT dequant_matrices, - size_t covered_blocks, const size_t* sbx, + const Quantizer& quantizer, size_t covered_blocks, + const size_t* sbx, const float* JXL_RESTRICT* JXL_RESTRICT dc_row, size_t dc_stride, const float* JXL_RESTRICT biases, ACPtr qblock[3], float* JXL_RESTRICT block) { @@ -153,12 +151,12 @@ void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, const auto scaled_dequant_y = Set(d, scaled_dequant_s); const auto scaled_dequant_b = Set(d, 
scaled_dequant_s * b_dm_multiplier); - const size_t dq_ofs = quantizer.DequantMatrixOffset(kind, 0); + const float* dequant_matrices = quantizer.DequantMatrix(kind, 0); for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) { DequantLane(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b, - dequant_matrices, dq_ofs, size, k, x_cc_mul, b_cc_mul, - biases, qblock, block); + dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases, + qblock, block); } for (size_t c = 0; c < 3; c++) { LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride, @@ -186,8 +184,6 @@ Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block, const size_t dc_stride = dec_state->shared->dc->PixelsPerRow(); const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale(); - const float* JXL_RESTRICT dequant_matrices = - dec_state->shared->quantizer.DequantMatrix(0, 0); const YCbCrChromaSubsampling& cs = dec_state->shared->frame_header.chroma_subsampling; @@ -428,7 +424,7 @@ Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block, dequant_block( acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier, dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(), - size, dec_state->shared->quantizer, dequant_matrices, + size, dec_state->shared->quantizer, acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows, dc_stride, dec_state->output_encoding_info.opsin_params.quant_biases, qblock, diff --git a/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc b/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc index eaab5feb9156..5f336d5f1fd8 100644 --- a/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc +++ b/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc @@ -1113,24 +1113,25 @@ Status FinalizeImageRect( // TODO(veluca): all blending should happen here. 
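// --------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream patch): the FinalizeImageRect
// change that follows computes the cropped line rect once, skips groups of
// lines that are clipped away entirely, and reuses the same rect for both the
// RGB-buffer and pixel-callback outputs. The underlying pattern, with made-up
// names:
#include <cstddef>

// Clip a half-open line range [first, first + count) to the visible height.
inline void ClipLines(size_t first, size_t count, size_t image_ysize,
                      size_t* begin, size_t* end) {
  *begin = first < image_ysize ? first : image_ysize;
  *end = first + count < image_ysize ? first + count : image_ysize;
}

inline bool ProcessLines(size_t first, size_t count, size_t image_ysize) {
  size_t begin, end;
  ClipLines(first, count, image_ysize, &begin, &end);
  if (begin >= end) return false;  // fully clipped: skip both output paths
  // ... emit lines [begin, end) via the RGB buffer and/or the callback ...
  return true;
}

int main() {
  // Lines starting at the image height are entirely clipped, so nothing runs.
  return ProcessLines(/*first=*/8, /*count=*/4, /*image_ysize=*/8) ? 1 : 0;
}
// --------------------------------------------------------------------------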
+ Rect image_line_rect = upsampled_frame_rect.Lines(available_y, num_ys) + .Crop(Rect(0, 0, frame_dim.xsize_upsampled, + frame_dim.ysize_upsampled)); + if ((image_line_rect.xsize()) == 0 || image_line_rect.ysize() == 0) { + continue; + } + if (dec_state->rgb_output != nullptr) { HWY_DYNAMIC_DISPATCH(FloatToRGBA8) (*output_pixel_data_storage, upsampled_frame_rect_for_storage.Lines(available_y, num_ys), dec_state->rgb_output_is_rgba, alpha, - alpha_rect.Lines(available_y, num_ys), - upsampled_frame_rect.Lines(available_y, num_ys) - .Crop(Rect(0, 0, frame_dim.xsize_upsampled, - frame_dim.ysize_upsampled)), + alpha_rect.Lines(available_y, num_ys), image_line_rect, dec_state->rgb_output, dec_state->rgb_stride); } if (dec_state->pixel_callback != nullptr) { Rect alpha_line_rect = alpha_rect.Lines(available_y, num_ys); Rect color_input_line_rect = upsampled_frame_rect_for_storage.Lines(available_y, num_ys); - Rect image_line_rect = upsampled_frame_rect.Lines(available_y, num_ys) - .Crop(Rect(0, 0, frame_dim.xsize_upsampled, - frame_dim.ysize_upsampled)); const float* line_buffers[4]; for (size_t iy = 0; iy < image_line_rect.ysize(); iy++) { for (size_t c = 0; c < 3; c++) { diff --git a/third_party/jpeg-xl/lib/jxl/decode.cc b/third_party/jpeg-xl/lib/jxl/decode.cc index 8aa89bd9ea6a..0cbf0b474b97 100644 --- a/third_party/jpeg-xl/lib/jxl/decode.cc +++ b/third_party/jpeg-xl/lib/jxl/decode.cc @@ -1462,6 +1462,7 @@ JxlDecoderStatus JxlDecoderProcessCodestream(JxlDecoder* dec, const uint8_t* in, bool is_rgba = dec->image_out_format.num_channels == 4; dec->frame_dec->MaybeSetFloatCallback( [dec](const float* pixels, size_t x, size_t y, size_t num_pixels) { + JXL_DASSERT(num_pixels > 0); dec->image_out_callback(dec->image_out_opaque, x, y, num_pixels, pixels); }, @@ -2424,6 +2425,10 @@ JxlDecoderStatus JxlDecoderFlushImage(JxlDecoder* dec) { return JXL_DEC_ERROR; } + if (dec->jpeg_decoder.IsOutputSet() && dec->ib->jpeg_data != nullptr) { + return JXL_DEC_SUCCESS; + } + if (dec->frame_dec->HasRGBBuffer()) { return JXL_DEC_SUCCESS; } diff --git a/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h b/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h index f4e1ae9a8f57..68fd06e665a7 100644 --- a/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h +++ b/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h @@ -121,7 +121,7 @@ class JxlToJpegDecoder { auto write = [&tmp_next_out, &tmp_avail_size](const uint8_t* buf, size_t len) { size_t to_write = std::min(tmp_avail_size, len); - memcpy(tmp_next_out, buf, to_write); + if (to_write != 0) memcpy(tmp_next_out, buf, to_write); tmp_next_out += to_write; tmp_avail_size -= to_write; return to_write; diff --git a/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc index c0ed68fde1cd..3ef979242847 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc @@ -1008,6 +1008,17 @@ void AcStrategyHeuristics::Init(const Image3F& src, const CompressParams& cparams = enc_state->cparams; const float butteraugli_target = cparams.butteraugli_distance; + if (cparams.speed_tier >= SpeedTier::kCheetah) { + JXL_CHECK(enc_state->shared.matrices.EnsureComputed(1)); // DCT8 only + } else { + uint32_t acs_mask = 0; + // All transforms up to 64x64. + for (size_t i = 0; i < AcStrategy::DCT128X128; i++) { + acs_mask |= (1 << i); + } + JXL_CHECK(enc_state->shared.matrices.EnsureComputed(acs_mask)); + } + // Image row pointers and strides. 
config.quant_field_row = enc_state->initial_quant_field.Row(0); config.quant_field_stride = enc_state->initial_quant_field.PixelsPerRow(); diff --git a/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc index 17afa4c7c001..14b796fe4de3 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc @@ -104,7 +104,8 @@ void BitWriter::AppendByteAligned(const std::vector& others) { size_t pos = BitsWritten() / kBitsPerByte; for (const BitWriter& writer : others) { const Span span = writer.GetSpan(); - memcpy(storage_.data() + pos, span.data(), span.size()); + if (span.size() != 0) + memcpy(storage_.data() + pos, span.data(), span.size()); pos += span.size(); } storage_[pos++] = 0; // for next Write diff --git a/third_party/jpeg-xl/lib/jxl/enc_color_management.cc b/third_party/jpeg-xl/lib/jxl/enc_color_management.cc index bb8d24bacfad..419a2b6b68c4 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_color_management.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_color_management.cc @@ -229,6 +229,15 @@ Status DoColorSpaceTransform(void* cms_data, const size_t thread, } xform_src = mutable_xform_src; } +#else + if (t->channels_src == 4 && !t->skip_lcms) { + // LCMS does CMYK in a weird way: 0 = white, 100 = max ink + float* mutable_xform_src = t->buf_src.Row(thread); + for (size_t x = 0; x < xsize * 4; ++x) { + mutable_xform_src[x] = 100.f - 100.f * mutable_xform_src[x]; + } + xform_src = mutable_xform_src; + } #endif #if JXL_CMS_VERBOSE >= 2 @@ -244,10 +253,13 @@ Status DoColorSpaceTransform(void* cms_data, const size_t thread, } // else: in-place, no need to copy } else { #if JPEGXL_ENABLE_SKCMS - JXL_CHECK(skcms_Transform( - xform_src, skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque, - &t->profile_src, buf_dst, skcms_PixelFormat_RGB_fff, - skcms_AlphaFormat_Opaque, &t->profile_dst, xsize)); + JXL_CHECK( + skcms_Transform(xform_src, + (t->channels_src == 4 ? 
skcms_PixelFormat_RGBA_ffff + : skcms_PixelFormat_RGB_fff), + skcms_AlphaFormat_Opaque, &t->profile_src, buf_dst, + skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque, + &t->profile_dst, xsize)); #else // JPEGXL_ENABLE_SKCMS cmsDoTransform(t->lcms_transform, xform_src, buf_dst, static_cast(xsize)); @@ -396,6 +408,9 @@ Status DecodeProfile(const cmsContext context, const PaddedBytes& icc, ColorSpace ColorSpaceFromProfile(const skcms_ICCProfile& profile) { switch (profile.data_color_space) { case skcms_Signature_RGB: + case skcms_Signature_CMYK: + // spec says CMYK is encoded as RGB (the kBlack extra channel signals that + // it is actually CMYK) return ColorSpace::kRGB; case skcms_Signature_Gray: return ColorSpace::kGray; @@ -518,7 +533,8 @@ void DetectTransferFunction(const skcms_ICCProfile& profile, #else // JPEGXL_ENABLE_SKCMS -uint32_t Type32(const ColorEncoding& c) { +uint32_t Type32(const ColorEncoding& c, bool cmyk) { + if (cmyk) return TYPE_CMYK_FLT; if (c.IsGray()) return TYPE_GRAY_FLT; return TYPE_RGB_FLT; } @@ -531,6 +547,7 @@ uint32_t Type64(const ColorEncoding& c) { ColorSpace ColorSpaceFromProfile(const Profile& profile) { switch (cmsGetColorSpace(profile.get())) { case cmsSigRgbData: + case cmsSigCmykData: return ColorSpace::kRGB; case cmsSigGrayData: return ColorSpace::kGray; @@ -826,6 +843,7 @@ Status ColorEncoding::SetFieldsFromICC() { } SetColorSpace(ColorSpaceFromProfile(profile)); + cmyk_ = (profile.data_color_space == skcms_Signature_CMYK); CIExy wp_unadapted; JXL_RETURN_IF_ERROR(UnadaptedWhitePoint(profile, &wp_unadapted)); @@ -838,7 +856,7 @@ Status ColorEncoding::SetFieldsFromICC() { DetectTransferFunction(profile, this); // ICC and RenderingIntent have the same values (0..3). rendering_intent = static_cast(rendering_intent32); -#else // JPEGXL_ENABLE_SKCMS +#else // JPEGXL_ENABLE_SKCMS std::lock_guard guard(LcmsMutex()); const cmsContext context = GetContext(); @@ -851,8 +869,14 @@ Status ColorEncoding::SetFieldsFromICC() { if (rendering_intent32 > 3) { return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32); } + // ICC and RenderingIntent have the same values (0..3). + rendering_intent = static_cast(rendering_intent32); SetColorSpace(ColorSpaceFromProfile(profile)); + if (cmsGetColorSpace(profile.get()) == cmsSigCmykData) { + cmyk_ = true; + return true; + } const cmsCIEXYZ wp_unadapted = UnadaptedWhitePoint(context, profile, *this); JXL_RETURN_IF_ERROR(SetWhitePoint(CIExyFromXYZ(wp_unadapted))); @@ -863,8 +887,6 @@ Status ColorEncoding::SetFieldsFromICC() { // Relies on color_space/white point/primaries being set already. DetectTransferFunction(context, profile, this); - // ICC and RenderingIntent have the same values (0..3). - rendering_intent = static_cast(rendering_intent32); #endif // JPEGXL_ENABLE_SKCMS return true; @@ -882,6 +904,7 @@ void ColorEncoding::DecideIfWantICC() { const cmsContext context = GetContext(); Profile profile; if (!DecodeProfile(context, ICC(), &profile)) return; + if (cmsGetColorSpace(profile.get()) == cmsSigCmykData) return; if (!MaybeCreateProfile(*this, &icc_new)) return; equivalent = ProfileEquivalentToICC(context, profile, icc_new, *this); #endif // JPEGXL_ENABLE_SKCMS @@ -1071,9 +1094,10 @@ void* JxlCmsInit(void* init_data, size_t num_threads, size_t xsize, #endif // JPEGXL_ENABLE_SKCMS // Not including alpha channel (copied separately). - const size_t channels_src = c_src.Channels(); + const size_t channels_src = (c_src.IsCMYK() ? 
4 : c_src.Channels()); const size_t channels_dst = c_dst.Channels(); - JXL_CHECK(channels_src == channels_dst); + JXL_CHECK(channels_src == channels_dst || + (channels_src == 4 && channels_dst == 3)); #if JXL_CMS_VERBOSE printf("Channels: %" PRIuS "; Threads: %" PRIuS "\n", channels_src, num_threads); @@ -1081,8 +1105,8 @@ void* JxlCmsInit(void* init_data, size_t num_threads, size_t xsize, #if !JPEGXL_ENABLE_SKCMS // Type includes color space (XYZ vs RGB), so can be different. - const uint32_t type_src = Type32(c_src); - const uint32_t type_dst = Type32(c_dst); + const uint32_t type_src = Type32(c_src, channels_src == 4); + const uint32_t type_dst = Type32(c_dst, false); const uint32_t intent = static_cast(c_dst.rendering_intent); // Use cmsFLAGS_NOCACHE to disable the 1-pixel cache and make calling // cmsDoTransform() thread-safe. @@ -1110,7 +1134,7 @@ void* JxlCmsInit(void* init_data, size_t num_threads, size_t xsize, #if JPEGXL_ENABLE_SKCMS // SkiaCMS doesn't support grayscale float buffers, so we create space for RGB // float buffers anyway. - t->buf_src = ImageF(xsize * 3, num_threads); + t->buf_src = ImageF(xsize * (channels_src == 4 ? 4 : 3), num_threads); t->buf_dst = ImageF(xsize * 3, num_threads); #else t->buf_src = ImageF(xsize * channels_src, num_threads); diff --git a/third_party/jpeg-xl/lib/jxl/enc_color_management.h b/third_party/jpeg-xl/lib/jxl/enc_color_management.h index 03d67b833191..0d701d74f51c 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_color_management.h +++ b/third_party/jpeg-xl/lib/jxl/enc_color_management.h @@ -46,13 +46,15 @@ class ColorSpaceTransform { input_profile.icc.size = icc_src_.size(); ConvertInternalToExternalColorEncoding(c_src, &input_profile.color_encoding); - input_profile.num_channels = c_src.Channels(); + input_profile.num_channels = c_src.IsCMYK() ? 
4 : c_src.Channels(); JxlColorProfile output_profile; icc_dst_ = c_dst.ICC(); output_profile.icc.data = icc_dst_.data(); output_profile.icc.size = icc_dst_.size(); ConvertInternalToExternalColorEncoding(c_dst, &output_profile.color_encoding); + if (c_dst.IsCMYK()) + return JXL_FAILURE("Conversion to CMYK is not supported"); output_profile.num_channels = c_dst.Channels(); cms_data_ = cms_.init(cms_.init_data, num_threads, xsize, &input_profile, &output_profile, intensity_target); diff --git a/third_party/jpeg-xl/lib/jxl/enc_frame.cc b/third_party/jpeg-xl/lib/jxl/enc_frame.cc index bb25e8a91ba5..de921020b6fc 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_frame.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_frame.cc @@ -114,7 +114,6 @@ void ClusterGroups(PassesEncoderState* enc_state) { return token_cost(tokens, num_contexts) - costs[i] - costs[j]; }; std::vector out{max}; - std::vector old_map(ac.size()); std::vector dists(ac.size()); size_t farthest = 0; for (size_t i = 0; i < ac.size(); i++) { @@ -133,7 +132,6 @@ void ClusterGroups(PassesEncoderState* enc_state) { float d = dist(out.back(), i); if (d < dists[i]) { dists[i] = d; - old_map[i] = enc_state->histogram_idx[i]; enc_state->histogram_idx[i] = out.size() - 1; } if (dists[i] > dists[farthest]) { @@ -507,13 +505,18 @@ class LossyFrameEncoder { PassesSharedState& shared = enc_state_->shared; if (!enc_state_->cparams.max_error_mode) { - float x_qm_scale_steps[3] = {0.65f, 1.25f, 9.0f}; - shared.frame_header.x_qm_scale = 1; + float x_qm_scale_steps[2] = {1.25f, 9.0f}; + shared.frame_header.x_qm_scale = 2; for (float x_qm_scale_step : x_qm_scale_steps) { if (enc_state_->cparams.butteraugli_distance > x_qm_scale_step) { shared.frame_header.x_qm_scale++; } } + if (enc_state_->cparams.butteraugli_distance < 0.299f) { + // Favor chromacity preservation for making images appear more + // faithful to original even with extreme (5-10x) zooming. + shared.frame_header.x_qm_scale++; + } } JXL_RETURN_IF_ERROR(enc_state_->heuristics->LossyFrameHeuristics( diff --git a/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc b/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc index cc4af55aaaa2..9de389cd5a1b 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc @@ -866,6 +866,9 @@ Status DefaultEncoderHeuristics::LossyFrameHeuristics( GaborishInverse(opsin, 0.9908511000000001f, pool); } + FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder, + &enc_state->shared.matrices); + cfl_heuristics.Init(*opsin); acs_heuristics.Init(*opsin, enc_state); @@ -934,9 +937,6 @@ Status DefaultEncoderHeuristics::LossyFrameHeuristics( &enc_state->shared.cmap); } - FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder, - &enc_state->shared.matrices); - // Refine quantization levels. FindBestQuantizer(original_pixels, *opsin, enc_state, cms, pool, aux_out); diff --git a/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc index 05d9105b4331..fe062475e6f2 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc @@ -50,6 +50,25 @@ Status CopyToT(const ImageMetadata* metadata, const ImageBundle* ib, // Interleave input. 
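// --------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream patch): the CMYK hunks in this
// patch juggle two sample conventions. Internally a CMYK sample of 0 means
// maximum ink and 1 means white (see the interleaving code below), while
// LittleCMS expects 0 = white and 100 = maximum ink, hence the
// 100.f - 100.f * v rewrite in enc_color_management.cc above. The conversion,
// standalone:
#include <cstdio>

// Map an internal CMYK sample (0 = max ink, 1 = white) to LittleCMS's
// expectation (0 = white, 100 = max ink).
inline float ToLcmsCmyk(float v) { return 100.0f - 100.0f * v; }

int main() {
  // Full ink coverage and plain white, expressed in both conventions.
  std::printf("max ink: %.1f -> %.1f\n", 0.0f, ToLcmsCmyk(0.0f));  // 100.0
  std::printf("white:   %.1f -> %.1f\n", 1.0f, ToLcmsCmyk(1.0f));  // 0.0
  return 0;
}
// --------------------------------------------------------------------------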
if (is_gray) { src_buf = rect.ConstPlaneRow(ib->color(), 0, y); + } else if (ib->c_current().IsCMYK()) { + if (!ib->HasBlack()) { + ok.store(false); + return; + } + const float* JXL_RESTRICT row_in0 = + rect.ConstPlaneRow(ib->color(), 0, y); + const float* JXL_RESTRICT row_in1 = + rect.ConstPlaneRow(ib->color(), 1, y); + const float* JXL_RESTRICT row_in2 = + rect.ConstPlaneRow(ib->color(), 2, y); + const float* JXL_RESTRICT row_in3 = rect.ConstRow(ib->black(), y); + for (size_t x = 0; x < rect.xsize(); x++) { + // CMYK convention in JXL: 0 = max ink, 1 = white + mutable_src_buf[4 * x + 0] = row_in0[x]; + mutable_src_buf[4 * x + 1] = row_in1[x]; + mutable_src_buf[4 * x + 2] = row_in2[x]; + mutable_src_buf[4 * x + 3] = row_in3[x]; + } } else { const float* JXL_RESTRICT row_in0 = rect.ConstPlaneRow(ib->color(), 0, y); @@ -107,7 +126,7 @@ Status ImageBundle::CopyTo(const Rect& rect, const ColorEncoding& c_desired, Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, const JxlCmsInterface& cms, ThreadPool* pool, ImageBundle* store, const ImageBundle** out) { - if (in.c_current().SameColorEncoding(c_desired)) { + if (in.c_current().SameColorEncoding(c_desired) && !in.HasBlack()) { *out = ∈ return true; } diff --git a/third_party/jpeg-xl/lib/jxl/encode.cc b/third_party/jpeg-xl/lib/jxl/encode.cc index db9c2c4d40d9..6ba1e2443d9d 100644 --- a/third_party/jpeg-xl/lib/jxl/encode.cc +++ b/third_party/jpeg-xl/lib/jxl/encode.cc @@ -297,6 +297,11 @@ JxlEncoderStatus JxlEncoderStruct::RefillOutputByteQueue() { std::move(input.frame); input_queue.erase(input_queue.begin()); num_queued_frames--; + for (unsigned idx = 0; idx < input_frame->ec_initialized.size(); idx++) { + if (!input_frame->ec_initialized[idx]) { + return JXL_API_ERROR("Extra channel %u is not initialized", idx); + } + } // TODO(zond): If the input queue is empty and the frames_closed is true, // then mark this frame as the last. @@ -1055,6 +1060,28 @@ JxlEncoderStatus JxlEncoderSetParallelRunner(JxlEncoder* enc, return JXL_ENC_SUCCESS; } +namespace { +JxlEncoderStatus GetCurrentDimensions( + const JxlEncoderFrameSettings* frame_settings, size_t& xsize, + size_t& ysize) { + xsize = frame_settings->enc->metadata.xsize(); + ysize = frame_settings->enc->metadata.ysize(); + if (frame_settings->values.header.layer_info.have_crop) { + xsize = frame_settings->values.header.layer_info.xsize; + ysize = frame_settings->values.header.layer_info.ysize; + } + if (frame_settings->values.cparams.already_downsampled) { + size_t factor = frame_settings->values.cparams.resampling; + xsize = jxl::DivCeil(xsize, factor); + ysize = jxl::DivCeil(ysize, factor); + } + if (xsize == 0 || ysize == 0) { + return JXL_API_ERROR("zero-sized frame is not allowed"); + } + return JXL_ENC_SUCCESS; +} +} // namespace + JxlEncoderStatus JxlEncoderAddJPEGFrame( const JxlEncoderFrameSettings* frame_settings, const uint8_t* buffer, size_t size) { @@ -1108,13 +1135,28 @@ JxlEncoderStatus JxlEncoderAddJPEGFrame( // default move constructor there. 
jxl::JxlEncoderQueuedFrame{ frame_settings->values, - jxl::ImageBundle(&frame_settings->enc->metadata.m)}); + jxl::ImageBundle(&frame_settings->enc->metadata.m), + {}}); if (!queued_frame) { return JXL_ENC_ERROR; } queued_frame->frame.SetFromImage(std::move(*io.Main().color()), io.Main().c_current()); - // TODO(firsching) add extra channels here + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR("bad dimensions"); + } + if (xsize != static_cast(io.Main().jpeg_data->width) || + ysize != static_cast(io.Main().jpeg_data->height)) { + return JXL_API_ERROR("JPEG dimensions don't match frame dimensions"); + } + std::vector extra_channels( + frame_settings->enc->metadata.m.num_extra_channels); + for (auto& extra_channel : extra_channels) { + extra_channel = jxl::ImageF(xsize, ysize); + queued_frame->ec_initialized.push_back(0); + } + queued_frame->frame.SetExtraChannels(std::move(extra_channels)); queued_frame->frame.jpeg_data = std::move(io.Main().jpeg_data); queued_frame->frame.color_transform = io.Main().color_transform; queued_frame->frame.chroma_subsampling = io.Main().chroma_subsampling; @@ -1123,28 +1165,6 @@ JxlEncoderStatus JxlEncoderAddJPEGFrame( return JXL_ENC_SUCCESS; } -namespace { -JxlEncoderStatus GetCurrentDimensions( - const JxlEncoderFrameSettings* frame_settings, size_t& xsize, - size_t& ysize) { - xsize = frame_settings->enc->metadata.xsize(); - ysize = frame_settings->enc->metadata.ysize(); - if (frame_settings->values.header.layer_info.have_crop) { - xsize = frame_settings->values.header.layer_info.xsize; - ysize = frame_settings->values.header.layer_info.ysize; - } - if (frame_settings->values.cparams.already_downsampled) { - size_t factor = frame_settings->values.cparams.resampling; - xsize = jxl::DivCeil(xsize, factor); - ysize = jxl::DivCeil(ysize, factor); - } - if (xsize == 0 || ysize == 0) { - return JXL_API_ERROR("zero-sized frame is not allowed"); - } - return JXL_ENC_SUCCESS; -} -} // namespace - JxlEncoderStatus JxlEncoderAddImageFrame( const JxlEncoderFrameSettings* frame_settings, const JxlPixelFormat* pixel_format, const void* buffer, size_t size) { @@ -1167,7 +1187,8 @@ JxlEncoderStatus JxlEncoderAddImageFrame( // default move constructor there. 
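// Usage sketch under assumptions (not code from the patch): with the change above,
// JxlEncoderAddJPEGFrame() queues blank extra-channel planes and marks them
// uninitialized, and RefillOutputByteQueue() rejects the frame unless each one is
// later filled. Assuming an encoder whose basic info declares one 16-bit extra
// channel and frame settings already created for it, the caller is expected to follow
// the JPEG frame with JxlEncoderSetExtraChannelBuffer(); the signature used here is
// the public libjxl encode API as I understand it and may differ between releases.
#include <jxl/encode.h>

JxlEncoderStatus AddJpegWithExtraChannel(JxlEncoderFrameSettings* frame_settings,
                                         const uint8_t* jpeg, size_t jpeg_size,
                                         const uint16_t* channel,
                                         size_t channel_bytes) {
  if (JxlEncoderAddJPEGFrame(frame_settings, jpeg, jpeg_size) !=
      JXL_ENC_SUCCESS) {
    return JXL_ENC_ERROR;
  }
  JxlPixelFormat format = {1, JXL_TYPE_UINT16, JXL_NATIVE_ENDIAN, 0};
  // Marks ec_initialized[0] = 1 for the queued frame.
  return JxlEncoderSetExtraChannelBuffer(frame_settings, &format, channel,
                                         channel_bytes, 0);
}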
jxl::JxlEncoderQueuedFrame{ frame_settings->values, - jxl::ImageBundle(&frame_settings->enc->metadata.m)}); + jxl::ImageBundle(&frame_settings->enc->metadata.m), + {}}); if (!queued_frame) { return JXL_ENC_ERROR; @@ -1202,6 +1223,14 @@ JxlEncoderStatus JxlEncoderAddImageFrame( extra_channel = jxl::ImageF(xsize, ysize); } queued_frame->frame.SetExtraChannels(std::move(extra_channels)); + for (auto& ec_info : frame_settings->enc->metadata.m.extra_channel_info) { + if (has_interleaved_alpha && ec_info.type == jxl::ExtraChannel::kAlpha) { + queued_frame->ec_initialized.push_back(1); + has_interleaved_alpha = 0; // only first Alpha is initialized + } else { + queued_frame->ec_initialized.push_back(0); + } + } if (!jxl::BufferToImageBundle(*pixel_format, xsize, ysize, buffer, size, frame_settings->enc->thread_pool.get(), @@ -1283,6 +1312,7 @@ JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelBuffer( .frame->frame.extra_channels()[index])) { return JXL_API_ERROR("Failed to set buffer for extra channel"); } + frame_settings->enc->input_queue.back().frame->ec_initialized[index] = 1; return JXL_ENC_SUCCESS; } diff --git a/third_party/jpeg-xl/lib/jxl/encode_internal.h b/third_party/jpeg-xl/lib/jxl/encode_internal.h index 8f112d333683..6d8d87ea06cc 100644 --- a/third_party/jpeg-xl/lib/jxl/encode_internal.h +++ b/third_party/jpeg-xl/lib/jxl/encode_internal.h @@ -51,6 +51,7 @@ constexpr unsigned char kLevelBoxHeader[] = {0, 0, 0, 0x9, 'j', 'x', 'l', 'l'}; struct JxlEncoderQueuedFrame { JxlEncoderFrameSettingsValues option_values; ImageBundle frame; + std::vector ec_initialized; }; struct JxlEncoderQueuedBox { diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle.cc b/third_party/jpeg-xl/lib/jxl/image_bundle.cc index b91baf8aba8f..189b9768c313 100644 --- a/third_party/jpeg-xl/lib/jxl/image_bundle.cc +++ b/third_party/jpeg-xl/lib/jxl/image_bundle.cc @@ -78,6 +78,13 @@ size_t ImageBundle::DetectRealBitdepth() const { // and there may be slight imprecisions in the floating point image. } +const ImageF& ImageBundle::black() const { + JXL_ASSERT(HasBlack()); + const size_t ec = metadata_->Find(ExtraChannel::kBlack) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return extra_channels_[ec]; +} const ImageF& ImageBundle::alpha() const { JXL_ASSERT(HasAlpha()); const size_t ec = metadata_->Find(ExtraChannel::kAlpha) - diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle.h b/third_party/jpeg-xl/lib/jxl/image_bundle.h index 80844295d0a4..de5951ddc2c9 100644 --- a/third_party/jpeg-xl/lib/jxl/image_bundle.h +++ b/third_party/jpeg-xl/lib/jxl/image_bundle.h @@ -178,6 +178,10 @@ class ImageBundle { ImageF* alpha(); // -- EXTRA CHANNELS + bool HasBlack() const { + return metadata_->Find(ExtraChannel::kBlack) != nullptr; + } + const ImageF& black() const; // Extra channels of unknown interpretation (e.g. spot colors). 
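// Generic illustration (not libjxl code) of the pointer-to-index idiom used by the new
// ImageBundle::black() above: Find() hands back a pointer into the metadata's
// extra_channel_info vector, and subtracting the vector's data() pointer recovers the
// index of the matching plane in extra_channels_. The arithmetic is only valid because
// the pointer is known to point into that same vector, which is why the result is
// asserted to be in range.
#include <cassert>
#include <cstddef>
#include <vector>

enum class Kind { kAlpha, kBlack, kSpot };
struct ChannelInfo { Kind kind; };

static const ChannelInfo* Find(const std::vector<ChannelInfo>& infos, Kind kind) {
  for (const ChannelInfo& info : infos) {
    if (info.kind == kind) return &info;
  }
  return nullptr;
}

static size_t ChannelIndex(const std::vector<ChannelInfo>& infos, Kind kind) {
  const ChannelInfo* found = Find(infos, kind);
  assert(found != nullptr);             // caller is expected to check Has...() first
  size_t index = found - infos.data();  // pointer difference -> vector index
  assert(index < infos.size());
  return index;
}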
void SetExtraChannels(std::vector&& extra_channels); diff --git a/third_party/jpeg-xl/lib/jxl/image_ops.h b/third_party/jpeg-xl/lib/jxl/image_ops.h index 649e78bbae05..5b1cc1d4bb7a 100644 --- a/third_party/jpeg-xl/lib/jxl/image_ops.h +++ b/third_party/jpeg-xl/lib/jxl/image_ops.h @@ -776,7 +776,7 @@ void ZeroFillImage(Image3* image) { for (size_t c = 0; c < 3; ++c) { for (size_t y = 0; y < image->ysize(); ++y) { T* JXL_RESTRICT row = image->PlaneRow(c, y); - memset(row, 0, image->xsize() * sizeof(T)); + if (image->xsize() != 0) memset(row, 0, image->xsize() * sizeof(T)); } } } diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc index 60fd394907f0..5336e47fd272 100644 --- a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc @@ -477,6 +477,7 @@ bool EncodeDCTBlockSequential(const coeff_t* coeffs, temp2 = temp; if (temp < 0) { temp = -temp; + if (temp < 0) return false; temp2--; } int dc_nbits = (temp == 0) ? 0 : (FloorLog2Nonzero(temp) + 1); @@ -493,6 +494,7 @@ bool EncodeDCTBlockSequential(const coeff_t* coeffs, } if (temp < 0) { temp = -temp; + if (temp < 0) return false; temp2 = ~temp; } else { temp2 = temp; @@ -534,6 +536,7 @@ bool EncodeDCTBlockProgressive(const coeff_t* coeffs, temp2 = temp; if (temp < 0) { temp = -temp; + if (temp < 0) return false; temp2--; } int nbits = (temp == 0) ? 0 : (FloorLog2Nonzero(temp) + 1); @@ -554,6 +557,7 @@ bool EncodeDCTBlockProgressive(const coeff_t* coeffs, } if (temp < 0) { temp = -temp; + if (temp < 0) return false; temp >>= Al; temp2 = ~temp; } else { diff --git a/third_party/jpeg-xl/lib/jxl/jxl_test.cc b/third_party/jpeg-xl/lib/jxl/jxl_test.cc index 4892e939bbca..7c30c5525cf7 100644 --- a/third_party/jpeg-xl/lib/jxl/jxl_test.cc +++ b/third_party/jpeg-xl/lib/jxl/jxl_test.cc @@ -595,7 +595,7 @@ TEST(JxlTest, RoundtripSmallPatchesAlpha) { EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 2000u); EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), /*distmap=*/nullptr, pool), - IsSlightlyBelow(0.24f)); + IsSlightlyBelow(0.32f)); } TEST(JxlTest, RoundtripSmallPatches) { @@ -623,7 +623,7 @@ TEST(JxlTest, RoundtripSmallPatches) { EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 2000u); EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), /*distmap=*/nullptr, pool), - IsSlightlyBelow(0.24f)); + IsSlightlyBelow(0.32f)); } // Test header encoding of original bits per sample diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc index 98ead131cae5..8f0fd750545d 100644 --- a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc @@ -491,14 +491,12 @@ Status FwdPaletteIteration(Image &input, uint32_t begin_c, uint32_t end_c, {1, 3}, {2, 2}, {1, 0}, {1, 4}, {2, 1}, {2, 3}, {2, 0}, {2, 4}}; float total_available = 0; - int n = 0; for (int i = 0; i < 11; ++i) { const int row = offsets[i][0]; const int col = offsets[i][1]; if (std::signbit(error_row[row][c][x + col]) != std::signbit(total_error)) { total_available += error_row[row][c][x + col]; - n++; } } float weight = diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights.cc b/third_party/jpeg-xl/lib/jxl/quant_weights.cc index 399a559ab93f..e8d9a10ed612 100644 --- a/third_party/jpeg-xl/lib/jxl/quant_weights.cc +++ b/third_party/jpeg-xl/lib/jxl/quant_weights.cc @@ 
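// The added `if (temp < 0) return false;` checks in dec_jpeg_data_writer.cc above look
// redundant but presumably are not: in two's complement, negating the most negative
// representable value cannot yield a positive number, so a corrupt coefficient equal
// to the minimum of the coefficient type would otherwise pass through with a negative
// "magnitude" and derail the bit-length computation. A minimal demonstration of the
// corner case (assuming wrap-around on negation overflow, which is the behaviour the
// guard defends against in practice):
#include <climits>
#include <cstdio>

int main() {
  int temp = INT_MIN;
  if (temp < 0) {
    temp = -temp;  // overflow: formally undefined, wraps back to INT_MIN on
                   // common two's-complement targets
    if (temp < 0) {
      std::puts("negation overflowed; reject the block");
      return 1;
    }
  }
  return 0;
}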
-21,14 +21,21 @@ #include "lib/jxl/fields.h" #include "lib/jxl/image.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc" +#include +#include + +#include "lib/jxl/fast_math-inl.h" + +HWY_BEFORE_NAMESPACE(); namespace jxl { +namespace HWY_NAMESPACE { // kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y) // coefficient in component c. Higher weights correspond to finer quantization // intervals and more bits spent in encoding. -namespace { - static constexpr const float kAlmostZero = 1e-8f; void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights, @@ -75,33 +82,47 @@ void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights, } } -float Mult(float v) { - if (v > 0) return 1 + v; - return 1 / (1 - v); -} - float Interpolate(float pos, float max, const float* array, size_t len) { float scaled_pos = pos * (len - 1) / max; size_t idx = scaled_pos; - JXL_ASSERT(idx + 1 < len); + JXL_DASSERT(idx + 1 < len); float a = array[idx]; float b = array[idx + 1]; - return a * pow(b / a, scaled_pos - idx); + return a * FastPowf(b / a, scaled_pos - idx); +} + +float Mult(float v) { + if (v > 0.0f) return 1.0f + v; + return 1.0f / (1.0f - v); +} + +using DF4 = HWY_CAPPED(float, 4); + +hwy::HWY_NAMESPACE::Vec InterpolateVec( + hwy::HWY_NAMESPACE::Vec scaled_pos, const float* array) { + HWY_CAPPED(int32_t, 4) di; + + auto idx = ConvertTo(di, scaled_pos); + + auto frac = scaled_pos - ConvertTo(DF4(), idx); + + // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but + // it's probably slower. + auto a = GatherIndex(DF4(), array, idx); + auto b = GatherIndex(DF4(), array + 1, idx); + + return a * FastPowf(DF4(), b / a, frac); } // Computes quant weights for a COLS*ROWS-sized transform, using num_bands // eccentricity bands and num_ebands eccentricity bands. If print_mode is 1, // prints the resulting matrix; if print_mode is 2, prints the matrix in a // format suitable for a 3d plot with gnuplot. -template Status GetQuantWeights( size_t ROWS, size_t COLS, const DctQuantWeightParams::DistanceBandsArray& distance_bands, size_t num_bands, float* out) { for (size_t c = 0; c < 3; c++) { - if (print_mode) { - fprintf(stderr, "Channel %" PRIuS "\n", c); - } float bands[DctQuantWeightParams::kMaxDistanceBands] = { distance_bands[c][0]}; if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); @@ -109,32 +130,235 @@ Status GetQuantWeights( bands[i] = bands[i - 1] * Mult(distance_bands[c][i]); if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); } - for (size_t y = 0; y < ROWS; y++) { - for (size_t x = 0; x < COLS; x++) { - float dx = 1.0f * x / (COLS - 1); - float dy = 1.0f * y / (ROWS - 1); - float distance = std::sqrt(dx * dx + dy * dy); - float weight = - num_bands == 1 - ? 
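// Scalar reference for the Interpolate() change above (the patch swaps std::pow for
// FastPowf and adds a 4-lane gather-based InterpolateVec, but the math is unchanged):
// interpolation between neighbouring distance bands is geometric, i.e. linear in the
// log domain, weight = a * (b / a)^frac.
#include <cmath>
#include <cstddef>

static float InterpolateScalar(float pos, float max, const float* array,
                               size_t len) {
  float scaled_pos = pos * (len - 1) / max;
  size_t idx = static_cast<size_t>(scaled_pos);
  float frac = scaled_pos - static_cast<float>(idx);
  float a = array[idx];
  float b = array[idx + 1];
  return a * std::pow(b / a, frac);
}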
bands[0] - : Interpolate(distance, std::sqrt(2) + 1e-6f, bands, num_bands); - - if (print_mode == 1) { - fprintf(stderr, "%15.12f, ", weight); - } - if (print_mode == 2) { - fprintf(stderr, "%" PRIuS " %" PRIuS " %15.12f\n", x, y, weight); - } - out[c * COLS * ROWS + y * COLS + x] = weight; + float scale = (num_bands - 1) / (kSqrt2 + 1e-6f); + float rcpcol = scale / (COLS - 1); + float rcprow = scale / (ROWS - 1); + JXL_ASSERT(COLS >= Lanes(DF4())); + HWY_ALIGN float l0123[4] = {0, 1, 2, 3}; + for (uint32_t y = 0; y < ROWS; y++) { + float dy = y * rcprow; + float dy2 = dy * dy; + for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) { + auto dx = (Set(DF4(), x) + Load(DF4(), l0123)) * Set(DF4(), rcpcol); + auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2))); + auto weight = num_bands == 1 ? Set(DF4(), bands[0]) + : InterpolateVec(scaled_distance, bands); + StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x); } - if (print_mode) fprintf(stderr, "\n"); - if (print_mode == 1) fprintf(stderr, "\n"); } - if (print_mode) fprintf(stderr, "\n"); } return true; } +// TODO(veluca): SIMD-fy. With 256x256, this is actually slow. +Status ComputeQuantTable(const QuantEncoding& encoding, + float* JXL_RESTRICT table, + float* JXL_RESTRICT inv_table, size_t table_num, + DequantMatrices::QuantTable kind, size_t* pos) { + constexpr size_t N = kBlockDim; + size_t wrows = 8 * DequantMatrices::required_size_x[kind], + wcols = 8 * DequantMatrices::required_size_y[kind]; + size_t num = wrows * wcols; + + std::vector weights(3 * num); + + switch (encoding.mode) { + case QuantEncoding::kQuantModeLibrary: { + // Library and copy quant encoding should get replaced by the actual + // parameters by the caller. + JXL_ASSERT(false); + break; + } + case QuantEncoding::kQuantModeID: { + JXL_ASSERT(num == kDCTBlockSize); + GetQuantWeightsIdentity(encoding.idweights, weights.data()); + break; + } + case QuantEncoding::kQuantModeDCT2: { + JXL_ASSERT(num == kDCTBlockSize); + GetQuantWeightsDCT2(encoding.dct2weights, weights.data()); + break; + } + case QuantEncoding::kQuantModeDCT4: { + JXL_ASSERT(num == kDCTBlockSize); + float weights4x4[3 * 4 * 4]; + // Always use 4x4 GetQuantWeights for DCT4 quantization tables. + JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 4, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x4)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x4[c * 16 + (y / 2) * 4 + (x / 2)]; + } + } + weights[c * num + 1] /= encoding.dct4multipliers[c][0]; + weights[c * num + N] /= encoding.dct4multipliers[c][0]; + weights[c * num + N + 1] /= encoding.dct4multipliers[c][1]; + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + JXL_ASSERT(num == kDCTBlockSize); + float weights4x8[3 * 4 * 8]; + // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables. 
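// Scalar equivalent of the vectorized weight loop above, for one channel c (a
// readability sketch; the patch processes Lanes(DF4()) columns per iteration).
// rcprow/rcpcol already fold in (num_bands - 1) / (sqrt(2) + 1e-6) as in the diff, so
// the Euclidean distance lands directly in band-index units.
#include <cmath>
#include <cstddef>
#include <cstdint>

static void FillWeightsScalar(size_t ROWS, size_t COLS, const float* bands,
                              size_t num_bands, float rcprow, float rcpcol,
                              size_t c, float* out) {
  for (uint32_t y = 0; y < ROWS; y++) {
    float dy = y * rcprow;
    for (uint32_t x = 0; x < COLS; x++) {
      float dx = x * rcpcol;
      float scaled_pos = std::sqrt(dx * dx + dy * dy);
      float weight;
      if (num_bands == 1) {
        weight = bands[0];
      } else {
        size_t idx = static_cast<size_t>(scaled_pos);
        float frac = scaled_pos - static_cast<float>(idx);
        weight = bands[idx] * std::pow(bands[idx + 1] / bands[idx], frac);
      }
      out[c * COLS * ROWS + y * COLS + x] = weight;
    }
  }
}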
+ JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x8[c * 32 + (y / 2) * 8 + x]; + } + } + weights[c * num + N] /= encoding.dct4x8multipliers[c]; + } + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(GetQuantWeights( + wrows, wcols, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights.data())); + break; + } + case QuantEncoding::kQuantModeRAW: { + if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { + return JXL_FAILURE("Invalid table encoding"); + } + for (size_t i = 0; i < 3 * num; i++) { + weights[i] = + 1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]); + } + break; + } + case QuantEncoding::kQuantModeAFV: { + constexpr float kFreqs[] = { + 0xBAD, + 0xBAD, + 0.8517778890324296, + 5.37778436506804, + 0xBAD, + 0xBAD, + 4.734747904497923, + 5.449245381693219, + 1.6598270267479331, + 4, + 7.275749096817861, + 10.423227632456525, + 2.662932286148962, + 7.630657783650829, + 8.962388608184032, + 12.97166202570235, + }; + + float weights4x8[3 * 4 * 8]; + JXL_RETURN_IF_ERROR(( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8))); + float weights4x4[3 * 4 * 4]; + JXL_RETURN_IF_ERROR((GetQuantWeights( + 4, 4, encoding.dct_params_afv_4x4.distance_bands, + encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); + + constexpr float lo = 0.8517778890324296; + constexpr float hi = 12.97166202570235f - lo + 1e-6f; + for (size_t c = 0; c < 3; c++) { + float bands[4]; + bands[0] = encoding.afv_weights[c][5]; + if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); + for (size_t i = 1; i < 4; i++) { + bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); + if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); + } + size_t start = c * 64; + auto set_weight = [&start, &weights](size_t x, size_t y, float val) { + weights[start + y * 8 + x] = val; + }; + weights[start] = 1; // Not used, but causes MSAN error otherwise. + // Weights for (0, 1) and (1, 0). + set_weight(0, 1, encoding.afv_weights[c][0]); + set_weight(1, 0, encoding.afv_weights[c][1]); + // AFV special weights for 3-pixel corner. + set_weight(0, 2, encoding.afv_weights[c][2]); + set_weight(2, 0, encoding.afv_weights[c][3]); + set_weight(2, 2, encoding.afv_weights[c][4]); + + // All other AFV weights. + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + if (x < 2 && y < 2) continue; + float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4); + set_weight(2 * x, 2 * y, val); + } + } + + // Put 4x8 weights in odd rows, except (1, 0). + for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y + 1) * kBlockDim + x] = + weights4x8[c * 32 + y * 8 + x]; + } + } + // Put 4x4 weights in even rows / odd columns, except (0, 1). 
+ for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim / 2; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] = + weights4x4[c * 16 + y * 4 + x]; + } + } + } + break; + } + } + size_t prev_pos = *pos; + HWY_CAPPED(float, 64) d; + for (size_t i = 0; i < num * 3; i += Lanes(d)) { + auto inv_val = LoadU(d, weights.data() + i); + if (JXL_UNLIKELY(!AllFalse(d, inv_val >= Set(d, 1.0f / kAlmostZero)) || + !AllFalse(d, inv_val < Set(d, kAlmostZero)))) { + return JXL_FAILURE("Invalid quantization table"); + } + auto val = Set(d, 1.0f) / inv_val; + StoreU(val, d, table + *pos + i); + StoreU(inv_val, d, inv_table + *pos + i); + } + (*pos) += 3 * num; + + // Ensure that the lowest frequencies have a 0 inverse table. + // This does not affect en/decoding, but allows AC strategy selection to be + // slightly simpler. + size_t xs = DequantMatrices::required_size_x[kind]; + size_t ys = DequantMatrices::required_size_y[kind]; + CoefficientLayout(&ys, &xs); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ys; y++) { + for (size_t x = 0; x < xs; x++) { + inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs + + x] = 0; + } + } + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +HWY_EXPORT(ComputeQuantTable); + +static constexpr const float kAlmostZero = 1e-8f; + Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { params->num_distance_bands = br->ReadFixedBits() + 1; @@ -250,200 +474,6 @@ Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x, return true; } -// TODO(veluca): SIMD-fy. With 256x256, this is actually slow. -Status ComputeQuantTable(const QuantEncoding& encoding, - float* JXL_RESTRICT table, - float* JXL_RESTRICT inv_table, size_t table_num, - DequantMatrices::QuantTable kind, size_t* pos) { - std::vector weights(3 * kMaxQuantTableSize); - - constexpr size_t N = kBlockDim; - size_t wrows = 8 * DequantMatrices::required_size_x[kind], - wcols = 8 * DequantMatrices::required_size_y[kind]; - size_t num = wrows * wcols; - - switch (encoding.mode) { - case QuantEncoding::kQuantModeLibrary: { - // Library and copy quant encoding should get replaced by the actual - // parameters by the caller. - JXL_ASSERT(false); - break; - } - case QuantEncoding::kQuantModeID: { - JXL_ASSERT(num == kDCTBlockSize); - GetQuantWeightsIdentity(encoding.idweights, weights.data()); - break; - } - case QuantEncoding::kQuantModeDCT2: { - JXL_ASSERT(num == kDCTBlockSize); - GetQuantWeightsDCT2(encoding.dct2weights, weights.data()); - break; - } - case QuantEncoding::kQuantModeDCT4: { - JXL_ASSERT(num == kDCTBlockSize); - float weights4x4[3 * 4 * 4]; - // Always use 4x4 GetQuantWeights for DCT4 quantization tables. 
- JXL_RETURN_IF_ERROR( - GetQuantWeights(4, 4, encoding.dct_params.distance_bands, - encoding.dct_params.num_distance_bands, weights4x4)); - for (size_t c = 0; c < 3; c++) { - for (size_t y = 0; y < kBlockDim; y++) { - for (size_t x = 0; x < kBlockDim; x++) { - weights[c * num + y * kBlockDim + x] = - weights4x4[c * 16 + (y / 2) * 4 + (x / 2)]; - } - } - weights[c * num + 1] /= encoding.dct4multipliers[c][0]; - weights[c * num + N] /= encoding.dct4multipliers[c][0]; - weights[c * num + N + 1] /= encoding.dct4multipliers[c][1]; - } - break; - } - case QuantEncoding::kQuantModeDCT4X8: { - JXL_ASSERT(num == kDCTBlockSize); - float weights4x8[3 * 4 * 8]; - // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables. - JXL_RETURN_IF_ERROR( - GetQuantWeights(4, 8, encoding.dct_params.distance_bands, - encoding.dct_params.num_distance_bands, weights4x8)); - for (size_t c = 0; c < 3; c++) { - for (size_t y = 0; y < kBlockDim; y++) { - for (size_t x = 0; x < kBlockDim; x++) { - weights[c * num + y * kBlockDim + x] = - weights4x8[c * 32 + (y / 2) * 8 + x]; - } - } - weights[c * num + N] /= encoding.dct4x8multipliers[c]; - } - break; - } - case QuantEncoding::kQuantModeDCT: { - JXL_RETURN_IF_ERROR(GetQuantWeights( - wrows, wcols, encoding.dct_params.distance_bands, - encoding.dct_params.num_distance_bands, weights.data())); - break; - } - case QuantEncoding::kQuantModeRAW: { - if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { - return JXL_FAILURE("Invalid table encoding"); - } - for (size_t i = 0; i < 3 * num; i++) { - weights[i] = - 1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]); - } - break; - } - case QuantEncoding::kQuantModeAFV: { - constexpr float kFreqs[] = { - 0xBAD, - 0xBAD, - 0.8517778890324296, - 5.37778436506804, - 0xBAD, - 0xBAD, - 4.734747904497923, - 5.449245381693219, - 1.6598270267479331, - 4, - 7.275749096817861, - 10.423227632456525, - 2.662932286148962, - 7.630657783650829, - 8.962388608184032, - 12.97166202570235, - }; - - float weights4x8[3 * 4 * 8]; - JXL_RETURN_IF_ERROR(( - GetQuantWeights(4, 8, encoding.dct_params.distance_bands, - encoding.dct_params.num_distance_bands, weights4x8))); - float weights4x4[3 * 4 * 4]; - JXL_RETURN_IF_ERROR((GetQuantWeights( - 4, 4, encoding.dct_params_afv_4x4.distance_bands, - encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); - - constexpr float lo = 0.8517778890324296; - constexpr float hi = 12.97166202570235 - lo + 1e-6; - for (size_t c = 0; c < 3; c++) { - float bands[4]; - bands[0] = encoding.afv_weights[c][5]; - if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); - for (size_t i = 1; i < 4; i++) { - bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); - if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); - } - size_t start = c * 64; - auto set_weight = [&start, &weights](size_t x, size_t y, float val) { - weights[start + y * 8 + x] = val; - }; - weights[start] = 1; // Not used, but causes MSAN error otherwise. - // Weights for (0, 1) and (1, 0). - set_weight(0, 1, encoding.afv_weights[c][0]); - set_weight(1, 0, encoding.afv_weights[c][1]); - // AFV special weights for 3-pixel corner. - set_weight(0, 2, encoding.afv_weights[c][2]); - set_weight(2, 0, encoding.afv_weights[c][3]); - set_weight(2, 2, encoding.afv_weights[c][4]); - - // All other AFV weights. 
- for (size_t y = 0; y < 4; y++) { - for (size_t x = 0; x < 4; x++) { - if (x < 2 && y < 2) continue; - float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4); - set_weight(2 * x, 2 * y, val); - } - } - - // Put 4x8 weights in odd rows, except (1, 0). - for (size_t y = 0; y < kBlockDim / 2; y++) { - for (size_t x = 0; x < kBlockDim; x++) { - if (x == 0 && y == 0) continue; - weights[c * num + (2 * y + 1) * kBlockDim + x] = - weights4x8[c * 32 + y * 8 + x]; - } - } - // Put 4x4 weights in even rows / odd columns, except (0, 1). - for (size_t y = 0; y < kBlockDim / 2; y++) { - for (size_t x = 0; x < kBlockDim / 2; x++) { - if (x == 0 && y == 0) continue; - weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] = - weights4x4[c * 16 + y * 4 + x]; - } - } - } - break; - } - } - size_t prev_pos = *pos; - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < num; i++) { - float inv_val = weights[c * num + i]; - if (inv_val > 1.0f / kAlmostZero || inv_val < kAlmostZero) { - return JXL_FAILURE("Invalid quantization table"); - } - float val = 1.0f / inv_val; - table[*pos] = val; - inv_table[*pos] = inv_val; - (*pos)++; - } - } - // Ensure that the lowest frequencies have a 0 inverse table. - // This does not affect en/decoding, but allows AC strategy selection to be - // slightly simpler. - size_t xs = DequantMatrices::required_size_x[kind]; - size_t ys = DequantMatrices::required_size_y[kind]; - CoefficientLayout(&ys, &xs); - for (size_t c = 0; c < 3; c++) { - for (size_t y = 0; y < ys; y++) { - for (size_t x = 0; x < xs; x++) { - inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs + - x] = 0; - } - } - } - return true; -} - } // namespace // These definitions are needed before C++17. @@ -463,7 +493,8 @@ Status DequantMatrices::Decode(BitReader* br, jxl::Decode(br, &encodings_[i], required_size_x[i % kNum], required_size_y[i % kNum], i, modular_frame_decoder)); } - return DequantMatrices::Compute(); + computed_mask_ = 0; + return true; } Status DequantMatrices::DecodeDC(BitReader* br) { @@ -1126,61 +1157,78 @@ const QuantEncoding* DequantMatrices::Library() { return reinterpret_cast(kDequantLibrary.data()); } -Status DequantMatrices::Compute() { +DequantMatrices::DequantMatrices() { + encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); size_t pos = 0; - - struct DefaultMatrices { - DefaultMatrices() { - const QuantEncoding* library = Library(); - size_t pos = 0; - for (size_t i = 0; i < kNum; i++) { - JXL_CHECK(ComputeQuantTable(library[i], table, inv_table, i, - QuantTable(i), &pos)); - } - JXL_CHECK(pos == kTotalTableSize); + size_t offsets[kNum * 3]; + for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; } - HWY_ALIGN_MAX float table[kTotalTableSize]; - HWY_ALIGN_MAX float inv_table[kTotalTableSize]; - }; - - static const DefaultMatrices& default_matrices = - *hwy::MakeUniqueAligned().release(); - - JXL_ASSERT(encodings_.size() == kNum); - - bool has_nondefault_matrix = false; - for (const auto& enc : encodings_) { - if (enc.mode != QuantEncoding::kQuantModeLibrary) { - has_nondefault_matrix = true; + pos += 3 * num; + } + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + for (size_t c = 0; c < 3; c++) { + table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; } } - if (has_nondefault_matrix) { +} + +Status DequantMatrices::EnsureComputed(uint32_t acs_mask) { + const QuantEncoding* library = Library(); + + 
if (!table_storage_) { table_storage_ = hwy::AllocateAligned(2 * kTotalTableSize); table_ = table_storage_.get(); inv_table_ = table_storage_.get() + kTotalTableSize; - for (size_t table = 0; table < kNum; table++) { - size_t prev_pos = pos; - if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { - size_t num = required_size_[table] * kDCTBlockSize; - memcpy(table_storage_.get() + prev_pos, - default_matrices.table + prev_pos, num * sizeof(float) * 3); - memcpy(table_storage_.get() + kTotalTableSize + prev_pos, - default_matrices.inv_table + prev_pos, num * sizeof(float) * 3); - pos += num * 3; - } else { - JXL_RETURN_IF_ERROR( - ComputeQuantTable(encodings_[table], table_storage_.get(), - table_storage_.get() + kTotalTableSize, table, - QuantTable(table), &pos)); - } - } - JXL_ASSERT(pos == kTotalTableSize); - } else { - table_ = default_matrices.table; - inv_table_ = default_matrices.inv_table; } + size_t offsets[kNum * 3 + 1]; + size_t pos = 0; + for (size_t i = 0; i < kNum; i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + offsets[kNum * 3] = pos; + JXL_ASSERT(pos == kTotalTableSize); + + uint32_t kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (acs_mask & (1u << i)) { + kind_mask |= 1u << kQuantTable[i]; + } + } + uint32_t computed_kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (computed_mask_ & (1u << i)) { + computed_kind_mask |= 1u << kQuantTable[i]; + } + } + for (size_t table = 0; table < kNum; table++) { + if ((1 << table) & computed_kind_mask) continue; + if ((1 << table) & ~kind_mask) continue; + size_t pos = offsets[table * 3]; + if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { + JXL_CHECK(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + library[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } else { + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + encodings_[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } + JXL_ASSERT(pos == offsets[table * 3 + 3]); + } + computed_mask_ |= acs_mask; + return true; } } // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights.h b/third_party/jpeg-xl/lib/jxl/quant_weights.h index 816362f81c83..b235dc754edd 100644 --- a/third_party/jpeg-xl/lib/jxl/quant_weights.h +++ b/third_party/jpeg-xl/lib/jxl/quant_weights.h @@ -363,26 +363,7 @@ class DequantMatrices { sizeof(kQuantTable) / sizeof *kQuantTable, "Update this array when adding or removing AC strategies."); - DequantMatrices() { - encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); - size_t pos = 0; - size_t offsets[kNum * 3]; - for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { - encodings_[i] = QuantEncoding::Library(0); - size_t num = required_size_[i] * kDCTBlockSize; - for (size_t c = 0; c < 3; c++) { - offsets[3 * i + c] = pos + c * num; - } - pos += 3 * num; - } - for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { - for (size_t c = 0; c < 3; c++) { - table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; - } - } - // Default quantization tables need to be valid. - JXL_CHECK(Compute()); - } + DequantMatrices(); static const QuantEncoding* Library(); @@ -393,20 +374,17 @@ class DequantMatrices { // .cc file. 
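// Generic sketch (not libjxl code) of the lazy-computation scheme introduced by
// EnsureComputed() above: requested work is expressed as a bitmask, bits already
// present in the "computed" mask are skipped, and the mask only grows until something
// like SetEncodings() resets it. In the patch the request bits are AC strategies that
// get mapped onto quant-table kinds; here they are plain table indices.
#include <cstdint>
#include <functional>

class LazyTables {
 public:
  // Compute every table whose bit is set in `mask` and not yet computed.
  void EnsureComputed(uint32_t mask,
                      const std::function<void(unsigned)>& compute) {
    const uint32_t todo = mask & ~computed_mask_;
    for (unsigned i = 0; i < 32; i++) {
      if (todo & (1u << i)) compute(i);
    }
    computed_mask_ |= mask;
  }
  bool IsComputed(unsigned i) const { return (computed_mask_ >> i) & 1u; }
  void Invalidate() { computed_mask_ = 0; }  // analogous to SetEncodings()

 private:
  uint32_t computed_mask_ = 0;
};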
static const DequantLibraryInternal LibraryInit(); - JXL_INLINE size_t MatrixOffset(size_t quant_kind, size_t c) const { - JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); - return table_offsets_[quant_kind * 3 + c]; - } - // Returns aligned memory. JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); - return &table_[MatrixOffset(quant_kind, c)]; + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &table_[table_offsets_[quant_kind * 3 + c]]; } JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); - return &inv_table_[MatrixOffset(quant_kind, c)]; + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &inv_table_[table_offsets_[quant_kind * 3 + c]]; } // DC quants are used in modular mode for XYB multipliers. @@ -418,6 +396,7 @@ class DequantMatrices { // For encoder. void SetEncodings(const std::vector& encodings) { encodings_ = encodings; + computed_mask_ = 0; } // For encoder. @@ -444,9 +423,9 @@ class DequantMatrices { static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y), "Update this array when adding or removing quant tables."); - private: - Status Compute(); + Status EnsureComputed(uint32_t kind_mask); + private: static constexpr size_t required_size_[] = { 1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512}; static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_), @@ -454,6 +433,7 @@ class DequantMatrices { static constexpr size_t kTotalTableSize = ArraySum(required_size_) * kDCTBlockSize * 3; + uint32_t computed_mask_ = 0; // kTotalTableSize entries followed by kTotalTableSize for inv_table hwy::AlignedFreeUniquePtr table_storage_; const float* table_; diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc b/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc index a700d178c45a..f0497948a7ff 100644 --- a/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc +++ b/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc @@ -173,6 +173,7 @@ TEST_P(QuantWeightsTargetTest, DCTUniform) { FrameHeader frame_header(&metadata); ModularFrameEncoder encoder(frame_header, CompressParams{}); DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder); + JXL_CHECK(dequant_matrices.EnsureComputed(~0u)); const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, 1.0f / kUniformQuant}; diff --git a/third_party/jpeg-xl/lib/jxl/quantizer.h b/third_party/jpeg-xl/lib/jxl/quantizer.h index 8d9a2347901c..1ff593e5c18a 100644 --- a/third_party/jpeg-xl/lib/jxl/quantizer.h +++ b/third_party/jpeg-xl/lib/jxl/quantizer.h @@ -123,10 +123,6 @@ class Quantizer { return dequant_->InvMatrix(quant_kind, c); } - JXL_INLINE size_t DequantMatrixOffset(size_t quant_kind, size_t c) const { - return dequant_->MatrixOffset(quant_kind, c); - } - // Calculates DC quantization step. JXL_INLINE float GetDcStep(size_t c) const { return inv_quant_dc_ * dequant_->DCQuant(c); diff --git a/third_party/jpeg-xl/lib/jxl/splines.cc b/third_party/jpeg-xl/lib/jxl/splines.cc index c55fb1638ee9..edbeb28dba6f 100644 --- a/third_party/jpeg-xl/lib/jxl/splines.cc +++ b/third_party/jpeg-xl/lib/jxl/splines.cc @@ -102,9 +102,13 @@ void ComputeSegments(const Spline::Point& center, const float intensity, std::vector& segments, std::vector>& segments_by_y, size_t* pixel_limit) { + // In worst case zero-sized dot spans over 2 rows / columns. 
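// Caller-side consequence of the header change above (a sketch mirroring the
// quant_weights_test.cc update): the DequantMatrices constructor no longer computes
// any tables, and Matrix()/InvMatrix() now JXL_DASSERT that the requested kind was
// computed, so callers must ask for the tables they need first. Passing ~0u requests
// every AC strategy. The include paths are assumptions based on the libjxl tree.
#include "lib/jxl/base/status.h"
#include "lib/jxl/quant_weights.h"

void UseDefaultMatrices() {
  jxl::DequantMatrices matrices;             // defaults, but nothing computed yet
  JXL_CHECK(matrices.EnsureComputed(~0u));   // compute all quant tables
  const float* m = matrices.Matrix(/*quant_kind=*/0, /*c=*/1);
  (void)m;
}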
+ constexpr const float kThinDotSpan = 2.0f; // Sanity check sigma, inverse sigma and intensity if (!(std::isfinite(sigma) && sigma != 0.0f && std::isfinite(1.0f / sigma) && std::isfinite(intensity))) { + // Even no-draw should still be accounted. + *pixel_limit -= std::min(*pixel_limit, kThinDotSpan * kThinDotSpan); return; } #if JXL_HIGH_PRECISION @@ -130,7 +134,7 @@ void ComputeSegments(const Spline::Point& center, const float intensity, segment.inv_sigma = 1.0f / sigma; segment.sigma_over_4_times_intensity = .25f * sigma * intensity; segment.maximum_distance = maximum_distance; - float cost = 2.0f * maximum_distance + 2.0f; + float cost = 2.0f * maximum_distance + kThinDotSpan; // Check cost^2 fits size_t. if (cost >= static_cast(1 << 15)) { // Too much to rasterize. @@ -142,9 +146,8 @@ void ComputeSegments(const Spline::Point& center, const float intensity, *pixel_limit = 0; return; } + // TODO(eustas): perhaps we should charge less: (y1 - y0) <= cost *pixel_limit -= area_cost; - // TODO(eustas): this will work incorrectly for (center.y >= 1 << 23) - // we have to use double precision in that case... ssize_t y0 = center.y - maximum_distance + .5f; ssize_t y1 = center.y + maximum_distance + 1.5f; // one-past-the-end for (ssize_t y = std::max(y0, 0); y < y1; y++) { @@ -418,7 +421,7 @@ Status QuantizedSpline::Dequantize(const Spline::Point& starting_point, int current_x = static_cast(roundf(starting_point.x)), current_y = static_cast(roundf(starting_point.y)); // It is not in spec, but reasonable limit to avoid overflows. - constexpr int kPosLimit = 1u << 30; + constexpr int kPosLimit = 1u << 23; if ((current_x >= kPosLimit) || (current_x <= -kPosLimit) || (current_y >= kPosLimit) || (current_y <= -kPosLimit)) { return JXL_FAILURE("Spline coordinates out of bounds"); diff --git a/third_party/jpeg-xl/lib/profiler/tsc_timer.h b/third_party/jpeg-xl/lib/profiler/tsc_timer.h index 802e40dde745..e3e1fb6ba7f9 100644 --- a/third_party/jpeg-xl/lib/profiler/tsc_timer.h +++ b/third_party/jpeg-xl/lib/profiler/tsc_timer.h @@ -24,7 +24,7 @@ #undef LoadFence #endif -#if defined(__MACH__) +#if defined(__APPLE__) #include #include #endif @@ -122,7 +122,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED Ticks TicksBefore() { LARGE_INTEGER counter; (void)QueryPerformanceCounter(&counter); t = counter.QuadPart; -#elif defined(__MACH__) +#elif defined(__APPLE__) t = mach_absolute_time(); #elif defined(__HAIKU__) t = system_time_nsecs(); // since boot diff --git a/third_party/jpeg-xl/third_party/CMakeLists.txt b/third_party/jpeg-xl/third_party/CMakeLists.txt index 82d55d47ae98..afefbaa80b32 100644 --- a/third_party/jpeg-xl/third_party/CMakeLists.txt +++ b/third_party/jpeg-xl/third_party/CMakeLists.txt @@ -82,6 +82,7 @@ endif() # BUILD_TESTING # Highway set(HWY_SYSTEM_GTEST ON CACHE INTERNAL "") +set(HWY_FORCE_STATIC_LIBS ON CACHE INTERNAL "") if((SANITIZER STREQUAL "asan") OR (SANITIZER STREQUAL "msan")) set(HWY_EXAMPLES_TESTS_INSTALL OFF CACHE INTERNAL "") endif()
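// Why kPosLimit in splines.cc drops from 1 << 30 to 1 << 23: spline rasterization in
// ComputeSegments() works in single precision, and the TODO removed by this patch
// already noted that coordinates at or beyond 2^23 would need doubles. Above 2^24 a
// float cannot even represent every integer, so expressions like
// `center.y + maximum_distance + 1.5f` start collapsing onto the same value; capping
// coordinates below 2^23 leaves headroom for that arithmetic. Quick check of the
// representable spacing:
#include <cstdio>

int main() {
  float at_2_23 = 8388608.0f;   // 2^23, float spacing here is exactly 1.0f
  float at_2_24 = 16777216.0f;  // 2^24, float spacing here is 2.0f
  std::printf("%.1f + 1.0f = %.1f\n", at_2_23, at_2_23 + 1.0f);  // 8388609.0
  std::printf("%.1f + 1.0f = %.1f\n", at_2_24, at_2_24 + 1.0f);  // 16777216.0
  return 0;
}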