11// SPDX-License-Identifier: BSD-3-Clause
22
33#include < ddptensor/idtr.hpp>
4- #include < ddptensor/jit/mlir.hpp>
4+ // #include <ddptensor/jit/mlir.hpp>
55#include < ddptensor/DDPTensorImpl.hpp>
66#include < ddptensor/MPITransceiver.hpp>
77
8- #include < imex/Dialect/PTensor/IR/PTensorOps .h>
8+ #include < imex/Dialect/PTensor/IR/PTensorDefs .h>
99
1010#include < cassert>
1111#include < memory>
12+ #include < iostream>
1213
1314using container_type = std::unordered_map<id_type, std::unique_ptr<DDPTensorImpl>>;
1415
@@ -160,6 +161,51 @@ static DTypeId mlir2ddpt(const ::imex::ptensor::DType dt)
160161 };
161162}
162163
164+
165+ template <typename T, typename OP>
166+ void forall (uint64_t d, const T * cptr, const int64_t * sizes, const int64_t * strides, uint64_t nd, OP op)
167+ {
168+ auto stride = strides[d];
169+ auto sz = sizes[d];
170+ if (d==nd-1 ) {
171+ for (auto i=0 ; i<sz; ++i) {
172+ op (&cptr[i*stride]);
173+ }
174+ } else {
175+ for (auto i=0 ; i<sz; ++i) {
176+ forall (d+1 , cptr, sizes, strides, nd, op);
177+ }
178+ }
179+ }
180+
181+ bool is_contiguous (const int64_t * sizes, const int64_t * strides, uint64_t nd)
182+ {
183+ if (nd == 0 ) return true ;
184+ if (strides[nd-1 ] != 1 ) return false ;
185+ auto sz = 1 ;
186+ for (auto i=nd-1 ; i>0 ; --i) {
187+ sz *= sizes[i];
188+ if (strides[i-1 ] != sz) return false ;
189+ }
190+ return true ;
191+ }
192+
193+ void * bufferize (void * cptr, DTypeId dtype, const int64_t * sizes, const int64_t * strides, uint64_t nd, void * out)
194+ {
195+ if (is_contiguous (sizes, strides, nd)) {
196+ return cptr;
197+ } else {
198+ dispatch (dtype, cptr, [sizes, strides, nd, out](auto * ptr) {
199+ auto buff = static_cast <decltype (ptr)>(out);
200+ forall (0 , ptr, sizes, strides, nd, [&buff](const auto * in) {
201+ *buff = *in;
202+ ++buff;
203+ });
204+ });
205+ return out;
206+ }
207+ }
208+
163209extern " C" {
164210// Elementwise inplace allreduce
165211void idtr_reduce_all (void * inout, DTypeId dtype, uint64_t N, ReduceOpId op)
@@ -168,12 +214,59 @@ void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, ReduceOpId op)
168214}
169215
170216// FIXME hard-coded for contiguous layout
171- void _idtr_reduce_all (uint64_t rank, void * data, int64_t * sizes, int64_t * strides, int dtype, int op)
217+ void _idtr_reduce_all (uint64_t rank, void * data, const int64_t * sizes, const int64_t * strides, int dtype, int op)
172218{
173219 assert (rank == 0 || strides[rank-1 ] == 1 );
174220 idtr_reduce_all (data,
175221 mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype)),
176222 rank ? rank : 1 ,
177223 mlir2ddpt (static_cast <imex::ptensor::ReduceOpId>(op)));
178224}
225+
226+ void _idtr_rebalance (uint64_t rank, const int64_t * gShape , const int64_t * lOffs,
227+ void * data, const int64_t * sizes, const int64_t * strides, int dtype,
228+ uint64_t outRank, void * out, const int64_t * outSizes, const int64_t * outStrides)
229+ {
230+ assert (rank);
231+ is_contiguous (outSizes, outStrides, outRank);
232+ auto N = (int64_t )getTransceiver ()->nranks ();
233+ auto myOff = lOffs[0 ];
234+ auto mySz = sizes[0 ];
235+ auto myEnd = myOff + mySz;
236+ auto tSz = gShape [0 ];
237+ auto sz = (tSz + N - 1 ) / N;
238+ auto ddpttype = mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype));
239+ auto nSz = std::accumulate (&sizes[1 ], &sizes[rank], 1 , std::multiplies<int64_t >());
240+ std::vector<int > soffs (N);
241+ std::vector<int > sszs (N, 0 );
242+ for (auto i=0 ; i<N; ++i) {
243+ auto tOff = i * sz;
244+ auto tEnd = std::min (tSz, tOff + sz);
245+ if (tEnd > myOff && tOff < myEnd) {
246+ // We have a target partition which is inside my local data
247+ // we now compute what data goes to this target partition
248+ auto start = std::max (myOff, tOff);
249+ auto end = std::min (myEnd, tEnd);
250+ soffs[i] = (int )(start - myOff) * nSz;
251+ sszs[i] = (int )(end - start) * nSz;
252+ } else {
253+ soffs[i] = i ? soffs[i-1 ] + sszs[i-1 ] : 0 ;
254+ }
255+ }
256+ // we now send our send sizes to others and receiver theirs
257+ std::vector<int > rszs (N);
258+ getTransceiver ()->alltoall (sszs.data (), 1 , INT32, rszs.data ());
259+ // For the actual alltoall we need the receive-displacements
260+ std::vector<int > roffs (N);
261+ roffs[0 ] = 0 ;
262+ for (auto i=1 ; i<N; ++i) {
263+ // compute for all i > 0
264+ roffs[i] = roffs[i-1 ] + rszs[i-1 ];
265+ }
266+ // create send buffer (might be strided!)
267+ Buffer buff (nSz * mySz * sizeof_dtype (ddpttype));
268+ auto ptr = bufferize (data, ddpttype, sizes, strides, rank, buff.data ());
269+ // Finally communicate elements
270+ getTransceiver ()->alltoall (ptr, sszs.data (), soffs.data (), ddpttype, out, rszs.data (), roffs.data ());
271+ }
179272} // extern "C"
0 commit comments