Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 38dd6f9

Browse files
committed
enabling non-compiled operations, adding sort, disabling PVSlice
1 parent 6f1cb8e commit 38dd6f9

File tree

13 files changed

+107
-79
lines changed

13 files changed

+107
-79
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ set(DDPTSrcs
100100
${PROJECT_SOURCE_DIR}/src/ReduceOp.cpp
101101
${PROJECT_SOURCE_DIR}/src/Service.cpp
102102
${PROJECT_SOURCE_DIR}/src/SetGetItem.cpp
103+
${PROJECT_SOURCE_DIR}/src/Sorting.cpp
103104
)
104105
set(IDTRSrcs
105106
${PROJECT_SOURCE_DIR}/src/idtr.cpp

ddptensor/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,8 @@ def to_numpy(a):
102102
exec(
103103
f"{func} = lambda this: dtensor(_cdt.LinAlgOp.{func}(this._t))"
104104
)
105+
106+
for func in api.api_categories["SortOp"]:
107+
exec(
108+
f"{func} = lambda this, axis=-1, descending=False, stable=True: dtensor(_cdt.SortOp.{func}(this._t, descending))"
109+
)

ddptensor/array_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@
190190
"tensordot", # (x1, x2, /, *, axes=2)
191191
"vecdot", # (x1, x2, /, *, axis=-1)
192192
],
193+
194+
"SortOp" : [
195+
"argsort", # (x, /, *, axis=-1, descending=False, stable=True)
196+
"sort", #(x, /, *, axis=-1, descending=False, stable=True)
197+
],
193198
})
194199

195200
misc_methods = [

src/Deferred.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,11 @@ void process_promises()
8888

8989
jit::DepManager dm(function);
9090

91+
Runable::ptr_type d;
9192
while(true) {
92-
Runable::ptr_type d;
9393
_deferred.pop(d);
9494
if(d) {
9595
if(d->generate_mlir(builder, loc, dm)) {
96-
d.reset();
9796
break;
9897
};
9998
// keep alive for later set_value
@@ -105,28 +104,30 @@ void process_promises()
105104
}
106105
}
107106

108-
if(runables.empty()) continue;
107+
if(!runables.empty()) {
108+
// create return statement and adjust function type
109+
uint64_t osz = dm.handleResult(builder);
110+
// also request generation of c-wrapper function
111+
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
112+
function.getFunctionType().dump(); std::cout << std::endl;
113+
// add the function to the module
114+
module.push_back(function);
109115

110-
// create return statement and adjust function type
111-
uint64_t osz = dm.handleResult(builder);
112-
// also request generation of c-wrapper function
113-
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
114-
function.getFunctionType().dump();
115-
// add the function to the module
116-
module.push_back(function);
117-
module.dump();
116+
// get input buffers (before results!)
117+
auto input = std::move(dm.store_inputs());
118118

119-
// get input buffers (before results!)
120-
auto input = std::move(dm.store_inputs());
119+
// compile and run the module
120+
intptr_t * output = new intptr_t[osz];
121+
if(jit.run(module, fname, input, output)) throw std::runtime_error("failed running jit");
121122

122-
// compile and run the module
123-
intptr_t * output = new intptr_t[osz];
124-
if(jit.run(module, fname, input, output)) throw std::runtime_error("failed running jit");
123+
// push results to deliver promises
124+
dm.deliver(output, osz);
125125

126-
// push results to deliver promises
127-
dm.deliver(output, osz);
126+
delete [] output;
127+
} // no else needed
128128

129-
delete [] output;
129+
// now we execute the deferred action which could not be compiled
130+
if(d) d->run();
130131
} while(!done);
131132
}
132133

src/EWBinOp.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "ddptensor/DDPTensorImpl.hpp"
1313

1414
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
15+
#include <imex/Dialect/Dist/IR/DistOps.h>
1516
#include <mlir/IR/Builders.h>
1617
#include <mlir/Dialect/Shape/IR/Shape.h>
1718

