LCOV - code coverage report
Current view: top level - Source/nnue - weight.hpp (source / functions) Coverage Total Hit
Test: coverage Lines: 57.6 % 33 19
Test Date: 2026-03-02 16:42:41 Functions: 100.0 % 2 2

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : #include "definition.hpp"
       4              : 
       5              : #ifdef WITH_NNUE
       6              : 
       7              : #include "inputLayer.hpp"
       8              : #include "layer.hpp"
       9              : #include "weightReader.hpp"
      10              : 
      11              : template<typename NT, bool Q> struct NNUEWeights {
      12              : 
      13              :    static constexpr int nbuckets = 2;
      14              : 
      15              :    InputLayer<NT, inputLayerSize, firstInnerLayerSize, Q> w {};
      16              :    InputLayer<NT, inputLayerSize, firstInnerLayerSize, Q> b {};
      17              : 
      18              :    struct InnerLayer{
      19              :      Layer<NT,  2 * firstInnerLayerSize, 8, Q> fc0;
      20              :      Layer<NT,  8, 8, Q>                       fc1;
      21              :      Layer<NT, 16, 8, Q>                       fc2;
      22              :      Layer<NT, 24, 1, Q>                       fc3;
      23              :    };
      24              : 
      25              :    array1d<InnerLayer, nbuckets> innerLayer;
      26              : 
      27              :    uint32_t version {0};
      28              : 
      29           22 :    NNUEWeights<NT, Q>& load(WeightsReader<NT>& ws, bool readVersion) {
      30           22 :       quantizationInfo<Q>();
      31           22 :       if (readVersion) ws.readVersion(version);
      32           22 :       w.load_(ws);
      33           22 :       b.load_(ws);
      34           66 :       for (auto & l : innerLayer) l.fc0.load_(ws);
      35           66 :       for (auto & l : innerLayer) l.fc1.load_(ws);
      36           66 :       for (auto & l : innerLayer) l.fc2.load_(ws);
      37           66 :       for (auto & l : innerLayer) l.fc3.load_(ws);
      38           22 :       return *this;
      39              :    }
      40              : 
      41           22 :    static bool load(const std::string& path, NNUEWeights<NT, Q>& loadedWeights) {
      42              :       [[maybe_unused]] constexpr uint32_t expectedVersion {0xc0ffee03};
      43              :       [[maybe_unused]] constexpr int      expectedSize    {151049100}; // net size + 4 for version
      44              :       [[maybe_unused]] constexpr bool     withVersion     {true}; // used for backward compatiblity and debug
      45              : 
      46           22 :       if (path != "embedded") { // read from disk
      47              : #ifndef WITHOUT_FILESYSTEM
      48              :          std::error_code ec;
      49            0 :          const auto fsize = std::filesystem::file_size(path, ec);
      50            0 :          if (ec) {
      51            0 :             Logging::LogIt(Logging::logError) << "File " << path << " is not accessible";
      52            0 :             return false;
      53              :          }
      54            0 :          if (fsize != expectedSize) { // with or without version
      55            0 :             Logging::LogIt(Logging::logError) << "File " << path << " does not look like a compatible net";
      56            0 :             return false;
      57              :          }
      58              : #endif
      59            0 :          std::fstream stream(path, std::ios_base::in | std::ios_base::binary);
      60              :          auto ws = WeightsReader<NT>(stream);
      61            0 :          loadedWeights.load(ws, withVersion);
      62            0 :       }
      63              : #ifdef EMBEDDEDNNUEPATH
      64              :       else {                                             // read from embedded data
      65           22 :          if (embedded::weightsFileSize != expectedSize) { // with or without version
      66            0 :             Logging::LogIt(Logging::logError) << "File " << path << " does not look like a compatible net";
      67            0 :             return false;
      68              :          }
      69           44 :          std::istringstream stream(std::string((const char*)embedded::weightsFileData, embedded::weightsFileSize), std::stringstream::binary);
      70              :          auto               ws = WeightsReader<NT>(stream);
      71           22 :          loadedWeights.load(ws, withVersion);
      72           22 :       }
      73              : #else
      74              :       else {
      75              :          Logging::LogIt(Logging::logError) << "Minic was not compiled with an embedded net";
      76              :          return false;
      77              :       }
      78              : #endif
      79              : 
      80              :       if (withVersion) {
      81           22 :          Logging::LogIt(Logging::logInfo) << "Expected net version : " << toHexString(expectedVersion);
      82           44 :          Logging::LogIt(Logging::logInfo) << "Read net version     : " << toHexString(loadedWeights.version);
      83           22 :          if (loadedWeights.version != expectedVersion) {
      84            0 :             Logging::LogIt(Logging::logError) << "File " << path << " is not a compatible version of the net";
      85            0 :             return false;
      86              :          }
      87              :       }
      88              :       return true;
      89              :    }
      90              : };
      91              : 
      92              : #endif // WITH_NNUE
        

Generated by: LCOV version 2.0-1