Line data Source code
1 : #pragma once
2 :
3 : #include <cstring>
4 :
5 : // Highly inspired by/copied from https://github.com/xianyi/OpenBLAS, same naming convention here.
6 :
7 : /*
8 : My gcc (and clang) gives those macros for simd extension :
9 :
10 : >> gcc -march=skylake-avx512 -dM -E - < /dev/null | egrep "SSE|AVX" | sort
11 :
12 : #define __AVX__ 1
13 : #define __AVX2__ 1
14 : #define __AVX512BW__ 1
15 : #define __AVX512CD__ 1
16 : #define __AVX512DQ__ 1
17 : #define __AVX512F__ 1
18 : #define __AVX512VL__ 1
19 : #define __MMX_WITH_SSE__ 1
20 : #define __SSE__ 1
21 : #define __SSE2__ 1
22 : #define __SSE2_MATH__ 1
23 : #define __SSE3__ 1
24 : #define __SSE4_1__ 1
25 : #define __SSE4_2__ 1
26 : #define __SSE_MATH__ 1
27 : #define __SSSE3__ 1
28 : */
29 :
30 : /** SSE **/
31 : #ifdef __SSE__
32 : #include <xmmintrin.h>
33 : #endif
34 : /** SSE2 **/
35 : #ifdef __SSE2__
36 : #include <emmintrin.h>
37 : #endif
38 : /** SSE3 **/
39 : #ifdef __SSE3__
40 : #include <pmmintrin.h>
41 : #endif
42 : /** SSSE3 **/
43 : #ifdef __SSSE3__
44 : #include <tmmintrin.h>
45 : #endif
46 : /** SSE41 **/
47 : #ifdef __SSE4_1__
48 : #include <smmintrin.h>
49 : #endif
50 : /** AVX **/
51 : #if defined(__AVX__) || defined(__FMA__)
52 : #include <immintrin.h>
53 : #endif
54 :
55 : //----------------------------------
56 : // AVX512
57 : //----------------------------------
58 : #if defined(__AVX512VL__NONONO)
59 : #define V_SIMD_512 512
60 : using v_f32_512 = __m512;
61 : inline constexpr auto v_nlanes_f32_512 = 16;
62 : #define v_add_f32_512 _mm512_add_ps
63 : #define v_mul_f32_512 _mm512_mul_ps
64 : #define v_muladd_f32_512 _mm512_fmadd_ps
65 :
66 : FORCE_FINLINE float v_sum_f32_256(__m256 a) {
67 : __m256 sum_halves = _mm256_hadd_ps(a, a);
68 : sum_halves = _mm256_hadd_ps(sum_halves, sum_halves);
69 : const __m128 lo = _mm256_castps256_ps128(sum_halves);
70 : const __m128 hi = _mm256_extractf128_ps(sum_halves, 1);
71 : const __m128 sum = _mm_add_ps(lo, hi);
72 : return _mm_cvtss_f32(sum);
73 : }
74 :
75 : FORCE_FINLINE float v_sum_f32_512(__m512 a) {
76 : const __m256 low = _mm512_castps512_ps256(a);
77 : const __m256 high = _mm512_extractf32x8_ps(a,1);
78 : return v_sum_f32_256(low+high);
79 : }
80 :
81 : #define v_load_f32_512(PTR) _mm512_load_ps((const __m512*)(PTR))
82 : #define v_store_f32_512 _mm512_store_ps
83 : #define v_zero_f32_512 _mm512_setzero_ps
84 : #define v_set_f32_512 _mm512_set1_ps
85 : #define v_max_f32_512 _mm512_max_ps
86 : #define v_min_f32_512 _mm512_min_ps
87 :
88 : template<bool Q>
89 : FORCE_FINLINE void simdClippedReLU512Helper(float * RESTRICT x, const __m512 & zero, const __m512 & un){
90 : v_store_f32_512(x, v_max_f32_512(zero, v_min_f32_512(un, v_load_f32_512(x))));
91 : }
92 :
93 : template<size_t N, bool Q>
94 : void simdActivation512(float * RESTRICT x, const __m512 & zero, const __m512 & un){
95 : constexpr int vstep = v_nlanes_f32_512;
96 : constexpr int unrollx4 = N & (-vstep * 4);
97 : constexpr int unrollx = N & -vstep;
98 : int i = 0;
99 : if constexpr(unrollx4){
100 : while (i < unrollx4) {
101 : simdClippedReLU512Helper<Q>(x + i , zero, un);
102 : simdClippedReLU512Helper<Q>(x + i + vstep , zero, un);
103 : simdClippedReLU512Helper<Q>(x + i + vstep * 2, zero, un);
104 : simdClippedReLU512Helper<Q>(x + i + vstep * 3, zero, un);
105 : i += vstep * 4;
106 : }
107 : }
108 : while (i < unrollx) {
109 : simdClippedReLU512Helper<Q>(x + i, zero, un);
110 : i += vstep;
111 : }
112 : }
113 :
114 : template<size_t N, bool Q>
115 : [[nodiscard]] float simdDotProduct512(const float* RESTRICT x, const float* RESTRICT y) {
116 : constexpr int vstep = v_nlanes_f32_512;
117 : constexpr int unrollx4 = N & (-vstep * 4);
118 : constexpr int unrollx = N & -vstep;
119 : int i = 0;
120 : v_f32_512 vsum0 = v_zero_f32_512();
121 : if constexpr(unrollx4){
122 : v_f32_512 vsum1 = v_zero_f32_512();
123 : v_f32_512 vsum2 = v_zero_f32_512();
124 : v_f32_512 vsum3 = v_zero_f32_512();
125 : while (i < unrollx4) {
126 : vsum0 = v_muladd_f32_512(v_load_f32_512(x + i ), v_load_f32_512(y + i ), vsum0);
127 : vsum1 = v_muladd_f32_512(v_load_f32_512(x + i + vstep ), v_load_f32_512(y + i + vstep ), vsum1);
128 : vsum2 = v_muladd_f32_512(v_load_f32_512(x + i + vstep * 2), v_load_f32_512(y + i + vstep * 2), vsum2);
129 : vsum3 = v_muladd_f32_512(v_load_f32_512(x + i + vstep * 3), v_load_f32_512(y + i + vstep * 3), vsum3);
130 : i += vstep * 4;
131 : }
132 : vsum0 = v_add_f32_512(v_add_f32_512(vsum0, vsum1), v_add_f32_512(vsum2, vsum3));
133 : }
134 : while (i < unrollx) {
135 : vsum0 = v_muladd_f32_512(v_load_f32_512(x + i), v_load_f32_512(y + i), vsum0);
136 : i += vstep;
137 : }
138 : return v_sum_f32_512(vsum0);
139 : }
140 : #endif
141 :
142 : //----------------------------------
143 : // AVX
144 : //----------------------------------
145 : #if defined(__AVX2__)
146 : #define V_SIMD_256 256
147 : using v_f32_256 = __m256;
148 : inline constexpr auto v_nlanes_f32_256 = 8;
149 : #define v_add_f32_256 _mm256_add_ps
150 : #define v_mul_f32_256 _mm256_mul_ps
151 : #ifdef __FMA__
152 : #define v_muladd_f32_256 _mm256_fmadd_ps
153 : #else
154 : FORCE_FINLINE __m256 v_muladd_f32_256(__m256 a, __m256 b, __m256 c) { return v_add_f32_256(v_mul_f32_256(a, b), c); }
155 : #endif
156 : #ifndef V_SIMD_512
157 : FORCE_FINLINE float v_sum_f32_256(__m256 a) {
158 : __m256 sum_halves = _mm256_hadd_ps(a, a);
159 : sum_halves = _mm256_hadd_ps(sum_halves, sum_halves);
160 : const __m128 lo = _mm256_castps256_ps128(sum_halves);
161 : const __m128 hi = _mm256_extractf128_ps(sum_halves, 1);
162 : const __m128 sum = _mm_add_ps(lo, hi);
163 : return _mm_cvtss_f32(sum);
164 : }
165 : #endif
166 : #define v_load_f32_256 _mm256_load_ps
167 : #define v_store_f32_256 _mm256_store_ps
168 : #define v_zero_f32_256 _mm256_setzero_ps
169 : #define v_set_f32_256 _mm256_set1_ps
170 : #define v_max_f32_256 _mm256_max_ps
171 : #define v_min_f32_256 _mm256_min_ps
172 :
173 : template<bool Q>
174 : FORCE_FINLINE void simdClippedReLU256Helper(float * RESTRICT x, const __m256 & zero, const __m256 & un){
175 : v_store_f32_256(x, v_max_f32_256(zero, v_min_f32_256(un, v_load_f32_256(x))));
176 : }
177 :
178 : template<size_t N, bool Q>
179 : void simdActivation256(float * RESTRICT x, const __m256 & zero, const __m256 & un){
180 : constexpr int vstep = v_nlanes_f32_256;
181 : constexpr int unrollx4 = N & (-vstep * 4);
182 : constexpr int unrollx = N & -vstep;
183 : int i = 0;
184 : if constexpr(unrollx4){
185 : while (i < unrollx4) {
186 : simdClippedReLU256Helper<Q>(x + i , zero, un);
187 : simdClippedReLU256Helper<Q>(x + i + vstep , zero, un);
188 : simdClippedReLU256Helper<Q>(x + i + vstep * 2, zero, un);
189 : simdClippedReLU256Helper<Q>(x + i + vstep * 3, zero, un);
190 : i += vstep * 4;
191 : }
192 : }
193 : while (i < unrollx) {
194 : simdClippedReLU256Helper<Q>(x + i, zero, un);
195 : i += vstep;
196 : }
197 : }
198 :
199 : template<size_t N, bool Q>
200 : [[nodiscard]] float simdDotProduct256(const float* RESTRICT x, const float* RESTRICT y) {
201 : constexpr int vstep = v_nlanes_f32_256;
202 : constexpr int unrollx4 = N & (-vstep * 4);
203 : constexpr int unrollx = N & -vstep;
204 : int i = 0;
205 : v_f32_256 vsum0 = v_zero_f32_256();
206 : if constexpr(unrollx4){
207 : v_f32_256 vsum1 = v_zero_f32_256();
208 : v_f32_256 vsum2 = v_zero_f32_256();
209 : v_f32_256 vsum3 = v_zero_f32_256();
210 : while (i < unrollx4) {
211 : vsum0 = v_muladd_f32_256(v_load_f32_256(x + i ), v_load_f32_256(y + i ), vsum0);
212 : vsum1 = v_muladd_f32_256(v_load_f32_256(x + i + vstep ), v_load_f32_256(y + i + vstep ), vsum1);
213 : vsum2 = v_muladd_f32_256(v_load_f32_256(x + i + vstep * 2), v_load_f32_256(y + i + vstep * 2), vsum2);
214 : vsum3 = v_muladd_f32_256(v_load_f32_256(x + i + vstep * 3), v_load_f32_256(y + i + vstep * 3), vsum3);
215 : i += vstep * 4;
216 : }
217 : vsum0 = v_add_f32_256(v_add_f32_256(vsum0, vsum1), v_add_f32_256(vsum2, vsum3));
218 : }
219 : while (i < unrollx) {
220 : vsum0 = v_muladd_f32_256(v_load_f32_256(x + i), v_load_f32_256(y + i), vsum0);
221 : i += vstep;
222 : }
223 : return v_sum_f32_256(vsum0);
224 : }
225 : #endif
226 :
227 : //----------------------------------
228 : // SSE
229 : //----------------------------------
230 : #if defined(__SSE2__)
231 : #define V_SIMD_128 128
232 : using v_f32_128 = __m128;
233 : inline constexpr auto v_nlanes_f32_128 = 4;
234 : #define v_add_f32_128 _mm_add_ps
235 : #define v_mul_f32_128 _mm_mul_ps
236 : #ifdef __FMA__
237 : #define v_muladd_f32_128 _mm_fmadd_ps
238 : //#elif defined(__FMA4__)
239 : //#define v_muladd_f32_128 _mm_macc_ps
240 : #else
241 : FORCE_FINLINE __m128 v_muladd_f32_128(__m128 a, __m128 b, __m128 c) { return v_add_f32_128(v_mul_f32_128(a, b), c); }
242 : #endif
243 : FORCE_FINLINE float v_sum_f32_128(__m128 a) {
244 : #ifdef __SSE3__
245 : const __m128 sum_halves = _mm_hadd_ps(a, a);
246 : return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
247 : #else
248 : const __m128 t1 = _mm_movehl_ps(a, a);
249 : const __m128 t2 = _mm_add_ps(a, t1);
250 : const __m128 t3 = _mm_shuffle_ps(t2, t2, 1);
251 : const __m128 t4 = _mm_add_ss(t2, t3);
252 : return _mm_cvtss_f32(t4);
253 : #endif
254 : }
255 : #define v_load_f32_128 _mm_load_ps
256 : #define v_store_f32_128 _mm_store_ps
257 : #define v_zero_f32_128 _mm_setzero_ps
258 : #define v_set_f32_128 _mm_set1_ps
259 : #define v_max_f32_128 _mm_max_ps
260 : #define v_min_f32_128 _mm_min_ps
261 :
262 : #if defined(__SSE2__)
263 : #if defined(__SSE4_1__)
264 : #define v_cvtepi16_epi32_128 _mm_cvtepi16_epi32
265 : #else
266 : FORCE_FINLINE __m128i v_cvtepi16_epi32_128(__m128i src_i16) {
267 : const __m128i sign = _mm_srai_epi16(src_i16, 15);
268 : return _mm_unpacklo_epi16(src_i16, sign);
269 : }
270 : #endif
271 : #endif
272 :
273 : template<bool Q>
274 : FORCE_FINLINE void simdClippedReLU128Helper(float * RESTRICT x, const v_f32_128 & zero, const v_f32_128 & un){
275 : v_store_f32_128(x, v_max_f32_128(zero, v_min_f32_128(un, v_load_f32_128(x))));
276 : }
277 :
278 :
279 : template<size_t N, bool Q>
280 : void simdActivation128(float * RESTRICT x, const v_f32_128 & zero, const v_f32_128 & un){
281 : constexpr int vstep = v_nlanes_f32_128;
282 : constexpr int unrollx4 = N & (-vstep * 4);
283 : constexpr int unrollx = N & -vstep;
284 : int i = 0;
285 : if constexpr(unrollx4){
286 : while (i < unrollx4) {
287 : simdClippedReLU128Helper<Q>(x + i , zero, un);
288 : simdClippedReLU128Helper<Q>(x + i + vstep , zero, un);
289 : simdClippedReLU128Helper<Q>(x + i + vstep * 2, zero, un);
290 : simdClippedReLU128Helper<Q>(x + i + vstep * 3, zero, un);
291 : i += vstep * 4;
292 : }
293 : }
294 19147203 : while (i < unrollx) {
295 12764802 : simdClippedReLU128Helper<Q>(x + i, zero, un);
296 12764802 : i += vstep;
297 : }
298 : }
299 :
300 : template<size_t N, bool Q>
301 36166939 : [[nodiscard]] float simdDotProduct128(const float* RESTRICT x, const float* RESTRICT y) {
302 : constexpr int vstep = v_nlanes_f32_128;
303 : constexpr int unrollx4 = N & (-vstep * 4);
304 : constexpr int unrollx = N & -vstep;
305 : int i = 0;
306 : v_f32_128 vsum0 = v_zero_f32_128();
307 : if constexpr(unrollx4){
308 : v_f32_128 vsum1 = v_zero_f32_128();
309 : v_f32_128 vsum2 = v_zero_f32_128();
310 : v_f32_128 vsum3 = v_zero_f32_128();
311 872261470 : while (i < unrollx4) {
312 836094531 : vsum0 = v_muladd_f32_128(v_load_f32_128(x + i ), v_load_f32_128(y + i ), vsum0);
313 836094531 : vsum1 = v_muladd_f32_128(v_load_f32_128(x + i + vstep ), v_load_f32_128(y + i + vstep ), vsum1);
314 836094531 : vsum2 = v_muladd_f32_128(v_load_f32_128(x + i + vstep * 2), v_load_f32_128(y + i + vstep * 2), vsum2);
315 836094531 : vsum3 = v_muladd_f32_128(v_load_f32_128(x + i + vstep * 3), v_load_f32_128(y + i + vstep * 3), vsum3);
316 816947328 : i += vstep * 4;
317 : }
318 : vsum0 = v_add_f32_128(v_add_f32_128(vsum0, vsum1), v_add_f32_128(vsum2, vsum3));
319 : }
320 57441609 : while (i < unrollx) {
321 38294406 : vsum0 = v_muladd_f32_128(v_load_f32_128(x + i), v_load_f32_128(y + i), vsum0);
322 38294406 : i += vstep;
323 : }
324 36166939 : return v_sum_f32_128(vsum0);
325 : }
326 :
327 : #endif
328 :
329 : template<size_t N, bool Q>
330 : FORCE_FINLINE void simdActivationDefault(float * RESTRICT x){
331 : constexpr int n1 = N & -4;
332 : for (int i = 0; i < n1; i += 4) {
333 : x[i] = std::min(std::max(x[i] , 0.f), 1.f);
334 : x[i + 1] = std::min(std::max(x[i + 1], 0.f), 1.f);
335 : x[i + 2] = std::min(std::max(x[i + 2], 0.f), 1.f);
336 : x[i + 3] = std::min(std::max(x[i + 3], 0.f), 1.f);
337 : }
338 : }
339 :
340 : template<size_t N, bool Q>
341 : [[nodiscard]] float simdDotProductDefault(const float* RESTRICT x, const float* RESTRICT y) {
342 : constexpr int n1 = N & -4;
343 : float dot = 0.f;
344 : for (int i = 0; i < n1; i += 4) {
345 : dot += y[i ] * x[i ]
346 : + y[i + 1] * x[i + 1]
347 : + y[i + 2] * x[i + 2]
348 : + y[i + 3] * x[i + 3];
349 : }
350 : return dot;
351 : }
352 :
353 : template<size_t N>
354 : FORCE_FINLINE void simdAdd_i16(int16_t* RESTRICT dst, const int16_t* RESTRICT src) {
355 : size_t i = 0;
356 : #if V_SIMD_256
357 : constexpr size_t vstep = 16; // 256 bits / 16 bits
358 : constexpr size_t unrollx4 = N & (-vstep * 4);
359 : constexpr size_t unrollx = N & -vstep;
360 : if constexpr (unrollx4) {
361 : while (i < unrollx4) {
362 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i ), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i ))));
363 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep ), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep ))));
364 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 2), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 2)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 2))));
365 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 3), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 3)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 3))));
366 : i += vstep * 4;
367 : }
368 : }
369 : while (i + vstep <= unrollx) {
370 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i), _mm256_add_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i))));
371 : i += vstep;
372 : }
373 : #endif
374 : #if V_SIMD_128
375 : constexpr size_t vstep128 = 8;
376 25357484614 : while (i + vstep128 <= N) {
377 24839984928 : _mm_store_si128(reinterpret_cast<__m128i*>(dst + i), _mm_add_epi16(_mm_load_si128(reinterpret_cast<const __m128i*>(dst + i)), _mm_load_si128(reinterpret_cast<const __m128i*>(src + i))));
378 : i += vstep128;
379 : }
380 : #endif
381 : const size_t tail = N - i;
382 : for (size_t j = 0; j < tail; ++j) dst[i + j] += src[i + j];
383 : }
384 :
385 : template<size_t N>
386 : FORCE_FINLINE void simdSub_i16(int16_t* RESTRICT dst, const int16_t* RESTRICT src) {
387 : size_t i = 0;
388 : #if V_SIMD_256
389 : constexpr size_t vstep = 16;
390 : constexpr size_t unrollx4 = N & (-vstep * 4);
391 : constexpr size_t unrollx = N & -vstep;
392 : if constexpr (unrollx4) {
393 : while (i < unrollx4) {
394 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i ), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i ))));
395 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep ), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep )), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep ))));
396 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 2), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 2)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 2))));
397 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i + vstep * 3), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i + vstep * 3)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i + vstep * 3))));
398 : i += vstep * 4;
399 : }
400 : }
401 : while (i + vstep <= unrollx) {
402 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i), _mm256_sub_epi16(_mm256_load_si256(reinterpret_cast<const __m256i*>(dst + i)), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i))));
403 : i += vstep;
404 : }
405 : #endif
406 : #if V_SIMD_128
407 : constexpr size_t vstep128 = 8;
408 16794124172 : while (i + vstep128 <= N) {
409 16451386944 : _mm_store_si128(reinterpret_cast<__m128i*>(dst + i), _mm_sub_epi16(_mm_load_si128(reinterpret_cast<const __m128i*>(dst + i)), _mm_load_si128(reinterpret_cast<const __m128i*>(src + i))));
410 : i += vstep128;
411 : }
412 : #endif
413 : const size_t tail = N - i;
414 : for (size_t j = 0; j < tail; ++j) dst[i + j] -= src[i + j];
415 : }
416 :
417 : template<size_t N>
418 : FORCE_FINLINE void simdCopy_i16(int16_t* RESTRICT dst, const int16_t* RESTRICT src) {
419 : size_t i = 0;
420 : #if V_SIMD_256
421 : constexpr size_t vstep = 16;
422 : while (i + vstep <= N) {
423 : _mm256_store_si256(reinterpret_cast<__m256i*>(dst + i), _mm256_load_si256(reinterpret_cast<const __m256i*>(src + i)));
424 : i += vstep;
425 : }
426 : #endif
427 : #if V_SIMD_128
428 : constexpr size_t vstep128 = 8;
429 423218677 : while (i + vstep128 <= N) {
430 414582720 : _mm_store_si128(reinterpret_cast<__m128i*>(dst + i), _mm_load_si128(reinterpret_cast<const __m128i*>(src + i)));
431 : i += vstep128;
432 : }
433 : #endif
434 : const size_t tail = N - i;
435 : if (tail) std::memcpy(dst + i, src + i, tail * sizeof(int16_t));
436 : }
437 :
438 : template<size_t N>
439 : FORCE_FINLINE void simdCopy_f32(float* RESTRICT dst, const float* RESTRICT src) {
440 : size_t i = 0;
441 : #if V_SIMD_256
442 : constexpr size_t vstep = 8;
443 : while (i + vstep <= N) {
444 : _mm256_store_ps(dst + i, _mm256_load_ps(src + i));
445 : i += vstep;
446 : }
447 : #endif
448 : #if V_SIMD_128
449 : constexpr size_t vstep128 = 4;
450 23402137 : while (i + vstep128 <= N) {
451 17019736 : _mm_store_ps(dst + i, _mm_load_ps(src + i));
452 : i += vstep128;
453 : }
454 : #endif
455 : const size_t tail = N - i;
456 : if (tail) std::memcpy(dst + i, src + i, tail * sizeof(float));
457 : }
458 :
459 : template<size_t N>
460 : FORCE_FINLINE void simdDequantize_i16_f32(float* RESTRICT dst, const int16_t* RESTRICT src, const float scale) {
461 : size_t i = 0;
462 : #if V_SIMD_256
463 : constexpr size_t vstep = 8; // 8 floats per __m256
464 : const __m256 vscale = _mm256_set1_ps(scale);
465 : while (i + vstep <= N) {
466 : // Load 8 x int16, sign-extend to 8 x int32, convert to 8 x float, multiply by scale
467 : const __m128i src_i16 = _mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
468 : const __m256i src_i32 = _mm256_cvtepi16_epi32(src_i16);
469 : const __m256 src_f32 = _mm256_cvtepi32_ps(src_i32);
470 : _mm256_store_ps(dst + i, _mm256_mul_ps(src_f32, vscale));
471 : i += vstep;
472 : }
473 : #endif
474 : #if V_SIMD_128
475 : constexpr size_t vstep128 = 4;
476 : const __m128 vscale128 = _mm_set1_ps(scale);
477 : while (i + vstep128 <= N) {
478 : // Load 4 x int16 (as 64-bit), sign-extend to 4 x int32, convert to 4 x float
479 : const __m128i src_i16 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(src + i));
480 : const __m128i src_i32 = v_cvtepi16_epi32_128(src_i16);
481 : const __m128 src_f32 = _mm_cvtepi32_ps(src_i32);
482 : _mm_store_ps(dst + i, _mm_mul_ps(src_f32, vscale128));
483 : i += vstep128;
484 : }
485 : #endif
486 : const size_t tail = N - i;
487 : for (size_t j = 0; j < tail; ++j) dst[i + j] = scale * static_cast<float>(src[i + j]);
488 : }
489 :
490 : template<size_t N0, size_t N1>
491 : FORCE_FINLINE void simdSplice_f32(float* RESTRICT dst, const float* RESTRICT a, const float* RESTRICT b) {
492 : simdCopy_f32<N0>(dst, a);
493 : simdCopy_f32<N1>(dst + N0, b);
494 : }
495 :
496 : template<size_t N, bool Q>
497 : FORCE_FINLINE void simdDequantizeActivate_i16_f32(float* RESTRICT dst, const int16_t* RESTRICT src, const float scale) {
498 : size_t i = 0;
499 : #if V_SIMD_256
500 : constexpr size_t vstep = 8;
501 : const __m256 vscale = _mm256_set1_ps(scale);
502 : const __m256 vzero = _mm256_setzero_ps();
503 : const __m256 vone = _mm256_set1_ps(1.0f);
504 : while (i + vstep <= N) {
505 : const __m128i src_i16 = _mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
506 : const __m256i src_i32 = _mm256_cvtepi16_epi32(src_i16);
507 : const __m256 deq = _mm256_mul_ps(_mm256_cvtepi32_ps(src_i32), vscale);
508 : _mm256_store_ps(dst + i, _mm256_max_ps(vzero, _mm256_min_ps(vone, deq)));
509 : i += vstep;
510 : }
511 : #endif
512 : #if V_SIMD_128
513 : constexpr size_t vstep128 = 4;
514 : const __m128 vscale128 = _mm_set1_ps(scale);
515 : const __m128 vzero128 = _mm_setzero_ps();
516 : const __m128 vone128 = _mm_set1_ps(1.0f);
517 410601131 : while (i + vstep128 <= N) {
518 408473664 : const __m128i src_i16 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(src + i));
519 : const __m128i src_i32 = v_cvtepi16_epi32_128(src_i16);
520 : const __m128 deq = _mm_mul_ps(_mm_cvtepi32_ps(src_i32), vscale128);
521 408473664 : _mm_store_ps(dst + i, _mm_max_ps(vzero128, _mm_min_ps(vone128, deq)));
522 : i += vstep128;
523 : }
524 : #endif
525 : const size_t tail = N - i;
526 : for (size_t j = 0; j < tail; ++j) {
527 : const float deq = scale * static_cast<float>(src[i + j]);
528 : dst[i + j] = std::max(0.f, std::min(1.f, deq));
529 : }
530 : }
531 :
532 : template<size_t N, bool Q>
533 : void simdActivation(float* RESTRICT x) {
534 : size_t i = 0;
535 : if constexpr (N <= 0) return;
536 :
537 : #if V_SIMD_512
538 : if (N-i >= 16){
539 : const v_f32_512 zero = v_zero_f32_512();
540 : const v_f32_512 un = v_set_f32_512(1.f);
541 : simdActivation512<N,Q>(x+i, zero, un);
542 : i += ((N-i) & -16);
543 : }
544 : #endif
545 :
546 : #if V_SIMD_256
547 : if (N-i >= 8){
548 : const v_f32_256 zero = v_zero_f32_256();
549 : const v_f32_256 un = v_set_f32_256(1.f);
550 : simdActivation256<N,Q>(x+i, zero, un);
551 : i += ((N-i) & -8);
552 : }
553 : #endif
554 :
555 : #if V_SIMD_128
556 : if (N-i >= 4){
557 : const v_f32_128 zero = v_zero_f32_128();
558 : const v_f32_128 un = v_set_f32_128(1.f);
559 : simdActivation128<N,Q>(x+i, zero, un);
560 : i += ((N-i) & -4);
561 : }
562 : #endif
563 :
564 : if (N-i >= 4){
565 : simdActivationDefault<N,Q>(x+i);
566 : i += ((N-i) & -4);
567 : }
568 :
569 : while (i < N) {
570 : x[i] = std::min(std::max(x[i], 0.f), 1.f);
571 : ++i;
572 : }
573 : }
574 :
575 : template<size_t N, bool Q>
576 : [[nodiscard]] float simdDotProduct(const float* RESTRICT x, const float* RESTRICT y) {
577 : size_t i = 0;
578 : float dot = 0.0f;
579 : if constexpr (N <= 0) return dot;
580 :
581 : #if V_SIMD_512
582 : if (N-i >= 16){
583 : dot += simdDotProduct512<N,Q>(x+i,y+i);
584 : i += ((N-i) & -16);
585 : }
586 : #endif
587 :
588 : #if V_SIMD_256
589 : if (N-i >= 8){
590 : dot += simdDotProduct256<N,Q>(x+i,y+i);
591 : i += ((N-i) & -8);
592 : }
593 : #endif
594 :
595 : #if V_SIMD_128
596 : if (N-i >= 4){
597 53186675 : dot += simdDotProduct128<N,Q>(x+i,y+i);
598 : i += ((N-i) & -4);
599 : }
600 : #endif
601 :
602 : if (N-i >= 4){
603 : dot += simdDotProductDefault<N,Q>(x+i,y+i);
604 : i += ((N-i) & -4);
605 : }
606 :
607 : while (i < N) {
608 : dot += y[i] * x[i];
609 : ++i;
610 : }
611 : return dot;
612 : }
613 :
614 :
615 :
616 : #ifdef TESTING
617 : int main(int, char**){
618 : constexpr int n = 768;
619 : alignas(64) float a[n];
620 : alignas(64) float b[n];
621 :
622 : for (int i = 0; i < n; ++i){
623 : a[i] = (i+1)/1000.f/(n-1);
624 : b[i] = 1.f/a[i];
625 : }
626 : //std::cout << simdDotProduct512<n,true>(a,b) << std::endl;
627 : std::cout << simdDotProduct256<n,true>(a,b) << std::endl;
628 : std::cout << simdDotProduct128<n,true>(a,b) << std::endl;
629 : std::cout << simdDotProductDefault<n,true>(a,b) << std::endl;
630 :
631 : for (int i = 0; i < n; ++i){
632 : a[i] = -5.f + 10.f*i/(n-1);
633 : b[i] = 2.f;
634 : }
635 :
636 : auto reset = [&](){
637 : for (int i = 0; i < n; ++i){
638 : std::cout << a[i] << " ";
639 : a[i] = -5.f + 10.f*i/(n-1);
640 : }
641 : std::cout << std::endl;
642 : };
643 :
644 : {
645 : const v_f32_256 zero = v_zero_f32_256();
646 : const v_f32_256 un = v_set_f32_256(1.f);
647 : simdActivation256<n,true>(a, zero, un);
648 : reset();
649 : }
650 : {
651 : const v_f32_128 zero = v_zero_f32_128();
652 : const v_f32_128 un = v_set_f32_128(1.f);
653 : simdActivation128<n,true>(a, zero, un);
654 : reset();
655 : }
656 : {
657 : simdActivationDefault<n,true>(a);
658 : reset();
659 : }
660 :
661 : return 0;
662 : }
663 : #endif
|