Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 08a3baf

Browse files
committed
introducing idtr::rebalance; fixing calls to imex::ptensor::create
1 parent 985dfd8 commit 08a3baf

File tree

13 files changed

+196
-83
lines changed

13 files changed

+196
-83
lines changed

CMakeLists.txt

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,23 +102,26 @@ set(DDPTSrcs
102102
${PROJECT_SOURCE_DIR}/src/SetGetItem.cpp
103103
${PROJECT_SOURCE_DIR}/src/Sorting.cpp
104104
)
105-
set(IDTRSrcs
106-
${PROJECT_SOURCE_DIR}/src/idtr.cpp
105+
set(RTSrcs
107106
${PROJECT_SOURCE_DIR}/src/CollComm.cpp
108107
${PROJECT_SOURCE_DIR}/src/DDPTensorImpl.cpp
109108
${PROJECT_SOURCE_DIR}/src/Deferred.cpp
110109
${PROJECT_SOURCE_DIR}/src/Factory.cpp
111110
${PROJECT_SOURCE_DIR}/src/Mediator.cpp
112111
${PROJECT_SOURCE_DIR}/src/MPIMediator.cpp
113-
${PROJECT_SOURCE_DIR}/src/MPITransceiver.cpp
114112
${PROJECT_SOURCE_DIR}/src/Registry.cpp
115-
${PROJECT_SOURCE_DIR}/src/Transceiver.cpp
116113
${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp
117114
)
115+
set(IDTRSrcs
116+
${PROJECT_SOURCE_DIR}/src/idtr.cpp
117+
${PROJECT_SOURCE_DIR}/src/MPITransceiver.cpp
118+
${PROJECT_SOURCE_DIR}/src/Transceiver.cpp
119+
)
118120

119121
pybind11_add_module(_ddptensor MODULE ${DDPTSrcs} ${Hpps})
122+
add_library(_ddpt_rt SHARED ${RTSrcs} ${Hpps})
120123
add_library(idtr SHARED ${IDTRSrcs} ${Hpps})
121-
set(AllTargets _ddptensor idtr)
124+
set(AllTargets _ddptensor _ddpt_rt idtr)
122125

123126
add_compile_definitions(USE_MKL=1)
124127
add_compile_options("-ftemplate-backtrace-limit=0")
@@ -144,17 +147,24 @@ get_property(imex_all_libs GLOBAL PROPERTY IMEX_ALL_LIBS)
144147

145148
#llvm_update_compile_flags(_ddpttensor)
146149
target_link_directories(_ddptensor PRIVATE ${CONDA_PREFIX}/lib)
147-
target_link_directories(idtr PRIVATE ${CONDA_PREFIX}/lib ${IMEX_INSTALL_PREFIX}/lib)
150+
target_link_directories(_ddpt_rt PRIVATE ${CONDA_PREFIX}/lib) # ${IMEX_INSTALL_PREFIX}/lib)
151+
target_link_directories(idtr PRIVATE ${CONDA_PREFIX}/lib)
148152

149153
target_link_libraries(_ddptensor PRIVATE
150154
# ${MKL_LIBRARIES}
151155
# tbb
156+
_ddpt_rt
152157
idtr
153158
)
154159
target_link_libraries(idtr PRIVATE
155160
${MPI_C_LIBRARIES}
156161
# ${MKL_LIBRARIES}
157162
tbb
163+
)
164+
target_link_libraries(_ddpt_rt PRIVATE
165+
${MPI_C_LIBRARIES}
166+
# ${MKL_LIBRARIES}
167+
tbb
158168
IMEXPTensorDialect
159169
IMEXPTensorTransforms
160170
IMEXPTensorToLinalg

src/Creator.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ddptensor/Deferred.hpp"
44
#include "ddptensor/Factory.hpp"
55
#include "ddptensor/DDPTensorImpl.hpp"
6+
#include "ddptensor/Transceiver.hpp"
67

78
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
89
#include <imex/Utils/PassUtils.h>
@@ -153,11 +154,10 @@ struct DeferredFull : public Deferred
153154
::imex::ptensor::DType dtyp;
154155
::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
155156

156-
auto dmy = ::imex::createInt<1>(loc, builder, 0);
157157
auto team = ::imex::createIndex(loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
158158

159159
dm.addVal(this->guid(),
160-
builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val, dmy, team),
160+
builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val, nullptr, team),
161161
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
162162
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
163163
assert(rank == this->_shape.size());
@@ -207,13 +207,10 @@ struct DeferredArange : public Deferred
207207
auto start = ::imex::createInt(loc, builder, _start);
208208
auto stop = ::imex::createInt(loc, builder, _end);
209209
auto step = ::imex::createInt(loc, builder, _step);
210-
auto dtype = builder.getI64Type(); // FIXME
211-
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), 1, dtype, false);
212-
auto dmy = ::imex::createInt<1>(loc, builder, 0);
213210
// ::mlir::Value
214211
auto team = ::imex::createIndex(loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
215212
dm.addVal(this->guid(),
216-
builder.create<::imex::ptensor::ARangeOp>(loc, artype, start, stop, step, dmy, team),
213+
builder.create<::imex::ptensor::ARangeOp>(loc, start, stop, step, nullptr, team),
217214
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
218215
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
219216
assert(rank == 1);

src/DDPTensorImpl.cpp

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <ddptensor/DDPTensorImpl.hpp>
77
#include <ddptensor/CppTypes.hpp>
8+
#include <ddptensor/Transceiver.hpp>
89

910
#include <algorithm>
1011

@@ -152,21 +153,6 @@ int64_t DDPTensorImpl::__int__() const
152153
return res;
153154
}
154155

155-
void DDPTensorImpl::bufferize(const NDSlice & slc, Buffer & buff) const
156-
{
157-
// FIXME slices/strides
158-
#if 0
159-
if(slc.size() <= 0) return;
160-
NDSlice lslice = NDSlice(slice().tile_shape()).slice(slc);
161-
#endif
162-
assert(_strides[0] == 1);
163-
auto pos = buff.size();
164-
auto sz = size()*item_size();
165-
buff.resize(pos + sz);
166-
void * out = buff.data() + pos;
167-
dispatch(_dtype, _aligned, [this, sz, out](auto * ptr) { memcpy(out, ptr + this->_offset, sz); });
168-
}
169-
170156
void DDPTensorImpl::add_to_args(std::vector<void*> & args, int ndims)
171157
{
172158
assert(ndims == this->ndims());
@@ -180,24 +166,25 @@ void DDPTensorImpl::add_to_args(std::vector<void*> & args, int ndims)
180166
args.push_back(buff);
181167
// second the team
182168
args.push_back(reinterpret_cast<void*>(1));
183-
if(ndims > 0)
184-
// global shape third
185-
buff = new intptr_t[dtensor_sz(1)];
186-
buff[0] = reinterpret_cast<intptr_t>(_gs_allocated);
187-
buff[1] = reinterpret_cast<intptr_t>(_gs_aligned);
188-
buff[2] = 0;
189-
buff[3] = ndims;
190-
buff[4] = 1;
191-
args.push_back(buff);
192-
assert(5 == memref_sz(1));
193-
// local offsets last
194-
buff = new intptr_t[dtensor_sz(1)];
195-
buff[0] = reinterpret_cast<intptr_t>(_lo_allocated);
196-
buff[1] = reinterpret_cast<intptr_t>(_lo_aligned);
197-
buff[2] = 0;
198-
buff[3] = ndims;
199-
buff[4] = 1;
200-
args.push_back(buff);
169+
if(ndims > 0) {
170+
// global shape third
171+
buff = new intptr_t[dtensor_sz(1)];
172+
buff[0] = reinterpret_cast<intptr_t>(_gs_allocated);
173+
buff[1] = reinterpret_cast<intptr_t>(_gs_aligned);
174+
buff[2] = 0;
175+
buff[3] = ndims;
176+
buff[4] = 1;
177+
args.push_back(buff);
178+
assert(5 == memref_sz(1));
179+
// local offsets last
180+
buff = new intptr_t[dtensor_sz(1)];
181+
buff[0] = reinterpret_cast<intptr_t>(_lo_allocated);
182+
buff[1] = reinterpret_cast<intptr_t>(_lo_aligned);
183+
buff[2] = 0;
184+
buff[3] = ndims;
185+
buff[4] = 1;
186+
args.push_back(buff);
187+
}
201188
}
202189

203190
void DDPTensorImpl::replicate()

src/MPITransceiver.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,23 +168,36 @@ void MPITransceiver::reduce_all(void * inout, DTypeId T, size_t N, RedOpType op)
168168
// Vector all-to-all exchange (MPI_Alltoallv): each rank sends a
// rank-specific number of elements to every other rank.
// buffer_send/buffer_recv: base pointers of the send/receive data
// counts_send/counts_recv: per-rank element counts (arrays of length nranks)
// displacements_send/displacements_recv: per-rank element offsets into the
//   respective buffers (arrays of length nranks)
// datatype: element type, used for both send and receive sides
void MPITransceiver::alltoall(const void* buffer_send,
                              const int* counts_send,
                              const int* displacements_send,
                              DTypeId datatype,
                              void* buffer_recv,
                              const int* counts_recv,
                              const int* displacements_recv)
{
    MPI_Alltoallv(buffer_send,
                  counts_send,
                  displacements_send,
                  to_mpi(datatype),
                  buffer_recv,
                  counts_recv,
                  displacements_recv,
                  to_mpi(datatype),
                  _comm);
}
187186

187+
// Uniform all-to-all exchange (MPI_Alltoall): every rank sends `counts`
// elements of `datatype` to each rank and receives `counts` elements from
// each rank. buffer_send and buffer_recv must each hold nranks*counts
// elements.
void MPITransceiver::alltoall(const void* buffer_send,
                              const int counts,
                              DTypeId datatype,
                              void* buffer_recv)
{
    MPI_Alltoall(buffer_send,
                 counts,
                 to_mpi(datatype),
                 buffer_recv,
                 counts,
                 to_mpi(datatype),
                 _comm);
}
200+
188201
void MPITransceiver::gather(void* buffer,
189202
const int* counts,
190203
const int* displacements,

src/SetGetItem.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "ddptensor/DDPTensorImpl.hpp"
55
#include "ddptensor/Mediator.hpp"
66
#include "ddptensor/Factory.hpp"
7+
#include "ddptensor/NDSlice.hpp"
78

89
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
910
#include <imex/Dialect/Dist/IR/DistOps.h>

src/idtr.cpp

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
// SPDX-License-Identifier: BSD-3-Clause
22

33
#include <ddptensor/idtr.hpp>
4-
#include <ddptensor/jit/mlir.hpp>
4+
// #include <ddptensor/jit/mlir.hpp>
55
#include <ddptensor/DDPTensorImpl.hpp>
66
#include <ddptensor/MPITransceiver.hpp>
77

8-
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
8+
#include <imex/Dialect/PTensor/IR/PTensorDefs.h>
99

1010
#include <cassert>
1111
#include <memory>
12+
#include <iostream>
1213

1314
using container_type = std::unordered_map<id_type, std::unique_ptr<DDPTensorImpl>>;
1415

@@ -160,6 +161,51 @@ static DTypeId mlir2ddpt(const ::imex::ptensor::DType dt)
160161
};
161162
}
162163

164+
165+
// Apply `op` to every element of a strided nd-dimensional array, visiting
// elements in row-major order.
// d: current dimension (call with d = 0)
// cptr: base pointer of the current sub-array
// sizes/strides: per-dimension extents and strides (strides in elements)
// nd: number of dimensions (must be >= 1)
// op: callable taking a const T* to each element
template<typename T, typename OP>
void forall(uint64_t d, const T * cptr, const int64_t * sizes, const int64_t * strides, uint64_t nd, OP op)
{
    auto stride = strides[d];
    auto sz = sizes[d];
    if(d == nd-1) {
        // innermost dimension: apply op to each element
        for(int64_t i=0; i<sz; ++i) {
            op(&cptr[i*stride]);
        }
    } else {
        // outer dimension: advance the base pointer by this dimension's
        // stride before recursing. (Bug fix: the previous version recursed
        // with an unmodified cptr, so every iteration re-visited the first
        // sub-array instead of walking the outer dimension.)
        for(int64_t i=0; i<sz; ++i) {
            forall(d+1, cptr + i*stride, sizes, strides, nd, op);
        }
    }
}
180+
181+
// Return true iff the nd-dimensional array described by sizes/strides is
// C-contiguous (row-major, unit innermost stride, densely packed).
// A 0d array is trivially contiguous.
bool is_contiguous(const int64_t * sizes, const int64_t * strides, uint64_t nd)
{
    if(nd == 0) return true;
    // innermost dimension must have unit stride
    if(strides[nd-1] != 1) return false;
    // Walking outwards, each stride must equal the product of all inner
    // extents. Accumulate in int64_t (was `auto sz = 1`, i.e. int), so
    // large extents do not overflow the running product.
    int64_t sz = 1;
    for(auto i=nd-1; i>0; --i) {
        sz *= sizes[i];
        if(strides[i-1] != sz) return false;
    }
    return true;
}
192+
193+
// Pack a possibly-strided nd-dimensional array into a contiguous buffer.
// cptr: base pointer of the source data
// dtype: element type tag, used to select the typed code path via dispatch
// sizes/strides: per-dimension extents and strides (in elements); nd: rank
// out: caller-provided destination; must hold product(sizes) elements
// Returns cptr unchanged when the data is already contiguous (no copy),
// otherwise copies element-wise into out and returns out.
void * bufferize(void * cptr, DTypeId dtype, const int64_t * sizes, const int64_t * strides, uint64_t nd, void * out)
{
    if(is_contiguous(sizes, strides, nd)) {
        // already densely packed - hand back the input, avoiding the copy
        return cptr;
    } else {
        // dispatch invokes the lambda with `ptr` typed according to dtype
        dispatch(dtype, cptr, [sizes, strides, nd, out](auto * ptr) {
            // reinterpret out as the same element type as the input
            auto buff = static_cast<decltype(ptr)>(out);
            // forall visits all elements in row-major order; append each
            // one to the contiguous output cursor
            forall(0, ptr, sizes, strides, nd, [&buff](const auto * in) {
                *buff = *in;
                ++buff;
            });
        });
        return out;
    }
}
208+
163209
extern "C" {
164210
// Elementwise inplace allreduce
165211
void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, ReduceOpId op)
@@ -168,12 +214,59 @@ void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, ReduceOpId op)
168214
}
169215

