Line data Source code
1 : #include "distributed.h"
2 :
3 : #include "searcher.hpp"
4 : #include "searchConfig.hpp"
5 : #include "threading.hpp"
6 :
7 : #ifdef WITH_MPI
8 :
9 : #include <memory>
10 : #include <mutex>
11 :
12 : #include "dynamicConfig.hpp"
13 : #include "logging.hpp"
14 : #include "stats.hpp"
15 :
16 : namespace Distributed {
int worldSize;    // number of MPI processes (set in init())
int rank;         // rank of this process in MPI_COMM_WORLD (set in init())
std::string name; // processor name reported by MPI_Get_processor_name (set in init())

// Mapping from C++ types to the MPI datatype used to transfer them.
const MPI_Datatype TraitMpiType<bool> ::type = MPI_CXX_BOOL;
const MPI_Datatype TraitMpiType<char> ::type = MPI_CHAR;
const MPI_Datatype TraitMpiType<Counter> ::type = MPI_LONG_LONG_INT;
const MPI_Datatype TraitMpiType<EntryHash>::type = MPI_CHAR; // WARNING : size must be adapted *sizeof(EntryHash)
const MPI_Datatype TraitMpiType<Move> ::type = MPI_INT;

// One duplicated communicator per traffic type (all duplicated in init()) so
// the different asynchronous exchanges never share a communication channel.
MPI_Comm _commTT = MPI_COMM_NULL;         // TT entry all-gather (setEntry/syncTT)
MPI_Comm _commTT2 = MPI_COMM_NULL;        // TT transfer-count reduce and final barrier (syncTT)
MPI_Comm _commStat = MPI_COMM_NULL;       // async stat sum-reduction (sendStat)
MPI_Comm _commStat2 = MPI_COMM_NULL;      // stat poll-count reduce (syncStat)
MPI_Comm _commInput = MPI_COMM_NULL;      // presumably GUI/input forwarding — used outside this file, verify at call sites
MPI_Comm _commMove = MPI_COMM_NULL;       // presumably move exchange — used outside this file, verify at call sites
MPI_Comm _commStopFromR0 = MPI_COMM_NULL; // communicator backing the RMA stop-flag window (lateInit)

// Handles for the in-flight asynchronous communication on each channel.
MPI_Request _requestTT = MPI_REQUEST_NULL;
MPI_Request _requestStat = MPI_REQUEST_NULL;
MPI_Request _requestInput = MPI_REQUEST_NULL;
MPI_Request _requestMove = MPI_REQUEST_NULL;
//MPI_Request _requestStopFromR0 = MPI_REQUEST_NULL;

// RMA window exposing the main thread's stop flag so rank 0 can stop other ranks (created in lateInit()).
MPI_Win _winStopFromR0;

// Stat counters: local sums to send, plus a double-buffered receive area so a
// new reduction can run while the previous result is still being read.
array1d<Counter, Stats::sid_maxid> _countersBufSend;
array1d<Counter, Stats::sid_maxid> _countersBufRecv[2];
uint8_t _doubleBufferStatParity; // parity selecting the half of _countersBufRecv currently written
uint64_t _nbStatPoll;            // number of stat reductions launched since the last syncStat()

// Distributed TT sharing state (see setEntry()/syncTT()).
const uint64_t & _ttBufSize = SearchConfig::distributedTTBufSize; // entries per send buffer (config alias)
uint64_t _ttCurPos;                   // next write position in the active send buffer
const DepthType _ttMinDepth = 3;      // entries with depth <= this are not shared (see setEntry)
std::vector<EntryHash> _ttBufSend[2]; // double-buffered send area
std::vector<EntryHash> _ttBufRecv;    // receive area, sized worldSize * _ttBufSize in init()
uint8_t _doubleBufferTTParity;        // parity selecting the active half of _ttBufSend
std::mutex _ttMutex;                  // guards setEntry() against concurrent search threads
uint64_t _nbTTTransfert;              // number of TT all-gathers launched
uint64_t _nbTTBufferOverruns;         // full buffers dropped because the previous send was still pending
58 :
59 : void checkError(int err, const std::string& context) {
60 : if (err != MPI_SUCCESS) {
61 : char error_string[MPI_MAX_ERROR_STRING];
62 : int length;
63 : MPI_Error_string(err, error_string, &length);
64 : Logging::LogIt(Logging::logFatal) << "MPI error" << (context.empty() ? "" : " in " + context)
65 : << ": " << error_string << " (code: " << err << ")";
66 : }
67 : }
68 :
// Initialize the MPI runtime, the per-purpose communicators and all the
// distributed stat/TT buffers. Must run once, before any other Distributed:: call.
void init() {
   std::cout << Logging::_protocolComment[Logging::ct] << "Initializing MPI ..." << std::endl;
   int provided;
   // MPI_THREAD_MULTIPLE is requested because several search threads issue MPI calls.
   checkError(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided));
   checkError(MPI_Comm_size(MPI_COMM_WORLD, &worldSize));
   checkError(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
   char processor_name[MPI_MAX_PROCESSOR_NAME];
   int name_len;
   checkError(MPI_Get_processor_name(processor_name, &name_len));
   name = processor_name;

   // Only rank 0 talks to the outside world; silence logging on every other process.
   if (!isMainProcess()) DynamicConfig::minOutputLevel = Logging::logOff;

   if (provided < MPI_THREAD_MULTIPLE) { Logging::LogIt(Logging::logFatal) << "The threading support level is lesser than needed"; }

   // One dedicated communicator per traffic type so asynchronous exchanges cannot interleave.
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commTT));
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commTT2));
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commStat));
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commStat2));
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commInput));
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commMove));
   checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commStopFromR0));

   // Initialize MPI requests to NULL (no communication in flight yet)
   _requestTT = MPI_REQUEST_NULL;
   _requestStat = MPI_REQUEST_NULL;
   _requestInput = MPI_REQUEST_NULL;
   _requestMove = MPI_REQUEST_NULL;

   // Reset the stat double-buffer state.
   _nbStatPoll = 0ull;
   _doubleBufferStatParity = 0;

   // TT receive buffer size depends on worldSize: each rank contributes _ttBufSize entries.
   _ttBufSend[0].resize(_ttBufSize);
   _ttBufSend[1].resize(_ttBufSize);
   _ttBufRecv.resize(worldSize * _ttBufSize);
   _ttCurPos = 0ull;
   _nbTTTransfert = 0ull;
   _nbTTBufferOverruns = 0ull;
   _doubleBufferTTParity = 0;
}
110 :
111 : void lateInit() {
112 : if (moreThanOneProcess()) {
113 : checkError(MPI_Win_create(&ThreadPool::instance().main().stopFlag, sizeof(bool), sizeof(bool), MPI_INFO_NULL, _commStopFromR0, &_winStopFromR0), "MPI_Win_create");
114 : checkError(MPI_Win_fence(0, _winStopFromR0), "MPI_Win_fence in lateInit");
115 : }
116 : }
117 :
118 : void finalize() {
119 : checkError(MPI_Comm_free(&_commTT), "MPI_Comm_free _commTT");
120 : checkError(MPI_Comm_free(&_commTT2), "MPI_Comm_free _commTT2");
121 : checkError(MPI_Comm_free(&_commStat), "MPI_Comm_free _commStat");
122 : checkError(MPI_Comm_free(&_commStat2), "MPI_Comm_free _commStat2");
123 : checkError(MPI_Comm_free(&_commInput), "MPI_Comm_free _commInput");
124 : checkError(MPI_Comm_free(&_commMove), "MPI_Comm_free _commMove");
125 : checkError(MPI_Comm_free(&_commStopFromR0), "MPI_Comm_free _commStopFromR0");
126 :
127 : if (moreThanOneProcess()) checkError(MPI_Win_free(&_winStopFromR0), "MPI_Win_free");
128 :
129 : checkError(MPI_Finalize(), "MPI_Finalize");
130 : }
131 :
132 : bool isMainProcess() { return rank == 0; }
133 :
134 : bool moreThanOneProcess() {
135 : return worldSize > 1;
136 : }
137 :
138 : // classic busy wait
139 : void sync(MPI_Comm& com, const std::string& msg) {
140 : if (!moreThanOneProcess()) return;
141 : Logging::LogIt(Logging::logDebug) << "syncing: " << msg;
142 : checkError(MPI_Barrier(com), "MPI_Barrier: " + msg);
143 : Logging::LogIt(Logging::logDebug) << "...done";
144 : }
145 :
146 : // "softer" wait on other process
147 : void waitRequest(MPI_Request& req) {
148 : if (!moreThanOneProcess()) return;
149 : // don't rely on MPI to do a "passive wait", most implementations are doing a busy-wait, so use 100% cpu
150 : if (Distributed::isMainProcess()) { checkError(MPI_Wait(&req, MPI_STATUS_IGNORE), "MPI_Wait in waitRequest"); }
151 : else {
152 : while (true) {
153 : int flag;
154 : checkError(MPI_Test(&req, &flag, MPI_STATUS_IGNORE), "MPI_Test in waitRequest");
155 : if (flag) break;
156 : else
157 : std::this_thread::sleep_for(std::chrono::milliseconds(1));
158 : }
159 : }
160 : }
161 :
162 : void initStat() {
163 : if (!moreThanOneProcess()) return;
164 : _countersBufSend.fill(0ull);
165 : _countersBufRecv[0].fill(0ull);
166 : _countersBufRecv[1].fill(0ull);
167 : }
168 :
// Launch an asynchronous sum-reduction of the local counters across all ranks.
// The result lands in the _countersBufRecv half selected by the current parity;
// completion is observed later via _requestStat (pollStat()/syncStat()).
void sendStat() {
   if (!moreThanOneProcess()) return;
   // launch an async reduce
   Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " sendstat";
   asyncAllReduceSum(_countersBufSend.data(), _countersBufRecv[_doubleBufferStatParity % 2].data(), Stats::sid_maxid, _requestStat, _commStat);
   Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " sendstat done";
   ++_nbStatPoll; // one more reduction in flight / completed since last syncStat()
}
177 :
178 : void pollStat() { // only called from main thread
179 : if (!moreThanOneProcess()) return;
180 : int flag = 1;
181 : Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " pollstat";
182 : checkError(MPI_Test(&_requestStat, &flag, MPI_STATUS_IGNORE), "MPI_Test in pollStat");
183 : // if previous comm is done, launch another one
184 : if (flag) {
185 : ++_doubleBufferStatParity;
186 : // gather stats from all local threads
187 : for (size_t k = 0; k < Stats::sid_maxid; ++k) {
188 : _countersBufSend[k] = 0ull;
189 : for (const auto& it : ThreadPool::instance()) { _countersBufSend[k] += it->stats.counters[k]; }
190 : }
191 : sendStat();
192 : }
193 : Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " pollstat done";
194 : }
195 :
// get all rank to a common synchronous state at the end of search
// Ranks that launched fewer reductions than the most advanced rank catch up so
// that every collective call is matched on every process.
void syncStat() { // only called from main thread
   if (!moreThanOneProcess()) return;
   Logging::LogIt(Logging::logInfo) << "Syncing stat";
   Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " syncstat";
   // wait for equilibrium: find the highest poll count among all ranks
   uint64_t globalNbPoll = 0ull;
   allReduceMax(&_nbStatPoll, &globalNbPoll, 1, _commStat2);
   while (_nbStatPoll < globalNbPoll) {
      // wait for our pending reduction, then issue another (sendStat increments _nbStatPoll)
      checkError(MPI_Wait(&_requestStat, MPI_STATUS_IGNORE), "MPI_Wait in syncStat loop");
      sendStat();
   }

   // final resync: complete the pending reduce, do one last poll cycle, complete it too
   checkError(MPI_Wait(&_requestStat, MPI_STATUS_IGNORE), "MPI_Wait in syncStat final 1");
   pollStat();
   checkError(MPI_Wait(&_requestStat, MPI_STATUS_IGNORE), "MPI_Wait in syncStat final 2");

   // make both halves of the double buffer carry the final reduced values
   _countersBufRecv[(_doubleBufferStatParity + 1) % 2] = _countersBufRecv[(_doubleBufferStatParity) % 2];

   //showStat(); // debug
   //sync(_commStat, __PRETTY_FUNCTION__);
   Logging::LogIt(Logging::logInfo) << "...ok";
   _nbStatPoll = 0; // reset for the next search
   Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " syncstat done";
}
222 :
223 : void showStat() {
224 : // show reduced stats
225 : for (size_t k = 0; k < Stats::sid_maxid; ++k) {
226 : Logging::LogIt(Logging::logInfo) << Stats::Names[k] << " " << _countersBufRecv[(_doubleBufferStatParity + 1) % 2][k];
227 : }
228 : }
229 :
230 : Counter counter(Stats::StatId id) {
231 : // because doubleBuffer is by design not initially filled, we return locals data at the beginning)
232 : const Counter global = _countersBufRecv[(_doubleBufferStatParity + 1) % 2][id];
233 : return global != 0 ? global : ThreadPool::instance().counter(id, true);
234 : }
235 :
// Stage a TT entry for sharing with the other ranks.
// Entries accumulate in the active half of a double send buffer; when it is
// full and the previous all-gather has completed, the received entries are
// merged into the local TT and a new asynchronous all-gather is launched on
// the other half. Called concurrently by all search threads of this process.
void setEntry(const Hash h, const TT::Entry& e) {
   if (!moreThanOneProcess()) return;
   DEBUGCOUT("set entry")
   // do not share entries near the leaves, they are changing too quickly
   if (e.d > _ttMinDepth) {
      DEBUGCOUT("depth ok")
      // try_lock: if another thread is already staging, just drop this entry
      // instead of stalling the search
      if (_ttMutex.try_lock()) { // this can be called from multiple threads of this process !
         DEBUGCOUT("lock ok")
         const uint8_t writeBuffer = _doubleBufferTTParity % 2;
         _ttBufSend[writeBuffer][_ttCurPos++] = {h, e};
         if (_ttCurPos == _ttBufSize) { // buffer is full
            DEBUGCOUT("buffer full")
            // if previous comm is done then use its data and launch another one
            int flag;
            checkError(MPI_Test(&_requestTT, &flag, MPI_STATUS_IGNORE), "MPI_Test in setEntry");
            if (flag) {
               // receive previous data (if this is not the first send)
               if (_nbTTTransfert > 0) {
#ifdef DEBUG_DISTRIBUTED
                  static uint64_t received = 0;
#endif
                  DEBUGCOUT("buffer received " + std::to_string(received++))
                  for (const auto& i : _ttBufRecv) {
                     if (i.h != nullHash) { // Only process valid entries
                        TT::_setEntry(i.h, i.e); // always replace (favour data from other process)
                     }
                  }
               }
               // send current buffer (writeBuffer) to every rank
               DEBUGCOUT("sending data " + std::to_string(_nbTTTransfert))
               asyncAllGather(_ttBufSend[writeBuffer].data(), _ttBufRecv.data(), _ttBufSize * sizeof(EntryHash), _requestTT, _commTT);
               ++_nbTTTransfert;
               // switch to the other buffer half and reset the write position
               ++_doubleBufferTTParity;
               _ttCurPos = 0ull;
            }
            else {
               // Previous comm not done, reset position and overwrite old data in same buffer.
               // This means we're losing TT entries - track this for monitoring.
               ++_nbTTBufferOverruns;
               if (_nbTTBufferOverruns % 100 == 1) { // Log occasionally to avoid spam
                  Logging::LogIt(Logging::logWarn) << "Rank " << rank << ": TT buffer overrun #"
                                                   << _nbTTBufferOverruns << " (MPI send slower than generation)";
               }
               DEBUGCOUT("previous comm not done, reusing buffer")
               _ttCurPos = 0ull;
            }
         }
         _ttMutex.unlock();
      } // end of lock
   } // depth ok
}
288 :
// get all rank to a common synchronous state at the end of search
// Completes any in-flight TT all-gather, then performs catch-up transfers so
// every rank has issued the same number of collective calls, merges everything
// received into the local TT, and finally barriers. Only called from main thread.
void syncTT() {
   if (!moreThanOneProcess()) return;
   Logging::LogIt(Logging::logInfo) << "Syncing TT";

   // First, ensure any pending TT transfer is complete
   DEBUGCOUT("sync TT initial wait")
   if (_requestTT != MPI_REQUEST_NULL) {
      checkError(MPI_Wait(&_requestTT, MPI_STATUS_IGNORE), "MPI_Wait in syncTT initial");
      // Process received data (only meaningful once a first transfer happened)
      if (_nbTTTransfert > 0) {
         for (const auto& i : _ttBufRecv) {
            if (i.h != nullHash) { // skip empty slots
               TT::_setEntry(i.h, i.e);
            }
         }
      }
   }

   // Synchronize transfer counts across all processes
   uint64_t globalNbPoll = 0ull;
   allReduceMax(&_nbTTTransfert, &globalNbPoll, 1, _commTT2);
   DEBUGCOUT("sync TT " + std::to_string(_nbTTTransfert) + " " + std::to_string(globalNbPoll))

   // Make sure all processes have done the same number of transfers,
   // so that every collective all-gather call is matched on every rank
   while (_nbTTTransfert < globalNbPoll) {
      DEBUGCOUT("sync TT catchup")
      // Send a partial buffer or empty buffer to catch up
      std::lock_guard<std::mutex> lock(_ttMutex); // keep setEntry() out while we reuse the buffers
      const uint8_t sendBuffer = _doubleBufferTTParity % 2;
      asyncAllGather(_ttBufSend[sendBuffer].data(), _ttBufRecv.data(), _ttBufSize * sizeof(EntryHash), _requestTT, _commTT);
      checkError(MPI_Wait(&_requestTT, MPI_STATUS_IGNORE), "MPI_Wait in syncTT catchup");
      ++_nbTTTransfert;
      // Process received data
      for (const auto& i : _ttBufRecv) {
         if (i.h != nullHash) {
            TT::_setEntry(i.h, i.e);
         }
      }
   }

   DEBUGCOUT("sync TT final barrier")
   sync(_commTT2, "syncTT final");

   Logging::LogIt(Logging::logInfo) << "... ok";

   // Report how many full buffers were dropped because MPI could not keep up
   if (_nbTTBufferOverruns > 0) {
      Logging::LogIt(Logging::logInfo) << "Rank " << rank << ": Total TT buffer overruns: " << _nbTTBufferOverruns;
   }

   // Reset the sharing state for the next search
   _nbTTTransfert = 0;
   _nbTTBufferOverruns = 0;
   _ttCurPos = 0;
   _requestTT = MPI_REQUEST_NULL;
}
344 : } // namespace Distributed
345 :
346 : #else
347 :
348 : namespace Distributed {
int worldSize = 1; // single-process build: exactly one "rank"
int rank = 0;      // always the main (and only) process

// Dummy placeholders so the rest of the code base can refer to communicators,
// requests and the RMA window without #ifdef-ing every use in non-MPI builds.
DummyType _commTT = 0;
DummyType _commTT2 = 0;
DummyType _commStat = 0;
DummyType _commStat2 = 0;
DummyType _commInput = 0;
DummyType _commMove = 0;
DummyType _commStopFromR0 = 0;

DummyType _requestTT = 0;
DummyType _requestStat = 0;
DummyType _requestInput = 0;
DummyType _requestMove = 0;
//DummyType _requestStopFromR0 = 0;

DummyType _winStopFromR0 = 0;
367 :
368 0 : Counter counter(Stats::StatId id) { return ThreadPool::instance().counter(id, true); }
369 :
370 : } // namespace Distributed
371 :
372 : #endif
|