
Commit ecc2cc7
adding spmd.gather
1 parent 9abdeab

12 files changed: +202 -37 lines

ddptensor/spmd.py
Lines changed: 3 additions & 0 deletions

@@ -5,3 +5,6 @@ def get_slice(obj, *args):
 
 def get_local(obj):
     return _cdt._get_local(obj._t, obj)
+
+def gather(obj):
+    return _cdt._gather(obj._t)

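The new gather complements get_local: get_local returns only the calling rank's partition, while gather hands every rank a copy of the complete tensor. A minimal usage sketch (not part of the commit; dt.arange and its signature are assumptions about ddptensor's creation API):

import ddptensor as dt
from ddptensor import spmd

t = dt.arange(0, 16, 1, dt.int64)  # tensor distributed across all ranks (hypothetical creator)
loc = spmd.get_local(t)            # this rank's local partition only
full = spmd.gather(t)              # complete tensor, replicated on every rank
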
setup.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ def build_cmake(self, ext):
         extdir.parent.mkdir(parents=True, exist_ok=True)
 
         # example of cmake args
-        config = 'Debug' if self.debug else 'Release' # 'RelWithDebInfo' #'Release'
+        config = 'Debug'# if self.debug else 'Release' # 'RelWithDebInfo' #'Release'
         cmake_args = [
             '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()),
             '-DCMAKE_BUILD_TYPE=' + config

src/MPIMediator.cpp
Lines changed: 19 additions & 20 deletions

@@ -8,6 +8,7 @@
 
 #include "ddptensor/UtilsAndTypes.hpp"
 #include "ddptensor/MPIMediator.hpp"
+#include "ddptensor/MPITransceiver.hpp"
 #include "ddptensor/NDSlice.hpp"
 #include "ddptensor/Factory.hpp"
 
@@ -18,29 +19,30 @@ constexpr static int DEFER_TAG = 14714;
 constexpr static int EXIT_TAG = 14715;
 static std::mutex ak_mutex;
 
-void send_to_workers(const Deferred::ptr_type & dfrd, bool self = false);
+void send_to_workers(const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm);
 
 MPIMediator::MPIMediator()
     : _listener(nullptr)
 {
-    MPI_Comm comm = MPI_COMM_WORLD;
+    auto c = dynamic_cast<MPITransceiver*>(theTransceiver);
+    if(c == nullptr) throw std::runtime_error("Expected Transceiver to be MPITransceiver.");
+    _comm = c->comm();
     int sz;
-    MPI_Comm_size(comm, &sz);
+    MPI_Comm_size(_comm, &sz);
     if(sz > 1)
         _listener = new std::thread(&MPIMediator::listen, this);
 }
 
 MPIMediator::~MPIMediator()
 {
     std::cerr << "MPIMediator::~MPIMediator()" << std::endl;
-    MPI_Comm comm = MPI_COMM_WORLD;
     int rank, sz;
-    MPI_Comm_rank(comm, &rank);
-    MPI_Comm_size(comm, &sz);
+    MPI_Comm_rank(_comm, &rank);
+    MPI_Comm_size(_comm, &sz);
 
     if(is_cw() && rank == 0) to_workers(nullptr);
-    MPI_Barrier(comm);
-    if(!is_cw() || rank == 0) send_to_workers(nullptr, true);
+    MPI_Barrier(_comm);
+    if(!is_cw() || rank == 0) send_to_workers(nullptr, true, _comm);
     if(_listener) {
         _listener->join();
         delete _listener;
@@ -50,7 +52,6 @@ MPIMediator::~MPIMediator()
 
 void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void * rbuff)
 {
-    MPI_Comm comm = MPI_COMM_WORLD;
     MPI_Request request[2];
     MPI_Status status[2];
     Buffer buff;
@@ -65,8 +66,8 @@ void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void
     int cnt = static_cast<int>(ser.adapter().writtenBytesCount());
 
     auto sz = slice.size() * Registry::get(id).get()->item_size();
-    MPI_Irecv(rbuff, sz, MPI_CHAR, from, PUSH_TAG, comm, &request[1]);
-    MPI_Isend(buff.data(), cnt, MPI_CHAR, from, REQ_TAG, comm, &request[0]);
+    MPI_Irecv(rbuff, sz, MPI_CHAR, from, PUSH_TAG, _comm, &request[1]);
+    MPI_Isend(buff.data(), cnt, MPI_CHAR, from, REQ_TAG, _comm, &request[0]);
     auto error_code = MPI_Waitall(2, &request[0], &status[0]);
     if (error_code != MPI_SUCCESS) {
         throw std::runtime_error("MPI_Waitall returned error code " + std::to_string(error_code));
@@ -81,10 +82,9 @@ void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void
     if(cnt != sz) throw(std::runtime_error("Received unexpected message size."));
 }
 
-void send_to_workers(const Deferred::ptr_type & dfrd, bool self)
+void send_to_workers(const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm)
 {
     int rank, sz;
-    MPI_Comm comm = MPI_COMM_WORLD;
     MPI_Comm_rank(comm, &rank);
     MPI_Comm_size(comm, &sz);
 
@@ -126,22 +126,21 @@ void send_to_workers(const Deferred::ptr_type & dfrd, bool self)
 
 void MPIMediator::to_workers(const Deferred::ptr_type & dfrd)
 {
-    send_to_workers(dfrd);
+    send_to_workers(dfrd, false, _comm);
 }
 
 void MPIMediator::listen()
 {
     int nranks;
-    MPI_Comm_size(MPI_COMM_WORLD, &nranks);
+    MPI_Comm_size(_comm, &nranks);
     if(nranks < 2 ) return;
 
     constexpr int BSZ = 256;
-    MPI_Comm comm = MPI_COMM_WORLD;
     MPI_Request request_in = MPI_REQUEST_NULL, request_out = MPI_REQUEST_NULL;
     Buffer rbuff;
     // Issue async recv request
     Buffer buff(BSZ);
-    MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, comm, &request_in);
+    MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, _comm, &request_in);
     do {
         MPI_Status status;
         // Wait for any request
@@ -170,15 +169,15 @@ void MPIMediator::listen()
 
             // Issue async recv request for next msg
            buff.resize(BSZ);
-           MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, comm, &request_in);
+           MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, _comm, &request_in);
 
            // Now find the array in question and send back its bufferized slice
            tensor_i::ptr_type ptr = Registry::get(id).get();
            // Wait for previous answer to complete so that we can re-use the buffer
            MPI_Wait(&request_out, MPI_STATUS_IGNORE);
            ptr->bufferize(slice, rbuff);
            if(slice.size() * ptr->item_size() != rbuff.size()) throw(std::runtime_error("Got unexpected buffer size."));
-           MPI_Isend(rbuff.data(), rbuff.size(), MPI_CHAR, requester, PUSH_TAG, comm, &request_out);
+           MPI_Isend(rbuff.data(), rbuff.size(), MPI_CHAR, requester, PUSH_TAG, _comm, &request_out);
            break;
        }
        case EXIT_TAG:
@@ -190,7 +189,7 @@ void MPIMediator::listen()
         if(request_in == MPI_REQUEST_NULL) {
             // Issue async recv request for next msg
             buff.resize(BSZ);
-            MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, comm, &request_in);
+            MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, REQ_TAG, _comm, &request_in);
         }
     } while(true);
     // MPI_Cancel(&request_in);
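
The refactor above replaces every hard-coded MPI_COMM_WORLD with the transceiver's communicator: once workers are added via MPI_Comm_spawn, the parent's MPI_COMM_WORLD no longer contains them, so mediator traffic must run over the merged intra-communicator. A minimal mpi4py sketch of that spawn-and-merge pattern (illustration only, not ddptensor code; the script re-executes itself as the worker):

import sys
from mpi4py import MPI

parent = MPI.Comm.Get_parent()
if parent == MPI.COMM_NULL:
    # original process: spawn two workers running this same script
    inter = MPI.COMM_WORLD.Spawn(sys.executable, args=[__file__], maxprocs=2)
    comm = inter.Merge(high=False)   # parent ranks ordered first
else:
    # spawned worker: merge towards the parent
    comm = parent.Merge(high=True)

# comm spans parent and workers; MPI.COMM_WORLD alone would not
print(comm.Get_rank(), "of", comm.Get_size())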

src/MPITransceiver.cpp
Lines changed: 91 additions & 7 deletions

@@ -2,9 +2,11 @@
 
 #include <mpi.h>
 #include <limits>
+#include <sstream>
 #include "ddptensor/MPITransceiver.hpp"
 
 MPITransceiver::MPITransceiver()
+    : _nranks(1), _rank(0), _comm(MPI_COMM_WORLD)
 {
     int flag;
     MPI_Initialized(&flag);
@@ -21,9 +23,81 @@ MPITransceiver::MPITransceiver()
         throw(std::logic_error("MPI had been initialized incorrectly: not MPI_THREAD_MULTIPLE"));
         std::cerr << "MPI already initialized\n";
     }
+
     int nranks, rank;
-    MPI_Comm_size(MPI_COMM_WORLD, &nranks);
-    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_rank(_comm, &rank);
+    MPI_Comm parentComm;
+    MPI_Comm_get_parent(&parentComm);
+
+    // rank father-of-all checks if he's requested to spawn processes:
+    if(rank == 0 && parentComm == MPI_COMM_NULL) {
+        // Ok, let's spawn the clients.
+        // I need some information for the startup.
+        // 1. Name of the executable (default is the current exe)
+        const char * _tmp = getenv("DDPT_MPI_SPAWN");
+        if(_tmp) {
+            int nClientsToSpawn = atol(_tmp);
+            _tmp = getenv("DDPT_MPI_EXECUTABLE");
+            std::string clientExe(_tmp ? _tmp : getenv("PYTHON_EXE"));
+            if(clientExe.empty()) throw std::runtime_error("Spawning MPI processes requires setting 'DDPT_MPI_EXECUTABLE' or 'PYTHON_EXE'");
+
+            // 2. arguments
+            _tmp = getenv("DDPT_MPI_EXE_ARGS");
+            std::vector<std::string> args;
+            if(_tmp) {
+                std::istringstream iss(_tmp);
+                std::copy(std::istream_iterator<std::string>(iss),
+                          std::istream_iterator<std::string>(),
+                          std::back_inserter(args));
+            } else {
+                _tmp = "-c import ddptensor as dt; dt.init(True)";
+                args.push_back("-c");
+                args.push_back("import ddptensor as dt; dt.init(True)");
+            }
+            const char * clientArgs[args.size()+1];
+            for(int i=0; i<args.size(); ++i) clientArgs[i] = args[i].c_str();
+            clientArgs[args.size()] = nullptr;
+
+            // 3. Special setting for MPI_Info: hosts
+            const char * clientHost = getenv("DDPT_MPI_HOSTS");
+
+            // Prepare MPI_Info object:
+            MPI_Info clientInfo = MPI_INFO_NULL;
+            if(clientHost) {
+                MPI_Info_create(&clientInfo);
+                MPI_Info_set(clientInfo, const_cast< char * >("host"), const_cast< char * >(clientHost));
+                std::cerr << "[DDPT " << rank << "] Set MPI_Info_set(\"host\", \"" << clientHost << "\")\n";
+            }
+            // Now spawn the client processes:
+            // can't use Speaker yet, need Channels to be inited
+            std::cerr << "[DDPT " << rank << "] Spawning " << nClientsToSpawn << " MPI processes ("
+                      << clientExe << " " << _tmp << ")" << std::endl;
+            int* errCodes = new int[nClientsToSpawn];
+            MPI_Comm interComm;
+            int err = MPI_Comm_spawn(const_cast< char * >(clientExe.c_str()),
+                                     const_cast< char ** >(clientArgs),
+                                     nClientsToSpawn, clientInfo, 0,
+                                     MPI_COMM_WORLD, &interComm, errCodes);
+            delete [] errCodes;
+            if (err) {
+                // can't use Speaker yet, need Channels to be inited
+                std::cerr << "[DDPT " << rank << "] Error in MPI_Comm_spawn. Skipping process spawning";
+            } else {
+                MPI_Intercomm_merge(interComm, 0, &_comm);
+            }
+        } // else {
+          // No process spawning
+          // MPI-1 situation: all clients to be started by mpiexec
+          // _comm = MPI_COMM_WORLD;
+          //}
+    }
+    if(parentComm != MPI_COMM_NULL) {
+        // I am a child. Build intra-comm to the parent.
+        MPI_Intercomm_merge(parentComm, 1, &_comm);
+    }
+
+    MPI_Comm_size(_comm, &nranks);
+    MPI_Comm_rank(_comm, &rank);
     _nranks = nranks;
     _rank = rank;
 };
@@ -73,17 +147,17 @@ static MPI_Op to_mpi(RedOpType o)
 
 void MPITransceiver::barrier()
 {
-    MPI_Barrier(MPI_COMM_WORLD);
+    MPI_Barrier(_comm);
 }
 
 void MPITransceiver::bcast(void * ptr, size_t N, rank_type root)
 {
-    MPI_Bcast(ptr, N, MPI_CHAR, root, MPI_COMM_WORLD);
+    MPI_Bcast(ptr, N, MPI_CHAR, root, _comm);
 }
 
 void MPITransceiver::reduce_all(void * inout, DTypeId T, size_t N, RedOpType op)
 {
-    MPI_Allreduce(MPI_IN_PLACE, inout, N, to_mpi(T), to_mpi(op), MPI_COMM_WORLD);
+    MPI_Allreduce(MPI_IN_PLACE, inout, N, to_mpi(T), to_mpi(op), _comm);
 }
 
 void MPITransceiver::alltoall(const void* buffer_send,
@@ -103,7 +177,17 @@ void MPITransceiver::alltoall(const void* buffer_send,
                   counts_recv,
                   displacements_recv,
                   to_mpi(datatype_recv),
-                  MPI_COMM_WORLD);
+                  _comm);
 }
+
+void MPITransceiver::allgather(void* buffer,
+                               const int* counts,
+                               const int* displacements,
+                               DTypeId datatype)
+{
+    MPI_Allgatherv(MPI_IN_PLACE, 0, to_mpi(datatype),
+                   buffer, counts, displacements, to_mpi(datatype),
+                   _comm);
+}
 
 void MPITransceiver::send_recv(void* buffer_send,
@@ -120,6 +204,6 @@ void MPITransceiver::send_recv(void* buffer_send,
                  SRTAG,
                  source,
                  SRTAG,
-                 MPI_COMM_WORLD,
+                 _comm,
                  MPI_STATUS_IGNORE);
 }
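
The spawn path in the constructor is driven entirely by environment variables. A hypothetical controller-side launch sketch (the variable names DDPT_MPI_SPAWN, DDPT_MPI_EXECUTABLE, DDPT_MPI_EXE_ARGS and DDPT_MPI_HOSTS come from the diff above; the concrete values and the final init call are assumptions):

import os

# must be set before ddptensor initializes MPI
os.environ["DDPT_MPI_SPAWN"] = "3"                      # ask rank 0 to spawn 3 workers
os.environ["DDPT_MPI_EXECUTABLE"] = "/usr/bin/python3"  # worker executable
# workers default to running: -c "import ddptensor as dt; dt.init(True)"

import ddptensor as dt
dt.init(True)  # rank 0 spawns the workers and merges the inter-communicator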

src/SetGetItem.cpp
Lines changed: 47 additions & 4 deletions

@@ -14,11 +14,11 @@ namespace x {
         template<typename T>
         static ptr_type op(const NDSlice & slice, const std::shared_ptr<DPTensorX<T>> & a_ptr)
         {
-            auto nd = a_ptr->shape().size();
-            if(nd != slice.ndims())
+            const auto & slc = a_ptr->slice();
+            if(slc.ndims() != slice.ndims())
                 throw std::runtime_error("Index dimensionality must match array dimensionality");
 
-            return operatorx<T>::mk_tx(*a_ptr.get(), slice);
+            return operatorx<T>::mk_tx(*a_ptr.get(), slice.trim(slc.slice()));
         }
     };
 
@@ -50,7 +50,6 @@
             NDSlice my_curr_local_slice = my_curr_needed_view.local_slice_of_rank(theTransceiver->rank());
 
             if(curr_needed_norm_slice.size()) {
-                py::tuple tpl = _make_tuple(my_curr_local_slice); //my_curr_view.slice());
                 if(i == theTransceiver->rank()) {
                     // copy locally
                     auto to_v = xt::strided_view(dest/*.xarray()*/, to_xt(my_curr_local_slice));
@@ -125,6 +124,44 @@
             T * data = a_ptr->xarray().data();
             return py::array(std::move(slc.shape()), std::move(strides), data + off, handle);
         }
+
+        // gather
+        // We simply create a local buffer, copy our local data to the right place
+        // and then call AllGatherV via inplace operation.
+        template<typename T>
+        static py::object op(const std::shared_ptr<DPTensorX<T>> & a_ptr)
+        {
+            auto nranks = theTransceiver->nranks();
+            auto rank = theTransceiver->rank();
+            const auto & slc = a_ptr->slice();
+
+            // create buffer/numpy array
+            auto res = py::array_t<T>(std::move(slc.shape()));
+            T * ptr = reinterpret_cast<T*>(res.mutable_data());
+            int displacements[nranks];
+            int counts[nranks];
+            int off = 0;
+            // for each rank compute counts and displacements
+            for(auto i=0; i<nranks; ++i) {
+                uint64_t szi = slc.slice_of_rank(i).size();
+                counts[i] = szi;
+                displacements[i] = off;
+                // copy our local data
+                if(i == rank) {
+                    if(a_ptr->is_sliced()) {
+                        // if non-contiguous copy element by element
+                        const auto & av = xt::strided_view(a_ptr->xarray(), a_ptr->lslice());
+                        uint64_t i = off-1;
+                        for(auto v : av) ptr[++i] = v;
+                    } else {
+                        memcpy(&ptr[off], a_ptr->xarray().data(), szi*sizeof(T));
+                    }
+                }
+                off += szi;
+            }
+            theTransceiver->allgather(ptr, counts, displacements, DTYPE<T>::value);
+            return res;
+        }
     };
 
 } // namespace x
@@ -212,5 +249,11 @@ py::object GetItem::get_local(const ddptensor & a, py::handle h)
     return TypeDispatch<x::SPMD>(aa, h);
 }
 
+py::object GetItem::gather(const ddptensor & a)
+{
+    const auto aa = std::move(a.get().get());
+    return TypeDispatch<x::SPMD>(aa);
+}
+
 FACTORY_INIT(DeferredGetItem, F_GETITEM);
 FACTORY_INIT(DeferredSetItem, F_SETITEM);
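
The gather comments above describe the pattern: allocate the full-size result, copy the local partition to its displacement, then let an in-place allgather fill in the remote parts. A standalone mpi4py sketch of the same in-place MPI_Allgatherv pattern (illustration only; plain 1-D NumPy buffers and an even split are assumed, whereas ddptensor derives per-rank counts from slc.slice_of_rank(i)):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, nranks = comm.Get_rank(), comm.Get_size()

counts = [4] * nranks                    # elements owned by each rank
displs = [i * 4 for i in range(nranks)]  # offset of each rank's block in the result

buf = np.zeros(sum(counts), dtype=np.float64)
# each rank fills only its own region, like the memcpy branch above
buf[displs[rank]:displs[rank] + counts[rank]] = rank

# in-place all-gather: afterwards every rank holds the complete buffer
comm.Allgatherv(MPI.IN_PLACE, [buf, counts, displs, MPI.DOUBLE])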

src/ddptensor.cpp
Lines changed: 4 additions & 4 deletions

@@ -83,6 +83,8 @@ void fini()
 void init(bool cw)
 {
     if(inited) return;
+    theTransceiver = new MPITransceiver();
+    theMediator = new MPIMediator();
     if(cw) {
         _is_cw = true;
         if(theTransceiver->rank()) {
@@ -114,9 +116,6 @@ PYBIND11_MODULE(_ddptensor, m) {
     Factory::init<F_RANDOM>();
     Factory::init<F_SERVICE>();
 
-    theTransceiver = new MPITransceiver();
-    theMediator = new MPIMediator();
-
     m.doc() = "A partitioned and distributed tensor";
 
     def_enums(m);
@@ -126,7 +125,8 @@ PYBIND11_MODULE(_ddptensor, m) {
         .def("sync", &sync)
        .def("myrank", &myrank)
        .def("_get_slice", &GetItem::get_slice)
-       .def("_get_local", &GetItem::get_local);
+       .def("_get_local", &GetItem::get_local)
+       .def("_gather", &GetItem::gather);
 
     py::class_<Creator>(m, "Creator")
         .def("create_from_shape", &Creator::create_from_shape)
