Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 5dbf1d2

Browse files
committed
adding getitem and simple single-node sort
1 parent 38dd6f9 commit 5dbf1d2

File tree

18 files changed

+135
-61
lines changed

18 files changed

+135
-61
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ set(IDTRSrcs
111111
${PROJECT_SOURCE_DIR}/src/Mediator.cpp
112112
${PROJECT_SOURCE_DIR}/src/MPIMediator.cpp
113113
${PROJECT_SOURCE_DIR}/src/MPITransceiver.cpp
114-
${PROJECT_SOURCE_DIR}/src/PVSlice.cpp
115114
${PROJECT_SOURCE_DIR}/src/Registry.cpp
116115
${PROJECT_SOURCE_DIR}/src/Transceiver.cpp
117116
${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp
@@ -173,6 +172,7 @@ target_link_libraries(idtr PRIVATE
173172
MLIRLinalgDialect
174173
MLIRLinalgToLLVM
175174
MLIRLinalgTransforms
175+
MLIRLLVMDialect
176176
MLIRMathDialect
177177
MLIRMathToLLVM
178178
MLIRMathTransforms

ddptensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def to_numpy(a):
7272
)
7373
elif func == "arange":
7474
exec(
75-
f"{func} = lambda start, end, step, dtype: dtensor(_cdt.Creator.arange(start, end, step, dtype))"
75+
f"{func} = lambda start, end, step, dtype, team=0: dtensor(_cdt.Creator.arange(start, end, step, dtype, team))"
7676
)
7777

7878
for func in api.api_categories["ReduceOp"]:

ddptensor/ddptensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def _inplace(self, t):
5050
f"{att} = property(lambda self: self._t.{att})"
5151
)
5252

53-
def __getitem__(self, *args):
54-
return dtensor(self._t.__getitem__(*args))
53+
def __getitem__(self, key):
54+
return dtensor(self._t.__getitem__(key if isinstance(key, list) else [key,]))
5555

5656
def __setitem__(self, key, value):
5757
self._t = self._t.__setitem__(key, value._t) # if isinstance(value, dtensor) else value)

src/CollComm.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// This is not implemented: we need an extra mechanism to work with reshape-views or alike.
1616
std::vector<std::vector<int>> CollComm::map(const PVSlice & n_slc, const PVSlice & o_slc)
1717
{
18+
#if 0
1819
auto nr = getTransceiver()->nranks();
1920
std::vector<int> counts_send(nr, 0);
2021
std::vector<int> disp_send(nr, 0);
@@ -63,4 +64,6 @@ std::vector<std::vector<int>> CollComm::map(const PVSlice & n_slc, const PVSlice
6364
disp_send[r] = soverlap._start - o_llslc._start;
6465
}
6566
return {counts_send, disp_send, counts_recv, disp_recv};
67+
#endif // if 0
68+
return {};
6669
}

src/Creator.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,12 @@ ddptensor * Creator::full(const shape_type & shape, const py::object & val, DTyp
152152

153153
struct DeferredArange : public Deferred
154154
{
155-
uint64_t _start, _end, _step;
155+
uint64_t _start, _end, _step, _team;
156156

157157
DeferredArange() = default;
158-
DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
158+
DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype, uint64_t team = 0)
159159
: Deferred(dtype, 1),
160-
_start(start), _end(end), _step(step)
160+
_start(start), _end(end), _step(step), _team(team)
161161
{}
162162

163163
void run() override
@@ -171,9 +171,10 @@ struct DeferredArange : public Deferred
171171
auto stop = ::imex::createInt(loc, builder, _end);
172172
auto step = ::imex::createInt(loc, builder, _step);
173173
auto dtype = builder.getI64Type(); // FIXME
174-
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({-1}, dtype), false);
174+
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), 1, dtype, false);
175175
auto dmy = ::imex::createInt<1>(loc, builder, 0);
176-
auto team = ::imex::createInt(loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
176+
// ::mlir::Value
177+
auto team = ::imex::createIndex(loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
177178
dm.addVal(this->guid(),
178179
builder.create<::imex::ptensor::ARangeOp>(loc, artype, start, stop, step, dmy, team),
179180
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
@@ -201,9 +202,9 @@ struct DeferredArange : public Deferred
201202
}
202203
};
203204

204-
ddptensor * Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
205+
ddptensor * Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype, uint64_t team)
205206
{
206-
return new ddptensor(defer<DeferredArange>(start, end, step, dtype));
207+
return new ddptensor(defer<DeferredArange>(start, end, step, dtype, team));
207208
}
208209

209210
ddptensor * Creator::mk_future(const py::object & b)

src/EWBinOp.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "ddptensor/Creator.hpp"
99
#include "ddptensor/TypePromotion.hpp"
1010
#include "ddptensor/CollComm.hpp"
11-
#include "ddptensor/Chunker.hpp"
1211
#include "ddptensor/DDPTensorImpl.hpp"
1312

1413
#include <imex/Dialect/PTensor/IR/PTensorOps.h>

