Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 181d463

Browse files
committed
adding spmd.get_local
1 parent cab2f8b commit 181d463

File tree

13 files changed

+203
-52
lines changed

13 files changed

+203
-52
lines changed

ddptensor/spmd.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from . import _ddptensor as _cdt
22

3-
def get_slice(self, *args):
4-
return _cdt._get_slice(self._t, *args)
3+
def get_slice(obj, *args):
4+
return _cdt._get_slice(obj._t, *args)
5+
6+
def get_local(obj):
7+
return _cdt._get_local(obj._t, obj)

src/Creator.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace x {
1414
static ptr_type op(CreatorId c, const shape_type & shp)
1515
{
1616
PVSlice pvslice(shp);
17-
shape_type shape(std::move(pvslice.tile_shape()));
17+
shape_type shape(std::move(pvslice.shape_of_rank()));
1818
switch(c) {
1919
case EMPTY:
2020
return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::empty<T>(std::move(shape))));
@@ -32,7 +32,7 @@ namespace x {
3232
{
3333
if(c == FULL) {
3434
PVSlice pvslice(shp);
35-
shape_type shape(std::move(pvslice.tile_shape()));
35+
shape_type shape(std::move(pvslice.shape_of_rank()));
3636
auto a = xt::empty<T>(std::move(shape));
3737
a.fill(to_native<T>(v));
3838
return operatorx<T>::mk_tx(std::move(pvslice), std::move(a));

src/ManipOp.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <mpi.h>
2+
#include "ddptensor/ManipOp.hpp"
3+
#include "ddptensor/TypeDispatch.hpp"
4+
#include "ddptensor/x.hpp"
5+
#include "ddptensor/CollComm.hpp"
6+
7+
namespace x {
8+
9+
class ManipOp
10+
{
11+
public:
12+
using ptr_type = DPTensorBaseX::ptr_type;
13+
14+
// Reshape
15+
// For now we always create a copy/new array.
16+
template<typename T>
17+
static ptr_type op(const shape_type & shape, const std::shared_ptr<DPTensorX<T>> & a_ptr)
18+
{
19+
auto b_ptr = x::operatorx<T>::mk_tx(shape);
20+
CollComm::coll_copy(b_ptr, a_ptr);
21+
return b_ptr;
22+
}
23+
};
24+
}
25+
26+
tensor_i::ptr_type ManipOp::reshape(x::DPTensorBaseX::ptr_type a, const shape_type & shape)
27+
{
28+
return TypeDispatch<x::ManipOp>(a, shape);
29+
}

src/Random.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace x {
1414
{
1515
if constexpr (std::is_floating_point<T>::value) {
1616
PVSlice pvslice(shp);
17-
shape_type shape(std::move(pvslice.tile_shape()));
17+
shape_type shape(std::move(pvslice.shape_of_rank()));
1818
return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::random::rand(std::move(shape), to_native<T>(lower), to_native<T>(upper))));
1919
}
2020
}

src/SetGetItem.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,37 +111,60 @@ namespace x {
111111
public:
112112
using ptr_type = DPTensorBaseX::ptr_type;
113113

114+
// get_slice
114115
template<typename T>
115116
static py::object op(const NDSlice & slice, const std::shared_ptr<DPTensorX<T>> & a_ptr)
116117
{
117118
auto shp = slice.shape();
118119
auto sz = VPROD(shp);
119120
auto res = py::array_t<T>(sz);
120121
auto ax = xt::adapt(res.mutable_data(), sz, xt::no_ownership(), shp);
121-
std::cerr << ax << std::endl << py::str(res).cast<std::string>() << res.mutable_data() << std::endl;
122-
// Create dtensor without creating id: do not use create_dtensor
123-
// auto out = DPTensorX<T>(ax, PVSlice(shp, NOSPLIT));
124122
PVSlice slc{shp, NOSPLIT};
125123
SetItem::_set_slice<T>(ax, slc, slc.slice(), a_ptr, slice);
126-
std::cerr << ax << std::endl << py::str(res).cast<std::string>() << std::endl;
127-
// res.reshape(shp);
128124
return res;
129125
}
126+
127+
// get_local
128+
template<typename T>
129+
static py::object op(py::handle & handle, const std::shared_ptr<DPTensorX<T>> & a_ptr)
130+
{
131+
auto slc = a_ptr->slice().local_slice_of_rank();
132+
auto tshp = a_ptr->slice().tile_shape();
133+
auto nd = slc.ndims();
134+
// buffer protocol accepts strides in number of bytes not elements!
135+
std::vector<uint64_t> strides(nd, sizeof(T));
136+
uint64_t off = slc.dim(nd-1)._start * sizeof(T); // start offset
137+
for(int i=nd-2; i>=0; --i) {
138+
auto slci = slc.dim(i);
139+
auto tmp = strides[i+1] * tshp[i+1];
140+
strides[i] = slci._step * tmp;
141+
off += slci._start * tmp;
142+
}
143+
off /= sizeof(T); // we need the offset in number of elements
144+
strides.back() = slc.dim(nd-1)._step * sizeof(T);
145+
T * data = a_ptr->xarray().data();
146+
return py::array(std::move(slc.shape()), std::move(strides), data + off, handle);
147+
}
130148
};
131149

132150
} // namespace x
133151

134-
void SetItem::__setitem__(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b)
152+
void SetItem::__setitem__(tensor_i::ptr_type a, const std::vector<py::slice> & v, tensor_i::ptr_type b)
135153
{
136154
return TypeDispatch<x::SetItem>(a, b, NDSlice(v));
137155
}
138156

139-
tensor_i::ptr_type GetItem::__getitem__(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
157+
tensor_i::ptr_type GetItem::__getitem__(tensor_i::ptr_type a, const std::vector<py::slice> & v)
140158
{
141159
return TypeDispatch<x::GetItem>(a, NDSlice(v));
142160
}
143161

144-
py::object GetItem::get_slice(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
162+
py::object GetItem::get_slice(tensor_i::ptr_type a, const std::vector<py::slice> & v)
145163
{
146164
return TypeDispatch<x::SPMD>(a, NDSlice(v));
147165
}
166+
167+
py::object GetItem::get_local(tensor_i::ptr_type a, py::handle h)
168+
{
169+
return TypeDispatch<x::SPMD>(a, h);
170+
}

src/ddptensor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ PYBIND11_MODULE(_ddptensor, m) {
6262

6363
m.def("fini", &fini)
6464
.def("myrank", &myrank)
65-
.def("_get_slice", &GetItem::get_slice);
65+
.def("_get_slice", &GetItem::get_slice)
66+
.def("_get_local", &GetItem::get_local);
6667

6768
py::class_<Creator>(m, "Creator")
6869
.def("create_from_shape", &Creator::create_from_shape)

src/include/ddptensor/CollComm.hpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
3+
#pragma once
4+
5+
#include "UtilsAndTypes.hpp"
6+
#include "x.hpp"
7+
8+
struct CollComm
9+
{
10+
// We assume we split in first dimension.
11+
// We also assume partitions are assigned to ranks in sequence from 0-N.
12+
// With this we know that our buffers (old and new) get data in the
13+
// same order. The only thing which might have changed is the tile-size.
14+
// Actually, the tile-size might change only if old or new shape does not evenly
15+
// distribute data (e.g. last partition is smaller).
16+
// In theory we could re-shape in-place when the norm-tile-size does not change.
17+
// This is not implemented: we need an extra mechanism to work with reshape-views or alike.
18+
template<typename T, typename U>
19+
static tensor_i::ptr_type coll_copy(std::shared_ptr<x::DPTensorX<T>> b_ptr, const std::shared_ptr<x::DPTensorX<U>> & a_ptr)
20+
{
21+
assert(! a_ptr->is_sliced() && ! b_ptr->is_sliced());
22+
23+
auto o_slc = a_ptr->slice();
24+
// norm tile-size of orig array
25+
auto o_ntsz = o_slc.tile_size(0);
26+
// tilesize of my local partition of orig array
27+
auto o_tsz = o_slc.tile_size();
28+
// linearized local slice of orig array
29+
auto o_llslc = Slice(o_ntsz * theTransceiver->rank(), o_ntsz * theTransceiver->rank() + o_tsz);
30+
31+
PVSlice n_slc = b_ptr->slice();
32+
// norm tile-size of new (reshaped) array
33+
auto n_ntsz = n_slc.tile_size(0);
34+
// tilesize of my local partition of new (reshaped) array
35+
auto n_tsz = n_slc.tile_size();
36+
// linearized/flattened/1d local slice of new (reshaped) array
37+
auto n_llslc = Slice(n_ntsz * theTransceiver->rank(), n_ntsz * theTransceiver->rank() + n_tsz);
38+
39+
auto nr = theTransceiver->nranks();
40+
// We need a few C-arrays for MPI (counts and displacements in send/recv buffers)
41+
int counts_send[nr] = {0};
42+
int disp_send[nr] = {0};
43+
int counts_recv[nr] = {0};
44+
int disp_recv[nr] = {0};
45+
46+
for(auto r=0; r<nr; ++r) {
47+
// determine what I receive from rank r
48+
// i.e. which parts of my new slice overlap with rank r's old slice
49+
// Get local slice of rank r of orig array
50+
auto o_rslc = o_slc.local_slice_of_rank(r);
51+
// Flatten to 1d
52+
auto o_lrslc = Slice(o_ntsz * r, o_ntsz * r + o_rslc.size());
53+
// Determine overlap with local partition of linearized new array
54+
auto roverlap = n_llslc.overlap(o_lrslc);
55+
// number of elements to be received from rank r
56+
counts_recv[r] = roverlap.size();
57+
// displacement in new array where elements from rank r get copied to
58+
disp_recv[r] = roverlap._start - n_llslc._start;
59+
60+
// determine what I send to rank r
61+
// i.e. which parts of my old slice overlap with rank r's new slice
62+
// Get local slice of rank r of new array
63+
auto n_rslc = n_slc.local_slice_of_rank(r);
64+
// Flatten to 1d
65+
auto n_lrslc = Slice(n_ntsz * r, n_ntsz * r + n_rslc.size());
66+
// Determine overlap with local partition of linearized orig array
67+
auto soverlap = o_llslc.overlap(n_lrslc);
68+
// number of elements to be sent to rank r
69+
counts_send[r] = soverlap.size();
70+
// displacement in orig array where elements for rank r get copied from
71+
disp_send[r] = soverlap._start - o_llslc._start;
72+
}
73+
74+
// Now we can send/recv directly to/from xarray buffers.
75+
theTransceiver->alltoall(a_ptr->xarray().data(), // buffer_send
76+
counts_send,
77+
disp_send,
78+
DTYPE<T>::value,
79+
b_ptr->xarray().data(), // buffer_recv
80+
counts_recv,
81+
disp_recv,
82+
DTYPE<T>::value);
83+
84+
return b_ptr;
85+
}
86+
};

src/include/ddptensor/NDIndex.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
///
1010
typedef std::vector<int64_t> NDIndex;
1111

12+
#if 0
1213
///
1314
/// @return tile-sizes for each dimension, as if leading dimensions were cut.
1415
/// @param tile_shape tile-shape in question
@@ -42,3 +43,4 @@ uint64_t linearize(const std::vector<T> & idx, const std::vector<uint64_t> & tss
4243
}
4344
return tidx;
4445
}
46+
#endif

src/include/ddptensor/NDSlice.hpp

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ class NDSlice {
9797
///
9898
/// @return total number of elements represented by the nd-slice
9999
///
100-
value_type::value_type size() const
100+
value_type::value_type size(uint64_t dim = 0) const
101101
{
102102
if(_sizes.empty()) init_sizes();
103-
return _sizes[0];
103+
return _sizes[dim];
104104
}
105105

106106
///
@@ -120,29 +120,6 @@ class NDSlice {
120120
_sizes.resize(0);
121121
}
122122

123-
///
124-
/// @return ith index-tuple in canonical (flat) order of the expanded slice.
125-
/// does not check bounds, e.g. can return indices beyond end of slice
126-
///
127-
value_type operator[](value_type::value_type i) const {
128-
if(_sizes.empty()) init_sizes();
129-
value_type ret(_slice_vec.size(), 0);
130-
auto sz = ++(_sizes.begin());
131-
auto slc = _slice_vec.rbegin();
132-
// iterate over dimensions to compute ith index
133-
for(auto v = ret.begin(); v != ret.end(); ++v, ++slc) {
134-
if(sz != _sizes.end()) {
135-
auto idx = i / (*sz);
136-
*v = (*slc)[idx];
137-
i -= idx * (*sz);
138-
++sz;
139-
} else {
140-
*v = (*slc)[i];
141-
}
142-
}
143-
return ret;
144-
}
145-
146123
template<typename C>
147124
NDSlice _convert(const C & conv) const
148125
{

src/include/ddptensor/PVSlice.hpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,18 @@ class BasePVSlice
3636
}
3737

3838
uint64_t offset() const { return _offset; }
39-
uint64_t tile_size() const { return _tile_size; }
39+
uint64_t tile_size(rank_type rank = theTransceiver->rank()) const
40+
{
41+
if(rank < theTransceiver->nranks() - 1) return _tile_size;
42+
return VPROD(_shape) - (rank-1 * _tile_size);
43+
}
44+
shape_type tile_shape(rank_type rank = theTransceiver->rank()) const
45+
{
46+
shape_type r(_shape);
47+
if(rank < theTransceiver->nranks() - 1) r[_split_dim] = offset();
48+
else r[_split_dim] = r[_split_dim] - (rank-1 * offset());
49+
return r;
50+
}
4051
int split_dim() const { return _split_dim; }
4152
const shape_type & shape() const { return _shape; }
4253
shape_type shape(rank_type rank) const
@@ -129,9 +140,14 @@ class PVSlice
129140
return _base->split_dim();
130141
}
131142

132-
const uint64_t tile_size() const
143+
const bool is_sliced() const
144+
{
145+
return base_shape() != shape();
146+
}
147+
148+
const uint64_t tile_size(rank_type rank = theTransceiver->rank()) const
133149
{
134-
return _base->tile_size();
150+
return _base->tile_size(rank);
135151
}
136152

137153
const shape_type & shape() const
@@ -145,6 +161,11 @@ class PVSlice
145161
}
146162

147163
const shape_type tile_shape(rank_type rank = theTransceiver->rank()) const
164+
{
165+
return _base->tile_shape(rank);
166+
}
167+
168+
const shape_type shape_of_rank(rank_type rank = theTransceiver->rank()) const
148169
{
149170
return slice_of_rank(rank).shape();
150171
}

0 commit comments

Comments
 (0)