Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit c43c05f

Browse files
authored
using static team and offsets in imex (#41)
1 parent 9310c82 commit c43c05f

File tree

13 files changed

+124
-110
lines changed

13 files changed

+124
-110
lines changed

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e2ecbb3c08e4a87256cdf11d8057fe765d3a04b8
1+
89b5d56c4774ddb82ab8f896c3d977c6edae267b

src/Creator.cpp

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -64,32 +64,33 @@ struct DeferredFull : public Deferred {
6464
::imex::ptensor::DType dtyp;
6565
::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
6666

67-
auto team =
68-
_team == 0
69-
? ::mlir::Value()
70-
: ::imex::createIndex(loc, builder,
71-
reinterpret_cast<uint64_t>(getTransceiver()));
67+
auto transceiver = getTransceiver();
68+
auto teamV = team() == 0
69+
? ::mlir::Value()
70+
: ::imex::createIndex(loc, builder,
71+
reinterpret_cast<uint64_t>(team()));
7272

7373
auto rTyp = ::imex::ptensor::PTensorType::get(
7474
shape(), imex::ptensor::toMLIR(builder, dtyp));
7575

7676
dm.addVal(this->guid(),
7777
builder.create<::imex::ptensor::CreateOp>(loc, rTyp, shp, dtyp,
78-
val, nullptr, team),
79-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
80-
void *l_aligned, intptr_t l_offset,
81-
const intptr_t *l_sizes, const intptr_t *l_strides,
82-
void *o_allocated, void *o_aligned, intptr_t o_offset,
78+
val, nullptr, teamV),
79+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
80+
intptr_t l_offset, const intptr_t *l_sizes,
81+
const intptr_t *l_strides, void *o_allocated,
82+
void *o_aligned, intptr_t o_offset,
8383
const intptr_t *o_sizes, const intptr_t *o_strides,
8484
void *r_allocated, void *r_aligned, intptr_t r_offset,
8585
const intptr_t *r_sizes, const intptr_t *r_strides,
8686
uint64_t *lo_allocated, uint64_t *lo_aligned) {
8787
assert(rank == this->rank());
8888
this->set_value(std::move(mk_tnsr(
89-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
90-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
91-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
92-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
89+
reinterpret_cast<Transceiver *>(this->team()), _dtype,
90+
this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
91+
l_strides, o_allocated, o_aligned, o_offset, o_sizes,
92+
o_strides, r_allocated, r_aligned, r_offset, r_sizes,
93+
r_strides, lo_allocated, lo_aligned)));
9394
});
9495
return false;
9596
}
@@ -126,11 +127,11 @@ struct DeferredArange : public Deferred {
126127
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
127128
jit::DepManager &dm) override {
128129
// ::mlir::Value
129-
auto team =
130-
_team == 0
131-
? ::mlir::Value()
132-
: ::imex::createIndex(loc, builder,
133-
reinterpret_cast<uint64_t>(getTransceiver()));
130+
auto transceiver = getTransceiver();
131+
auto teamV = team() == 0
132+
? ::mlir::Value()
133+
: ::imex::createIndex(loc, builder,
134+
reinterpret_cast<uint64_t>(team()));
134135

135136
auto _num = shape()[0];
136137

@@ -142,22 +143,23 @@ struct DeferredArange : public Deferred {
142143

143144
dm.addVal(this->guid(),
144145
builder.create<::imex::ptensor::LinSpaceOp>(
145-
loc, rTyp, start, stop, num, false, nullptr, team),
146-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
147-
void *l_aligned, intptr_t l_offset,
148-
const intptr_t *l_sizes, const intptr_t *l_strides,
149-
void *o_allocated, void *o_aligned, intptr_t o_offset,
146+
loc, rTyp, start, stop, num, false, nullptr, teamV),
147+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
148+
intptr_t l_offset, const intptr_t *l_sizes,
149+
const intptr_t *l_strides, void *o_allocated,
150+
void *o_aligned, intptr_t o_offset,
150151
const intptr_t *o_sizes, const intptr_t *o_strides,
151152
void *r_allocated, void *r_aligned, intptr_t r_offset,
152153
const intptr_t *r_sizes, const intptr_t *r_strides,
153154
uint64_t *lo_allocated, uint64_t *lo_aligned) {
154155
assert(rank == 1);
155156
assert(l_strides[0] == 1);
156157
this->set_value(std::move(mk_tnsr(
157-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
158-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
159-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
160-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
158+
reinterpret_cast<Transceiver *>(this->team()), _dtype,
159+
this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
160+
l_strides, o_allocated, o_aligned, o_offset, o_sizes,
161+
o_strides, r_allocated, r_aligned, r_offset, r_sizes,
162+
r_strides, lo_allocated, lo_aligned)));
161163
});
162164
return false;
163165
}
@@ -193,11 +195,10 @@ struct DeferredLinspace : public Deferred {
193195
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
194196
jit::DepManager &dm) override {
195197
// ::mlir::Value
196-
auto team =
197-
_team == 0
198-
? ::mlir::Value()
199-
: ::imex::createIndex(loc, builder,
200-
reinterpret_cast<uint64_t>(getTransceiver()));
198+
auto teamV = team() == 0
199+
? ::mlir::Value()
200+
: ::imex::createIndex(loc, builder,
201+
reinterpret_cast<uint64_t>(team()));
201202

202203
auto start = ::imex::createFloat(loc, builder, _start);
203204
auto stop = ::imex::createFloat(loc, builder, _end);
@@ -207,22 +208,23 @@ struct DeferredLinspace : public Deferred {
207208

208209
dm.addVal(this->guid(),
209210
builder.create<::imex::ptensor::LinSpaceOp>(
210-
loc, rTyp, start, stop, num, _endpoint, nullptr, team),
211-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
212-
void *l_aligned, intptr_t l_offset,
213-
const intptr_t *l_sizes, const intptr_t *l_strides,
214-
void *o_allocated, void *o_aligned, intptr_t o_offset,
211+
loc, rTyp, start, stop, num, _endpoint, nullptr, teamV),
212+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
213+
intptr_t l_offset, const intptr_t *l_sizes,
214+
const intptr_t *l_strides, void *o_allocated,
215+
void *o_aligned, intptr_t o_offset,
215216
const intptr_t *o_sizes, const intptr_t *o_strides,
216217
void *r_allocated, void *r_aligned, intptr_t r_offset,
217218
const intptr_t *r_sizes, const intptr_t *r_strides,
218219
uint64_t *lo_allocated, uint64_t *lo_aligned) {
219220
assert(rank == 1);
220221
assert(l_strides[0] == 1);
221222
this->set_value(std::move(mk_tnsr(
222-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
223-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
224-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
225-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
223+
reinterpret_cast<Transceiver *>(this->team()), _dtype,
224+
this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
225+
l_strides, o_allocated, o_aligned, o_offset, o_sizes,
226+
o_strides, r_allocated, r_aligned, r_offset, r_sizes,
227+
r_strides, lo_allocated, lo_aligned)));
226228
});
227229
return false;
228230
}

