Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit ea8bcc6

Browse files
committed
adding setitem and adjustsing to new imex
1 parent 5dbf1d2 commit ea8bcc6

File tree

5 files changed

+42
-8
lines changed

5 files changed

+42
-8
lines changed

ddptensor/ddptensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,4 @@ def __getitem__(self, key):
5454
return dtensor(self._t.__getitem__(key if isinstance(key, list) else [key,]))
5555

5656
def __setitem__(self, key, value):
57-
self._t = self._t.__setitem__(key, value._t) # if isinstance(value, dtensor) else value)
57+
self._t.__setitem__(key if isinstance(key, list) else [key,], value._t) # if isinstance(value, dtensor) else value)

src/Creator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "ddptensor/DDPTensorImpl.hpp"
66

77
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
8-
#include <imex/internal/PassUtils.h>
8+
#include <imex/Utils/PassUtils.h>
99

1010
#include <mlir/IR/Builders.h>
1111
#include <mlir/Dialect/Arith/IR/Arith.h>

src/SetGetItem.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
99
#include <imex/Dialect/Dist/IR/DistOps.h>
10-
#include <imex/internal/PassUtils.h>
10+
#include <imex/Utils/PassUtils.h>
1111
#include <mlir/IR/Builders.h>
1212

1313
#if 0
@@ -202,6 +202,38 @@ struct DeferredSetItem : public Deferred
202202
//set_value(std::move(TypeDispatch<x::SetItem>(a, b, _slc, _b)));
203203
}
204204

205+
bool generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::DepManager & dm) override
206+
{
207+
// get params and extract offsets/sizes/strides
208+
const auto dtype = this->dtype();
209+
auto av = dm.getDependent(builder, _a);
210+
auto bv = dm.getDependent(builder, _b);
211+
auto & offs = _slc.offsets();
212+
auto & sizes = _slc.sizes();
213+
auto & strides = _slc.strides();
214+
auto nd = offs.size();
215+
// convert C++ slices into vectors of MLIR Values
216+
std::vector<::mlir::Value> offsV(nd);
217+
std::vector<::mlir::Value> sizesV(nd);
218+
std::vector<::mlir::Value> stridesV(nd);
219+
for(auto i = 0; i<nd; ++i) {
220+
offsV[i] = ::imex::createIndex(loc, builder, offs[i]);
221+
sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]);
222+
stridesV[i] = ::imex::createIndex(loc, builder, strides[i]);
223+
}
224+
// insertsliceop has no return value, so we just craete the op...
225+
builder.create<::imex::ptensor::InsertSliceOp>(loc, av, bv, offsV, sizesV, stridesV);
226+
// ... and use av as to later create the ptensor
227+
dm.addVal(this->guid(), av,
228+
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
229+
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
230+
this->set_value(Registry::get(this->_a).get());
231+
// this->set_value(std::move(mk_tnsr(dtype, rank, allocated, aligned, offset, sizes, strides,
232+
// gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
233+
});
234+
return false;
235+
}
236+
205237
FactoryId factory() const
206238
{
207239
return F_SETITEM;
@@ -264,7 +296,7 @@ struct DeferredGetItem : public Deferred
264296
sizesV,
265297
stridesV),
266298
[this, dtype](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
267-
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
299+
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
268300
this->set_value(std::move(mk_tnsr(dtype, rank, allocated, aligned, offset, sizes, strides,
269301
gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
270302
});

src/idtr.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: BSD-3-Clause
22

33
#include <ddptensor/idtr.hpp>
4+
#include <ddptensor/jit/mlir.hpp>
45
#include <ddptensor/DDPTensorImpl.hpp>
56
#include <ddptensor/MPITransceiver.hpp>
67

@@ -120,14 +121,15 @@ static ReduceOpId mlir2ddpt(const ::imex::ptensor::ReduceOpId rop)
120121
// Elementwise inplace allreduce
121122
void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, int op)
122123
{
123-
124124
getTransceiver()->reduce_all(inout, dtype, N, mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
125125
}
126126

127127
// FIXME hard-coded 0d tensor
128-
void _idtr_reduce_all(uint64_t * allocated, uint64_t * aligned, uint64_t offset, DTypeId dtype, int op)
128+
void _idtr_reduce_all(uint64_t rank, uint64_t * mrd, DTypeId dtype, int op)
129129
{
130-
idtr_reduce_all(aligned + offset, dtype, 1, op);
130+
assert(rank==0);
131+
auto descr = reinterpret_cast<jit::JIT::MemRefDescriptor<uint64_t, 0>*>(mrd);
132+
idtr_reduce_all(descr->aligned + descr->offset, dtype, 1, op);
131133
}
132134

133135
} // extern "C"

src/jit/mlir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ int JIT::run(::mlir::ModuleOp & module, const std::string & fname, std::vector<v
296296
static const char * pass_pipeline =
297297
getenv("DDPT_PASSES")
298298
? getenv("DDPT_PASSES")
299-
: "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-expand,canonicalize,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,canonicalize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-alias-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,convert-scf-to-cf,reconcile-unrealized-casts";
299+
: "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-expand,canonicalize,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,canonicalize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
300300

301301
JIT::JIT()
302302
: _context(::mlir::MLIRContext::Threading::DISABLED),

0 commit comments

Comments
 (0)