LCOV - code coverage report
Current view: top level - Source - distributed.cpp (source / functions) Coverage Total Hit
Test: coverage Lines: 0.0 % 1 0
Test Date: 2026-03-02 16:42:41 Functions: 0.0 % 1 0

            Line data    Source code
       1              : #include "distributed.h"
       2              : 
       3              : #include "searcher.hpp"
       4              : #include "searchConfig.hpp"
       5              : #include "threading.hpp"
       6              : 
       7              : #ifdef WITH_MPI
       8              : 
       9              : #include <memory>
      10              : #include <mutex>
      11              : 
      12              : #include "dynamicConfig.hpp"
      13              : #include "logging.hpp"
      14              : #include "stats.hpp"
      15              : 
      16              : namespace Distributed {
// MPI topology, filled once by init().
int         worldSize;
int         rank;
std::string name;

// Map C++ value types onto the MPI datatypes used in collective calls.
const MPI_Datatype TraitMpiType<bool>     ::type = MPI_CXX_BOOL;
const MPI_Datatype TraitMpiType<char>     ::type = MPI_CHAR;
const MPI_Datatype TraitMpiType<Counter>  ::type = MPI_LONG_LONG_INT;
const MPI_Datatype TraitMpiType<EntryHash>::type = MPI_CHAR; // WARNING : entries travel as raw bytes, so counts must be adapted *sizeof(EntryHash)
const MPI_Datatype TraitMpiType<Move>     ::type = MPI_INT;


// One duplicated communicator per traffic type (all created in init()) so
// that the various asynchronous collectives can never interleave.
MPI_Comm _commTT         = MPI_COMM_NULL;
MPI_Comm _commTT2        = MPI_COMM_NULL;
MPI_Comm _commStat       = MPI_COMM_NULL;
MPI_Comm _commStat2      = MPI_COMM_NULL;
MPI_Comm _commInput      = MPI_COMM_NULL;
MPI_Comm _commMove       = MPI_COMM_NULL;
MPI_Comm _commStopFromR0 = MPI_COMM_NULL;

// Pending non-blocking request handles, one per communication channel.
MPI_Request _requestTT         = MPI_REQUEST_NULL;
MPI_Request _requestStat       = MPI_REQUEST_NULL;
MPI_Request _requestInput      = MPI_REQUEST_NULL;
MPI_Request _requestMove       = MPI_REQUEST_NULL;
//MPI_Request _requestStopFromR0 = MPI_REQUEST_NULL;

// RMA window over the main thread's stop flag (created in lateInit()),
// letting rank 0 remotely flip the stop flag of the other processes.
MPI_Win _winStopFromR0;

// Stats reduction buffers: one send buffer plus a double-buffered receive
// side, so a completed reduce can be read while the next one is in flight.
array1d<Counter, Stats::sid_maxid> _countersBufSend;
array1d<Counter, Stats::sid_maxid> _countersBufRecv[2];
uint8_t _doubleBufferStatParity; // parity % 2 selects the _countersBufRecv half currently being received into
uint64_t _nbStatPoll;            // number of stat reduces issued so far (used by syncStat to re-sync ranks)

// TT sharing buffers: entries are staged in a double-buffered send array and
// exchanged with all ranks through a non-blocking allgather (see setEntry()/syncTT()).
const uint64_t & _ttBufSize = SearchConfig::distributedTTBufSize;
uint64_t         _ttCurPos;  // next free slot in the active send buffer
const DepthType  _ttMinDepth = 3; // entries at depth <= this are not shared (see setEntry)
std::vector<EntryHash> _ttBufSend[2];
std::vector<EntryHash> _ttBufRecv; // worldSize * _ttBufSize slots, one chunk per rank
uint8_t     _doubleBufferTTParity; // parity % 2 selects the active half of _ttBufSend
std::mutex  _ttMutex;              // guards the staging buffer against concurrent search threads
uint64_t    _nbTTTransfert;        // number of allgathers issued so far
uint64_t    _nbTTBufferOverruns;   // full buffers dropped because the previous allgather was still in flight
      58              : 
      59              : void checkError(int err, const std::string& context) {
      60              :    if (err != MPI_SUCCESS) {
      61              :       char error_string[MPI_MAX_ERROR_STRING];
      62              :       int length;
      63              :       MPI_Error_string(err, error_string, &length);
      64              :       Logging::LogIt(Logging::logFatal) << "MPI error" << (context.empty() ? "" : " in " + context) 
      65              :                                         << ": " << error_string << " (code: " << err << ")";
      66              :    }
      67              : }
      68              : 
      69              : void init() {
      70              :    std::cout << Logging::_protocolComment[Logging::ct] << "Initializing MPI ..." << std::endl;
      71              :    int provided;
      72              :    checkError(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided));
      73              :    checkError(MPI_Comm_size(MPI_COMM_WORLD, &worldSize));
      74              :    checkError(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
      75              :    char processor_name[MPI_MAX_PROCESSOR_NAME];
      76              :    int  name_len;
      77              :    checkError(MPI_Get_processor_name(processor_name, &name_len));
      78              :    name = processor_name;
      79              : 
      80              :    if (!isMainProcess()) DynamicConfig::minOutputLevel = Logging::logOff;
      81              : 
      82              :    if (provided < MPI_THREAD_MULTIPLE) { Logging::LogIt(Logging::logFatal) << "The threading support level is lesser than needed"; }
      83              : 
      84              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commTT));
      85              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commTT2));
      86              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commStat));
      87              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commStat2));
      88              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commInput));
      89              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commMove));
      90              :    checkError(MPI_Comm_dup(MPI_COMM_WORLD, &_commStopFromR0));
      91              : 
      92              :    // Initialize MPI requests to NULL
      93              :    _requestTT = MPI_REQUEST_NULL;
      94              :    _requestStat = MPI_REQUEST_NULL;
      95              :    _requestInput = MPI_REQUEST_NULL;
      96              :    _requestMove = MPI_REQUEST_NULL;
      97              : 
      98              :    _nbStatPoll             = 0ull;
      99              :    _doubleBufferStatParity = 0;
     100              : 
     101              :    // buffer size depends on worldsize
     102              :    _ttBufSend[0].resize(_ttBufSize);
     103              :    _ttBufSend[1].resize(_ttBufSize);
     104              :    _ttBufRecv.resize(worldSize * _ttBufSize);
     105              :    _ttCurPos      = 0ull;
     106              :    _nbTTTransfert = 0ull;
     107              :    _nbTTBufferOverruns = 0ull;
     108              :    _doubleBufferTTParity = 0;
     109              : }
     110              : 