src/DDPTensorImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ void DDPTensorImpl::add_to_args(std::vector<void *> &args) {
195195
args.push_back(storeMR(_lData));
196196
} else {
197197
// transceiver/team first
198-
args.push_back(_transceiver);
198+
// args.push_back(_transceiver);
199199
// local tensor first
200200
if (ndims > 0) {
201201
args.push_back(storeMR(_lhsHalo));

src/EWBinOp.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,20 @@ struct DeferredEWBinOp : public Deferred {
106106
// builder.create<::imex::ptensor::EWBinOp>(loc, ddpt2mlir(_op), av,
107107
// bv);
108108
dm.addVal(this->guid(), bop,
109-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
110-
void *l_aligned, intptr_t l_offset,
111-
const intptr_t *l_sizes, const intptr_t *l_strides,
112-
void *o_allocated, void *o_aligned, intptr_t o_offset,
109+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
110+
intptr_t l_offset, const intptr_t *l_sizes,
111+
const intptr_t *l_strides, void *o_allocated,
112+
void *o_aligned, intptr_t o_offset,
113113
const intptr_t *o_sizes, const intptr_t *o_strides,
114114
void *r_allocated, void *r_aligned, intptr_t r_offset,
115115
const intptr_t *r_sizes, const intptr_t *r_strides,
116116
uint64_t *lo_allocated, uint64_t *lo_aligned) {
117117
this->set_value(std::move(mk_tnsr(
118-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
119-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
120-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
121-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
118+
reinterpret_cast<Transceiver *>(this->team()), _dtype,
119+
this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
120+
l_strides, o_allocated, o_aligned, o_offset, o_sizes,
121+
o_strides, r_allocated, r_aligned, r_offset, r_sizes,
122+
r_strides, lo_allocated, lo_aligned)));
122123
});
123124
return false;
124125
}

src/EWUnyOp.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,20 @@ struct DeferredEWUnyOp : public Deferred {
212212
auto uop = builder.create<::imex::ptensor::EWUnyOp>(
213213
loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av);
214214
dm.addVal(this->guid(), uop,
215-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
216-
void *l_aligned, intptr_t l_offset,
217-
const intptr_t *l_sizes, const intptr_t *l_strides,
218-
void *o_allocated, void *o_aligned, intptr_t o_offset,
215+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
216+
intptr_t l_offset, const intptr_t *l_sizes,
217+
const intptr_t *l_strides, void *o_allocated,
218+
void *o_aligned, intptr_t o_offset,
219219
const intptr_t *o_sizes, const intptr_t *o_strides,
220220
void *r_allocated, void *r_aligned, intptr_t r_offset,
221221
const intptr_t *r_sizes, const intptr_t *r_strides,
222222
uint64_t *lo_allocated, uint64_t *lo_aligned) {
223223
this->set_value(std::move(mk_tnsr(
224-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
225-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
226-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
227-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
224+
reinterpret_cast<Transceiver *>(this->team()), _dtype,
225+
this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
226+
l_strides, o_allocated, o_aligned, o_offset, o_sizes,
227+
o_strides, r_allocated, r_aligned, r_offset, r_sizes,
228+
r_strides, lo_allocated, lo_aligned)));
228229
});
229230
return false;
230231
}