170216
// FIXME hard-coded for contiguous layout
171-
void _idtr_reduce_all(uint64_t rank, void * data, int64_t * sizes, int64_t * strides, int dtype, int op)
217+
void _idtr_reduce_all(uint64_t rank, void * data, const int64_t * sizes, const int64_t * strides, int dtype, int op)
172218
{
173219
assert(rank == 0 || strides[rank-1] == 1);
174220
idtr_reduce_all(data,
175221
mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype)),
176222
rank ? rank : 1,
177223
mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
178224
}
225+
226+
void _idtr_rebalance(uint64_t rank, const int64_t * gShape, const int64_t * lOffs,
227+
void * data, const int64_t * sizes, const int64_t * strides, int dtype,
228+
uint64_t outRank, void * out, const int64_t * outSizes, const int64_t * outStrides)
229+
{
230+
assert(rank);
231+
is_contiguous(outSizes, outStrides, outRank);
232+
auto N = (int64_t)getTransceiver()->nranks();
233+
auto myOff = lOffs[0];
234+
auto mySz = sizes[0];
235+
auto myEnd = myOff + mySz;
236+
auto tSz = gShape[0];
237+
auto sz = (tSz + N - 1) / N;
238+
auto ddpttype = mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype));
239+
auto nSz = std::accumulate(&sizes[1], &sizes[rank], 1, std::multiplies<int64_t>());
240+
std::vector<int> soffs(N);
241+
std::vector<int> sszs(N, 0);
242+
for(auto i=0; i<N; ++i) {
243+
auto tOff = i * sz;
244+
auto tEnd = std::min(tSz, tOff + sz);
245+
if(tEnd > myOff && tOff < myEnd) {
246+
// We have a target partition which is inside my local data
247+
// we now compute what data goes to this target partition
248+
auto start = std::max(myOff, tOff);
249+
auto end = std::min(myEnd, tEnd);
250+
soffs[i] = (int)(start - myOff) * nSz;
251+
sszs[i] = (int)(end - start) * nSz;
252+
} else {
253+
soffs[i] = i ? soffs[i-1] + sszs[i-1] : 0;
254+
}
255+
}
256+
// we now send our send sizes to others and receiver theirs
257+
std::vector<int> rszs(N);
258+
getTransceiver()->alltoall(sszs.data(), 1, INT32, rszs.data());
259+
// For the actual alltoall we need the receive-displacements
260+
std::vector<int> roffs(N);
261+
roffs[0] = 0;
262+
for(auto i=1; i<N; ++i) {
263+
// compute for all i > 0
264+
roffs[i] = roffs[i-1] + rszs[i-1];
265+
}
266+
// create send buffer (might be strided!)
267+
Buffer buff(nSz * mySz * sizeof_dtype(ddpttype));
268+
auto ptr = bufferize(data, ddpttype, sizes, strides, rank, buff.data());
269+
// Finally communicate elements
270+
getTransceiver()->alltoall(ptr, sszs.data(), soffs.data(), ddpttype, out, rszs.data(), roffs.data());
271+
}
179272
} // extern "C"

