Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit a6a8d78

Browse files
committed
adding PGAS feature 'get_slice'
1 parent f25ca19 commit a6a8d78

File tree

9 files changed

+88
-41
lines changed

9 files changed

+88
-41
lines changed

ddptensor/ddptensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,6 @@ def __getitem__(self, *args):
116116

117117
def __setitem__(self, key, value):
118118
x = self._t.__setitem__(key, value._t if isinstance(value, dtensor) else value)
119+
120+
def get_slice(self, *args):
121+
return self._t.get_slice(*args)

src/MPIMediator.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ MPIMediator::MPIMediator()
3030
MPIMediator::~MPIMediator()
3131
{
3232
std::cerr << "MPIMediator::~MPIMediator()" << std::endl;
33+
MPI_Barrier(MPI_COMM_WORLD);
3334
int rank;
3435
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
3536
Buffer buff;
@@ -40,7 +41,6 @@ MPIMediator::~MPIMediator()
4041
MPI_Send(buff.data(), buff.size(), MPI_CHAR, rank, PULL_TAG, MPI_COMM_WORLD);
4142
_listener.join();
4243
s_ak.clear();
43-
MPI_Barrier(MPI_COMM_WORLD);
4444
}
4545

4646
uint64_t MPIMediator::register_array(tensor_i::ptr_type ary)
@@ -49,7 +49,7 @@ uint64_t MPIMediator::register_array(tensor_i::ptr_type ary)
4949
return s_last_id;
5050
}
5151

52-
void MPIMediator::pull(rank_type from, const tensor_i::ptr_type & ary, const NDSlice & slice, void * rbuff)
52+
void MPIMediator::pull(rank_type from, const tensor_i * ary, const NDSlice & slice, void * rbuff)
5353
{
5454
MPI_Comm comm = MPI_COMM_WORLD;
5555
MPI_Request request[2];

src/ddptensor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ class dtensor
5050
return _tensor->dtype();
5151
}
5252