// Second init stage, run once the thread pool exists: expose the main
// thread's stopFlag through an MPI RMA window so rank 0 can remotely stop
// the search on the other processes.
void lateInit() {
   if (moreThanOneProcess()) {
      // The window maps exactly one bool: the main searcher's stopFlag.
      checkError(MPI_Win_create(&ThreadPool::instance().main().stopFlag, sizeof(bool), sizeof(bool), MPI_INFO_NULL, _commStopFromR0, &_winStopFromR0), "MPI_Win_create");
      // Collective fence completes window creation on all ranks before any RMA access.
      checkError(MPI_Win_fence(0, _winStopFromR0), "MPI_Win_fence in lateInit");
   }
}
     117              : 
     118              : void finalize() {
     119              :    checkError(MPI_Comm_free(&_commTT), "MPI_Comm_free _commTT");
     120              :    checkError(MPI_Comm_free(&_commTT2), "MPI_Comm_free _commTT2");
     121              :    checkError(MPI_Comm_free(&_commStat), "MPI_Comm_free _commStat");
     122              :    checkError(MPI_Comm_free(&_commStat2), "MPI_Comm_free _commStat2");
     123              :    checkError(MPI_Comm_free(&_commInput), "MPI_Comm_free _commInput");
     124              :    checkError(MPI_Comm_free(&_commMove), "MPI_Comm_free _commMove");
     125              :    checkError(MPI_Comm_free(&_commStopFromR0), "MPI_Comm_free _commStopFromR0");
     126              : 
     127              :    if (moreThanOneProcess()) checkError(MPI_Win_free(&_winStopFromR0), "MPI_Win_free");
     128              : 
     129              :    checkError(MPI_Finalize(), "MPI_Finalize");
     130              : }
     131              : 
     132              : bool isMainProcess() { return rank == 0; }
     133              : 
     134              : bool moreThanOneProcess() {
     135              :    return worldSize > 1;
     136              : }
     137              : 
     138              : // classic busy wait
     139              : void sync(MPI_Comm& com, const std::string& msg) {
     140              :    if (!moreThanOneProcess()) return;
     141              :    Logging::LogIt(Logging::logDebug) << "syncing: " << msg;
     142              :    checkError(MPI_Barrier(com), "MPI_Barrier: " + msg);
     143              :    Logging::LogIt(Logging::logDebug) << "...done";
     144              : }
     145              : 
     146              : // "softer" wait on other process
     147              : void waitRequest(MPI_Request& req) {
     148              :    if (!moreThanOneProcess()) return;
     149              :    // don't rely on MPI to do a "passive wait", most implementations are doing a busy-wait, so use 100% cpu
     150              :    if (Distributed::isMainProcess()) { checkError(MPI_Wait(&req, MPI_STATUS_IGNORE), "MPI_Wait in waitRequest"); }
     151              :    else {
     152              :       while (true) {
     153              :          int flag;
     154              :          checkError(MPI_Test(&req, &flag, MPI_STATUS_IGNORE), "MPI_Test in waitRequest");
     155              :          if (flag) break;
     156              :          else
     157              :             std::this_thread::sleep_for(std::chrono::milliseconds(1));
     158              :       }
     159              :    }
     160              : }
     161              : 
     162              : void initStat() {
     163              :    if (!moreThanOneProcess()) return;
     164              :    _countersBufSend.fill(0ull);
     165              :    _countersBufRecv[0].fill(0ull);
     166              :    _countersBufRecv[1].fill(0ull);
     167              : }
     168              : 
     169              : void sendStat() {
     170              :    if (!moreThanOneProcess()) return;
     171              :    // launch an async reduce
     172              :    Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " sendstat";
     173              :    asyncAllReduceSum(_countersBufSend.data(), _countersBufRecv[_doubleBufferStatParity % 2].data(), Stats::sid_maxid, _requestStat, _commStat);
     174              :    Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " sendstat done";
     175              :    ++_nbStatPoll;
     176              : }
     177              : 
     178              : void pollStat() { // only called from main thread
     179              :    if (!moreThanOneProcess()) return;
     180              :    int flag = 1;
     181              :    Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " pollstat";
     182              :    checkError(MPI_Test(&_requestStat, &flag, MPI_STATUS_IGNORE), "MPI_Test in pollStat");
     183              :    // if previous comm is done, launch another one
     184              :    if (flag) {
     185              :       ++_doubleBufferStatParity;
     186              :       // gather stats from all local threads
     187              :       for (size_t k = 0; k < Stats::sid_maxid; ++k) {
     188              :          _countersBufSend[k] = 0ull;
     189              :          for (const auto& it : ThreadPool::instance()) { _countersBufSend[k] += it->stats.counters[k]; }
     190              :       }
     191              :       sendStat();
     192              :    }
     193              :    Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " pollstat done";
     194              : }
     195              : 
