Line data Source code
1 : #pragma once
2 :
3 : #include "definition.hpp"
4 :
5 : #ifdef WITH_NNUE
6 :
7 : // NT is the network type as written inside the binary file
8 : template<typename NT>
9 : struct WeightsReader {
10 : std::istream* file = nullptr;
11 :
12 22 : WeightsReader<NT>& readVersion(uint32_t& version) {
13 22 : assert(file);
14 22 : file->read((char*)&version, sizeof(uint32_t));
15 22 : return *this;
16 : }
17 :
18 : template<typename T>
19 176 : WeightsReader<NT>& streamW(T* dst, const size_t request, [[maybe_unused]] size_t dim0, [[maybe_unused]] size_t dim1) {
20 176 : assert(file);
21 176 : Logging::LogIt(Logging::logInfo) << "Reading inner weight";
22 : // we will get min and max weight for display purpose
23 176 : NT minW = std::numeric_limits<NT>::max();
24 176 : NT maxW = std::numeric_limits<NT>::lowest();
25 176 : array1d<char, sizeof(NT)> singleElement {};
26 :
27 280016 : for (size_t i(0); i < request; ++i) {
28 279840 : file->read(singleElement.data(), singleElement.size());
29 : NT tmp {0};
30 : std::memcpy(&tmp, singleElement.data(), singleElement.size());
31 : // update min/max
32 279840 : minW = std::min(minW, tmp);
33 279840 : maxW = std::max(maxW, tmp);
34 : #ifdef USE_SIMD_INTRIN
35 : // transpose data ///@todo as this is default now, do this in trainer ...
36 279840 : const size_t j = (i % dim1) * dim0 + i / dim1;
37 : #else
38 : const size_t j = i;
39 : #endif
40 279840 : dst[j] = static_cast<T>(tmp);
41 : }
42 176 : Logging::LogIt(Logging::logInfo) << "Weight in [" << minW << ", " << maxW << "]";
43 176 : return *this;
44 : }
45 :
46 : template<typename T, bool Q>
47 44 : WeightsReader<NT>& streamWI(T* dst, const size_t request) {
48 44 : assert(file);
49 44 : Logging::LogIt(Logging::logInfo) << "Reading input weight";
50 : // we will get min and max weight for display purpose
51 44 : NT minW = std::numeric_limits<NT>::max();
52 44 : NT maxW = std::numeric_limits<NT>::lowest();
53 44 : array1d<char, sizeof(NT)> singleElement {};
54 : const NT Wscale = Quantization<Q>::scale;
55 : // read each weight one by one, and scale them if quantization is active
56 830472236 : for (size_t i(0); i < request; ++i) {
57 830472192 : file->read(singleElement.data(), singleElement.size());
58 : NT tmp {0};
59 : std::memcpy(&tmp, singleElement.data(), singleElement.size());
60 : // update min/max
61 830472192 : minW = std::min(minW, tmp);
62 830472192 : maxW = std::max(maxW, tmp);
63 : // if quantization is active and we overflow, just clamp and warn
64 830472192 : if (Q && Abs(tmp * Wscale) > (NT)std::numeric_limits<T>::max()) {
65 : NT tmp2 = tmp;
66 0 : tmp = std::clamp(tmp2 * Wscale, (NT)std::numeric_limits<T>::lowest(), (NT)std::numeric_limits<T>::max());
67 0 : Logging::LogIt(Logging::logWarn) << "Overflow weight " << tmp2 << " -> " << tmp;
68 : }
69 : else {
70 830472192 : tmp = tmp * Wscale;
71 : }
72 830472192 : dst[i] = static_cast<T>(Quantization<Q>::round(tmp));
73 : }
74 44 : Logging::LogIt(Logging::logInfo) << "Weight in [" << minW << ", " << maxW << "]";
75 44 : return *this;
76 : }
77 :
78 : template<typename T>
79 176 : WeightsReader<NT>& streamB(T* dst, const size_t request) {
80 176 : assert(file);
81 176 : Logging::LogIt(Logging::logInfo) << "Reading inner bias";
82 : // we will get min and max bias for display purpose
83 176 : NT minB = std::numeric_limits<NT>::max();
84 176 : NT maxB = std::numeric_limits<NT>::lowest();
85 176 : array1d<char, sizeof(NT)> singleElement {};
86 : // read each bias one by one, and scale them if quantization is active
87 1276 : for (size_t i(0); i < request; ++i) {
88 1100 : file->read(singleElement.data(), singleElement.size());
89 : NT tmp {0};
90 : std::memcpy(&tmp, singleElement.data(), singleElement.size());
91 : // update min/max
92 1100 : minB = std::min(minB, tmp);
93 1100 : maxB = std::max(maxB, tmp);
94 1100 : dst[i] = static_cast<T>(tmp);
95 : }
96 176 : Logging::LogIt(Logging::logInfo) << "Bias in [" << minB << ", " << maxB << "]";
97 176 : return *this;
98 : }
99 :
100 : template<typename T, bool Q>
101 44 : WeightsReader<NT>& streamBI(T* dst, const size_t request) {
102 44 : assert(file);
103 44 : Logging::LogIt(Logging::logInfo) << "Reading input bias";
104 : // we will get min and max bias for display purpose
105 44 : NT minB = std::numeric_limits<NT>::max();
106 44 : NT maxB = std::numeric_limits<NT>::lowest();
107 44 : array1d<char, sizeof(NT)> singleElement {};
108 : const NT Bscale = Quantization<Q>::scale;
109 : // read each bias one by one, and scale them if quantization is active
110 16940 : for (size_t i(0); i < request; ++i) {
111 16896 : file->read(singleElement.data(), singleElement.size());
112 : NT tmp {0};
113 : std::memcpy(&tmp, singleElement.data(), singleElement.size());
114 : // update min/max
115 16896 : minB = std::min(minB, tmp);
116 16896 : maxB = std::max(maxB, tmp);
117 : // if quantization is active and we overflow, just clamp and warn
118 16896 : if (Q && Abs(tmp * Bscale) > (NT)std::numeric_limits<T>::max()) {
119 : NT tmp2 = tmp;
120 0 : tmp = std::clamp(tmp2 * Bscale, (NT)std::numeric_limits<T>::lowest(), (NT)std::numeric_limits<T>::max());
121 0 : Logging::LogIt(Logging::logWarn) << "Overflow bias " << tmp2 << " -> " << tmp;
122 : }
123 : else {
124 16896 : tmp = tmp * Bscale;
125 : }
126 16896 : dst[i] = static_cast<T>(Quantization<Q>::round(tmp));
127 : }
128 44 : Logging::LogIt(Logging::logInfo) << "Bias in [" << minB << ", " << maxB << "]";
129 44 : return *this;
130 : }
131 :
132 22 : explicit WeightsReader(std::istream& stream): file(&stream) {}
133 : };
134 :
135 : #endif // WITH_NNUE
|