LCOV - code coverage report
Current view: top level - Source/nnue - weightReader.hpp (source / functions) Coverage Total Hit
Test: coverage Lines: 93.9 % 66 62
Test Date: 2026-03-02 16:42:41 Functions: 100.0 % 5 5

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : #include "definition.hpp"
       4              : 
       5              : #ifdef WITH_NNUE
       6              : 
       7              : // NT is the network type as written inside the binary file
       8              : template<typename NT> 
       9              : struct WeightsReader {
      10              :    std::istream* file = nullptr;
      11              : 
      12           22 :    WeightsReader<NT>& readVersion(uint32_t& version) {
      13           22 :       assert(file);
      14           22 :       file->read((char*)&version, sizeof(uint32_t));
      15           22 :       return *this;
      16              :    }
      17              : 
      18              :    template<typename T> 
      19          176 :    WeightsReader<NT>& streamW(T* dst, const size_t request, [[maybe_unused]] size_t dim0, [[maybe_unused]] size_t dim1) {
      20          176 :       assert(file);
      21          176 :       Logging::LogIt(Logging::logInfo) << "Reading inner weight";
      22              :       // we will get min and max weight for display purpose
      23          176 :       NT minW = std::numeric_limits<NT>::max();
      24          176 :       NT maxW = std::numeric_limits<NT>::lowest();
      25          176 :       array1d<char, sizeof(NT)> singleElement {};
      26              :       
      27       280016 :       for (size_t i(0); i < request; ++i) {
      28       279840 :          file->read(singleElement.data(), singleElement.size());
      29              :          NT tmp {0};
      30              :          std::memcpy(&tmp, singleElement.data(), singleElement.size());
      31              :          // update min/max
      32       279840 :          minW = std::min(minW, tmp);
      33       279840 :          maxW = std::max(maxW, tmp);
      34              : #ifdef USE_SIMD_INTRIN
      35              :          // transpose data ///@todo as this is default now, do this in trainer ...
      36       279840 :          const size_t j = (i % dim1) * dim0 + i / dim1;
      37              : #else
      38              :          const size_t j = i;
      39              : #endif
      40       279840 :          dst[j] = static_cast<T>(tmp);
      41              :       }
      42          176 :       Logging::LogIt(Logging::logInfo) << "Weight in [" << minW << ", " << maxW << "]";
      43          176 :       return *this;
      44              :    }
      45              : 
      46              :    template<typename T, bool Q> 
      47           44 :    WeightsReader<NT>& streamWI(T* dst, const size_t request) {
      48           44 :       assert(file);
      49           44 :       Logging::LogIt(Logging::logInfo) << "Reading input weight";
      50              :       // we will get min and max weight for display purpose
      51           44 :       NT minW = std::numeric_limits<NT>::max();
      52           44 :       NT maxW = std::numeric_limits<NT>::lowest();
      53           44 :       array1d<char, sizeof(NT)> singleElement {};
      54              :       const NT Wscale = Quantization<Q>::scale;
      55              :       // read each weight one by one, and scale them if quantization is active
      56    830472236 :       for (size_t i(0); i < request; ++i) {
      57    830472192 :          file->read(singleElement.data(), singleElement.size());
      58              :          NT tmp {0};
      59              :          std::memcpy(&tmp, singleElement.data(), singleElement.size());
      60              :          // update min/max
      61    830472192 :          minW = std::min(minW, tmp);
      62    830472192 :          maxW = std::max(maxW, tmp);
      63              :          // if quantization is active and we overflow, just clamp and warn
      64    830472192 :          if (Q && Abs(tmp * Wscale) > (NT)std::numeric_limits<T>::max()) {
      65              :             NT tmp2 = tmp;
      66            0 :             tmp = std::clamp(tmp2 * Wscale, (NT)std::numeric_limits<T>::lowest(), (NT)std::numeric_limits<T>::max());
      67            0 :             Logging::LogIt(Logging::logWarn) << "Overflow weight " << tmp2 << " -> " << tmp;
      68              :          }
      69              :          else {
      70    830472192 :             tmp = tmp * Wscale;
      71              :          }
      72    830472192 :          dst[i] = static_cast<T>(Quantization<Q>::round(tmp));
      73              :       }
      74           44 :       Logging::LogIt(Logging::logInfo) << "Weight in [" << minW << ", " << maxW << "]";
      75           44 :       return *this;
      76              :    }
      77              : 
      78              :    template<typename T> 
      79          176 :    WeightsReader<NT>& streamB(T* dst, const size_t request) {
      80          176 :       assert(file);
      81          176 :       Logging::LogIt(Logging::logInfo) << "Reading inner bias";
      82              :       // we will get min and max bias for display purpose
      83          176 :       NT minB = std::numeric_limits<NT>::max();
      84          176 :       NT maxB = std::numeric_limits<NT>::lowest();
      85          176 :       array1d<char, sizeof(NT)> singleElement {};
      86              :       // read each bias one by one, and scale them if quantization is active
      87         1276 :       for (size_t i(0); i < request; ++i) {
      88         1100 :          file->read(singleElement.data(), singleElement.size());
      89              :          NT tmp {0};
      90              :          std::memcpy(&tmp, singleElement.data(), singleElement.size());
      91              :          // update min/max
      92         1100 :          minB = std::min(minB, tmp);
      93         1100 :          maxB = std::max(maxB, tmp);
      94         1100 :          dst[i] = static_cast<T>(tmp);
      95              :       }
      96          176 :       Logging::LogIt(Logging::logInfo) << "Bias in [" << minB << ", " << maxB << "]";
      97          176 :       return *this;
      98              :    }
      99              : 
     100              :    template<typename T, bool Q> 
     101           44 :    WeightsReader<NT>& streamBI(T* dst, const size_t request) {
     102           44 :       assert(file);
     103           44 :       Logging::LogIt(Logging::logInfo) << "Reading input bias";
     104              :       // we will get min and max bias for display purpose
     105           44 :       NT minB = std::numeric_limits<NT>::max();
     106           44 :       NT maxB = std::numeric_limits<NT>::lowest();
     107           44 :       array1d<char, sizeof(NT)> singleElement {};
     108              :       const NT Bscale = Quantization<Q>::scale;
     109              :       // read each bias one by one, and scale them if quantization is active
     110        16940 :       for (size_t i(0); i < request; ++i) {
     111        16896 :          file->read(singleElement.data(), singleElement.size());
     112              :          NT tmp {0};
     113              :          std::memcpy(&tmp, singleElement.data(), singleElement.size());
     114              :          // update min/max
     115        16896 :          minB = std::min(minB, tmp);
     116        16896 :          maxB = std::max(maxB, tmp);
     117              :          // if quantization is active and we overflow, just clamp and warn
     118        16896 :          if (Q && Abs(tmp * Bscale) > (NT)std::numeric_limits<T>::max()) {
     119              :             NT tmp2 = tmp;
     120            0 :             tmp = std::clamp(tmp2 * Bscale, (NT)std::numeric_limits<T>::lowest(), (NT)std::numeric_limits<T>::max());
     121            0 :             Logging::LogIt(Logging::logWarn) << "Overflow bias " << tmp2 << " -> " << tmp;
     122              :          }
     123              :          else {
     124        16896 :             tmp = tmp * Bscale;
     125              :          }
     126        16896 :          dst[i] = static_cast<T>(Quantization<Q>::round(tmp));
     127              :       }
     128           44 :       Logging::LogIt(Logging::logInfo) << "Bias in [" << minB << ", " << maxB << "]";
     129           44 :       return *this;
     130              :    }
     131              : 
     132           22 :    explicit WeightsReader(std::istream& stream): file(&stream) {}
     133              : };
     134              : 
     135              : #endif // WITH_NNUE
        

Generated by: LCOV version 2.0-1