Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit b2a7b0d

Browse files
committed
adding reshape (partial) and arange
1 parent 077bffd commit b2a7b0d

File tree

13 files changed

+366
-245
lines changed

13 files changed

+366
-245
lines changed

ddptensor/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,20 @@
5656
exec(
5757
f"{func} = lambda shape, val, dtype: dtensor(_cdt.Creator.full(_cdt.{FUNC}, shape, val, dtype))"
5858
)
59+
elif func == "arange":
60+
exec(
61+
f"{func} = lambda start, end, step, dtype: dtensor(_cdt.Creator.arange(start, end, step, dtype))"
62+
)
5963

6064
for func in api.api_categories["ReduceOp"]:
6165
FUNC = func.upper()
6266
exec(
6367
f"{func} = lambda this, dim: dtensor(_cdt.ReduceOp.op(_cdt.{FUNC}, this._t, dim))"
6468
)
69+
70+
for func in api.api_categories["ManipOp"]:
71+
FUNC = func.upper()
72+
if func == "reshape":
73+
exec(
74+
f"{func} = lambda this, shape: dtensor(_cdt.ManipOp.reshape(this._t, shape))"
75+
)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def build_cmake(self, ext):
2929
extdir.parent.mkdir(parents=True, exist_ok=True)
3030

3131
# example of cmake args
32-
config = 'Debug' if self.debug else 'Release'
32+
config = 'Debug'# if self.debug else 'Release'
3333
cmake_args = [
3434
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()),
3535
'-DCMAKE_BUILD_TYPE=' + config

src/Creator.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ namespace x {
3838
}
3939
throw std::runtime_error("Unknown creator");
4040
}
41+
42+
static ptr_type op(uint64_t start, uint64_t end, uint64_t step)
43+
{
44+
PVSlice pvslice({Slice(start, end, step).size()});
45+
auto lslc = pvslice.slice_of_rank();
46+
const auto & l1dslc = lslc.dim(0);
47+
auto a = xt::arange<T>(start + l1dslc._start*step, start + l1dslc._end * step, l1dslc._step);
48+
return operatorx<T>::mk_tx(std::move(pvslice), std::move(a));
49+
}
4150
}; // class creatorx
4251
} // namespace x
4352

@@ -51,3 +60,8 @@ tensor_i::ptr_type Creator::full(const shape_type & shape, const py::object & va
5160
auto op = FULL;
5261
return TypeDispatch<x::Creator>(dtype, op, shape, val);
5362
}
63+
64+
tensor_i::ptr_type Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
65+
{
66+
return TypeDispatch<x::Creator>(dtype, start, end, step);
67+
}

src/MPITransceiver.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,23 @@ void MPITransceiver::reduce_all(void * inout, DTypeId T, size_t N, RedOpType op)
7272
{
7373
MPI_Allreduce(MPI_IN_PLACE, inout, N, to_mpi(T), to_mpi(op), MPI_COMM_WORLD);
7474
}
75+
76+
void MPITransceiver::alltoall(const void* buffer_send,
77+
const int* counts_send,
78+
const int* displacements_send,
79+
DTypeId datatype_send,
80+
void* buffer_recv,
81+
const int* counts_recv,
82+
const int* displacements_recv,
83+
DTypeId datatype_recv)
84+
{
85+
MPI_Alltoallv(buffer_send,
86+
counts_send,
87+
displacements_send,
88+
to_mpi(datatype_send),
89+
buffer_recv,
90+
counts_recv,
91+
displacements_recv,
92+
to_mpi(datatype_recv),
93+
MPI_COMM_WORLD);
94+
}