// Get all ranks to a common synchronous state at the end of search: every
// rank must have issued the same number of stat reduces (collective calls
// must match across processes) before the final counters can be read.
void syncStat() { // only called from main thread
   if (!moreThanOneProcess()) return;
   Logging::LogIt(Logging::logInfo) << "Syncing stat";
   Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " syncstat";
   // wait for equilibrium: find the highest poll count among all ranks,
   // then issue extra reduces until this rank has caught up (sendStat
   // increments _nbStatPoll)
   uint64_t globalNbPoll = 0ull;
   allReduceMax(&_nbStatPoll, &globalNbPoll, 1, _commStat2);
   while (_nbStatPoll < globalNbPoll) {
      checkError(MPI_Wait(&_requestStat, MPI_STATUS_IGNORE), "MPI_Wait in syncStat loop");
      sendStat();
   }

   // final resync: drain the pending reduce, poll once so the freshest local
   // counters get reduced, then drain that last reduce too
   checkError(MPI_Wait(&_requestStat, MPI_STATUS_IGNORE), "MPI_Wait in syncStat final 1");
   pollStat();
   checkError(MPI_Wait(&_requestStat, MPI_STATUS_IGNORE), "MPI_Wait in syncStat final 2");

   // duplicate the freshly received half of the double buffer into the other
   // half, the one readers (showStat/counter) consult
   _countersBufRecv[(_doubleBufferStatParity + 1) % 2] = _countersBufRecv[(_doubleBufferStatParity) % 2];

   //showStat(); // debug
   //sync(_commStat, __PRETTY_FUNCTION__);
   Logging::LogIt(Logging::logInfo) << "...ok";
   _nbStatPoll = 0;
   Logging::LogIt(Logging::logDebug) << Logging::_protocolComment[Logging::ct] << rank << " syncstat done";
}
     222              : 
     223              : void showStat() {
     224              :    // show reduced stats
     225              :    for (size_t k = 0; k < Stats::sid_maxid; ++k) {
     226              :       Logging::LogIt(Logging::logInfo) << Stats::Names[k] << " " << _countersBufRecv[(_doubleBufferStatParity + 1) % 2][k];
     227              :    }
     228              : }
     229              : 
     230              : Counter counter(Stats::StatId id) {
     231              :    // because doubleBuffer is by design not initially filled, we return locals data at the beginning)
     232              :    const Counter global = _countersBufRecv[(_doubleBufferStatParity + 1) % 2][id];
     233              :    return global != 0 ? global : ThreadPool::instance().counter(id, true);
     234              : }
     235              : 
