88
99#include " ddptensor/UtilsAndTypes.hpp"
1010#include " ddptensor/MPIMediator.hpp"
11+ #include " ddptensor/MPITransceiver.hpp"
1112#include " ddptensor/NDSlice.hpp"
1213#include " ddptensor/Factory.hpp"
1314
@@ -18,29 +19,30 @@ constexpr static int DEFER_TAG = 14714;
1819constexpr static int EXIT_TAG = 14715 ;
1920static std::mutex ak_mutex;
2021
21- void send_to_workers (const Deferred::ptr_type & dfrd, bool self = false );
22+ void send_to_workers (const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm );
2223
2324MPIMediator::MPIMediator ()
2425 : _listener(nullptr )
2526{
26- MPI_Comm comm = MPI_COMM_WORLD;
27+ auto c = dynamic_cast <MPITransceiver*>(theTransceiver);
28+ if (c == nullptr ) throw std::runtime_error (" Expected Transceiver to be MPITransceiver." );
29+ _comm = c->comm ();
2730 int sz;
28- MPI_Comm_size (comm , &sz);
31+ MPI_Comm_size (_comm , &sz);
2932 if (sz > 1 )
3033 _listener = new std::thread (&MPIMediator::listen, this );
3134}
3235
3336MPIMediator::~MPIMediator ()
3437{
3538 std::cerr << " MPIMediator::~MPIMediator()" << std::endl;
36- MPI_Comm comm = MPI_COMM_WORLD;
3739 int rank, sz;
38- MPI_Comm_rank (comm , &rank);
39- MPI_Comm_size (comm , &sz);
40+ MPI_Comm_rank (_comm , &rank);
41+ MPI_Comm_size (_comm , &sz);
4042
4143 if (is_cw () && rank == 0 ) to_workers (nullptr );
42- MPI_Barrier (comm );
43- if (!is_cw () || rank == 0 ) send_to_workers (nullptr , true );
44+ MPI_Barrier (_comm );
45+ if (!is_cw () || rank == 0 ) send_to_workers (nullptr , true , _comm );
4446 if (_listener) {
4547 _listener->join ();
4648 delete _listener;
@@ -50,7 +52,6 @@ MPIMediator::~MPIMediator()
5052
5153void MPIMediator::pull (rank_type from, id_type guid, const NDSlice & slice, void * rbuff)
5254{
53- MPI_Comm comm = MPI_COMM_WORLD;
5455 MPI_Request request[2 ];
5556 MPI_Status status[2 ];
5657 Buffer buff;
@@ -65,8 +66,8 @@ void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void
6566 int cnt = static_cast <int >(ser.adapter ().writtenBytesCount ());
6667
6768 auto sz = slice.size () * Registry::get (id).get ()->item_size ();
68- MPI_Irecv (rbuff, sz, MPI_CHAR, from, PUSH_TAG, comm , &request[1 ]);
69- MPI_Isend (buff.data (), cnt, MPI_CHAR, from, REQ_TAG, comm , &request[0 ]);
69+ MPI_Irecv (rbuff, sz, MPI_CHAR, from, PUSH_TAG, _comm , &request[1 ]);
70+ MPI_Isend (buff.data (), cnt, MPI_CHAR, from, REQ_TAG, _comm , &request[0 ]);
7071 auto error_code = MPI_Waitall (2 , &request[0 ], &status[0 ]);
7172 if (error_code != MPI_SUCCESS) {
7273 throw std::runtime_error (" MPI_Waitall returned error code " + std::to_string (error_code));
@@ -81,10 +82,9 @@ void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void
8182 if (cnt != sz) throw (std::runtime_error (" Received unexpected message size." ));
8283}
8384
84- void send_to_workers (const Deferred::ptr_type & dfrd, bool self)
85+ void send_to_workers (const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm )
8586{
8687 int rank, sz;
87- MPI_Comm comm = MPI_COMM_WORLD;
8888 MPI_Comm_rank (comm, &rank);
8989 MPI_Comm_size (comm, &sz);
9090
@@ -126,22 +126,21 @@ void send_to_workers(const Deferred::ptr_type & dfrd, bool self)
126126
127127void MPIMediator::to_workers (const Deferred::ptr_type & dfrd)
128128{
129- send_to_workers (dfrd);
129+ send_to_workers (dfrd, false , _comm );
130130}
131131
132132void MPIMediator::listen ()
133133{
134134 int nranks;
135- MPI_Comm_size (MPI_COMM_WORLD , &nranks);
135+ MPI_Comm_size (_comm , &nranks);
136136 if (nranks < 2 ) return ;
137137
138138 constexpr int BSZ = 256 ;
139- MPI_Comm comm = MPI_COMM_WORLD;
140139 MPI_Request request_in = MPI_REQUEST_NULL, request_out = MPI_REQUEST_NULL;
141140 Buffer rbuff;
142141 // Issue async recv request
143142 Buffer buff (BSZ);
144- MPI_Irecv (buff.data (), buff.size (), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, comm , &request_in);
143+ MPI_Irecv (buff.data (), buff.size (), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, _comm , &request_in);
145144 do {
146145 MPI_Status status;
147146 // Wait for any request
@@ -170,15 +169,15 @@ void MPIMediator::listen()
170169
171170 // Issue async recv request for next msg
172171 buff.resize (BSZ);
173- MPI_Irecv (buff.data (), buff.size (), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, comm , &request_in);
172+ MPI_Irecv (buff.data (), buff.size (), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, _comm , &request_in);
174173
175174 // Now find the array in question and send back its bufferized slice
176175 tensor_i::ptr_type ptr = Registry::get (id).get ();
177176 // Wait for previous answer to complete so that we can re-use the buffer
178177 MPI_Wait (&request_out, MPI_STATUS_IGNORE);
179178 ptr->bufferize (slice, rbuff);
180179 if (slice.size () * ptr->item_size () != rbuff.size ()) throw (std::runtime_error (" Got unexpected buffer size." ));
181- MPI_Isend (rbuff.data (), rbuff.size (), MPI_CHAR, requester, PUSH_TAG, comm , &request_out);
180+ MPI_Isend (rbuff.data (), rbuff.size (), MPI_CHAR, requester, PUSH_TAG, _comm , &request_out);
182181 break ;
183182 }
184183 case EXIT_TAG:
@@ -190,7 +189,7 @@ void MPIMediator::listen()
190189 if (request_in == MPI_REQUEST_NULL) {
191190 // Issue async recv request for next msg
192191 buff.resize (BSZ);
193- MPI_Irecv (buff.data (), buff.size (), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, comm , &request_in);
192+ MPI_Irecv (buff.data (), buff.size (), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, _comm , &request_in);
194193 }
195194 } while (true );
196195 // MPI_Cancel(&request_in);
0 commit comments