Line data Source code
1 : #pragma once
2 :
3 : #include "definition.hpp"
4 :
5 : #ifdef WITH_NNUE
6 :
7 : #include "inputLayer.hpp"
8 : #include "layer.hpp"
9 : #include "weightReader.hpp"
10 :
11 : template<typename NT, bool Q> struct NNUEWeights {
12 :
13 : static constexpr int nbuckets = 2;
14 :
15 : InputLayer<NT, inputLayerSize, firstInnerLayerSize, Q> w {};
16 : InputLayer<NT, inputLayerSize, firstInnerLayerSize, Q> b {};
17 :
18 : struct InnerLayer{
19 : Layer<NT, 2 * firstInnerLayerSize, 8, Q> fc0;
20 : Layer<NT, 8, 8, Q> fc1;
21 : Layer<NT, 16, 8, Q> fc2;
22 : Layer<NT, 24, 1, Q> fc3;
23 : };
24 :
25 : array1d<InnerLayer, nbuckets> innerLayer;
26 :
27 : uint32_t version {0};
28 :
29 22 : NNUEWeights<NT, Q>& load(WeightsReader<NT>& ws, bool readVersion) {
30 22 : quantizationInfo<Q>();
31 22 : if (readVersion) ws.readVersion(version);
32 22 : w.load_(ws);
33 22 : b.load_(ws);
34 66 : for (auto & l : innerLayer) l.fc0.load_(ws);
35 66 : for (auto & l : innerLayer) l.fc1.load_(ws);
36 66 : for (auto & l : innerLayer) l.fc2.load_(ws);
37 66 : for (auto & l : innerLayer) l.fc3.load_(ws);
38 22 : return *this;
39 : }
40 :
41 22 : static bool load(const std::string& path, NNUEWeights<NT, Q>& loadedWeights) {
42 : [[maybe_unused]] constexpr uint32_t expectedVersion {0xc0ffee03};
43 : [[maybe_unused]] constexpr int expectedSize {151049100}; // net size + 4 for version
44 : [[maybe_unused]] constexpr bool withVersion {true}; // used for backward compatiblity and debug
45 :
46 22 : if (path != "embedded") { // read from disk
47 : #ifndef WITHOUT_FILESYSTEM
48 : std::error_code ec;
49 0 : const auto fsize = std::filesystem::file_size(path, ec);
50 0 : if (ec) {
51 0 : Logging::LogIt(Logging::logError) << "File " << path << " is not accessible";
52 0 : return false;
53 : }
54 0 : if (fsize != expectedSize) { // with or without version
55 0 : Logging::LogIt(Logging::logError) << "File " << path << " does not look like a compatible net";
56 0 : return false;
57 : }
58 : #endif
59 0 : std::fstream stream(path, std::ios_base::in | std::ios_base::binary);
60 : auto ws = WeightsReader<NT>(stream);
61 0 : loadedWeights.load(ws, withVersion);
62 0 : }
63 : #ifdef EMBEDDEDNNUEPATH
64 : else { // read from embedded data
65 22 : if (embedded::weightsFileSize != expectedSize) { // with or without version
66 0 : Logging::LogIt(Logging::logError) << "File " << path << " does not look like a compatible net";
67 0 : return false;
68 : }
69 44 : std::istringstream stream(std::string((const char*)embedded::weightsFileData, embedded::weightsFileSize), std::stringstream::binary);
70 : auto ws = WeightsReader<NT>(stream);
71 22 : loadedWeights.load(ws, withVersion);
72 22 : }
73 : #else
74 : else {
75 : Logging::LogIt(Logging::logError) << "Minic was not compiled with an embedded net";
76 : return false;
77 : }
78 : #endif
79 :
80 : if (withVersion) {
81 22 : Logging::LogIt(Logging::logInfo) << "Expected net version : " << toHexString(expectedVersion);
82 44 : Logging::LogIt(Logging::logInfo) << "Read net version : " << toHexString(loadedWeights.version);
83 22 : if (loadedWeights.version != expectedVersion) {
84 0 : Logging::LogIt(Logging::logError) << "File " << path << " is not a compatible version of the net";
85 0 : return false;
86 : }
87 : }
88 : return true;
89 : }
90 : };
91 :
92 : #endif // WITH_NNUE
|