Line data Source code
1 : #pragma once
2 :
3 : #include "definition.hpp"
4 :
5 : #ifdef WITH_NNUE
6 :
7 : #include "quantization.hpp"
8 : #include "stackVector.hpp"
9 : #include "weightReader.hpp"
10 :
11 : template<typename NT, size_t dim0, size_t dim1, bool Q>
12 : struct Layer {
13 : static constexpr size_t nbW = dim0 * dim1;
14 : static constexpr size_t nbB = dim1;
15 :
16 : using BT = typename Quantization<Q>::BT;
17 : using WT = typename Quantization<Q>::WT;
18 :
19 : // Layer is always for inner layer, so we can safely use WT and BT
20 : // Always small enough to be statically allocated
21 : alignas(NNUEALIGNMENT) WT W[nbW];
22 : alignas(NNUEALIGNMENT) BT b[nbB];
23 :
24 : template<typename T>
25 2127467 : CONSTEXPR StackVector<BT, dim1, Q> forward(const StackVector<T, dim0, Q>& x) const {
26 : StackVector<BT, dim1, Q> result;
27 2127467 : result.from(b);
28 : #ifdef USE_SIMD_INTRIN
29 : #pragma omp simd
30 21274670 : for (size_t i = 0; i < dim1; ++i) { result.data[i] += x.dot_(W + i * dim0); }
31 : #else
32 : #pragma omp simd
33 : for (size_t i = 0; i < dim0; ++i) { result.fma_(x.data[i], W + i * dim1); }
34 : #endif // USE_SIMD_INTRIN
35 2127467 : return result; // RVO
36 : }
37 :
38 : #ifdef USE_SIMD_INTRIN
39 : template<typename T>
40 : FORCE_FINLINE void forwardTo(const StackVector<T, dim0, Q>& x, BT* RESTRICT dst) const {
41 38294406 : for (size_t i = 0; i < dim1; ++i) { dst[i] = b[i]; }
42 38294406 : for (size_t i = 0; i < dim1; ++i) { dst[i] += x.dot_(W + i * dim0); }
43 : }
44 : #endif
45 :
46 176 : Layer<NT, dim0, dim1, Q>& load_(WeightsReader<NT>& ws) {
47 176 : ws.template streamW<WT>(W, nbW, dim0, dim1)
48 176 : .template streamB<BT>(b, nbB);
49 176 : return *this;
50 : }
51 :
52 : // non copyable
53 : Layer<NT, dim0, dim1, Q>& operator=(const Layer<NT, dim0, dim1, Q>& other) = delete;
54 : Layer<NT, dim0, dim1, Q>& operator=(Layer<NT, dim0, dim1, Q>&& other) = delete;
55 : Layer(const Layer<NT, dim0, dim1, Q>& other) = delete;
56 : Layer(Layer<NT, dim0, dim1, Q>&& other) = delete;
57 :
58 : Layer(){};
59 : ~Layer() = default;
60 :
61 : };
62 :
63 : #endif // WITH_NNUE
|