Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit fe1976a

Browse files
committed
new future type and ReadyCallback for non-tensor values, fixing syncing with insert_slice ops
1 parent 8e62a55 commit fe1976a

File tree

14 files changed

+130
-72
lines changed

14 files changed

+130
-72
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ set(DDPTSrcs
9898
${PROJECT_SOURCE_DIR}/src/ManipOp.cpp
9999
${PROJECT_SOURCE_DIR}/src/Random.cpp
100100
${PROJECT_SOURCE_DIR}/src/ReduceOp.cpp
101-
${PROJECT_SOURCE_DIR}/src/Service.cpp
102101
${PROJECT_SOURCE_DIR}/src/SetGetItem.cpp
103102
${PROJECT_SOURCE_DIR}/src/Sorting.cpp
104103
)
@@ -110,6 +109,7 @@ set(RTSrcs
110109
${PROJECT_SOURCE_DIR}/src/Mediator.cpp
111110
${PROJECT_SOURCE_DIR}/src/MPIMediator.cpp
112111
${PROJECT_SOURCE_DIR}/src/Registry.cpp
112+
${PROJECT_SOURCE_DIR}/src/Service.cpp
113113
${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp
114114
)
115115
set(IDTRSrcs

src/DDPTensorImpl.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ddptensor/Transceiver.hpp>
99

1010
#include <algorithm>
11+
#include <iostream>
1112

1213
DDPTensorImpl::DDPTensorImpl(Transceiver *transceiver, DTypeId dtype,
1314
uint64_t ndims, void *allocated, void *aligned,
@@ -161,6 +162,8 @@ void DDPTensorImpl::add_to_args(std::vector<void *> &args, int ndims) {
161162
buff[2] = static_cast<intptr_t>(_offset);
162163
memcpy(buff + 3, _sizes, ndims * sizeof(intptr_t));
163164
memcpy(buff + 3 + ndims, _strides, ndims * sizeof(intptr_t));
165+
for (auto i = 0; i < 3 + 2 * ndims; ++i)
166+
std::cerr << " " << buff[i];
164167
args.push_back(buff);
165168
// second the transceiver
166169
args.push_back(&_transceiver);

src/Deferred.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "include/ddptensor/Deferred.hpp"
1111
#include "include/ddptensor/Mediator.hpp"
1212
#include "include/ddptensor/Registry.hpp"
13+
#include "include/ddptensor/Service.hpp"
1314
#include "include/ddptensor/Transceiver.hpp"
1415

1516
#include <imex/Dialect/Dist/IR/DistOps.h>
@@ -50,7 +51,9 @@ Deferred::future_type defer_tensor(Runable::ptr_type &&_d, bool is_global) {
5051
throw std::runtime_error("Expected Deferred Tensor promise");
5152
if (is_global) {
5253
_dist(d);
53-
d->set_guid(Registry::get_guid());
54+
if (d->guid() == Registry::NOGUID) {
55+
d->set_guid(Registry::get_guid());
56+
}
5457
}
5558
auto f = d->get_future();
5659
Registry::put(f);
@@ -154,9 +157,4 @@ void process_promises() {
154157
} while (!done);
155158
}
156159

157-
void sync_promises() {
158-
// FIXME this does not wait for the last deferred to complete
159-
while (!_deferred.empty()) {
160-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
161-
}
162-
}
160+
void sync_promises() { (void)Service::run().get(); }

src/EWBinOp.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -459,20 +459,19 @@ struct DeferredEWBinOp : public Deferred {
459459
auto outTyp =
460460
::imex::ptensor::PTensorType::get(shape, aTyp.getElementType());
461461

462-
dm.addVal(
463-
this->guid(),
464-
builder.create<::imex::ptensor::EWBinOp>(
465-
loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv),
466-
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
467-
void *aligned, intptr_t offset, const intptr_t *sizes,
468-
const intptr_t *strides, uint64_t *gs_allocated,
469-
uint64_t *gs_aligned, uint64_t *lo_allocated,
470-
uint64_t *lo_aligned, uint64_t balanced) {
471-
this->set_value(std::move(
472-
mk_tnsr(transceiver, _dtype, rank, allocated, aligned, offset,
473-
sizes, strides, gs_allocated, gs_aligned, lo_allocated,
474-
lo_aligned, balanced)));
475-
});
462+
auto bop = builder.create<::imex::ptensor::EWBinOp>(
463+
loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv);
464+
dm.addVal(this->guid(), bop,
465+
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
466+
void *aligned, intptr_t offset, const intptr_t *sizes,
467+
const intptr_t *strides, uint64_t *gs_allocated,
468+
uint64_t *gs_aligned, uint64_t *lo_allocated,
469+
uint64_t *lo_aligned, uint64_t balanced) {
470+
this->set_value(std::move(
471+
mk_tnsr(transceiver, _dtype, rank, allocated, aligned,
472+
offset, sizes, strides, gs_allocated, gs_aligned,
473+
lo_allocated, lo_aligned, balanced)));
474+
});
476475
return false;
477476
}
478477

src/Registry.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ id_type get_guid() { return ++_nguid; }
2222

2323
void put(const tensor_i::future_type &ptr) {
2424
locker _l(_mutex);
25-
_keeper[ptr.id()] = ptr;
25+
_keeper.insert({ptr.id(), ptr});
26+
}
27+
28+
bool has(id_type id) {
29+
locker _l(_mutex);
30+
return _keeper.find(id) != _keeper.end();
2631
}
2732

2833
tensor_i::future_type get(id_type id) {

src/Service.cpp

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ namespace x {
3939
}
4040
#endif // if 0
4141

42-
struct DeferredService : public Deferred {
43-
enum Op : int { REPLICATE, DROP, RUN, SERVICE_LAST };
42+
// **************************************************************************
43+
44+
struct DeferredService : public DeferredT<Service::service_promise_type,
45+
Service::service_future_type> {
46+
enum Op : int { DROP, RUN, SERVICE_LAST };
4447

4548
id_type _a;
4649
Op _op;
@@ -51,18 +54,12 @@ struct DeferredService : public Deferred {
5154

5255
void run() {
5356
switch (_op) {
54-
case REPLICATE: {
55-
const auto a = std::move(Registry::get(_a).get());
56-
auto ddpt = dynamic_cast<DDPTensorImpl *>(a.get());
57-
assert(ddpt);
58-
ddpt->replicate();
59-
set_value(a);
60-
break;
61-
}
6257
case RUN:
58+
set_value(true);
6359
break;
6460
default:
65-
throw(std::runtime_error("Unkown Service operation requested."));
61+
throw(std::runtime_error(
62+
"Execution of unkown service operation requested."));
6663
}
6764
}
6865

@@ -71,13 +68,14 @@ struct DeferredService : public Deferred {
7168
switch (_op) {
7269
case DROP:
7370
dm.drop(_a);
71+
set_value(true);
7472
// FIXME create delete op and return it
7573
break;
7674
case RUN:
77-
case REPLICATE:
7875
return true;
7976
default:
80-
throw(std::runtime_error("Unkown Service operation requested."));
77+
throw(std::runtime_error(
78+
"MLIR generation for unkown service operation requested."));
8179
}
8280

8381
return false;
@@ -91,24 +89,52 @@ struct DeferredService : public Deferred {
9189
}
9290
};
9391

94-
ddptensor *Service::replicate(const ddptensor &a) {
95-
return new ddptensor(
96-
defer<DeferredService>(DeferredService::REPLICATE, a.get()));
97-
}
92+
// **************************************************************************
9893

99-
void Service::run() {
100-
defer<DeferredService>(DeferredService::RUN);
101-
// defer_lambda([](){ return true; });
102-
}
94+
struct DeferredReplicate : public Deferred {
95+
id_type _a;
96+
97+
DeferredReplicate() : _a() {}
98+
DeferredReplicate(const tensor_i::future_type &a) : _a(a.id()) {}
99+
100+
void run() {
101+
const auto a = std::move(Registry::get(_a).get());
102+
auto ddpt = dynamic_cast<DDPTensorImpl *>(a.get());
103+
assert(ddpt);
104+
ddpt->replicate();
105+
set_value(a);
106+
}
107+
108+
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
109+
jit::DepManager &dm) override {
110+
return true;
111+
}
112+
113+
FactoryId factory() const { return F_REPLICATE; }
114+
115+
template <typename S> void serialize(S &ser) {
116+
ser.template value<sizeof(_a)>(_a);
117+
}
118+
};
119+
120+
// **************************************************************************
103121

104122
bool inited = false;
105123
bool finied = false;
106124

107-
void Service::drop(const ddptensor &a) {
125+
Service::service_future_type Service::drop(const ddptensor &a) {
108126
if (inited) {
109-
// if(getTransceiver()->is_spmd()) getTransceiver()->barrier();
110-
defer<DeferredService>(DeferredService::DROP, a.get());
127+
return defer<DeferredService>(DeferredService::DROP, a.get());
111128
}
112129
}
113130

131+
Service::service_future_type Service::run() {
132+
return defer<DeferredService>(DeferredService::RUN);
133+
}
134+
135+
ddptensor *Service::replicate(const ddptensor &a) {
136+
return new ddptensor(defer<DeferredReplicate>(a.get()));
137+
}
138+
114139
FACTORY_INIT(DeferredService, F_SERVICE);
140+
FACTORY_INIT(DeferredReplicate, F_REPLICATE);

src/SetGetItem.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ struct DeferredSetItem : public Deferred {
273273
DeferredSetItem(const tensor_i::future_type &a,
274274
const tensor_i::future_type &b,
275275
const std::vector<py::slice> &v)
276-
: _a(a.id()), _b(b.id()), _slc(v) {}
276+
: Deferred(a.id(), a.dtype(), a.rank(), a.balanced()), _a(a.id()),
277+
_b(b.id()), _slc(v) {}
277278

278279
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
279280
jit::DepManager &dm) override {
@@ -298,19 +299,10 @@ struct DeferredSetItem : public Deferred {
298299
(void)builder.create<::imex::ptensor::InsertSliceOp>(loc, av, bv, offsV,
299300
sizesV, stridesV);
300301
// ... and use av to later create the ptensor
301-
dm.addVal(this->guid(), av,
302-
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
303-
void *aligned, intptr_t offset, const intptr_t *sizes,
304-
const intptr_t *strides, uint64_t *gs_allocated,
305-
uint64_t *gs_aligned, uint64_t *lo_allocated,
306-
uint64_t *lo_aligned, uint64_t balanced) {
307-
this->set_value(Registry::get(this->_a).get());
308-
// this->set_value(std::move(mk_tnsr(dtype, rank, allocated,
309-
// aligned, offset, sizes, strides,
310-
// gs_allocated, gs_aligned,
311-
// lo_allocated,
312-
// lo_aligned)));
313-
});
302+
dm.addReady(this->guid(), [this](id_type guid) {
303+
assert(this->guid() == guid);
304+
this->set_value(Registry::get(this->_a).get());
305+
});
314306
return false;
315307
}
316308

@@ -368,6 +360,10 @@ struct DeferredGetItem : public Deferred {
368360
// auto outnd = nd == 0 || _slc.size() == 1 ? 0 : nd;
369361
auto outTyp =
370362
::imex::ptensor::PTensorType::get(shape, oTyp.getElementType());
363+
// if(auto dtyp = av.getType().dyn_cast<::imex::dist::DistTensorType>()) {
364+
// av = builder.create<::mlir::UnrealizedConversionCastOp>(loc,
365+
// dtyp.getPTensorType(), av).getResult(0);
366+
// }
371367
// now we can create the PTensor op using the above Values
372368
auto res = builder.create<::imex::ptensor::SubviewOp>(
373369
loc, outTyp, av, offsV, sizesV, stridesV);
@@ -413,10 +409,10 @@ ddptensor *SetItem::__setitem__(ddptensor &a, const std::vector<py::slice> &v,
413409
const py::object &b) {
414410

415411
auto bb = Creator::mk_future(b);
416-
auto res = new ddptensor(defer<DeferredSetItem>(a.get(), bb.first->get(), v));
412+
a.put(defer<DeferredSetItem>(a.get(), bb.first->get(), v));
417413
if (bb.second)
418414
delete bb.first;
419-
return res;
415+
return &a;
420416
}
421417

422418
py::object GetItem::get_slice(const ddptensor &a,

src/include/ddptensor/CppTypes.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ enum FactoryId : int {
210210
F_UNYOP,
211211
F_GATHER,
212212
F_GETLOCAL,
213+
F_REPLICATE,
213214
FACTORY_LAST
214215
};
215216

src/include/ddptensor/Registry.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ tensor_i::future_type get(id_type id);
2626
/// remove future tensor with guid id from registry
2727
void del(id_type id);
2828

29+
/// @return true if given guid is registered
30+
bool has(id_type);
31+
2932
/// finalize registry (before shutdown)
3033
void fini();
3134

src/include/ddptensor/Service.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66

77
#pragma once
88

9+
#include <future>
10+
911
class ddptensor;
1012

1113
struct Service {
14+
using service_promise_type = std::promise<bool>;
15+
using service_future_type = std::shared_future<bool>;
16+
1217
/// replicate the given ddptensor on all ranks
1318
static ddptensor *replicate(const ddptensor &a);
1419
/// start running/executing operations, e.g. trigger compile&run
1520
/// this is not blocking, use futures for synchronization
16-
static void run();
21+
static service_future_type run();
1722
/// signal that the given ddptensor is no longer needed and can be deleted
18-
static void drop(const ddptensor &a);
23+
static service_future_type drop(const ddptensor &a);
1924
};

0 commit comments

Comments
 (0)