53+
py::object get_slice(const std::vector<py::slice> & v)
54+
{
55+
return _tensor->get_slice(NDSlice(v));
56+
}
57+
5358
dtensor __getitem__(const NDIndex & v)
5459
{
5560
return dtensor(_tensor->__getitem__(NDSlice(v)));
@@ -224,7 +229,8 @@ PYBIND11_MODULE(_ddptensor, m) {
224229
.def("__getitem__", py::overload_cast<const std::vector<py::slice> &>(&dtensor::__getitem__))
225230
.def("__getitem__", py::overload_cast<const py::slice &>(&dtensor::__getitem__))
226231
.def("__getitem__", py::overload_cast<int64_t>(&dtensor::__getitem__))
227-
.def("__setitem__", &dtensor::__setitem__);
232+
.def("__setitem__", &dtensor::__setitem__)
233+
.def("get_slice", &dtensor::get_slice);
228234

229235
//py::class_<dpdlpack>(m, "dpdlpack")
230236
// .def("__dlpack__", &dpdlpack.__dlpack__);

src/include/ddptensor/MPIMediator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class MPIMediator : public Mediator
1313
MPIMediator();
1414
virtual ~MPIMediator();
1515
virtual uint64_t register_array(tensor_i::ptr_type ary);
16-
virtual void pull(rank_type from, const tensor_i::ptr_type & ary, const NDSlice & slice, void * buffer);
16+
virtual void pull(rank_type from, const tensor_i * ary, const NDSlice & slice, void * buffer);
1717

1818
protected:
1919
void listen();

src/include/ddptensor/Mediator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class Mediator
1313
public:
1414
virtual ~Mediator() {}
1515
virtual uint64_t register_array(tensor_i::ptr_type ary) = 0;
16-
virtual void pull(rank_type from, const tensor_i::ptr_type & ary, const NDSlice & slice, void * buffer) = 0;
16+
virtual void pull(rank_type from, const tensor_i * ary, const NDSlice & slice, void * buffer) = 0;
1717
};
1818

1919
extern Mediator * theMediator;

src/include/ddptensor/PVSlice.hpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
using offsets_type = std::vector<uint64_t>;
99

10+
constexpr static int NOSPLIT = -1;
11+
1012
class BasePVSlice
1113
{
1214
uint64_t _offset;
@@ -18,13 +20,13 @@ class BasePVSlice
1820
BasePVSlice(const BasePVSlice &) = delete;
1921
BasePVSlice(BasePVSlice &&) = default;
2022
BasePVSlice(const shape_type & shape, int split=0)
21-
: _offset((shape[split] + theTransceiver->nranks() - 1) / theTransceiver->nranks()),
23+
: _offset(split == NOSPLIT ? 0 : (shape[split] + theTransceiver->nranks() - 1) / theTransceiver->nranks()),
2224
_shape(shape),
2325
_split_dim(split)
2426
{
2527
}
2628
BasePVSlice(shape_type && shape, int split=0)
27-
: _offset((shape[split] + theTransceiver->nranks() - 1) / theTransceiver->nranks()),
29+
: _offset(split == NOSPLIT ? 0 : (shape[split] + theTransceiver->nranks() - 1) / theTransceiver->nranks()),
2830
_shape(std::move(shape)),
2931
_split_dim(split)
3032
{
@@ -35,6 +37,9 @@ class BasePVSlice
3537
const shape_type & shape() const { return _shape; }
3638
shape_type shape(rank_type rank) const
3739
{
40+
if(split_dim() == NOSPLIT) {
41+
return rank == theTransceiver->rank() ? _shape : shape_type();
42+
}
3843
shape_type shp(_shape);
3944
auto end = (rank+1) * _offset;
4045
if(end <= _shape[_split_dim]) shp[_split_dim] = _offset;
@@ -43,7 +48,7 @@ class BasePVSlice
4348
}
4449
rank_type owner(const NDSlice & slice) const
4550
{
46-
return slice.dim(split_dim())._start / offset();
51+
return split_dim() == NOSPLIT ? theTransceiver->rank() : slice.dim(split_dim())._start / offset();
4752
}
4853
};
4954

@@ -150,10 +155,12 @@ class PVSlice
150155
return _slice;
151156
}
152157

158+
#if 0
153159
NDSlice normalized_slice() const
154160
{
155161
return _slice.normalize(_base->split_dim());
156162
}
163+
#endif
157164

158165
NDSlice map_slice(const NDSlice & slc) const
159166
{
@@ -162,11 +169,17 @@ class PVSlice
162169

163170
NDSlice slice_of_rank(rank_type rank) const
164171
{
172+
if(_base->split_dim() == NOSPLIT) {
173+
return rank == theTransceiver->rank() ? slice() : NDSlice();
174+
}
165175
return _slice.trim(_base->split_dim(), rank * _base->offset(), (rank+1) * _base->offset());
166176
}
167177

168178
NDSlice local_slice_of_rank(rank_type rank) const
169179
{
180+
if(_base->split_dim() == NOSPLIT) {
181+
return rank == theTransceiver->rank() ? slice() : NDSlice();
182+
}
170183
return _slice.trim_shift(_base->split_dim(),
171184
rank * _base->offset(),
172185
(rank+1) * _base->offset(),
@@ -175,6 +188,7 @@ class PVSlice
175188

176189
bool need_reduce(const dim_vec_type & dims) const
177190
{
191+
if(_base->split_dim() == NOSPLIT) return false;
178192
auto nd = dims.size();
179193
// Reducing to a single scalar or over a subset of dimensions *including* the split axis.
180194
if(nd == 0

src/include/ddptensor/ddptensor_impl.hpp

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,17 @@ class dtensor_impl : public tensor_i
224224
}
225225

226226
// since the API works on tensor_i we need to downcast to the actual type
227-
const dtensor_impl<T> * cast(const ptr_type & b) const
227+
static dtensor_impl<T> * cast(ptr_type & b)
228228
{
229229
// FIXME; use attribute/vfunction + reinterpret_cast for better performance
230-
auto ptr = dynamic_cast<const dtensor_impl<T>*>(b.get());
230+
auto ptr = dynamic_cast<dtensor_impl<T>*>(b.get());
231231
// if(ptr == nullptr) throw(std::runtime_error("Incompatible tensor types."));
232232
return ptr;
233233
}
234+
static const dtensor_impl<T> * cast(const ptr_type & b)
235+
{
236+
return cast(const_cast<ptr_type &>(b));
237+
}
234238

235239
ptr_type _ew_op(const char * op, const char * mod, py::args args, const py::kwargs & kwargs)
236240
{
@@ -331,42 +335,28 @@ class dtensor_impl : public tensor_i
331335
}
332336
}
333337

334-
// FIXME We use a generic SPMD/PGAS mechanism to pull elements from remote
335-
// on all procs simultaneously. Since __setitem__ is collective we could
336-
// implement a probably more efficient mechanism which pushes data and/or uses RMA.
337-
void __setitem__(const NDSlice & slice, const ptr_type & val)
338+
// copy data from val into (*dest)[slice]
339+
// this is a non-collective call.
340+
static void _set_slice(const dtensor_impl<T> * val, const NDSlice & val_slice, dtensor_impl<T> * dest, const NDSlice & dest_slice)
338341
{
339-
std::cerr << " __setitem__ " << slice << " " << val->pvslice().slice() << std::endl;
340-
auto nd = shape().size();
341-
if(owner() == REPLICATED && nd > 0)
342+
std::cerr << "_set_slice " << val_slice << " " << dest_slice << std::endl;
343+
auto nd = dest->shape().size();
344+
if(dest->owner() == REPLICATED && nd > 0)
342345
std::cerr << "Warning: __setitem__ on replicated data updates local tile only" << std::endl;
343-
if(nd != slice.ndims())
346+
if(nd != dest_slice.ndims())
344347
throw std::runtime_error("Index dimensionality must match array dimensionality");
348+
if(val_slice.size() != dest_slice.size())
349+
throw std::runtime_error("Input and output slices must be of same size");
345350

346-
auto slc_sz = slice.size();
347-
auto val_sz = VPROD(val->shape());
348-
if(slc_sz != val_sz)
349-
throw std::runtime_error("Given tensor does not match: it has different size than 'slice'");
350-
351-
NDSlice norm_slice = pvslice().normalized_slice();
352-
std::cerr << "norm_slice: " << norm_slice << std::endl;
353351
// Use given slice to create a global view into orig array
354-
PVSlice g_slc_view(pvslice(), slice);
352+
PVSlice g_slc_view(dest->pvslice(), dest_slice);
355353
std::cerr << "g_slice: " << g_slc_view.slice() << std::endl;
356-
PVSlice my_view(g_slc_view, theTransceiver->rank());
357-
NDSlice my_slice = my_view.slice();
358-
std::cerr << "my_slice: " << my_slice << std::endl;
359-
NDSlice my_norm_slice = g_slc_view.map_slice(my_slice);
360-
std::cerr << "my_norm_slice: " << my_norm_slice << std::endl;
361-
362354
// Create a view into val
363-
PVSlice needed_val_view(val->pvslice(), my_norm_slice);
355+
PVSlice needed_val_view(val->pvslice(), val_slice);
364356
std::cerr << "needed_val_view: " << needed_val_view.slice() << " (was " << val->pvslice().slice() << ")" << std::endl;
365357

366358
// Get the pointer to the local buffer
367-
auto ns = get_array_impl(_pyarray);
368-
//auto my_binfo = _pyarray.cast<py::buffer>().request();
369-
// T * my_buffer = reinterpret_cast<T*>(my_binfo.ptr);
359+
auto ns = get_array_impl(dest->_pyarray);
370360

371361
// we can now compute which ranks actually hold which piece of the data from val that we need locally
372362
for(rank_type i=0; i<theTransceiver->nranks(); ++i ) {
@@ -377,7 +367,7 @@ class dtensor_impl : public tensor_i
377367
std::cerr << i << " curr_needed_val_slice: " << curr_needed_val_slice << std::endl;
378368
NDSlice curr_local_val_slice = val_local_view.map_slice(curr_needed_val_slice);
379369
std::cerr << i << " curr_local_val_slice: " << curr_local_val_slice << std::endl;
380-
NDSlice curr_needed_norm_slice = val->pvslice().map_slice(curr_needed_val_slice);
370+
NDSlice curr_needed_norm_slice = needed_val_view.map_slice(curr_needed_val_slice);
381371
std::cerr << i << " curr_needed_norm_slice: " << curr_needed_norm_slice << std::endl;
382372
PVSlice my_curr_needed_view = PVSlice(g_slc_view, curr_needed_norm_slice);
383373
std::cerr << i << " my_curr_needed_slice: " << my_curr_needed_view.slice() << std::endl;
@@ -387,23 +377,39 @@ class dtensor_impl : public tensor_i
387377
py::tuple tpl = _make_tuple(my_curr_local_slice); //my_curr_view.slice());
388378
if(i == theTransceiver->rank()) {
389379
// copy locally
390-
auto rhs = cast(val)->_pyarray.attr("__getitem__")(_make_tuple(curr_local_val_slice));
380+
auto rhs = val->_pyarray.attr("__getitem__")(_make_tuple(curr_local_val_slice));
391381
std::cerr << py::str(rhs).cast<std::string>() << std::endl;
392-
_pyarray.attr("__setitem__")(tpl, rhs);
382+
dest->_pyarray.attr("__setitem__")(tpl, rhs);
393383
} else {
394384
// pull slice directly into new array
395385
auto obj = ns.attr("empty")(_make_tuple(curr_local_val_slice.shape()));
396386
auto binfo = obj.cast<py::buffer>().request();
397387
theMediator->pull(i, val, curr_local_val_slice, binfo.ptr);
398-
_pyarray.attr("__setitem__")(tpl, obj);
388+
dest->_pyarray.attr("__setitem__")(tpl, obj);
399389
}
400390
}
401391
}
402392
}
403393

394+
// FIXME We use a generic SPMD/PGAS mechanism to pull elements from remote
395+
// on all procs simultaneously. Since __setitem__ is collective we could
396+
// implement a probably more efficient mechanism which pushes data and/or uses RMA.
397+
void __setitem__(const NDSlice & slice, const ptr_type & val)
398+
{
399+
// Use given slice to create a global view into orig array
400+
PVSlice g_slc_view(this->pvslice(), slice);
401+
std::cerr << "g_slice: " << g_slc_view.slice() << std::endl;
402+
NDSlice my_slice = g_slc_view.slice_of_rank(theTransceiver->rank());
403+
std::cerr << "my_slice: " << my_slice << std::endl;
404+
NDSlice my_norm_slice = g_slc_view.map_slice(my_slice);
405+
std::cerr << "my_norm_slice: " << my_norm_slice << std::endl;
406+
407+
_set_slice(cast(val), my_norm_slice, this, my_slice);
408+
}
409+
404410
void bufferize(const NDSlice & slice, Buffer & buff)
405411
{
406-
PVSlice my_local_view = PVSlice(tile_shape()); // pvslice().view_normalized_by_rank(theTransceiver->rank());
412+
PVSlice my_local_view = PVSlice(tile_shape());
407413
PVSlice lview = PVSlice(my_local_view, slice);
408414
NDSlice lslice = lview.slice();
409415

@@ -422,6 +428,14 @@ class dtensor_impl : public tensor_i
422428
}
423429
}
424430

431+
py::object get_slice(const NDSlice & slice) const
432+
{
433+
auto shp = slice.shape();
434+
auto out = create_dtensor(PVSlice(shp, NOSPLIT), shp, DTYPE<T>::value, "empty");
435+
_set_slice(this, slice, cast(out), {shp});
436+
return cast(out)->_pyarray;
437+
}
438+
425439
std::string __repr__() const
426440
{
427441
return "dtensor(shape=" + to_string(shape(), 'x') + ", n_tiles="

src/include/ddptensor/tensor_i.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,6 @@ class tensor_i
4444
virtual void _ew_binary_op_inplace(const char * op, const ptr_type & b) = 0;
4545
virtual void _ew_binary_op_inplace(const char * op, const py::object & b) = 0;
4646
virtual ptr_type _reduce_op(const char * op, const dim_vec_type & dims) const = 0;
47+
48+
virtual py::object get_slice(const NDSlice & slice) const = 0;
4749
};

test/test_spmd.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from mpi4py import MPI
2+
import ddptensor as dt
3+
a = dt.ones((8,8))
4+
MPI.COMM_WORLD.barrier()
5+
b = a.get_slice((slice(1, 3+MPI.COMM_WORLD.rank), slice(2, 4+MPI.COMM_WORLD.rank)))
6+
print(type(b), b.shape, float(b[1,1]))
7+
print("done")
8+
dt.fini()

0 commit comments

Comments
 (0)