Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 5f57e98

Browse files
committed
adjust handling of 0d arrays, some GC improvements
1 parent 7babf97 commit 5f57e98

File tree

6 files changed

+67
-57
lines changed

6 files changed

+67
-57
lines changed

src/DDPTensorImpl.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ std::string DDPTensorImpl::__repr__() const
106106

107107
dispatch(_dtype, _aligned, [this, nd, &oss](auto * ptr) {
108108
auto cptr = ptr + this->_offset;
109-
printit(oss, 0, cptr);
109+
if(nd>0) printit(oss, 0, cptr);
110+
else oss << *cptr;
110111
});
111112
return oss.str();
112113
}
@@ -158,32 +159,33 @@ void DDPTensorImpl::bufferize(const NDSlice & slc, Buffer & buff) const
158159

159160
void DDPTensorImpl::add_to_args(std::vector<void*> & args, int ndims)
160161
{
161-
assert(ndims == this->ndims() || (ndims == 0 && this->ndims() == 1));
162-
// global shape first
163-
intptr_t * buff = new intptr_t[dtensor_sz(1)];
162+
assert(ndims == this->ndims());
163+
// local tensor first
164+
intptr_t * buff = new intptr_t[dtensor_sz(ndims)];
165+
buff[0] = reinterpret_cast<intptr_t>(_allocated);
166+
buff[1] = reinterpret_cast<intptr_t>(_aligned);
167+
buff[2] = static_cast<intptr_t>(_offset);
168+
memcpy(buff+3, _sizes, ndims*sizeof(intptr_t));
169+
memcpy(buff+3+ndims, _strides, ndims*sizeof(intptr_t));
170+
args.push_back(buff);
171+
// second the team
172+
args.push_back(reinterpret_cast<void*>(1));
173+
if(ndims > 0)
174+
// global shape third
175+
buff = new intptr_t[dtensor_sz(1)];
164176
buff[0] = reinterpret_cast<intptr_t>(_gs_allocated);
165177
buff[1] = reinterpret_cast<intptr_t>(_gs_aligned);
166178
buff[2] = 0;
167179
buff[3] = ndims;
168180
buff[4] = 1;
169181
args.push_back(buff);
170182
assert(5 == memref_sz(1));
171-
// local tensor
172-
buff = new intptr_t[dtensor_sz(ndims)];
173-
buff[0] = reinterpret_cast<intptr_t>(_allocated);
174-
buff[1] = reinterpret_cast<intptr_t>(_aligned);
175-
buff[2] = static_cast<intptr_t>(_offset);
176-
memcpy(buff+3, _sizes, ndims*sizeof(intptr_t));
177-
memcpy(buff+3+ndims, _strides, ndims*sizeof(intptr_t));
178-
args.push_back(buff);
179-
// local offsets
183+
// local offsets last
180184
buff = new intptr_t[dtensor_sz(1)];
181185
buff[0] = reinterpret_cast<intptr_t>(_lo_allocated);
182186
buff[1] = reinterpret_cast<intptr_t>(_lo_aligned);
183187
buff[2] = 0;
184188
buff[3] = ndims;
185189
buff[4] = 1;
186190
args.push_back(buff);
187-
// finally the team
188-
args.push_back(reinterpret_cast<void*>(1));
189191
}

src/Deferred.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ void process_promises()
6565
{
6666
bool done = false;
6767
jit::JIT jit;
68-
6968
do {
7069
::mlir::OpBuilder builder(&jit._context);
7170
auto loc = builder.getUnknownLoc();

src/ddptensor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ rank_type myrank()
4949
return getTransceiver()->rank();
5050
}
5151

52-
std::thread * pprocessor;
52+
std::thread * pprocessor = nullptr;
5353

5454
extern bool inited;
5555
extern bool finied;
@@ -63,6 +63,7 @@ void fini()
6363
if(getTransceiver()->nranks() == 1) defer(nullptr);
6464
pprocessor->join();
6565
delete pprocessor;
66+
pprocessor = nullptr;
6667
}
6768
fini_transceiver();
6869
Deferred::fini();