src/MPIMediator.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "ddptensor/CppTypes.hpp"
1010
#include "ddptensor/MPIMediator.hpp"
1111
#include "ddptensor/MPITransceiver.hpp"
12-
#include "ddptensor/NDSlice.hpp"
1312
#include "ddptensor/Factory.hpp"
1413

1514
constexpr static int REQ_TAG = 14711;
@@ -50,6 +49,7 @@ MPIMediator::~MPIMediator()
5049
}
5150
}
5251

52+
#if 0
5353
void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void * rbuff)
5454
{
5555
MPI_Request request[2];
@@ -81,6 +81,7 @@ void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void
8181
MPI_Get_count(&status[1], MPI_CHAR, &cnt);
8282
if(cnt != sz) throw(std::runtime_error("Received unexpected message size."));
8383
}
84+
#endif
8485

8586
void send_to_workers(const Runable * dfrd, bool self, MPI_Comm comm)
8687
{
@@ -162,6 +163,7 @@ void MPIMediator::listen()
162163
uptr.get()->defer(std::move(uptr)); // grmpf
163164
break;
164165
}
166+
#if 0
165167
case PULL_TAG: {
166168
uint64_t id;
167169
ser.value8b(id);
@@ -182,6 +184,7 @@ void MPIMediator::listen()
182184
MPI_Isend(rbuff.data(), rbuff.size(), MPI_CHAR, requester, PUSH_TAG, _comm, &request_out);
183185
break;
184186
}
187+
#endif
185188
case EXIT_TAG:
186189
defer(nullptr);
187190
return;

src/ReduceOp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ struct DeferredReduceOp : public Deferred
124124
auto av = dm.getDependent(builder, _a);
125125
auto aPtTyp = ::imex::dist::getPTensorType(av);
126126
assert(aPtTyp);
127-
::mlir::Type dtype = aPtTyp.getRtensor().getElementType();
127+
::mlir::Type dtype = aPtTyp.getElementType();
128128
// return type 0d with same dtype as input
129-
auto retPtTyp = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({}, dtype), false);
129+
auto retPtTyp = ::imex::ptensor::PTensorType::get(builder.getContext(), 0, dtype, false);
130130
// reduction op
131131
auto mop = ddpt2mlir(_op);
132132
auto op = builder.getIntegerAttr(builder.getIntegerType(sizeof(mop)*8), mop);

src/SetGetItem.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
#include "ddptensor/Mediator.hpp"
66
#include "ddptensor/Factory.hpp"
77

8+
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
9+
#include <imex/Dialect/Dist/IR/DistOps.h>
10+
#include <imex/internal/PassUtils.h>
11+
#include <mlir/IR/Builders.h>
12+
813
#if 0
914
namespace x {
1015

@@ -223,14 +228,48 @@ struct DeferredGetItem : public Deferred
223228

224229
DeferredGetItem() = default;
225230
DeferredGetItem(const tensor_i::future_type & a, const std::vector<py::slice> & v)
226-
: _a(a.id()), _slc(v)
231+
: Deferred(a.dtype(), a.rank()), _a(a.id()), _slc(v)
227232
{}
228233

229234
void run()
230235
{
231236
//const auto a = std::move(Registry::get(_a).get());
232237
//set_value(std::move(TypeDispatch<x::GetItem>(a, _slc)));
233238
}
239+
240+
bool generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::DepManager & dm) override
241+
{
242+
// get params and extract offsets/sizes/strides
243+
const auto dtype = this->dtype();
244+
auto av = dm.getDependent(builder, _a);
245+
auto & offs = _slc.offsets();
246+
auto & sizes = _slc.sizes();
247+
auto & strides = _slc.strides();
248+
auto nd = offs.size();
249+
// convert C++ slices into vectors of MLIR Values
250+
std::vector<::mlir::Value> offsV(nd);
251+
std::vector<::mlir::Value> sizesV(nd);
252+
std::vector<::mlir::Value> stridesV(nd);
253+
for(auto i = 0; i<nd; ++i) {
254+
offsV[i] = ::imex::createIndex(loc, builder, offs[i]);
255+
sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]);
256+
stridesV[i] = ::imex::createIndex(loc, builder, strides[i]);
257+
}
258+
// now we can create the PTensor op using the above Values
259+
dm.addVal(this->guid(),
260+
builder.create<::imex::ptensor::ExtractSliceOp>(loc,
261+
::imex::dist::getPTensorType(av),
262+
av,
263+
offsV,
264+
sizesV,
265+
stridesV),
266+
[this, dtype](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
267+
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
268+
this->set_value(std::move(mk_tnsr(dtype, rank, allocated, aligned, offset, sizes, strides,
269+
gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
270+
});
271+
return false;
272+
}
234273

235274
FactoryId factory() const
236275
{

src/include/ddptensor/Creator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ struct Creator
1010
{
1111
static ddptensor * create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype=FLOAT64);
1212
static ddptensor * full(const shape_type & shape, const py::object & val, DTypeId dtype=FLOAT64);
13-
static ddptensor * arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype=INT64);
13+
static ddptensor * arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype=INT64, uint64_t team=0);
1414
static ddptensor * mk_future(const py::object & b);
1515
};

0 commit comments

Comments (0)