Line data Source code
1 : #pragma once
2 :
3 : #include "definition.hpp"
4 :
5 : #ifdef WITH_NNUE
6 :
7 : template<typename T, size_t dim, bool Q>
8 : struct StackVector {
9 : alignas(NNUEALIGNMENT) T data[dim];
10 :
11 : template<typename T2>
12 517499686 : CONSTEXPR StackVector<T, dim, Q>& add_(const T2* other) {
13 : #ifdef USE_SIMD_INTRIN
14 : if constexpr (std::is_same_v<T, int16_t> && std::is_same_v<T2, int16_t>) {
15 517499686 : simdAdd_i16<dim>(data, other);
16 : } else
17 : #endif
18 : {
19 : #pragma omp simd
20 : for (size_t i = 0; i < dim; ++i) { data[i] += other[i]; }
21 : }
22 517499686 : return *this;
23 : }
24 :
25 : template<typename T2>
26 342737228 : CONSTEXPR StackVector<T, dim, Q>& sub_(const T2* other) {
27 : #ifdef USE_SIMD_INTRIN
28 : if constexpr (std::is_same_v<T, int16_t> && std::is_same_v<T2, int16_t>) {
29 342737228 : simdSub_i16<dim>(data, other);
30 : } else
31 : #endif
32 : {
33 : #pragma omp simd
34 : for (size_t i = 0; i < dim; ++i) { data[i] -= other[i]; }
35 : }
36 342737228 : return *this;
37 : }
38 :
39 : #ifndef USE_SIMD_INTRIN
40 : template<typename T2, typename T3>
41 : CONSTEXPR StackVector<T, dim, Q>& fma_(const T2 c, const T3* other) {
42 : #pragma omp simd
43 : for (size_t i = 0; i < dim; ++i) { data[i] += c * other[i]; }
44 : return *this;
45 : }
46 :
47 : template<typename F>
48 : CONSTEXPR StackVector<T, dim, Q>& apply_(F&& f) {
49 : #pragma omp simd
50 : for (size_t i = 0; i < dim; ++i) { data[i] = f(data[i]); }
51 : return *this;
52 : }
53 : #else
54 17019736 : T dot_(const T* other) const { return simdDotProduct<dim,Q>(data, other); }
55 : CONSTEXPR StackVector<T, dim, Q>& activation_() { simdActivation<dim,Q>(data); return *this;}
56 : #endif
57 :
58 : template<typename T2>
59 : FORCE_FINLINE void from(const T2* other) {
60 : #ifdef USE_SIMD_INTRIN
61 : if constexpr (std::is_same_v<T, int16_t> && std::is_same_v<T2, int16_t>) {
62 8637140 : simdCopy_i16<dim>(data, other);
63 : } else if constexpr (std::is_same_v<T, float> && std::is_same_v<T2, float>) {
64 : simdCopy_f32<dim>(data, other);
65 : } else
66 : #endif
67 : {
68 : #pragma omp simd
69 : for (size_t i = 0; i < dim; ++i) { data[i] = static_cast<T>(other[i]); }
70 : }
71 : }
72 :
73 : // note that quantization is done on read if needed (see weightReader)
74 : template <typename U>
75 : [[nodiscard]] CONSTEXPR StackVector<U, dim, Q> dequantize(const U& scale) const {
76 : static_assert(std::is_integral_v<T> && std::is_floating_point_v<U>);
77 : StackVector<U, dim, Q> result;
78 : #ifdef USE_SIMD_INTRIN
79 : if constexpr (std::is_same_v<T, int16_t> && std::is_same_v<U, float>) {
80 : simdDequantize_i16_f32<dim>(result.data, data, scale);
81 : } else
82 : #endif
83 : {
84 : #pragma omp simd
85 : for (size_t i = 0; i < dim; ++i) { result.data[i] = scale * static_cast<U>(data[i]); }
86 : }
87 : return result; // RVO
88 : }
89 :
90 : #ifdef DEBUG_NNUE_UPDATE
91 : bool operator==(const StackVector<T, dim, Q>& other) {
92 : constexpr T eps = std::numeric_limits<T>::epsilon() * 100;
93 : for (size_t i = 0; i < dim; ++i) {
94 : if (std::fabs(data[i] - other.data[i]) > eps) {
95 : std::cout << data[i] << "!=" << other.data[i] << std::endl;
96 : return false;
97 : }
98 : }
99 : return true;
100 : }
101 :
102 : bool operator!=(const StackVector<T, dim, Q>& other) {
103 : constexpr T eps = std::numeric_limits<T>::epsilon() * 100;
104 : for (size_t i = 0; i < dim; ++i) {
105 : if (std::fabs(data[i] - other.data[i]) > eps) {
106 : std::cout << data[i] << "!=" << other.data[i] << std::endl;
107 : return true;
108 : }
109 : }
110 : return false;
111 : }
112 : #endif
113 : };
114 :
115 : template<typename T, size_t dim0, size_t dim1, bool Q>
116 : CONSTEXPR StackVector<T, dim0 + dim1, Q> splice(const StackVector<T, dim0, Q>& a, const StackVector<T, dim1, Q>& b) {
117 : StackVector<T, dim0 + dim1, Q> c;
118 : #ifdef USE_SIMD_INTRIN
119 : if constexpr (std::is_same_v<T, float>) {
120 : simdSplice_f32<dim0, dim1>(c.data, a.data, b.data);
121 : } else
122 : #endif
123 : {
124 : #pragma omp simd
125 : for (size_t i = 0; i < dim0; ++i) { c.data[i] = a.data[i]; }
126 : #pragma omp simd
127 : for (size_t i = 0; i < dim1; ++i) { c.data[dim0 + i] = b.data[i]; }
128 : }
129 : return c; // RVO
130 : }
131 :
132 : template<typename T, size_t dim, bool Q>
133 : inline std::ostream& operator<<(std::ostream& ostr, const StackVector<T, dim, Q>& vec) {
134 : static_assert(dim != 0, "can't stream empty vector.");
135 : ostr << "StackVector<T, " << dim << ">([";
136 : for (size_t i = 0; i < (dim - 1); ++i) { ostr << vec.data[i] << ", "; }
137 : ostr << vec.data[dim - 1] << "])";
138 : return ostr;
139 : }
140 :
141 : #endif // WITH_NNUE
|