src/idtr.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,11 @@ void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, int op)
124124
getTransceiver()->reduce_all(inout, dtype, N, mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
125125
}
126126

127-
// FIXME hard-coded 0d tensor
128-
void _idtr_reduce_all(uint64_t rank, uint64_t * mrd, DTypeId dtype, int op)
127+
// FIXME hard-coded for contiguous layout
128+
void _idtr_reduce_all(uint64_t rank, void * data, int64_t * sizes, int64_t * strides, DTypeId dtype, int op)
129129
{
130-
assert(rank==0);
131-
auto descr = reinterpret_cast<jit::JIT::MemRefDescriptor<uint64_t, 0>*>(mrd);
132-
idtr_reduce_all(descr->aligned + descr->offset, dtype, 1, op);
130+
assert(rank == 0 || strides[rank-1] == 1);
131+
idtr_reduce_all(data, dtype, rank ? rank : 1, op);
133132
}
134133

135134
} // extern "C"

src/include/ddptensor/jit/mlir.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ void init();
5050
class DepManager
5151
{
5252
private:
53-
using IdValueMap = std::unordered_map<id_type, std::pair<::mlir::Value, SetResFunc>>;
53+
using IdValueMap = std::unordered_map<id_type, ::mlir::Value>;
54+
using IdCallbackMap = std::unordered_map<id_type, SetResFunc>;
5455
using IdRankMap = std::unordered_map<id_type, int>;
5556
using ArgList = std::vector<std::pair<id_type, int>>;
5657

5758
::mlir::func::FuncOp & _func; // MLIR function to which ops are added
58-
IdValueMap _ivm; // guid -> {mlir::Value, deliver-callback}
59+
IdValueMap _ivm; // guid -> mlir::Value
60+
IdCallbackMap _icm; // guid -> deliver-callback
5961
IdRankMap _irm; // guid -> rank as computed in MLIR
6062
ArgList _args; // input arguments of the generated function
6163

src/jit/mlir.cpp

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ ::mlir::Value DepManager::getDependent(::mlir::OpBuilder & builder, id_type guid
128128
_func.insertArgument(idx, typ, {}, loc);
129129
auto val = _func.getArgument(idx);
130130
_args.push_back({guid, fut.rank()});
131-
_ivm[guid] = {val, {}};
131+
_ivm[guid] = val;
132132
return val;
133133
} else {
134-
return d->second.first;
134+
return d->second;
135135
}
136136
}
137137

@@ -151,22 +151,23 @@ std::vector<void*> DepManager::store_inputs()
151151
auto f = Registry::get(a.first);
152152
f.get().get()->add_to_args(res, a.second);
153153
_ivm.erase(a.first); // inputs need no delivery
154+
_icm.erase(a.first);
154155
}
155156
return res;
156157
}
157158

158159
void DepManager::addVal(id_type guid, ::mlir::Value val, SetResFunc cb)
159160
{
160161
assert(_ivm.find(guid) == _ivm.end());
161-
_ivm[guid] = {val, cb};
162+
_ivm[guid] = val;
163+
_icm[guid] = cb;
162164
}
163165

164166
void DepManager::drop(id_type guid)
165167
{
166-
if(auto e = _ivm.find(guid); e != _ivm.end()) {
167-
_ivm.erase(e);
168-
// FIXME create delete op
169-
}
168+
_ivm.erase(guid);
169+
_icm.erase(guid);
170+
// FIXME create delete op
170171
}
171172

172173
// Now we have to define the return type as a ValueRange of all arrays which we have created
@@ -186,7 +187,7 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder & builder)
186187
uint64_t sz = 0;
187188
unsigned idx = 0;
188189
for(auto & v : _ivm) {
189-
::mlir::Value value = v.second.first;
190+
::mlir::Value value = v.second;
190191
// append the type and array/value
191192
auto retDtTyp = value.getType().dyn_cast<::imex::dist::DistTensorType>();
192193
if(!retDtTyp) {
@@ -207,44 +208,49 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder & builder)
207208
// add return statement
208209
auto ret_value = builder.create<::mlir::func::ReturnOp>(builder.getUnknownLoc(), ret_values);
209210

211+
// clear any reference to MLIR values
212+
_ivm.clear();
210213
return sz;
211214
}
212215