src/ddptensor.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ PYBIND11_MODULE(_ddptensor, m) {
6060

6161
py::class_<Creator>(m, "Creator")
6262
.def("create_from_shape", &Creator::create_from_shape)
63-
.def("full", &Creator::full);
63+
.def("full", &Creator::full)
64+
.def("arange", &Creator::arange);
6465

6566
py::class_<EWUnyOp>(m, "EWUnyOp")
6667
.def("op", &EWUnyOp::op);
@@ -74,6 +75,9 @@ PYBIND11_MODULE(_ddptensor, m) {
7475
py::class_<ReduceOp>(m, "ReduceOp")
7576
.def("op", &ReduceOp::op);
7677

78+
py::class_<ManipOp>(m, "ManipOp")
79+
.def("reshape", &ManipOp::reshape);
80+
7781
py::class_<tensor_i, tensor_i::ptr_type>(m, "DPTensorX")
7882
.def_property_readonly("dtype", &tensor_i::dtype)
7983
.def_property_readonly("shape", &tensor_i::shape)

src/include/ddptensor/MPITransceiver.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ class MPITransceiver : public Transceiver
2323
virtual void barrier();
2424
virtual void bcast(void * ptr, size_t N, rank_type root);
2525
virtual void reduce_all(void * inout, DTypeId T, size_t N, RedOpType op);
26+
virtual void alltoall(const void* buffer_send,
27+
const int* counts_send,
28+
const int* displacements_send,
29+
DTypeId datatype_send,
30+
void* buffer_recv,
31+
const int* counts_recv,
32+
const int* displacements_recv,
33+
DTypeId datatype_recv);
2634

2735
private:
2836
rank_type _nranks, _rank;

src/include/ddptensor/NDSlice.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,26 @@ class NDSlice {
201201
} );
202202
}
203203

204+
///
205+
/// @return Copy of NDSlice which was trimmed by given slice
206+
///
207+
NDSlice trim(const NDSlice & slc) const
208+
{
209+
return _convert([&](uint64_t i) {
210+
return _slice_vec[i].trim(slc.dim(i)._start, slc.dim(i)._end);
211+
} );
212+
}
213+
214+
///
215+
/// @return Copy of NDSlice representing the overlap with the given slice
216+
///
217+
NDSlice overlap(const NDSlice & slc) const
218+
{
219+
return _convert([&](uint64_t i) {
220+
return _slice_vec[i].overlap(slc.dim(i));
221+
} );
222+
}
223+
204224
///
205225
/// @return Copy of NDSlice which was trimmed by t_slc and shifted by s_slc._start's
206226
///

src/include/ddptensor/Operations.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct Creator
1111
{
1212
static tensor_i::ptr_type create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype=FLOAT64);
1313
static tensor_i::ptr_type full(const shape_type & shape, const py::object & val, DTypeId dtype=FLOAT64);
14+
static tensor_i::ptr_type arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype=INT64);
1415
};
1516

1617
struct IEWBinOp
@@ -44,6 +45,11 @@ struct SetItem
4445
static void __setitem__(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b);
4546
};
4647

48+
struct ManipOp
49+
{
50+
static tensor_i::ptr_type reshape(x::DPTensorBaseX::ptr_type a, const shape_type & shape);
51+
};
52+
4753

4854
// Dependent on dt, dispatch arguments to a operation class.
4955
// The operation must

src/include/ddptensor/PVSlice.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ constexpr static int NOSPLIT = -1;
1212
class BasePVSlice
1313
{
1414
uint64_t _offset;
15+
uint64_t _tile_size;
1516
shape_type _shape;
1617
int _split_dim;
1718

@@ -24,15 +25,18 @@ class BasePVSlice
2425
_shape(shape),
2526
_split_dim(split)
2627
{
28+
_tile_size = VPROD(_shape) / shape[_split_dim] * _offset;
2729
}
2830
BasePVSlice(shape_type && shape, int split=0)
2931
: _offset(split == NOSPLIT ? 0 : (shape[split] + theTransceiver->nranks() - 1) / theTransceiver->nranks()),
3032
_shape(std::move(shape)),
3133
_split_dim(split)
3234
{
35+
_tile_size = VPROD(_shape) / _shape[_split_dim] * _offset;
3336
}
3437

3538
uint64_t offset() const { return _offset; }
39+
uint64_t tile_size() const { return _tile_size; }
3640
int split_dim() const { return _split_dim; }
3741
const shape_type & shape() const { return _shape; }
3842
shape_type shape(rank_type rank) const
@@ -125,6 +129,11 @@ class PVSlice
125129
return _base->split_dim();
126130
}
127131

132+
const uint64_t tile_size() const
133+
{
134+
return _base->tile_size();
135+
}
136+
128137
const shape_type & shape() const
129138
{
130139
if(_shape.size() != _slice.ndims()) {

0 commit comments

Comments (0)