Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit a8330e2

Browse files
authored
adding linspace (#6)
adding linspace
1 parent 6c906a5 commit a8330e2

File tree

6 files changed

+125
-14
lines changed

6 files changed

+125
-14
lines changed

ddptensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def to_numpy(a):
8383
exec(
8484
f"{func} = lambda start, end, step, dtype, team=0: dtensor(_cdt.Creator.arange(start, end, step, dtype, team))"
8585
)
86+
elif func == "linspace":
87+
exec(
88+
f"{func} = lambda start, end, step, endpoint, dtype, team=0: dtensor(_cdt.Creator.linspace(start, end, step, endpoint, dtype, team))"
89+
)
8690

8791
for func in api.api_categories["ReduceOp"]:
8892
FUNC = func.upper()

src/Creator.cpp

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,11 @@ namespace x {
7979

8080
struct DeferredFromShape : public Deferred {
8181
shape_type _shape;
82-
DTypeId _dtype;
8382
CreatorId _op;
8483

8584
DeferredFromShape() = default;
8685
DeferredFromShape(CreatorId op, const shape_type &shape, DTypeId dtype)
87-
: Deferred(dtype, shape.size(), true), _shape(shape), _dtype(dtype),
88-
_op(op) {}
86+
: Deferred(dtype, shape.size(), true), _shape(shape), _op(op) {}
8987

9088
void run() {
9189
// set_value(std::move(TypeDispatch<x::Creator>(_dtype, _op, _shape)));
@@ -97,7 +95,6 @@ struct DeferredFromShape : public Deferred {
9795

9896
template <typename S> void serialize(S &ser) {
9997
ser.template container<sizeof(shape_type::value_type)>(_shape, 8);
100-
ser.template value<sizeof(_dtype)>(_dtype);
10198
ser.template value<sizeof(_op)>(_op);
10299
}
103100
};
@@ -110,12 +107,10 @@ ddptensor *Creator::create_from_shape(CreatorId op, const shape_type &shape,
110107
struct DeferredFull : public Deferred {
111108
shape_type _shape;
112109
PyScalar _val;
113-
DTypeId _dtype;
114110

115111
DeferredFull() = default;
116112
DeferredFull(const shape_type &shape, PyScalar val, DTypeId dtype)
117-
: Deferred(dtype, shape.size(), true), _shape(shape), _val(val),
118-
_dtype(dtype) {}
113+
: Deferred(dtype, shape.size(), true), _shape(shape), _val(val) {}
119114

120115
void run() {
121116
// auto op = FULL;
@@ -189,6 +184,8 @@ ddptensor *Creator::full(const shape_type &shape, const py::object &val,
189184
return new ddptensor(defer<DeferredFull>(shape, v, dtype));
190185
}
191186

187+
// ***************************************************************************
188+
192189
struct DeferredArange : public Deferred {
193190
uint64_t _start, _end, _step, _team;
194191

@@ -205,19 +202,25 @@ struct DeferredArange : public Deferred {
205202

206203
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
207204
jit::DepManager &dm) override {
208-
auto start = ::imex::createInt(loc, builder, _start);
209-
auto stop = ::imex::createInt(loc, builder, _end);
210-
auto step = ::imex::createInt(loc, builder, _step);
211205
// ::mlir::Value
212206
auto team = /*getTransceiver()->nranks() <= 1
213207
? ::mlir::Value()
214208
:*/
215209
::imex::createIndex(loc, builder,
216210
reinterpret_cast<uint64_t>(getTransceiver()));
217211

212+
auto _num = (_end - _start + _step + (_step < 0 ? 1 : -1)) / _step;
213+
214+
auto start = ::imex::createFloat(loc, builder, _start);
215+
auto stop = ::imex::createFloat(loc, builder, _start + _num * _step);
216+
auto num = ::imex::createIndex(loc, builder, _num);
217+
auto rTyp = ::imex::ptensor::PTensorType::get(
218+
::llvm::ArrayRef<int64_t>{::mlir::ShapedType::kDynamic},
219+
imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
220+
218221
dm.addVal(this->guid(),
219-
builder.create<::imex::ptensor::ARangeOp>(loc, start, stop, step,
220-
nullptr, team),
222+
builder.create<::imex::ptensor::LinSpaceOp>(
223+
loc, rTyp, start, stop, num, false, nullptr, team),
221224
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
222225
void *aligned, intptr_t offset, const intptr_t *sizes,
223226
const intptr_t *strides, uint64_t *gs_allocated,
@@ -239,7 +242,6 @@ struct DeferredArange : public Deferred {
239242
ser.template value<sizeof(_start)>(_start);
240243
ser.template value<sizeof(_end)>(_end);
241244
ser.template value<sizeof(_step)>(_step);
242-
ser.template value<sizeof(_dtype)>(_dtype);
243245
}
244246
};
245247

@@ -248,6 +250,76 @@ ddptensor *Creator::arange(uint64_t start, uint64_t end, uint64_t step,
248250
return new ddptensor(defer<DeferredArange>(start, end, step, dtype, team));
249251
}
250252

253+
// ***************************************************************************
254+
255+
struct DeferredLinspace : public Deferred {
256+
double _start, _end;
257+
uint64_t _num, _team;
258+
bool _endpoint;
259+
260+
DeferredLinspace() = default;
261+
DeferredLinspace(double start, double end, uint64_t num, bool endpoint,
262+
DTypeId dtype, uint64_t team = 0)
263+
: Deferred(dtype, 1, true), _start(start), _end(end), _num(num),
264+
_team(team), _endpoint(endpoint) {}
265+
266+
void run() override{
267+
// set_value(std::move(TypeDispatch<x::Creator>(_dtype, _start, _end,
268+
// _num)));
269+
};
270+
271+
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
272+
jit::DepManager &dm) override {
273+
// ::mlir::Value
274+
auto team = /*getTransceiver()->nranks() <= 1
275+
? ::mlir::Value()
276+
:*/
277+
::imex::createIndex(loc, builder,
278+
reinterpret_cast<uint64_t>(getTransceiver()));
279+
280+
auto start = ::imex::createFloat(loc, builder, _start);
281+
auto stop = ::imex::createFloat(loc, builder, _end);
282+
auto num = ::imex::createIndex(loc, builder, _num);
283+
auto rTyp = ::imex::ptensor::PTensorType::get(
284+
::llvm::ArrayRef<int64_t>{::mlir::ShapedType::kDynamic},
285+
imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
286+
287+
dm.addVal(this->guid(),
288+
builder.create<::imex::ptensor::LinSpaceOp>(
289+
loc, rTyp, start, stop, num, _endpoint, nullptr, team),
290+
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
291+
void *aligned, intptr_t offset, const intptr_t *sizes,
292+
const intptr_t *strides, uint64_t *gs_allocated,
293+
uint64_t *gs_aligned, uint64_t *lo_allocated,
294+
uint64_t *lo_aligned, uint64_t balanced) {
295+
assert(rank == 1);
296+
assert(strides[0] == 1);
297+
this->set_value(std::move(
298+
mk_tnsr(transceiver, _dtype, rank, allocated, aligned,
299+
offset, sizes, strides, gs_allocated, gs_aligned,
300+
lo_allocated, lo_aligned, balanced)));
301+
});
302+
return false;
303+
}
304+
305+
FactoryId factory() const { return F_ARANGE; }
306+
307+
template <typename S> void serialize(S &ser) {
308+
ser.template value<sizeof(_start)>(_start);
309+
ser.template value<sizeof(_end)>(_end);
310+
ser.template value<sizeof(_num)>(_num);
311+
ser.template value<sizeof(_endpoint)>(_endpoint);
312+
}
313+
};
314+
315+
// Create a (deferred) 1d tensor of `num` evenly spaced samples over
// [start, end]; `endpoint` controls whether `end` is included as the
// last sample. Returns a new ddptensor wrapping the deferred handle;
// ownership passes to the caller (pybind layer).
ddptensor *Creator::linspace(double start, double end, uint64_t num,
                             bool endpoint, DTypeId dtype, uint64_t team) {
  return new ddptensor(
      defer<DeferredLinspace>(start, end, num, endpoint, dtype, team));
}
320+
321+
// ***************************************************************************
322+
251323
std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b) {
252324
if (py::isinstance<ddptensor>(b)) {
253325
return {b.cast<ddptensor *>(), false};
@@ -263,3 +335,4 @@ std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b) {
263335
FACTORY_INIT(DeferredFromShape, F_FROMSHAPE);
264336
FACTORY_INIT(DeferredFull, F_FULL);
265337
FACTORY_INIT(DeferredArange, F_ARANGE);
338+
FACTORY_INIT(DeferredLinspace, F_LINSPACE);

src/ddptensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ PYBIND11_MODULE(_ddptensor, m) {
128128
Factory::init<F_GETITEM>();
129129
Factory::init<F_IEWBINOP>();
130130
Factory::init<F_LINALGOP>();
131+
Factory::init<F_LINSPACE>();
131132
Factory::init<F_MANIPOP>();
132133
Factory::init<F_RANDOM>();
133134
Factory::init<F_REDUCEOP>();
@@ -162,7 +163,8 @@ PYBIND11_MODULE(_ddptensor, m) {
162163
py::class_<Creator>(m, "Creator")
163164
.def("create_from_shape", &Creator::create_from_shape)
164165
.def("full", &Creator::full)
165-
.def("arange", &Creator::arange);
166+
.def("arange", &Creator::arange)
167+
.def("linspace", &Creator::linspace);
166168

167169
py::class_<EWUnyOp>(m, "EWUnyOp").def("op", &EWUnyOp::op);
168170
py::class_<IEWBinOp>(m, "IEWBinOp").def("op", &IEWBinOp::op);

src/include/ddptensor/CppTypes.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ enum FactoryId : int {
203203
F_GETLOCAL,
204204
F_IEWBINOP,
205205
F_LINALGOP,
206+
F_LINSPACE,
206207
F_MANIPOP,
207208
F_MAP,
208209
F_RANDOM,

src/include/ddptensor/Creator.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@ struct Creator {
1717
DTypeId dtype = FLOAT64);
1818
static ddptensor *arange(uint64_t start, uint64_t end, uint64_t step,
1919
DTypeId dtype = INT64, uint64_t team = 0);
20+
static ddptensor *linspace(double start, double end, uint64_t num,
21+
bool endpoint, DTypeId dtype, uint64_t team);
2022
static std::pair<ddptensor *, bool> mk_future(const py::object &b);
2123
};

src/include/ddptensor/jit/mlir.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,35 @@ template <> struct PT_DTYPE<bool> {
6363
constexpr static ::imex::ptensor::DType value = ::imex::ptensor::I1;
6464
};
6565

66+
/// Translate a runtime DTypeId into the matching ptensor element type by
/// routing through the TYPE<>/PT_DTYPE<> compile-time tables.
/// Throws std::runtime_error for an id with no ptensor mapping.
inline ::imex::ptensor::DType getPTDType(DTypeId dt) {
  switch (dt) {
  // floating point
  case FLOAT32:
    return PT_DTYPE<TYPE<FLOAT32>::dtype>::value;
  case FLOAT64:
    return PT_DTYPE<TYPE<FLOAT64>::dtype>::value;
  // signed integers
  case INT8:
    return PT_DTYPE<TYPE<INT8>::dtype>::value;
  case INT16:
    return PT_DTYPE<TYPE<INT16>::dtype>::value;
  case INT32:
    return PT_DTYPE<TYPE<INT32>::dtype>::value;
  case INT64:
    return PT_DTYPE<TYPE<INT64>::dtype>::value;
  // unsigned integers
  case UINT8:
    return PT_DTYPE<TYPE<UINT8>::dtype>::value;
  case UINT16:
    return PT_DTYPE<TYPE<UINT16>::dtype>::value;
  case UINT32:
    return PT_DTYPE<TYPE<UINT32>::dtype>::value;
  case UINT64:
    return PT_DTYPE<TYPE<UINT64>::dtype>::value;
  // boolean
  case BOOL:
    return PT_DTYPE<TYPE<BOOL>::dtype>::value;
  default:
    throw std::runtime_error("unknown dtype");
  }
}
94+
6695
// function type used for reporting back tensor results generated
6796
// by Deferred::generate_mlir
6897
using SetResFunc = std::function<void(

0 commit comments

Comments
 (0)