// Stage one TT entry for broadcast to the other processes. Entries accumulate
// in a double-buffered send array; when the active buffer fills up it is
// exchanged with every rank through a non-blocking allgather on _commTT.
// Called concurrently by all search threads of this process; if the staging
// lock is busy the entry is simply dropped (best effort sharing).
void setEntry(const Hash h, const TT::Entry& e) {
   if (!moreThanOneProcess()) return;
   DEBUGCOUT("set entry")
   // do not share entries near the leaves, they are changing too quickly
   if (e.d > _ttMinDepth) {
      DEBUGCOUT("depth ok")
      if (_ttMutex.try_lock()) { // this can be called from multiple threads of this process !
         DEBUGCOUT("lock ok")
         const uint8_t writeBuffer = _doubleBufferTTParity % 2;
         _ttBufSend[writeBuffer][_ttCurPos++] = {h, e};
         if (_ttCurPos == _ttBufSize) { // buffer is full
            DEBUGCOUT("buffer full")
            // if previous comm is done then use data and launch another one
            int flag;
            checkError(MPI_Test(&_requestTT, &flag, MPI_STATUS_IGNORE), "MPI_Test in setEntry");
            if (flag) {
               // inject previously received data into the local TT
               // (skipped on the very first send: nothing received yet)
               if (_nbTTTransfert > 0) {
#ifdef DEBUG_DISTRIBUTED
                  static uint64_t received = 0;
#endif
                  DEBUGCOUT("buffer received " +std::to_string(received++))
                  for (const auto& i : _ttBufRecv) {
                     if (i.h != nullHash) { // Only process valid entries
                        TT::_setEntry(i.h, i.e); // always replace (favour data from other process)
                     }
                  }
               }
               // send current buffer (writeBuffer)
               DEBUGCOUT("sending data " + std::to_string(_nbTTTransfert))
               asyncAllGather(_ttBufSend[writeBuffer].data(), _ttBufRecv.data(), _ttBufSize * sizeof(EntryHash), _requestTT, _commTT);
               ++_nbTTTransfert;
               // switch to other buffer and reset position
               ++_doubleBufferTTParity;
               _ttCurPos = 0ull;
            }
            else {
               // Previous comm not done, reset position and overwrite old data in same buffer.
               // This means we're losing TT entries - track this for monitoring.
               ++_nbTTBufferOverruns;
               if (_nbTTBufferOverruns % 100 == 1) { // Log occasionally to avoid spam
                  Logging::LogIt(Logging::logWarn) << "Rank " << rank << ": TT buffer overrun #" 
                                                    << _nbTTBufferOverruns << " (MPI send slower than generation)";
               }
               DEBUGCOUT("previous comm not done, reusing buffer")
               _ttCurPos = 0ull;
            }
         }
         _ttMutex.unlock();
      } // end of lock
   }    // depth ok
}
     288              : 
