Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 2a9da9c

Browse files
committed
allow unequally tiled arrays in vecdot
1 parent c27786f commit 2a9da9c

File tree

3 files changed

+54
-38
lines changed

3 files changed

+54
-38
lines changed

src/LinAlgOp.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ namespace x {
5454
template<typename A, typename B>
5555
static ptr_type matmul_2d(const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr, int axis)
5656
{
57-
if(!a_ptr->slice().is_equally_tiled() || !b_ptr->slice().is_equally_tiled())
58-
throw(std::runtime_error("vecdoc_2d supported for eually tiled tensors only"));
5957
if(a_ptr->slice().split_dim() != 0)
6058
throw(std::runtime_error("vecdoc_2d supported for split_dim=0 only"));
6159

@@ -64,37 +62,40 @@ namespace x {
6462
rank_type right = (me + 1) % nr;
6563
rank_type left = (nr + me - 1) % nr;
6664
auto tsz = b_ptr->slice().tile_size(0);
67-
auto tshpa = a_ptr->slice().tile_shape(0);
68-
auto tshpb = b_ptr->slice().tile_shape(0);
65+
auto my_tshp_a = a_ptr->slice().tile_shape(me);
66+
auto tshp_b = b_ptr->slice().tile_shape(0);
67+
auto my_tshp_b = me == 0 ? tshp_b : b_ptr->slice().tile_shape(me);
6968

7069
const auto & ax = a_ptr->xarray();
7170
const auto & bx = b_ptr->xarray();
72-
xt::xarray<A> cx = xt::zeros<A>({tshpa[0], tshpb[1]});
73-
auto buff = xt::empty_like(bx);
74-
buff = bx;
71+
xt::xarray<A> cx = xt::zeros<A>({my_tshp_a[0], tshp_b[1]});
72+
auto buff = xt::empty<B>(tshp_b);
73+
if(tshp_b[0] == my_tshp_b[0]) {
74+
buff = bx;
75+
} else { // last partitions can be smaller -> need a view to assign values
76+
auto bv = xt::view(buff, xt::range(0, my_tshp_b[0]), xt::range(0, my_tshp_b[1]));
77+
bv = bx;
78+
}
7579

7680
// We use an algo similar to Cannon's
81+
// the last partitions can be smaller -> k depends on "me", the source rank of current partition
7782
for(rank_type i = nr; i>0; --i) {
78-
// std::cerr << me*tshpb[0] << " " << (1+me) * tshpb[0] << std::endl;
79-
// auto av = xt::view(ax, xt::all(), xt::range(me * tshpb[0], (1+me) * tshpb[0]));
80-
// cx = cx + xt::linalg::dot(av, buff);
81-
TGEMM<A>::tgemm(CblasRowMajor,
82-
CblasNoTrans,
83-
CblasNoTrans,
84-
tshpa[0],
85-
tshpb[1],
86-
tshpb[0],
87-
1, // alpha
88-
ax.data() + (me * tshpb[0]),
89-
tshpa[1], // lda
90-
buff.data(),
91-
tshpb[1], // ldb
92-
1, // beta
93-
cx.data(),
94-
tshpb[1]); // ldc
95-
83+
if(my_tshp_a[0]) {
84+
TGEMM<A>::tgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
85+
my_tshp_a[0], tshp_b[1], me == 0 ? tshp_b[0] : b_ptr->slice().tile_shape(me)[0],
86+
1, // alpha
87+
ax.data() + (me * tshp_b[0]),
88+
my_tshp_a[1], // lda
89+
buff.data(),
90+
tshp_b[1], // ldb
91+
1, // beta
92+
cx.data(),
93+
tshp_b[1]); // ldc
94+
}
95+
9696
if(i > 1) {
9797
// data exchange
98+
// FIXME: optimize data transfer: last partition might contain unused data
9899
theTransceiver->send_recv(buff.data(),
99100
tsz,
100101
DTYPE<A>::value,
@@ -103,7 +104,7 @@ namespace x {
103104
me = (me + 1) % nr;
104105
}
105106
}
106-
return operatorx<A>::mk_tx(std::move(PVSlice({tshpa[0], tshpb[1]})), cx);
107+
return operatorx<A>::mk_tx(std::move(PVSlice({a_ptr->shape()[0], b_ptr->shape()[1]})), cx);
107108
}
108109
};
109110
}

src/include/ddptensor/PVSlice.hpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,30 @@ class BasePVSlice
4848

4949
uint64_t tile_size(rank_type rank = theTransceiver->rank()) const
5050
{
51-
if(rank == 0 || rank < theTransceiver->nranks() - 1) return _tile_size;
52-
return VPROD(_shape) - (rank-1 * _tile_size);
51+
// only rank 0 is guaranteed to have _tile_size, all other parts can be < _tile_size
52+
if(rank == 0) return _tile_size;
53+
auto sz = VPROD(_shape);
54+
auto off = rank * _tile_size;
55+
if(sz >= off) return _tile_size;
56+
auto r = off - sz;
57+
// if r < _tile_size it's the remainder, otherwise we are past the end of the global array
58+
return r < _tile_size ? r : 0UL;
5359
}
5460

5561
shape_type tile_shape(rank_type rank = theTransceiver->rank()) const
5662
{
63+
// only rank 0 is guaranteed to have _tile_size, all other parts can be < _tile_size
5764
shape_type r(_shape);
58-
if(rank == 0 || rank < theTransceiver->nranks() - 1) r[_split_dim] = offset();
59-
else r[_split_dim] = r[_split_dim] - (rank-1 * offset());
65+
if(rank == 0) r[_split_dim] = offset();
66+
else {
67+
auto end = (rank+1) * offset();
68+
if(r[_split_dim] >= end) r[_split_dim] = offset();
69+
else {
70+
auto diff = end - r[_split_dim];
71+
// if diff < offset() it's the remainder, otherwise we are past the end of the global array
72+
r[_split_dim] = diff < offset() ? diff : 0UL;
73+
}
74+
}
6075
return r;
6176
}
6277

test/test_linalg.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
import ddptensor as dt
2-
a = dt.arange(1, 25, 1, dtype=dt.float64)
3-
b = dt.arange(1, 31, 1, dtype=dt.float64)
2+
a = dt.arange(1, 36, 1, dtype=dt.float64)
3+
b = dt.arange(1, 22, 1, dtype=dt.float64)
44
print("a", a)
55
print("b", b)
6-
a = dt.reshape(a, (4,6))
7-
b = dt.reshape(b, (6,5))
6+
a = dt.reshape(a, (5,7))
7+
b = dt.reshape(b, (7,3))
88
print("a", a)
99
print("b", b)
1010
print()
1111
c = dt.vecdot(a, b, 0)
1212
print(c)
1313

1414
import numpy as np
15-
a = np.arange(1, 25, 1, dtype=np.float64)
16-
b = np.arange(1, 31, 1, dtype=np.float64)
17-
a = np.reshape(a, (4,6))
18-
b = np.reshape(b, (6,5))
15+
a = np.arange(1, 36, 1, dtype=np.float64)
16+
b = np.arange(1, 22, 1, dtype=np.float64)
17+
a = np.reshape(a, (5,7))
18+
b = np.reshape(b, (7,3))
1919
c = np.dot(a, b)
2020
print(c)
2121

0 commit comments

Comments
 (0)