@@ -462,8 +463,12 @@ struct DeferredEWBinOp : public Deferred
462463
// FIXME the type of the result is based on a only
463464
auto av = dm.getDependent(builder, _a);
464465
auto bv = dm.getDependent(builder, _b);
466+
467+
auto aPtTyp = ::imex::dist::getPTensorType(av);
468+
assert(aPtTyp);
469+
465470
dm.addVal(this->guid(),
466-
builder.create<::imex::ptensor::EWBinOp>(loc, av.getType(), builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv),
471+
builder.create<::imex::ptensor::EWBinOp>(loc, aPtTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv),
467472
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
468473
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
469474
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides,

src/ReduceOp.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,9 @@ struct DeferredReduceOp : public Deferred
122122
{
123123
// FIXME reduction over individual dimensions is not supported
124124
auto av = dm.getDependent(builder, _a);
125-
auto aDtTyp = av.getType().dyn_cast<::imex::dist::DistTensorType>();
126-
::mlir::Type dtype;
127-
if(aDtTyp) {
128-
dtype = aDtTyp.getPTensorType().getRtensor().getElementType();
129-
} else {
130-
auto aPtTyp = av.getType().dyn_cast<::imex::ptensor::PTensorType>();
131-
dtype = aPtTyp.getRtensor().getElementType();
132-
}
125+
auto aPtTyp = ::imex::dist::getPTensorType(av);
126+
assert(aPtTyp);
127+
::mlir::Type dtype = aPtTyp.getRtensor().getElementType();
133128
// return type 0d with same dtype as input
134129
auto retPtTyp = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({}, dtype), false);
135130
// reduction op

src/ddptensor.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,23 @@ using namespace pybind11::literals; // to bring _a
2323

2424
#define DEF_PY11_ENUMS // used in p2c_types.hpp
2525

26-
#include "ddptensor/MPITransceiver.hpp"
27-
#include "ddptensor/MPIMediator.hpp"
28-
#include "ddptensor/Deferred.hpp"
2926
#include "ddptensor/Creator.hpp"
30-
#include "ddptensor/IEWBinOp.hpp"
27+
#include "ddptensor/Deferred.hpp"
3128
#include "ddptensor/EWBinOp.hpp"
3229
#include "ddptensor/EWUnyOp.hpp"
33-
#include "ddptensor/ReduceOp.hpp"
34-
#include "ddptensor/ManipOp.hpp"
35-
#include "ddptensor/SetGetItem.hpp"
36-
#include "ddptensor/Random.hpp"
37-
#include "ddptensor/LinAlgOp.hpp"
38-
#include "ddptensor/Service.hpp"
3930
#include "ddptensor/Factory.hpp"
31+
#include "ddptensor/IEWBinOp.hpp"
4032
#include "ddptensor/IO.hpp"
4133
#include "ddptensor/jit/mlir.hpp"
34+
#include "ddptensor/LinAlgOp.hpp"
35+
#include "ddptensor/ManipOp.hpp"
36+
#include "ddptensor/MPIMediator.hpp"
37+
#include "ddptensor/MPITransceiver.hpp"
38+
#include "ddptensor/Random.hpp"
39+
#include "ddptensor/ReduceOp.hpp"
40+
#include "ddptensor/Service.hpp"
41+
#include "ddptensor/SetGetItem.hpp"
42+
#include "ddptensor/Sorting.hpp"
4243

4344
// #########################################################################
4445
// The following classes are wrappers bridging pybind11 defs to TypeDispatch
@@ -92,20 +93,21 @@ void init(bool cw)
9293
// #########################################################################
9394
// Finally our Python module
9495
PYBIND11_MODULE(_ddptensor, m) {
95-
Factory::init<F_ARANGE>();
96-
Factory::init<F_FULL>();
97-
Factory::init<F_FROMSHAPE>();
9896
// Factory::init<F_UNYOP>();
97+
Factory::init<F_ARANGE>();
98+
Factory::init<F_EWBINOP>();
9999
Factory::init<F_EWUNYOP>();
100+
Factory::init<F_FROMSHAPE>();
101+
Factory::init<F_FULL>();
102+
Factory::init<F_GETITEM>();
100103
Factory::init<F_IEWBINOP>();
101-
Factory::init<F_EWBINOP>();
102-
Factory::init<F_REDUCEOP>();
103-
Factory::init<F_MANIPOP>();
104104
Factory::init<F_LINALGOP>();
105-
Factory::init<F_GETITEM>();
106-
Factory::init<F_SETITEM>();
105+
Factory::init<F_MANIPOP>();
107106
Factory::init<F_RANDOM>();
107+
Factory::init<F_REDUCEOP>();
108108
Factory::init<F_SERVICE>();
109+
Factory::init<F_SETITEM>();
110+
Factory::init<F_SORTOP>();
109111
Factory::init<F_TONUMPY>();
110112

111113
jit::init();
@@ -148,6 +150,9 @@ PYBIND11_MODULE(_ddptensor, m) {
148150
py::class_<LinAlgOp>(m, "LinAlgOp")
149151
.def("vecdot", &LinAlgOp::vecdot);
150152

153+
py::class_<SortOp>(m, "SortOp")
154+
.def("sort", &SortOp::sort);
155+
151156
/// trigger compile&run and return given attribute _x
152157
#define SYNC_RETURN(_f, _a) Service::run(); return (_f).get().get()->_a()
153158
/// Rerplicate ddptensor/future and SYNC_RETURN attributre _a

src/idtr.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ uint64_t idtr_prank(int64_t team)
4949
id_t idtr_init_dtensor(const uint64_t * shape, uint64_t nD)
5050
{
5151
auto guid = get_guid();
52-
gtensors[guid] = std::unique_ptr<DDPTensorImpl>(nD ? new DDPTensorImpl(shape, nD) : new DDPTensorImpl);
52+
// gtensors[guid] = std::unique_ptr<DDPTensorImpl>(nD ? new DDPTensorImpl(shape, nD) : new DDPTensorImpl);
5353
return guid;
5454
}
5555

@@ -62,13 +62,15 @@ id_t _idtr_init_dtensor(void * alloced, void * aligned, intptr_t offset, intptr_
6262
// Result is stored in provided array.
6363
void idtr_local_offsets(id_t guid, uint64_t * offsets, uint64_t nD)
6464
{
65+
#if 0
6566
const auto & tnsr = gtensors.at(guid);
6667
auto slcs = tnsr->slice().local_slice().slices();
6768
assert(nD == slcs.size());
6869
int i = -1;
6970
for(auto s : slcs) {
7071
offsets[++i] = s._start;
7172
}
73+
#endif
7274
}
7375

7476
void _idtr_local_offsets(id_t guid, void * alloced, void * aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD)
@@ -80,9 +82,11 @@ void _idtr_local_offsets(id_t guid, void * alloced, void * aligned, intptr_t off
8082
// Result is stored in provided array.
8183
void idtr_local_shape(id_t guid, uint64_t * lshape, uint64_t N)
8284
{
85+
#if 0
8386
const auto & tnsr = gtensors.at(guid);
8487
auto shp = tnsr->slice().local_slice().shape();
8588
std::copy(shp.begin(), shp.end(), lshape);
89+
#endif
8690
}
8791

8892
void _idtr_local_shape(id_t guid, void * alloced, void * aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD)

src/include/ddptensor/CollComm.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct CollComm
2626
template<typename T, typename U>
2727
static tensor_i::ptr_type coll_copy(std::shared_ptr<DDPTensorImpl> b_ptr, const std::shared_ptr<DDPTensorImpl> & a_ptr)
2828
{
29+
#if 0
2930
assert(! a_ptr->is_sliced() && ! b_ptr->is_sliced());
3031
auto info = CollComm::map(b_ptr->slice(), a_ptr->slice());
3132

@@ -38,13 +39,15 @@ struct CollComm
3839
info[2].data(),
3940
info[3].data(),
4041
DTYPE<T>::value);
42+
#endif
4143

4244
return b_ptr;
4345
}
4446

4547
template<typename T, typename U>
4648
static std::array<int, 4> coll_map(const std::shared_ptr<DDPTensorImpl> & b_ptr, const std::shared_ptr<DDPTensorImpl> & a_ptr, std::vector<U> & rbuff)
4749
{
50+
#if 0
4851
auto info = CollComm::map(b_ptr->slice(), a_ptr->slice());
4952

5053
auto nr = getTransceiver()->nranks();
@@ -83,11 +86,14 @@ struct CollComm
8386
DTYPE<U>::value);
8487

8588
return {my_cnt_send, info[1][r], my_cnt_recv, info[3][r]};
89+
#endif
90+
return {-1,-1,-1,-1};
8691
}
8792

8893
template<typename A, typename B>
8994
static std::array<uint64_t, 2> coll_copy(const std::shared_ptr<DDPTensorImpl> & a_ptr, const std::array<std::vector<NDSlice>, 2> & a_overlap, std::vector<B> & rbuff)
9095
{
96+
#if 0
9197
if(a_overlap[0].empty()) return {0, 0};
9298

9399
auto nr = getTransceiver()->nranks();
@@ -120,5 +126,7 @@ struct CollComm
120126
&disp_recv[0],
121127
DTYPE<B>::value);
122128
return {(uint64_t)disp_send[rank], (uint64_t)disp_recv[rank]};
129+
#endif
130+
return {-1,-1};
123131
}
124132
};

src/include/ddptensor/CppTypes.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,21 @@ using id_type = uint64_t;
132132

133133
enum FactoryId : int {
134134
F_ARANGE,
135+
F_EWBINOP,
136+
F_EWUNYOP,
135137
F_FROMSHAPE,
136138
F_FULL,
137-
F_UNYOP,
138-
F_EWUNYOP,
139+
F_GETITEM,
139140
F_IEWBINOP,
140-
F_EWBINOP,
141-
F_REDUCEOP,
142-
F_MANIPOP,
143141
F_LINALGOP,
144-
F_GETITEM,
145-
F_SETITEM,
142+
F_MANIPOP,
146143
F_RANDOM,
144+
F_REDUCEOP,
147145
F_SERVICE,
146+
F_SETITEM,
147+
F_SORTOP,
148148
F_TONUMPY,
149+
F_UNYOP,
149150
FACTORY_LAST
150151
};
151152

0 commit comments

Comments
 (0)