// Get all ranks to a common synchronous state at the end of search. The TT
// allgather is a collective, so every rank must have issued the same number
// of transfers: laggards send catch-up buffers until the counts agree, then
// everyone meets at a barrier. Only called from the main thread.
void syncTT() { // only called from main thread
   if (!moreThanOneProcess()) return;
   Logging::LogIt(Logging::logInfo) << "Syncing TT";
   
   // First, ensure any pending TT transfer is complete
   DEBUGCOUT("sync TT initial wait")
   if (_requestTT != MPI_REQUEST_NULL) {
      checkError(MPI_Wait(&_requestTT, MPI_STATUS_IGNORE), "MPI_Wait in syncTT initial");
      // Inject received data into the local TT (only valid once a transfer happened)
      if (_nbTTTransfert > 0) {
         for (const auto& i : _ttBufRecv) {
            if (i.h != nullHash) {
               TT::_setEntry(i.h, i.e);
            }
         }
      }
   }
   
   // Synchronize transfer counts across all processes (take the maximum)
   uint64_t globalNbPoll = 0ull;
   allReduceMax(&_nbTTTransfert, &globalNbPoll, 1, _commTT2);
   DEBUGCOUT("sync TT " + std::to_string(_nbTTTransfert) + " " + std::to_string(globalNbPoll))
   
   // Issue extra (possibly partial/stale) sends until this rank has matched
   // the highest transfer count, so every collective call is paired
   while (_nbTTTransfert < globalNbPoll) {
      DEBUGCOUT("sync TT catchup")
      // Send a partial buffer or empty buffer to catch up
      std::lock_guard<std::mutex> lock(_ttMutex);
      const uint8_t sendBuffer = _doubleBufferTTParity % 2;
      asyncAllGather(_ttBufSend[sendBuffer].data(), _ttBufRecv.data(), _ttBufSize * sizeof(EntryHash), _requestTT, _commTT);
      checkError(MPI_Wait(&_requestTT, MPI_STATUS_IGNORE), "MPI_Wait in syncTT catchup");
      ++_nbTTTransfert;
      // Inject received data, skipping empty (nullHash) slots
      for (const auto& i : _ttBufRecv) {
         if (i.h != nullHash) {
            TT::_setEntry(i.h, i.e);
         }
      }
   }
   
   DEBUGCOUT("sync TT final barrier")
   sync(_commTT2, "syncTT final");
   
   Logging::LogIt(Logging::logInfo) << "... ok";
   
   if (_nbTTBufferOverruns > 0) {
      Logging::LogIt(Logging::logInfo) << "Rank " << rank << ": Total TT buffer overruns: " << _nbTTBufferOverruns;
   }
   
   // Reset the transfer bookkeeping for the next search
   _nbTTTransfert = 0;
   _nbTTBufferOverruns = 0;
   _ttCurPos = 0;
   _requestTT = MPI_REQUEST_NULL;
}
     344              : } // namespace Distributed
     345              : 
     346              : #else
     347              : 
     348              : namespace Distributed {
// Single-process fallbacks (build without MPI): there is exactly one process
// and all the communicator/request/window handles degrade to inert dummies.
int worldSize = 1;
int rank      = 0;

DummyType _commTT         = 0;
DummyType _commTT2        = 0;
DummyType _commStat       = 0;
DummyType _commStat2      = 0;
DummyType _commInput      = 0;
DummyType _commMove       = 0;
DummyType _commStopFromR0 = 0;

DummyType _requestTT         = 0;
DummyType _requestStat       = 0;
DummyType _requestInput      = 0;
DummyType _requestMove       = 0;
//DummyType _requestStopFromR0 = 0;

DummyType _winStopFromR0 = 0;

// Without MPI there is no cross-process reduction: just sum over the local threads.
Counter counter(Stats::StatId id) { return ThreadPool::instance().counter(id, true); }
     369              : 
     370              : } // namespace Distributed
     371              : 
     372              : #endif
        

Generated by: LCOV version 2.0-1