Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 8c6e63d

Browse files
fschlimbtkarna
andauthored
No copy (#36)
* Adjustments to latest IMEX * fixing no-dist-mode tensor-arg passing * propagation of global shapes * implement idtr_update_halo method * adding a few tests, disabling those using reshape * disabling buffer-dealocation for now * updating imex sha --------- Co-authored-by: Tuomas Karna <tuomas.karna@intel.com>
1 parent 05df8b7 commit 8c6e63d

28 files changed

+790
-529
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ target_link_libraries(_ddpt_rt PRIVATE
174174
IMEXDistDialect
175175
IMEXDistTransforms
176176
IMEXDistToStandard
177+
IMEXUtil
177178
IMEXTransforms
178179
MLIROptLib
179180
MLIRExecutionEngine

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
999b959287d098c0f5062dff294e4c3453ba3f38
1+
cfe20696de8126dee8fcd596ee8e4cc8522977f7

src/Creator.cpp

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@ inline uint64_t mkTeam(uint64_t team) {
2828
}
2929

3030
struct DeferredFull : public Deferred {
31-
shape_type _shape;
3231
PyScalar _val;
3332

3433
DeferredFull() = default;
3534
DeferredFull(const shape_type &shape, PyScalar val, DTypeId dtype,
3635
uint64_t team)
37-
: Deferred(dtype, shape.size(), team, true), _shape(shape), _val(val) {}
36+
: Deferred(dtype, shape, team, true), _val(val) {}
3837

3938
template <typename T> struct ValAndDType {
4039
static ::mlir::Value op(::mlir::OpBuilder &builder, ::mlir::Location loc,
@@ -57,9 +56,9 @@ struct DeferredFull : public Deferred {
5756

5857
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
5958
jit::DepManager &dm) override {
60-
::mlir::SmallVector<::mlir::Value> shp(_shape.size());
61-
for (auto i = 0; i < _shape.size(); ++i) {
62-
shp[i] = ::imex::createIndex(loc, builder, _shape[i]);
59+
::mlir::SmallVector<::mlir::Value> shp(rank());
60+
for (auto i = 0; i < rank(); ++i) {
61+
shp[i] = ::imex::createIndex(loc, builder, shape()[i]);
6362
}
6463

6564
::imex::ptensor::DType dtyp;
@@ -71,27 +70,34 @@ struct DeferredFull : public Deferred {
7170
: ::imex::createIndex(loc, builder,
7271
reinterpret_cast<uint64_t>(getTransceiver()));
7372

73+
auto rTyp = ::imex::ptensor::PTensorType::get(
74+
shape(), imex::ptensor::toMLIR(builder, dtyp));
75+
7476
dm.addVal(this->guid(),
75-
builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val,
76-
nullptr, team),
77-
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
78-
void *aligned, intptr_t offset, const intptr_t *sizes,
79-
const intptr_t *strides, int64_t *gs_allocated,
80-
int64_t *gs_aligned, uint64_t *lo_allocated,
81-
uint64_t *lo_aligned, uint64_t balanced) {
82-
assert(rank == this->_shape.size());
83-
this->set_value(std::move(
84-
mk_tnsr(transceiver, _dtype, rank, allocated, aligned,
85-
offset, sizes, strides, gs_allocated, gs_aligned,
86-
lo_allocated, lo_aligned, balanced)));
77+
builder.create<::imex::ptensor::CreateOp>(loc, rTyp, shp, dtyp,
78+
val, nullptr, team),
79+
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
80+
void *l_aligned, intptr_t l_offset,
81+
const intptr_t *l_sizes, const intptr_t *l_strides,
82+
void *o_allocated, void *o_aligned, intptr_t o_offset,
83+
const intptr_t *o_sizes, const intptr_t *o_strides,
84+
void *r_allocated, void *r_aligned, intptr_t r_offset,
85+
const intptr_t *r_sizes, const intptr_t *r_strides,
86+
uint64_t *lo_allocated, uint64_t *lo_aligned) {
87+
assert(rank == this->rank());
88+
this->set_value(std::move(mk_tnsr(
89+
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
90+
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
91+
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
92+
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
8793
});
8894
return false;
8995
}
9096

9197
FactoryId factory() const { return F_FULL; }
9298

9399
template <typename S> void serialize(S &ser) {
94-
ser.template container<sizeof(shape_type::value_type)>(_shape, 8);
100+
// ser.template container<sizeof(shape_type::value_type)>(_shape, 8);
95101
ser.template value<sizeof(_val)>(_val._int);
96102
ser.template value<sizeof(_dtype)>(_dtype);
97103
}
@@ -111,7 +117,11 @@ struct DeferredArange : public Deferred {
111117
DeferredArange() = default;
112118
DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype,
113119
uint64_t team)
114-
: Deferred(dtype, 1, team, true), _start(start), _end(end), _step(step) {}
120+
: Deferred(dtype,
121+
{static_cast<shape_type::value_type>(
122+
(end - start + step + (step < 0 ? 1 : -1)) / step)},
123+
team, true),
124+
_start(start), _end(end), _step(step) {}
115125

116126
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
117127
jit::DepManager &dm) override {
@@ -122,29 +132,32 @@ struct DeferredArange : public Deferred {
122132
: ::imex::createIndex(loc, builder,
123133
reinterpret_cast<uint64_t>(getTransceiver()));
124134

125-
auto _num = (_end - _start + _step + (_step < 0 ? 1 : -1)) / _step;
135+
auto _num = shape()[0];
126136

127137
auto start = ::imex::createFloat(loc, builder, _start);
128138
auto stop = ::imex::createFloat(loc, builder, _start + _num * _step);
129139
auto num = ::imex::createIndex(loc, builder, _num);
130140
auto rTyp = ::imex::ptensor::PTensorType::get(
131-
::llvm::ArrayRef<int64_t>{::mlir::ShapedType::kDynamic},
132-
imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
141+
shape(), imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
133142

134143
dm.addVal(this->guid(),
135144
builder.create<::imex::ptensor::LinSpaceOp>(
136145
loc, rTyp, start, stop, num, false, nullptr, team),
137-
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
138-
void *aligned, intptr_t offset, const intptr_t *sizes,
139-
const intptr_t *strides, int64_t *gs_allocated,
140-
int64_t *gs_aligned, uint64_t *lo_allocated,
141-
uint64_t *lo_aligned, uint64_t balanced) {
146+
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
147+
void *l_aligned, intptr_t l_offset,
148+
const intptr_t *l_sizes, const intptr_t *l_strides,
149+
void *o_allocated, void *o_aligned, intptr_t o_offset,
150+
const intptr_t *o_sizes, const intptr_t *o_strides,
151+
void *r_allocated, void *r_aligned, intptr_t r_offset,
152+
const intptr_t *r_sizes, const intptr_t *r_strides,
153+
uint64_t *lo_allocated, uint64_t *lo_aligned) {
142154
assert(rank == 1);
143-
assert(strides[0] == 1);
144-
this->set_value(std::move(
145-
mk_tnsr(transceiver, _dtype, rank, allocated, aligned,
146-
offset, sizes, strides, gs_allocated, gs_aligned,
147-
lo_allocated, lo_aligned, balanced)));
155+
assert(l_strides[0] == 1);
156+
this->set_value(std::move(mk_tnsr(
157+
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
158+
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
159+
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
160+
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
148161
});
149162
return false;
150163
}
@@ -174,8 +187,8 @@ struct DeferredLinspace : public Deferred {
174187
DeferredLinspace() = default;
175188
DeferredLinspace(double start, double end, uint64_t num, bool endpoint,
176189
DTypeId dtype, uint64_t team)
177-
: Deferred(dtype, 1, team, true), _start(start), _end(end), _num(num),
178-
_endpoint(endpoint) {}
190+
: Deferred(dtype, {static_cast<shape_type::value_type>(num)}, team, true),
191+
_start(start), _end(end), _num(num), _endpoint(endpoint) {}
179192

180193
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
181194
jit::DepManager &dm) override {
@@ -190,23 +203,26 @@ struct DeferredLinspace : public Deferred {
190203
auto stop = ::imex::createFloat(loc, builder, _end);
191204
auto num = ::imex::createIndex(loc, builder, _num);
192205
auto rTyp = ::imex::ptensor::PTensorType::get(
193-
::llvm::ArrayRef<int64_t>{::mlir::ShapedType::kDynamic},
194-
imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
206+
shape(), imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
195207

196208
dm.addVal(this->guid(),
197209
builder.create<::imex::ptensor::LinSpaceOp>(
198210
loc, rTyp, start, stop, num, _endpoint, nullptr, team),
199-
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
200-
void *aligned, intptr_t offset, const intptr_t *sizes,
201-
const intptr_t *strides, int64_t *gs_allocated,
202-
int64_t *gs_aligned, uint64_t *lo_allocated,
203-
uint64_t *lo_aligned, uint64_t balanced) {
211+
[this](Transceiver *transceiver, uint64_t rank, void *l_allocated,
212+
void *l_aligned, intptr_t l_offset,
213+
const intptr_t *l_sizes, const intptr_t *l_strides,
214+
void *o_allocated, void *o_aligned, intptr_t o_offset,
215+
const intptr_t *o_sizes, const intptr_t *o_strides,
216+
void *r_allocated, void *r_aligned, intptr_t r_offset,
217+
const intptr_t *r_sizes, const intptr_t *r_strides,
218+
uint64_t *lo_allocated, uint64_t *lo_aligned) {
204219
assert(rank == 1);
205-
assert(strides[0] == 1);
206-
this->set_value(std::move(
207-
mk_tnsr(transceiver, _dtype, rank, allocated, aligned,
208-
offset, sizes, strides, gs_allocated, gs_aligned,
209-
lo_allocated, lo_aligned, balanced)));
220+
assert(l_strides[0] == 1);
221+
this->set_value(std::move(mk_tnsr(
222+
transceiver, _dtype, this->shape(), l_allocated, l_aligned,
223+
l_offset, l_sizes, l_strides, o_allocated, o_aligned,
224+
o_offset, o_sizes, o_strides, r_allocated, r_aligned,
225+
r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
210226
});
211227
return false;
212228
}

0 commit comments

Comments
 (0)