src/include/ddptensor/CollComm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include "CppTypes.hpp"
6+
#include "PVSlice.hpp"
67
#include "DDPTensorImpl.hpp"
78

89
struct CollComm

src/include/ddptensor/DDPTensorImpl.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
#pragma once
77

8-
#include "PVSlice.hpp"
98
#include "p2c_ids.hpp"
109
#include "tensor_i.hpp"
1110
#include "TypeDispatch.hpp"
@@ -146,8 +145,6 @@ class DDPTensorImpl : public tensor_i
146145
return sizeof_dtype(_dtype);
147146
}
148147

149-
virtual void bufferize(const NDSlice & slc, Buffer & buff) const;
150-
151148
virtual void add_to_args(std::vector<void*> & args, int ndims);
152149

153150
template<typename T>

src/include/ddptensor/MPITransceiver.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@ class MPITransceiver : public Transceiver
4646
DTypeId datatype_send,
4747
void* buffer_recv,
4848
const int* counts_recv,
49-
const int* displacements_recv,
50-
DTypeId datatype_recv);
49+
const int* displacements_recv);
50+
virtual void alltoall(const void* buffer_send,
51+
const int counts,
52+
DTypeId datatype,
53+
void* buffer_recv);
5154
virtual void gather(void* buffer,
5255
const int* counts,
5356
const int* displacements,

0 commit comments

Comments
 (0)