LCOV - code coverage report
Current view: top level - Source/nnue - simd.hpp (source / functions) Coverage Total Hit
Test: coverage Lines: 100.0 % 26 26
Test Date: 2026-03-02 16:42:41 Functions: 100.0 % 3 3

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : #include <cstring>
       4              : 
       5              : // Highly inspired by/copied from https://github.com/xianyi/OpenBLAS, same naming convention here.
       6              : 
       7              : /*
       8              : My gcc (and clang) gives those macros for simd extension :
       9              : 
      10              : >> gcc -march=skylake-avx512 -dM -E - < /dev/null | egrep "SSE|AVX" | sort
      11              : 
      12              : #define __AVX__ 1
      13              : #define __AVX2__ 1
      14              : #define __AVX512BW__ 1
      15              : #define __AVX512CD__ 1
      16              : #define __AVX512DQ__ 1
      17              : #define __AVX512F__ 1
      18              : #define __AVX512VL__ 1
      19              : #define __MMX_WITH_SSE__ 1
      20              : #define __SSE__ 1
      21              : #define __SSE2__ 1
      22              : #define __SSE2_MATH__ 1
      23              : #define __SSE3__ 1
      24              : #define __SSE4_1__ 1
      25              : #define __SSE4_2__ 1
      26              : #define __SSE_MATH__ 1
      27              : #define __SSSE3__ 1
      28              : */
      29              : 
      30              : /** SSE **/
      31              : #ifdef __SSE__
      32              : #include <xmmintrin.h>
      33              : #endif
      34              : /** SSE2 **/
      35              : #ifdef __SSE2__
      36              : #include <emmintrin.h>
      37              : #endif
      38              : /** SSE3 **/
      39              : #ifdef __SSE3__
      40              : #include <pmmintrin.h>
      41              : #endif
      42              : /** SSSE3 **/
      43              : #ifdef __SSSE3__
      44              : #include <tmmintrin.h>
      45              : #endif
      46              : /** SSE41 **/
      47              : #ifdef __SSE4_1__
      48              : #include <smmintrin.h>
      49              : #endif
      50              : /** AVX **/
      51              : #if defined(__AVX__) || defined(__FMA__)
      52              : #include <immintrin.h>
      53              : #endif
      54              : 
      55              : //----------------------------------
      56              : // AVX512
      57              : //----------------------------------
      58              : #if defined(__AVX512VL__NONONO)
      59              : #define V_SIMD_512 512
      60              : using v_f32_512 = __m512;
      61              : inline constexpr auto v_nlanes_f32_512 = 16;
      62              : #define v_add_f32_512    _mm512_add_ps
      63              : #define v_mul_f32_512    _mm512_mul_ps
      64              : #define v_muladd_f32_512 _mm512_fmadd_ps
      65              : 
      66              : FORCE_FINLINE float v_sum_f32_256(__m256 a) {
      67              :    __m256 sum_halves = _mm256_hadd_ps(a, a);
      68              :    sum_halves        = _mm256_hadd_ps(sum_halves, sum_halves);
      69              :    const __m128 lo   = _mm256_castps256_ps128(sum_halves);
      70              :    const __m128 hi   = _mm256_extractf128_ps(sum_halves, 1);
      71              :    const __m128 sum  = _mm_add_ps(lo, hi);
      72              :    return _mm_cvtss_f32(sum);
      73              : }
      74              : 
      75              : FORCE_FINLINE float v_sum_f32_512(__m512 a) {
      76              :    const __m256 low  = _mm512_castps512_ps256(a);
      77              :    const __m256 high = _mm512_extractf32x8_ps(a,1);
      78              :    return v_sum_f32_256(low+high);
      79              : }
      80              : 
      81              : #define v_load_f32_512(PTR) _mm512_load_ps((const __m512*)(PTR))
      82              : #define v_store_f32_512     _mm512_store_ps
      83              : #define v_zero_f32_512      _mm512_setzero_ps
      84              : #define v_set_f32_512       _mm512_set1_ps
      85              : #define v_max_f32_512       _mm512_max_ps
      86              : #define v_min_f32_512       _mm512_min_ps
      87              : 
      88              : template<bool Q>
      89              : FORCE_FINLINE void simdClippedReLU512Helper(float * RESTRICT x, const __m512 & zero, const __m512 & un){
      90              :    v_store_f32_512(x, v_max_f32_512(zero, v_min_f32_512(un, v_load_f32_512(x))));
      91              : }
      92              : 
      93              : template<size_t N, bool Q>
      94              : void simdActivation512(float * RESTRICT x, const __m512 & zero, const __m512 & un){
      95              :    constexpr int vstep    = v_nlanes_f32_512;
      96              :    constexpr int unrollx4 = N & (-vstep * 4);
      97              :    constexpr int unrollx  = N & -vstep;
      98              :    int i = 0;
      99              :    if constexpr(unrollx4){
     100              :       while (i < unrollx4) {
     101              :             simdClippedReLU512Helper<Q>(x + i            , zero, un);
     102              :             simdClippedReLU512Helper<Q>(x + i + vstep    , zero, un);
     103              :             simdClippedReLU512Helper<Q>(x + i + vstep * 2, zero, un);
     104              :             simdClippedReLU512Helper<Q>(x + i + vstep * 3, zero, un);
     105              :          i += vstep * 4;
     106              :       }
     107              :    }
     108              :    while (i < unrollx) {
     109              :          simdClippedReLU512Helper<Q>(x + i, zero, un);
     110              :       i += vstep;
     111              :    }
     112              : }
     113              : 
     114              : template<size_t N, bool Q>
     115              : [[nodiscard]] float simdDotProduct512(const float* RESTRICT x, const float* RESTRICT y) {
     116              :    constexpr int vstep    = v_nlanes_f32_512;
     117              :    constexpr int unrollx4 = N & (-vstep * 4);
     118              :    constexpr int unrollx  = N & -vstep;
     119              :    int i = 0;
     120              :    v_f32_512 vsum0 = v_zero_f32_512();
     121              :    if constexpr(unrollx4){
     122              :       v_f32_512 vsum1 = v_zero_f32_512();
     123              :       v_f32_512 vsum2 = v_zero_f32_512();
     124              :       v_f32_512 vsum3 = v_zero_f32_512();
     125              :       while (i < unrollx4) {
     126              :          vsum0 = v_muladd_f32_512(v_load_f32_512(x + i            ), v_load_f32_512(y + i            ), vsum0);
     127              :          vsum1 = v_muladd_f32_512(v_load_f32_512(x + i + vstep    ), v_load_f32_512(y + i + vstep    ), vsum1);
     128              :          vsum2 = v_muladd_f32_512(v_load_f32_512(x + i + vstep * 2), v_load_f32_512(y + i + vstep * 2), vsum2);
     129              :          vsum3 = v_muladd_f32_512(v_load_f32_512(x + i + vstep * 3), v_load_f32_512(y + i + vstep * 3), vsum3);
     130              :          i += vstep * 4;
     131              :       }
     132              :       vsum0 = v_add_f32_512(v_add_f32_512(vsum0, vsum1), v_add_f32_512(vsum2, vsum3));
     133              :    }
     134              :    while (i < unrollx) {
     135              :       vsum0 = v_muladd_f32_512(v_load_f32_512(x + i), v_load_f32_512(y + i), vsum0);
     136              :       i += vstep;
     137              :    }
     138              :    return v_sum_f32_512(vsum0);
     139              : }
     140              : #endif
     141              : 
     142              : //----------------------------------
     143              : // AVX
     144              : //----------------------------------
     145              : #if defined(__AVX2__)
     146              : #define V_SIMD_256 256
     147              : using v_f32_256 = __m256;
     148              : inline constexpr auto v_nlanes_f32_256 = 8;
     149              : #define v_add_f32_256    _mm256_add_ps
     150              : #define v_mul_f32_256    _mm256_mul_ps
     151              : #ifdef __FMA__
     152              : #define v_muladd_f32_256 _mm256_fmadd_ps
     153              : #else
     154              : FORCE_FINLINE __m256 v_muladd_f32_256(__m256 a, __m256 b, __m256 c) { return v_add_f32_256(v_mul_f32_256(a, b), c); }
     155              : #endif
     156              : #ifndef V_SIMD_512
     157              : FORCE_FINLINE float v_sum_f32_256(__m256 a) {
     158              :    __m256 sum_halves = _mm256_hadd_ps(a, a);
     159              :    sum_halves        = _mm256_hadd_ps(sum_halves, sum_halves);
     160              :    const __m128 lo   = _mm256_castps256_ps128(sum_halves);
     161              :    const __m128 hi   = _mm256_extractf128_ps(sum_halves, 1);
     162              :    const __m128 sum  = _mm_add_ps(lo, hi);
     163              :    return _mm_cvtss_f32(sum);
     164              : }
     165              : #endif
     166              : #define v_load_f32_256  _mm256_load_ps
     167              : #define v_store_f32_256 _mm256_store_ps
     168              : #define v_zero_f32_256  _mm256_setzero_ps
     169              : #define v_set_f32_256   _mm256_set1_ps
     170              : #define v_max_f32_256   _mm256_max_ps
     171              : #define v_min_f32_256   _mm256_min_ps
     172              : 
     173              : template<bool Q>
     174              : FORCE_FINLINE void simdClippedReLU256Helper(float * RESTRICT x, const __m256 & zero, const __m256 & un){
     175              :    v_store_f32_256(x, v_max_f32_256(zero, v_min_f32_256(un, v_load_f32_256(x))));
     176              : }
     177              : 
     178              : template<size_t N, bool Q>
     179              : void simdActivation256(float * RESTRICT x, const __m256 & zero, const __m256 & un){
     180              :    constexpr int vstep    = v_nlanes_f32_256;
     181              :    constexpr int unrollx4 = N & (-vstep * 4);
     182              :    constexpr int unrollx  = N & -vstep;
     183              :    int i = 0;
     184              :    if constexpr(unrollx4){
     185              :       while (i < unrollx4) {
     186              :             simdClippedReLU256Helper<Q>(x + i            , zero, un);
     187              :             simdClippedReLU256Helper<Q>(x + i + vstep    , zero, un);
     188              :             simdClippedReLU256Helper<Q>(x + i + vstep * 2, zero, un);
     189              :             simdClippedReLU256Helper<Q>(x + i + vstep * 3, zero, un);
     190              :          i += vstep * 4;
     191              :       }
     192              :    }
     193              :    while (i < unrollx) {
     194              :          simdClippedReLU256Helper<Q>(x + i, zero, un);
     195              :       i += vstep;
     196              :    }
     197              : }
     198              : 
     199              : template<size_t N, bool Q> 
     200              : [[nodiscard]] float simdDotProduct256(const float* RESTRICT x, const float* RESTRICT y) {
     201              :    constexpr int vstep    = v_nlanes_f32_256;
     202              :    constexpr int unrollx4 = N & (-vstep * 4);
     203              :    constexpr int unrollx  = N & -vstep;
     204              :    int i = 0;
     205              :    v_f32_256 vsum0 = v_zero_f32_256();
     206              :    if constexpr(unrollx4){
     207              :       v_f32_256 vsum1 = v_zero_f32_256();
     208              :       v_f32_256 vsum2 = v_zero_f32_256();
     209              :       v_f32_256 vsum3 = v_zero_f32_256();
     210              :       while (i < unrollx4) {
     211              :          vsum0 = v_muladd_f32_256(v_load_f32_256(x + i            ), v_load_f32_256(y + i            ), vsum0);
     212              :          vsum1 = v_muladd_f32_256(v_load_f32_256(x + i + vstep    ), v_load_f32_256(y + i + vstep    ), vsum1);
     213              :          vsum2 = v_muladd_f32_256(v_load_f32_256(x + i + vstep * 2), v_load_f32_256(y + i + vstep * 2), vsum2);
     214              :          vsum3 = v_muladd_f32_256(v_load_f32_256(x + i + vstep * 3), v_load_f32_256(y + i + vstep * 3), vsum3);
     215              :          i += vstep * 4;
     216              :       }
     217              :       vsum0 = v_add_f32_256(v_add_f32_256(vsum0, vsum1), v_add_f32_256(vsum2, vsum3));
     218              :    }
     219              :    while (i < unrollx) {
     220              :       vsum0 = v_muladd_f32_256(v_load_f32_256(x + i), v_load_f32_256(y + i), vsum0);
     221              :       i += vstep;
     222              :    }
     223              :    return v_sum_f32_256(vsum0);
     224              : }
     225              : #endif
     226              : 
     227              : //----------------------------------
     228              : // SSE
     229              : //----------------------------------
     230              : #if defined(__SSE2__)
     231              : #define V_SIMD_128 128
     232              : using v_f32_128 = __m128;
     233              : inline constexpr auto v_nlanes_f32_128 = 4;
     234              : #define v_add_f32_128    _mm_add_ps
     235              : #define v_mul_f32_128    _mm_mul_ps
     236              : #ifdef __FMA__
     237              : #define v_muladd_f32_128 _mm_fmadd_ps
     238              : //#elif defined(__FMA4__)
     239              : //#define v_muladd_f32_128 _mm_macc_ps
     240              : #else
     241              : FORCE_FINLINE __m128 v_muladd_f32_128(__m128 a, __m128 b, __m128 c) { return v_add_f32_128(v_mul_f32_128(a, b), c); }
     242              : #endif
     243              : FORCE_FINLINE float v_sum_f32_128(__m128 a) {
     244              : #ifdef __SSE3__
     245              :    const __m128 sum_halves = _mm_hadd_ps(a, a);
     246              :    return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
     247              : #else
     248              :    const __m128 t1 = _mm_movehl_ps(a, a);
     249              :    const __m128 t2 = _mm_add_ps(a, t1);
     250              :    const __m128 t3 = _mm_shuffle_ps(t2, t2, 1);
     251              :    const __m128 t4 = _mm_add_ss(t2, t3);
     252              :    return _mm_cvtss_f32(t4);
     253              : #endif
     254              : }
     255              : #define v_load_f32_128  _mm_load_ps
     256              : #define v_store_f32_128 _mm_store_ps
     257              : #define v_zero_f32_128  _mm_setzero_ps
     258              : #define v_set_f32_128   _mm_set1_ps
     259              : #define v_max_f32_128   _mm_max_ps
     260              : #define v_min_f32_128   _mm_min_ps
     261              : 
     262              : #if defined(__SSE2__)
     263              : #if defined(__SSE4_1__)
     264              : #define v_cvtepi16_epi32_128 _mm_cvtepi16_epi32
     265              : #else
     266              : FORCE_FINLINE __m128i v_cvtepi16_epi32_128(__m128i src_i16) {
     267              :    const __m128i sign = _mm_srai_epi16(src_i16, 15);
     268              :    return _mm_unpacklo_epi16(src_i16, sign);
     269              : }
     270              : #endif
     271              : #endif
     272              : 
     273              : template<bool Q>
     274              : FORCE_FINLINE void simdClippedReLU128Helper(float * RESTRICT x, const v_f32_128 & zero, const v_f32_128 & un){
     275              :    v_store_f32_128(x, v_max_f32_128(zero, v_min_f32_128(un, v_load_f32_128(x))));
     276              : }
     277              : 
     278              : 
     279              : template<size_t N, bool Q>
     280              : void simdActivation128(float * RESTRICT x, const v_f32_128 & zero, const v_f32_128 & un){
     281              :    constexpr int vstep    = v_nlanes_f32_128;
     282              :    constexpr int unrollx4 = N & (-vstep * 4);
     283              :    constexpr int unrollx  = N & -vstep;   
     284              :    int i = 0;
     285              :    if constexpr(unrollx4){
     286              :       while (i < unrollx4) {
     287              :             simdClippedReLU128Helper<Q>(x + i            , zero, un);
     288              :             simdClippedReLU128Helper<Q>(x + i + vstep    , zero, un);
     289              :             simdClippedReLU128Helper<Q>(x + i + vstep * 2, zero, un);
     290              :             simdClippedReLU128Helper<Q>(x + i + vstep * 3, zero, un);
     291              :          i += vstep * 4;
     292              :       }
     293              :    }
     294     19147203 :    while (i < unrollx) {
     295     12764802 :          simdClippedReLU128Helper<Q>(x + i, zero, un);
     296     12764802 :       i += vstep;
     297              :    }
     298              : }
     299              : 
     300              : template<size_t N, bool Q> 
     301     36166939 : [[nodiscard]] float simdDotProduct128(const float* RESTRICT x, const float* RESTRICT y) {
     302              :    constexpr int vstep    = v_nlanes_f32_128;
     303              :    constexpr int unrollx4 = N & (-vstep * 4);
     304              :    constexpr int unrollx  = N & -vstep;
     305              :    int i = 0;
     306              :    v_f32_128 vsum0 = v_zero_f32_128();
     307              :    if constexpr(unrollx4){
     308              :       v_f32_128 vsum1 = v_zero_f32_128();
     309              :       v_f32_128 vsum2 = v_zero_f32_128();
     310              :       v_f32_128 vsum3 = v_zero_f32_128();
     311    872261470 :       while (i < unrollx4) {
     312    836094531 :          vsum0 = v_muladd_f32_128(v_load_f32_128(x + i            ), v_load_f32_128(y + i            ), vsum0);
     313    836094531 :          vsum1 = v_muladd_f32_128(v_load_f32_128(x + i + vstep    ), v_load_f32_128(y + i + vstep    ), vsum1);
     314    836094531 :          vsum2 = v_muladd_f32_128(v_load_f32_128(x + i + vstep * 2), v_load_f32_128(y + i + vstep * 2), vsum2);
     315    836094531 :          vsum3 = v_muladd_f32_128(v_load_f32_128(x + i + vstep * 3), v_load_f32_128(y + i + vstep * 3), vsum3);
     316    816947328 :          i += vstep * 4;
     317              :       }
     318              :       vsum0 = v_add_f32_128(v_add_f32_128(vsum0, vsum1), v_add_f32_128(vsum2, vsum3));
     319              :    }
     320     57441609 :    while (i < unrollx) {
     321     38294406 :       vsum0 = v_muladd_f32_128(v_load_f32_128(x + i), v_load_f32_128(y + i), vsum0);
     322     38294406 :       i += vstep;
     323              :    }
     324     36166939 :    return v_sum_f32_128(vsum0);
     325              : }
     326              : 
     327              : #endif
     328              : 
     329              : template<size_t N, bool Q> 
     330              : FORCE_FINLINE void simdActivationDefault(float * RESTRICT x){
     331              :    constexpr int n1 = N & -4;
     332              :    for (int i = 0; i < n1; i += 4) {
     333              :          x[i]     = std::min(std::max(x[i]    , 0.f), 1.f);
     334              :          x[i + 1] = std::min(std::max(x[i + 1], 0.f), 1.f);
     335              :          x[i + 2] = std::min(std::max(x[i + 2], 0.f), 1.f);
     336              :          x[i + 3] = std::min(std::max(x[i + 3], 0.f), 1.f);
     337              :    }   
     338              : }
     339              : 
     340              : template<size_t N, bool Q> 
     341              : [[nodiscard]] float simdDotProductDefault(const float* RESTRICT x, const float* RESTRICT y) {
     342              :    constexpr int n1 = N & -4;
     343              :    float dot = 0.f;
     344              :    for (int i = 0; i < n1; i += 4) { 
     345              :       dot += y[i    ] * x[i    ]
     346              :            + y[i + 1] * x[i + 1] 
     347              :            + y[i + 2] * x[i + 2] 
     348              :            + y[i + 3] * x[i + 3]; 
     349              :    }
     350              :    return dot;
     351              : }
     352              : 
     353              : template<size_t N>
     354              : FORCE_FINLINE void simdAdd_i16(int16_t* RESTRICT dst, const int16_t* RESTRICT src) {
     355              :    size_t i = 0;
     356              : #if V_SIMD_256
     357              :    constexpr size_t vstep = 16; // 256 bits / 16 bits
     358              :    constexpr size_t unrollx4 = N & (-vstep * 4);
     359              :    constexpr size_t unrollx  = N & -vstep;
     360              :    if constexpr (unrollx4) {
     361              :       while (i < unrollx4) {
     362              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i            ), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i            )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i            ))));
     363              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep    ), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep    )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep    ))));
     364              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 2), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 2)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 2))));
     365              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 3), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 3)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 3))));
     366              :          i += vstep * 4;
     367              :       }
     368              :    }
     369              :    while (i + vstep <= unrollx) {
     370              :       _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i))));
     371              :       i += vstep;
     372              :    }
     373              : #endif
     374              : #if V_SIMD_128
     375              :    constexpr size_t vstep128 = 8;
     376  25357484614 :    while (i + vstep128 <= N) {
     377  24839984928 :       _mm_store_si128(reinterpret_cast<__m128i*>(dst + i), _mm_add_epi16(_mm_load_si128(reinterpret_cast<const __m128i*>(dst + i)), _mm_load_si128(reinterpret_cast<const __m128i*>(src + i))));
     378              :       i += vstep128;
     379              :    }
     380              : #endif
     381              :    const size_t tail = N - i;
     382              :    for (size_t j = 0; j < tail; ++j) dst[i + j] += src[i + j];
     383              : }
     384              : 
     385              : template<size_t N>
     386              : FORCE_FINLINE void simdSub_i16(int16_t* RESTRICT dst, const int16_t* RESTRICT src) {
     387              :    size_t i = 0;
     388              : #if V_SIMD_256
     389              :    constexpr size_t vstep = 16;
     390              :    constexpr size_t unrollx4 = N & (-vstep * 4);
     391              :    constexpr size_t unrollx  = N & -vstep;
     392              :    if constexpr (unrollx4) {
     393              :       while (i < unrollx4) {
     394              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i            ), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i            )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i            ))));
     395              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep    ), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep    )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep    ))));
     396              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 2), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 2)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 2))));
     397              :          _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 3), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 3)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 3))));
     398              :          i += vstep * 4;
     399              :       }
     400              :    }
     401              :    while (i + vstep <= unrollx) {
     402              :       _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i))));
     403              :       i += vstep;
     404              :    }
     405              : #endif
     406              : #if V_SIMD_128
     407              :    constexpr size_t vstep128 = 8;
     408  16794124172 :    while (i + vstep128 <= N) {
     409  16451386944 :       _mm_store_si128(reinterpret_cast<__m128i*>(dst + i), _mm_sub_epi16(_mm_load_si128(reinterpret_cast<const __m128i*>(dst + i)), _mm_load_si128(reinterpret_cast<const __m128i*>(src + i))));
     410              :       i += vstep128;
     411              :    }
     412              : #endif
     413              :    const size_t tail = N - i;
     414              :    for (size_t j = 0; j < tail; ++j) dst[i + j] -= src[i + j];
     415              : }
     416              : 
     417              : template<size_t N>
     418              : FORCE_FINLINE void simdCopy_i16(int16_t* RESTRICT dst, const int16_t* RESTRICT src) {
     419              :    size_t i = 0;
     420              : #if V_SIMD_256
     421              :    constexpr size_t vstep = 16;
     422              :    while (i + vstep <= N) {
     423              :       _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i)));
     424              :       i += vstep;
     425              :    }
     426              : #endif
     427              : #if V_SIMD_128
     428              :    constexpr size_t vstep128 = 8;
     429    423218677 :    while (i + vstep128 <= N) {
     430    414582720 :       _mm_store_si128(reinterpret_cast<__m128i*>(dst + i), _mm_load_si128(reinterpret_cast<const __m128i*>(src + i)));
     431              :       i += vstep128;
     432              :    }
     433              : #endif
     434              :    const size_t tail = N - i;
     435              :    if (tail) std::memcpy(dst + i, src + i, tail * sizeof(int16_t));
     436              : }
     437              : 
     438              : template<size_t N>
     439              : FORCE_FINLINE void simdCopy_f32(float* RESTRICT dst, const float* RESTRICT src) {
     440              :    size_t i = 0;
     441              : #if V_SIMD_256
     442              :    constexpr size_t vstep = 8;
     443              :    while (i + vstep <= N) {
     444              :       _mm256_store_ps(dst + i, _mm256_load_ps(src + i));
     445              :       i += vstep;
     446              :    }
     447              : #endif
     448              : #if V_SIMD_128
     449              :    constexpr size_t vstep128 = 4;
     450     23402137 :    while (i + vstep128 <= N) {
     451     17019736 :       _mm_store_ps(dst + i, _mm_load_ps(src + i));
     452              :       i += vstep128;
     453              :    }
     454              : #endif
     455              :    const size_t tail = N - i;
     456              :    if (tail) std::memcpy(dst + i, src + i, tail * sizeof(float));
     457              : }
     458              : 
     459              : template<size_t N>
     460              : FORCE_FINLINE void simdDequantize_i16_f32(float* RESTRICT dst, const int16_t* RESTRICT src, const float scale) {
     461              :    size_t i = 0;
     462              : #if V_SIMD_256
     463              :    constexpr size_t vstep = 8; // 8 floats per __m256
     464              :    const __m256 vscale = _mm256_set1_ps(scale);
     465              :    while (i + vstep <= N) {
     466              :       // Load 8 x int16, sign-extend to 8 x int32, convert to 8 x float, multiply by scale
     467              :       const __m128i src_i16 = _mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
     468              :       const __m256i src_i32 = _mm256_cvtepi16_epi32(src_i16);
     469              :       const __m256  src_f32 = _mm256_cvtepi32_ps(src_i32);
     470              :       _mm256_store_ps(dst + i, _mm256_mul_ps(src_f32, vscale));
     471              :       i += vstep;
     472              :    }
     473              : #endif
     474              : #if V_SIMD_128
     475              :    constexpr size_t vstep128 = 4;
     476              :    const __m128 vscale128 = _mm_set1_ps(scale);
     477              :    while (i + vstep128 <= N) {
     478              :       // Load 4 x int16 (as 64-bit), sign-extend to 4 x int32, convert to 4 x float
     479              :       const __m128i src_i16 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(src + i));
     480              :       const __m128i src_i32 = v_cvtepi16_epi32_128(src_i16);
     481              :       const __m128  src_f32 = _mm_cvtepi32_ps(src_i32);
     482              :       _mm_store_ps(dst + i, _mm_mul_ps(src_f32, vscale128));
     483              :       i += vstep128;
     484              :    }
     485              : #endif
     486              :    const size_t tail = N - i;
     487              :    for (size_t j = 0; j < tail; ++j) dst[i + j] = scale * static_cast<float>(src[i + j]);
     488              : }
     489              : 
     490              : template<size_t N0, size_t N1>
     491              : FORCE_FINLINE void simdSplice_f32(float* RESTRICT dst, const float* RESTRICT a, const float* RESTRICT b) {
     492              :    simdCopy_f32<N0>(dst, a);
     493              :    simdCopy_f32<N1>(dst + N0, b);
     494              : }
     495              : 
     496              : template<size_t N, bool Q>
     497              : FORCE_FINLINE void simdDequantizeActivate_i16_f32(float* RESTRICT dst, const int16_t* RESTRICT src, const float scale) {
     498              :    size_t i = 0;
     499              : #if V_SIMD_256
     500              :    constexpr size_t vstep = 8;
     501              :    const __m256 vscale = _mm256_set1_ps(scale);
     502              :    const __m256 vzero  = _mm256_setzero_ps();
     503              :    const __m256 vone   = _mm256_set1_ps(1.0f);
     504              :    while (i + vstep <= N) {
     505              :       const __m128i src_i16 = _mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
     506              :       const __m256i src_i32 = _mm256_cvtepi16_epi32(src_i16);
     507              :       const __m256  deq     = _mm256_mul_ps(_mm256_cvtepi32_ps(src_i32), vscale);
     508              :       _mm256_store_ps(dst + i, _mm256_max_ps(vzero, _mm256_min_ps(vone, deq)));
     509              :       i += vstep;
     510              :    }
     511              : #endif
     512              : #if V_SIMD_128
     513              :    constexpr size_t vstep128 = 4;
     514              :    const __m128 vscale128 = _mm_set1_ps(scale);
     515              :    const __m128 vzero128  = _mm_setzero_ps();
     516              :    const __m128 vone128   = _mm_set1_ps(1.0f);
     517    410601131 :    while (i + vstep128 <= N) {
     518    408473664 :       const __m128i src_i16 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(src + i));
     519              :       const __m128i src_i32 = v_cvtepi16_epi32_128(src_i16);
     520              :       const __m128  deq     = _mm_mul_ps(_mm_cvtepi32_ps(src_i32), vscale128);
     521    408473664 :       _mm_store_ps(dst + i, _mm_max_ps(vzero128, _mm_min_ps(vone128, deq)));
     522              :       i += vstep128;
     523              :    }
     524              : #endif
     525              :    const size_t tail = N - i;
     526              :    for (size_t j = 0; j < tail; ++j) {
     527              :       const float deq = scale * static_cast<float>(src[i + j]);
     528              :       dst[i + j] = std::max(0.f, std::min(1.f, deq));
     529              :    }
     530              : }
     531              : 
     532              : template<size_t N, bool Q> 
     533              : void simdActivation(float* RESTRICT x) {
     534              :    size_t i = 0;
     535              :    if constexpr (N <= 0) return;
     536              : 
     537              : #if V_SIMD_512
     538              :    if (N-i >= 16){
     539              :       const v_f32_512 zero = v_zero_f32_512();
     540              :       const v_f32_512 un   = v_set_f32_512(1.f);      
     541              :       simdActivation512<N,Q>(x+i, zero, un);
     542              :       i += ((N-i) & -16);
     543              :    }
     544              : #endif
     545              : 
     546              : #if V_SIMD_256
     547              :    if (N-i >= 8){
     548              :       const v_f32_256 zero = v_zero_f32_256();
     549              :       const v_f32_256 un   = v_set_f32_256(1.f);
     550              :       simdActivation256<N,Q>(x+i, zero, un);
     551              :       i += ((N-i) & -8);
     552              :    }
     553              : #endif
     554              : 
     555              : #if V_SIMD_128
     556              :    if (N-i >= 4){
     557              :       const v_f32_128 zero = v_zero_f32_128();
     558              :       const v_f32_128 un   = v_set_f32_128(1.f);      
     559              :       simdActivation128<N,Q>(x+i, zero, un);
     560              :       i += ((N-i) & -4);
     561              :    }
     562              : #endif
     563              : 
     564              :    if (N-i >= 4){
     565              :       simdActivationDefault<N,Q>(x+i);
     566              :       i += ((N-i) & -4);
     567              :    }
     568              : 
     569              :    while (i < N) {
     570              :          x[i] = std::min(std::max(x[i], 0.f), 1.f);
     571              :       ++i;
     572              :    }
     573              : }
     574              : 
     575              : template<size_t N, bool Q> 
     576              : [[nodiscard]] float simdDotProduct(const float* RESTRICT x, const float* RESTRICT y) {
     577              :    size_t i = 0;
     578              :    float dot = 0.0f;
     579              :    if constexpr (N <= 0) return dot;
     580              : 
     581              : #if V_SIMD_512
     582              :    if (N-i >= 16){
     583              :       dot += simdDotProduct512<N,Q>(x+i,y+i);
     584              :       i += ((N-i) & -16);
     585              :    }
     586              : #endif
     587              : 
     588              : #if V_SIMD_256
     589              :    if (N-i >= 8){
     590              :       dot += simdDotProduct256<N,Q>(x+i,y+i);
     591              :       i += ((N-i) & -8);
     592              :    }
     593              : #endif
     594              : 
     595              : #if V_SIMD_128
     596              :    if (N-i >= 4){
     597     53186675 :       dot += simdDotProduct128<N,Q>(x+i,y+i);
     598              :       i += ((N-i) & -4);
     599              :    }
     600              : #endif
     601              : 
     602              :    if (N-i >= 4){
     603              :       dot += simdDotProductDefault<N,Q>(x+i,y+i);
     604              :       i += ((N-i) & -4);
     605              :    }
     606              : 
     607              :    while (i < N) {
     608              :       dot += y[i] * x[i];
     609              :       ++i;
     610              :    }
     611              :    return dot;
     612              : }
     613              : 
     614              : 
     615              : 
     616              : #ifdef TESTING
     617              : int main(int, char**){
     618              :    constexpr int n = 768;
     619              :    alignas(64) float a[n];
     620              :    alignas(64) float b[n];
     621              : 
     622              :    for (int i = 0; i < n; ++i){
     623              :       a[i] = (i+1)/1000.f/(n-1);
     624              :       b[i] = 1.f/a[i];
     625              :    }
     626              :    //std::cout << simdDotProduct512<n,true>(a,b) << std::endl;
     627              :    std::cout << simdDotProduct256<n,true>(a,b) << std::endl;
     628              :    std::cout << simdDotProduct128<n,true>(a,b) << std::endl;
     629              :    std::cout << simdDotProductDefault<n,true>(a,b) << std::endl;
     630              : 
     631              :    for (int i = 0; i < n; ++i){
     632              :       a[i] = -5.f + 10.f*i/(n-1);
     633              :       b[i] = 2.f;
     634              :    }
     635              : 
     636              :    auto reset = [&](){
     637              :       for (int i = 0; i < n; ++i){
     638              :          std::cout << a[i] << " ";
     639              :          a[i] = -5.f + 10.f*i/(n-1);
     640              :       }  
     641              :       std::cout << std::endl;    
     642              :    };
     643              : 
     644              :    {
     645              :    const v_f32_256 zero = v_zero_f32_256();
     646              :    const v_f32_256 un   = v_set_f32_256(1.f);
     647              :    simdActivation256<n,true>(a, zero, un);
     648              :    reset();
     649              :    }
     650              :    {
     651              :    const v_f32_128 zero = v_zero_f32_128();
     652              :    const v_f32_128 un   = v_set_f32_128(1.f);      
     653              :    simdActivation128<n,true>(a, zero, un);
     654              :    reset();
     655              :    }
     656              :    {
     657              :    simdActivationDefault<n,true>(a);
     658              :    reset();
     659              :    }
     660              : 
     661              :    return 0;
     662              : }
     663              : #endif
        

Generated by: LCOV version 2.0-1