213216
void DepManager::deliver(intptr_t * output, uint64_t sz)
214217
{
215218
size_t pos = 0;
216-
for(auto & v : _ivm) {
217-
auto value = v.second.first;
219+
for(auto & v : _icm) {
218220
auto rank = _irm[v.first];
219-
// first extract global shape
220-
uint64_t * gs_allocated = reinterpret_cast<uint64_t*>(output[pos]);
221-
uint64_t * gs_aligned = reinterpret_cast<uint64_t*>(output[pos+1]);
222-
intptr_t gs_offset = output[pos+2];
223-
// no sizes/stride needed
224-
pos += memref_sz(1);
225-
// second extract tensor
221+
// first extract tensor
226222
void * t_allocated = reinterpret_cast<void*>(output[pos]);
227223
void * t_aligned = reinterpret_cast<void*>(output[pos+1]);
228224
intptr_t t_offset = output[pos+2];
229225
intptr_t * t_sizes = output + pos + 3;
230226
intptr_t * t_stride = output + pos + 3 + rank;
231227
pos += memref_sz(rank);
232-
// third extract local offsets
233-
uint64_t * lo_allocated = reinterpret_cast<uint64_t*>(output[pos]);
234-
uint64_t * lo_aligned = reinterpret_cast<uint64_t*>(output[pos+1]);
235-
intptr_t lo_offset = output[pos+2];
236-
// no sizes/stride needed
237-
pos += memref_sz(1);
238-
// last is the team
228+
// second is the team
239229
// auto team = output[pos];
240230
pos += 1;
241-
// call finalization
242-
v.second.second(
243-
rank,
244-
t_allocated, t_aligned, t_offset, t_sizes, t_stride, // tensor
245-
gs_allocated, gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
246-
lo_allocated, lo_aligned + lo_offset // local offset is 1d tensor of uint64_t
247-
);
231+
if(rank > 0) {
232+
// third extract global shape
233+
uint64_t * gs_allocated = reinterpret_cast<uint64_t*>(output[pos]);
234+
uint64_t * gs_aligned = reinterpret_cast<uint64_t*>(output[pos+1]);
235+
intptr_t gs_offset = output[pos+2];
236+
// no sizes/stride needed
237+
pos += memref_sz(1);
238+
// lastly extract local offsets
239+
uint64_t * lo_allocated = reinterpret_cast<uint64_t*>(output[pos]);
240+
uint64_t * lo_aligned = reinterpret_cast<uint64_t*>(output[pos+1]);
241+
intptr_t lo_offset = output[pos+2];
242+
// no sizes/stride needed
243+
pos += memref_sz(1);
244+
// call finalization
245+
v.second(rank,
246+
t_allocated, t_aligned, t_offset, t_sizes, t_stride, // tensor
247+
gs_allocated, gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
248+
lo_allocated, lo_aligned + lo_offset // local offset is 1d tensor of uint64_t
249+
);
250+
} else { // 0d tensor
251+
v.second(rank, t_allocated, t_aligned, t_offset, t_sizes, t_stride,
252+
nullptr, nullptr, nullptr, nullptr);
253+
}
248254
}
249255
}
250256

@@ -296,8 +302,9 @@ int JIT::run(::mlir::ModuleOp & module, const std::string & fname, std::vector<v
296302
static const char * pass_pipeline =
297303
getenv("DDPT_PASSES")
298304
? getenv("DDPT_PASSES")
299-
: "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-expand,canonicalize,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,canonicalize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
300-
305+
// : "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-expand,canonicalize,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,canonicalize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
306+
// : "builtin.module(func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize,bufferization-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,expand-strided-metadata,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)";
307+
: "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize,bufferization-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,expand-strided-metadata,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
301308
JIT::JIT()
302309
: _context(::mlir::MLIRContext::Threading::DISABLED),
303310
_pm(&_context),

0 commit comments

Comments
 (0)