src/IEWBinOp.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ struct DeferredIEWBinOp : public Deferred {
8383
szs, strds);
8484
// ... and use av to later create the ptensor
8585
dm.addVal(this->guid(), av,
86-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
87-
void *l_aligned, intptr_t l_offset,
88-
const intptr_t *l_sizes, const intptr_t *l_strides,
89-
void *o_allocated, void *o_aligned, intptr_t o_offset,
86+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
87+
intptr_t l_offset, const intptr_t *l_sizes,
88+
const intptr_t *l_strides, void *o_allocated,
89+
void *o_aligned, intptr_t o_offset,
9090
const intptr_t *o_sizes, const intptr_t *o_strides,
9191
void *r_allocated, void *r_aligned, intptr_t r_offset,
9292
const intptr_t *r_sizes, const intptr_t *r_strides,

src/ManipOp.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,20 @@ struct DeferredReshape : public Deferred {
4141
builder.create<::imex::ptensor::ReshapeOp>(loc, outTyp, av, shp, copyA);
4242

4343
dm.addVal(this->guid(), op,
44-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
45-
void *l_aligned, intptr_t l_offset,
46-
const intptr_t *l_sizes, const intptr_t *l_strides,
47-
void *o_allocated, void *o_aligned, intptr_t o_offset,
44+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
45+
intptr_t l_offset, const intptr_t *l_sizes,
46+
const intptr_t *l_strides, void *o_allocated,
47+
void *o_aligned, intptr_t o_offset,
4848
const intptr_t *o_sizes, const intptr_t *o_strides,
4949
void *r_allocated, void *r_aligned, intptr_t r_offset,
5050
const intptr_t *r_sizes, const intptr_t *r_strides,
5151
uint64_t *lo_allocated, uint64_t *lo_aligned) {
52-
auto t = mk_tnsr(
53-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
54-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
55-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
56-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned);
52+
auto t = mk_tnsr(reinterpret_cast<Transceiver *>(this->team()),
53+
_dtype, this->shape(), l_allocated, l_aligned,
54+
l_offset, l_sizes, l_strides, o_allocated,
55+
o_aligned, o_offset, o_sizes, o_strides,
56+
r_allocated, r_aligned, r_offset, r_sizes,
57+
r_strides, lo_allocated, lo_aligned);
5758
if (_copy != COPY_ALWAYS) {
5859
assert(!"copy-free reshape not supported");
5960
if (Registry::has(_a)) {

src/ReduceOp.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,19 +129,20 @@ struct DeferredReduceOp : public Deferred {
129129
dm.addVal(
130130
this->guid(),
131131
builder.create<::imex::ptensor::ReductionOp>(loc, retPtTyp, op, av),
132-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
133-
void *l_aligned, intptr_t l_offset, const intptr_t *l_sizes,
132+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
133+
intptr_t l_offset, const intptr_t *l_sizes,
134134
const intptr_t *l_strides, void *o_allocated, void *o_aligned,
135135
intptr_t o_offset, const intptr_t *o_sizes,
136136
const intptr_t *o_strides, void *r_allocated, void *r_aligned,
137137
intptr_t r_offset, const intptr_t *r_sizes,
138138
const intptr_t *r_strides, uint64_t *lo_allocated,
139139
uint64_t *lo_aligned) {
140-
this->set_value(std::move(mk_tnsr(
141-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
142-
l_offset, l_sizes, l_strides, o_allocated, o_aligned, o_offset,
143-
o_sizes, o_strides, r_allocated, r_aligned, r_offset, r_sizes,
144-
r_strides, lo_allocated, lo_aligned)));
140+
this->set_value(std::move(
141+
mk_tnsr(reinterpret_cast<Transceiver *>(this->team()), _dtype,
142+
this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
143+
l_strides, o_allocated, o_aligned, o_offset, o_sizes,
144+
o_strides, r_allocated, r_aligned, r_offset, r_sizes,
145+
r_strides, lo_allocated, lo_aligned)));
145146
});
146147
return false;
147148
}

src/SetGetItem.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,19 +308,20 @@ struct DeferredGetItem : public Deferred {
308308
loc, outTyp, av, offsV, sizesV, stridesV);
309309

310310
dm.addVal(this->guid(), res,
311-
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
312-
void *l_aligned, intptr_t l_offset,
313-
const intptr_t *l_sizes, const intptr_t *l_strides,
314-
void *o_allocated, void *o_aligned, intptr_t o_offset,
311+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
312+
intptr_t l_offset, const intptr_t *l_sizes,
313+
const intptr_t *l_strides, void *o_allocated,
314+
void *o_aligned, intptr_t o_offset,
315315
const intptr_t *o_sizes, const intptr_t *o_strides,
316316
void *r_allocated, void *r_aligned, intptr_t r_offset,
317317
const intptr_t *r_sizes, const intptr_t *r_strides,
318318
uint64_t *lo_allocated, uint64_t *lo_aligned) {
319-
auto t = mk_tnsr(
320-
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
321-
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
322-
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
323-
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned);
319+
auto t = mk_tnsr(reinterpret_cast<Transceiver *>(this->team()),
320+
_dtype, this->shape(), l_allocated, l_aligned,
321+
l_offset, l_sizes, l_strides, o_allocated,
322+
o_aligned, o_offset, o_sizes, o_strides,
323+
r_allocated, r_aligned, r_offset, r_sizes,
324+
r_strides, lo_allocated, lo_aligned);
324325
if (Registry::has(_a)) {
325326
t->set_base(Registry::get(_a).get());
326327
} // else _a is a temporary and was dropped

src/include/ddptensor/Deferred.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#pragma once
1313

1414
#include "Registry.hpp"
15+
#include "Transceiver.hpp"
1516
#include "jit/mlir.hpp"
1617
#include "tensor_i.hpp"
1718

@@ -43,6 +44,7 @@ struct Runable {
4344
extern void push_runable(Runable::ptr_type &&r);
4445

4546
// helper class
47+
// FIXME team is currently set to getTransceiver() always
4648
template <typename P, typename F> struct DeferredT : public P, public Runable {
4749
using ptr_type = std::unique_ptr<DeferredT>;
4850
using promise_type = P;
@@ -52,11 +54,14 @@ template <typename P, typename F> struct DeferredT : public P, public Runable {
5254
DeferredT(const DeferredT<P, F> &) = delete;
5355
DeferredT(id_type guid, DTypeId dt, shape_type &&shape, uint64_t team,
5456
bool balanced)
55-
: P(guid, dt, std::forward<shape_type>(shape), team, balanced),
57+
: P(guid, dt, std::forward<shape_type>(shape),
58+
team ? reinterpret_cast<uint64_t>(getTransceiver()) : 0, balanced),
5659
Runable() {}
5760
DeferredT(id_type guid, DTypeId dt, const shape_type &shape, uint64_t team,
5861
bool balanced)
59-
: P(guid, dt, shape, team, balanced), Runable() {}
62+
: P(guid, dt, shape,
63+
team ? reinterpret_cast<uint64_t>(getTransceiver()) : 0, balanced),
64+
Runable() {}
6065
};
6166

6267
/// Deferred operation returning/producing a tensor

0 commit comments

Comments
 (0)