LCOV - code coverage report
Current view: top level - Source/nnue - evaluator.hpp (source / functions) Coverage Total Hit
Test: coverage Lines: 100.0 % 17 17
Test Date: 2026-03-02 16:42:41 Functions: 100.0 % 2 2

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : #include "definition.hpp"
       4              : 
       5              : #ifdef WITH_NNUE
       6              : 
       7              : #include "activation.hpp"
       8              : #include "featureTransformer.hpp"
       9              : #include "sided.hpp"
      10              : #include "weight.hpp"
      11              : 
      12              : template<typename NT, bool Q> 
      13              : struct NNUEEval : Sided<NNUEEval<NT, Q>, FeatureTransformer<NT, Q>> {
      14              :    // common data (weights and bias)
      15              :    static NNUEWeights<NT, Q> weights;
      16              : 
      17              :    // instance data (active index)
      18              :    FeatureTransformer<NT, Q> white;
      19              :    FeatureTransformer<NT, Q> black;
      20              : 
      21              :    // status of the FeatureTransformers
      22              :    // if dirty, then an update/reset is necessary
      23              :    bool dirty = true;
      24              : 
      25              :    FORCE_FINLINE void clear() {
      26         1218 :       dirty = true;
      27              :       white.clear();
      28              :       black.clear();
      29              :    }
      30              : 
      31              :    FORCE_FINLINE void clear(Color color) {
      32      8632338 :       dirty = true;
      33      8632338 :       if (color == Co_White)
      34              :          white.clear();
      35              :       else
      36              :          black.clear();
      37              :    }
      38              : 
      39              :    using BT = typename Quantization<Q>::BT;
      40              :    
      41      2127467 :    float propagate(Color c, const int bucket) const {
      42      2127467 :       assert(!dirty);
      43      2127467 :       assert(bucket >= 0);
      44      2127467 :       assert(bucket < (NNUEWeights<NT, Q>::nbuckets));
      45              :       
      46              :       constexpr float deqScale = 1.f / Quantization<Q>::scale;
      47      2127467 :       const auto& layer = weights.innerLayer[bucket];
      48              : 
      49              : #ifdef USE_SIMD_INTRIN
      50              :       StackVector<BT, 2 * firstInnerLayerSize, Q> x0;
      51              :       {
      52      2127467 :          const auto& first  = (c == Co_White) ? white : black;
      53      2127467 :          const auto& second = (c == Co_White) ? black : white;
      54              :          simdDequantizeActivate_i16_f32<firstInnerLayerSize, Q>(
      55      2127467 :             x0.data, first.active().data, deqScale);
      56              :          simdDequantizeActivate_i16_f32<firstInnerLayerSize, Q>(
      57      2127467 :             x0.data + firstInnerLayerSize, second.active().data, deqScale);
      58              :       }
      59              : 
      60      4254934 :       auto x1 = layer.fc0.forward(x0).activation_();
      61              : 
      62              :       StackVector<BT, 16, Q> x2;
      63              :       simdCopy_f32<8>(x2.data, x1.data);
      64              :       layer.fc1.forwardTo(x1, x2.data + 8);
      65              :       simdActivation<8, Q>(x2.data + 8);
      66              : 
      67              :       StackVector<BT, 24, Q> x3;
      68              :       simdCopy_f32<16>(x3.data, x2.data);
      69              :       layer.fc2.forwardTo(x2, x3.data + 16);
      70              :       simdActivation<8, Q>(x3.data + 16);
      71              : 
      72      2127467 :       const float val = layer.fc3.forward(x3).data[0];
      73              : #if defined(__AVX2__)
      74              :       _mm256_zeroupper();
      75              : #endif
      76              : 
      77              : #else
      78              :       // Non-SIMD fallback
      79              :       const auto w_x {white.active().dequantize(deqScale)
      80              :                                .apply_(activationInput<BT, Q>)
      81              :                   };
      82              :       const auto b_x {black.active().dequantize(deqScale)
      83              :                                .apply_(activationInput<BT, Q>)
      84              :                   };
      85              :       
      86              :       const auto x0 = c == Co_White ? splice(w_x, b_x) : splice(b_x, w_x);
      87              :       const auto x1 = layer.fc0.forward(x0)
      88              :                             .apply_(activation<BT, Q>);
      89              :       const auto x2 = splice(x1, layer.fc1.forward(x1)
      90              :                                       .apply_(activation<BT, Q>));
      91              :       const auto x3 = splice(x2, layer.fc2.forward(x2)
      92              :                                       .apply_(activation<BT, Q>));
      93              :       const float val = layer.fc3.forward(x3).data[0];
      94              : #endif
      95      2127467 :       return val * Quantization<Q>::outFactor;
      96              :    }
      97              : 
      98              : #ifdef DEBUG_NNUE_UPDATE
      99              :    bool operator==(const NNUEEval<NT,Q>& other) {
     100              :       if (white != other.white || black != other.black) return false;
     101              :       return true;
     102              :    }
     103              : 
     104              :    bool operator!=(const NNUEEval<NT,Q>& other) {
     105              :       if (white != other.white || black != other.black) return true;
     106              :       return false;
     107              :    }
     108              : #endif
     109              : 
     110              :    // default CTOR always use loaded weights in FeatureTransformer 
     111              :    // (as pointer of course, not a copy)
     112         1183 :    NNUEEval():
     113              :       white(&weights.w),
     114         1183 :       black(&weights.b){}
     115              : };
     116              : 
// Out-of-class definition of the shared static weights. Safe in a header:
// static data members of class templates may be defined in every translation
// unit (one instance per template instantiation, merged by the linker).
template<typename NT, bool Q> 
NNUEWeights<NT, Q> NNUEEval<NT, Q>::weights;
     119              : 
     120              : #endif // WITH_NNUE
        

Generated by: LCOV version 2.0-1