Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 157888f

Browse files
committed
adding basic vecdot/matmul
1 parent 181d463 commit 157888f

File tree

16 files changed

+275
-30
lines changed

16 files changed

+275
-30
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ find_package(MPI REQUIRED)
2121
#find_package(OpenMP)
2222

2323
set(MKL_LIBRARIES -L$ENV{MKLROOT}/lib -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread -lrt -ldl -lm)
24+
#set(CMAKE_INSTALL_RPATH $ENV{MKLROOT}/lib)
2425
# Use -fPIC even if statically compiled
2526
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2627

@@ -41,7 +42,7 @@ set(MyCppSources ${MyCppSources} ${P2C_HPP})
4142

4243
pybind11_add_module(_ddptensor MODULE ${MyCppSources})
4344

44-
target_compile_definitions(_ddptensor PRIVATE USE_MKL=1 XTENSOR_USE_XSIMD=1 XTENSOR_USE_OPENMP=1 DDPT_2TYPES=1)
45+
target_compile_definitions(_ddptensor PRIVATE XTENSOR_USE_XSIMD=1 XTENSOR_USE_OPENMP=1 DDPT_2TYPES=1 USE_MKL=1)
4546
target_include_directories(_ddptensor PRIVATE
4647
${PROJECT_SOURCE_DIR}/src/include
4748
${PROJECT_SOURCE_DIR}/third_party/xtl/include

ddptensor/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,18 @@
7373
exec(
7474
f"{func} = lambda this, shape: dtensor(_cdt.ManipOp.reshape(this._t, shape))"
7575
)
76+
77+
# Generate module-level wrappers for the LinAlg API: every entry listed in
# api.api_categories["LinAlgOp"] becomes a module-level lambda that unwraps
# the dtensor arguments (._t), calls the C++ binding and re-wraps the result.
for func in api.api_categories["LinAlgOp"]:
    FUNC = func.upper()  # NOTE(review): currently unused — confirm before removing
    if func in ["tensordot", "vecdot",]:
        # Binary ops that take an explicit contraction axis.
        exec(
            f"{func} = lambda this, other, axis: dtensor(_cdt.LinAlgOp.{func}(this._t, other._t, axis))"
        )
    elif func == "matmul":
        # matmul has no native binding yet; it is routed through vecdot with
        # axis 0 — presumably a placeholder until a dedicated binding exists.
        exec(
            f"{func} = lambda this, other: dtensor(_cdt.LinAlgOp.vecdot(this._t, other._t, 0))"
        )
    elif func == "matrix_transpose":
        # Unary op: no axis argument.
        exec(
            f"{func} = lambda this: dtensor(_cdt.LinAlgOp.{func}(this._t))"
        )

ddptensor/array_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@
183183
"squeeze", # (x, /, axis)
184184
"stack", # (arrays, /, *, axis=0)
185185
],
186+
187+
"LinAlgOp" : [
188+
"matmul", # (x1, x2, /)
189+
"matrix_transpose", # (x, /)
190+
"tensordot", # (x1, x2, /, *, axes=2)
191+
"vecdot", # (x1, x2, /, *, axis=-1)
192+
],
186193
})
187194

188195
misc_methods = [

ddptensor/random.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from . import _ddptensor as _cdt
2+
from . import float64
23
from . ddptensor import dtensor
34

4-
def uniform(low, high, size, dtype=_cdt.float64):
5+
def uniform(low, high, size, dtype=float64):
56
return dtensor(_cdt.Random.uniform(dtype, size, low, high))
67

78
def seed(s):

src/EWBinOp.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "ddptensor/EWBinOp.hpp"
2+
#include "ddptensor/LinAlgOp.hpp"
23
#include "ddptensor/TypeDispatch.hpp"
34
#include "ddptensor/x.hpp"
45

@@ -17,14 +18,14 @@ namespace x {
1718
if(a_ptr->is_sliced() || b_ptr->is_sliced()) {
1819
const auto & av = xt::strided_view(ax, a_ptr->lslice());
1920
const auto & bv = xt::strided_view(bx, b_ptr->lslice());
20-
return do_op(bop, av, bv, a_ptr);
21+
return do_op(bop, av, bv, a_ptr, b_ptr);
2122
}
22-
return do_op(bop, ax, bx, a_ptr);
23+
return do_op(bop, ax, bx, a_ptr, b_ptr);
2324
}
2425

2526
#pragma GCC diagnostic ignored "-Wswitch"
26-
template<typename T1, typename T2, typename A>
27-
static ptr_type do_op(EWBinOpId bop, const T1 & a, const T2 & b, const std::shared_ptr<DPTensorX<A>> & a_ptr)
27+
template<typename T1, typename T2, typename A, typename B>
28+
static ptr_type do_op(EWBinOpId bop, const T1 & a, const T2 & b, const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
2829
{
2930
switch(bop) {
3031
case __ADD__:
@@ -73,13 +74,20 @@ namespace x {
7374
case __RTRUEDIV__:
7475
return operatorx<A>::mk_tx_(a_ptr, b / a);
7576
case __MATMUL__:
77+
return LinAlgOp::vecdot(a_ptr, b_ptr, 0);
7678
case __POW__:
7779
case POW:
80+
return operatorx<A>::mk_tx_(a_ptr, xt::pow(a, b));
7881
case __RPOW__:
82+
return operatorx<A>::mk_tx_(a_ptr, xt::pow(b, a));
7983
case LOGADDEXP:
84+
return operatorx<A>::mk_tx_(a_ptr, xt::log(xt::exp(a) + xt::exp(b)));
8085
case LOGICAL_AND:
86+
// return operatorx<A>::mk_tx_(a_ptr, a && b);
8187
case LOGICAL_OR:
88+
// return operatorx<A>::mk_tx_(a_ptr, a || b);
8289
case LOGICAL_XOR:
90+
// return operatorx<A>::mk_tx_(a_ptr, xt::not_equal(!a, !b));
8391
// FIXME
8492
throw std::runtime_error("Binary operation not implemented");
8593
}

src/LinAlgOp.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#include <mpi.h>
2+
//#include <mkl.h>
3+
#include "ddptensor/LinAlgOp.hpp"
4+
#include "ddptensor/TypeDispatch.hpp"
5+
#include "ddptensor/x.hpp"
6+
#include "xtensor-blas/xlinalg.hpp"
7+
8+
namespace x {
9+
10+
template<typename T> struct TGEMM;
11+
template<> struct TGEMM<double> { static constexpr auto tgemm = cblas_dgemm; };
12+
template<> struct TGEMM<float> { static constexpr auto tgemm = cblas_sgemm; };
13+
14+
class LinAlgOp
15+
{
16+
public:
17+
using ptr_type = DPTensorBaseX::ptr_type;
18+
19+
template<typename A, typename B>
20+
static ptr_type op(int axis, const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
21+
{
22+
if constexpr (std::is_floating_point<A>::value && std::is_same<A, B>::value) {
23+
const auto & ax = a_ptr->xarray();
24+
const auto & bx = b_ptr->xarray();
25+
auto nda = a_ptr->slice().ndims();
26+
auto ndb = b_ptr->slice().ndims();
27+
28+
if(a_ptr->is_sliced() || b_ptr->is_sliced()) {
29+
if(nda != 1 || ndb != 1)
30+
throw(std::runtime_error("vecdoc on sliced tensors supported for 1d tensors only"));
31+
const auto & av = xt::strided_view(ax, a_ptr->lslice());
32+
const auto & bv = xt::strided_view(bx, b_ptr->lslice());
33+
return vecdot_1d(av, bv, axis);
34+
}
35+
36+
if(nda == 1 && ndb == 1) {
37+
return vecdot_1d(ax, bx, axis);
38+
} else if(nda == 2 && ndb == 2) {
39+
return matmul_2d(a_ptr, b_ptr, axis);
40+
}
41+
throw(std::runtime_error("'vecdot' supported for two 1d or two 2d tensors only."));
42+
} else
43+
throw(std::runtime_error("'vecdot' supported for 2 double or float tensors only."));
44+
}
45+
46+
template<typename T1, typename T2>
47+
static ptr_type vecdot_1d(const T1 & a, const T2 & b, int axis)
48+
{
49+
auto d = xt::linalg::dot(a, b)();
50+
theTransceiver->reduce_all(&d, DTYPE<decltype(d)>::value, 1, SUM);
51+
return operatorx<decltype(d)>::mk_tx(d, REPLICATED);
52+
}
53+
54+
template<typename A, typename B>
55+
static ptr_type matmul_2d(const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr, int axis)
56+
{
57+
if(!a_ptr->slice().is_equally_tiled() || !b_ptr->slice().is_equally_tiled())
58+
throw(std::runtime_error("vecdoc_2d supported for eually tiled tensors only"));
59+
if(a_ptr->slice().split_dim() != 0)
60+
throw(std::runtime_error("vecdoc_2d supported for split_dim=0 only"));
61+
62+
auto nr = theTransceiver->nranks();
63+
auto me = theTransceiver->rank();
64+
rank_type right = (me + 1) % nr;
65+
rank_type left = (nr + me - 1) % nr;
66+
auto tsz = b_ptr->slice().tile_size(0);
67+
auto tshpa = a_ptr->slice().tile_shape(0);
68+
auto tshpb = b_ptr->slice().tile_shape(0);
69+
70+
const auto & ax = a_ptr->xarray();
71+
const auto & bx = b_ptr->xarray();
72+
xt::xarray<A> cx = xt::zeros<A>({tshpa[0], tshpb[1]});
73+
auto buff = xt::empty_like(bx);
74+
buff = bx;
75+
76+
// We use an algo similar to Canon's
77+
for(rank_type i = nr; i>0; --i) {
78+
// std::cerr << me*tshpb[0] << " " << (1+me) * tshpb[0] << std::endl;
79+
// auto av = xt::view(ax, xt::all(), xt::range(me * tshpb[0], (1+me) * tshpb[0]));
80+
// cx = cx + xt::linalg::dot(av, buff);
81+
TGEMM<A>::tgemm(CblasRowMajor,
82+
CblasNoTrans,
83+
CblasNoTrans,
84+
tshpa[0],
85+
tshpb[1],
86+
tshpb[0],
87+
1, // alpha
88+
ax.data() + (me * tshpb[0]),
89+
tshpa[1], // lda
90+
buff.data(),
91+
tshpb[1], // ldb
92+
1, // beta
93+
cx.data(),
94+
tshpb[1]); // ldc
95+
96+
if(i > 1) {
97+
// data exchange
98+
theTransceiver->send_recv(buff.data(),
99+
tsz,
100+
DTYPE<A>::value,
101+
left,
102+
right);
103+
me = (me + 1) % nr;
104+
}
105+
}
106+
return operatorx<A>::mk_tx(std::move(PVSlice({a_ptr->slice().shape()[0], b_ptr->slice().shape()[1]})), cx);
107+
}
108+
};
109+
}
110+
111+
// Type-erased public entry point: dispatches to the typed implementation
// in x::LinAlgOp (which handles both the 1d dot and the 2d matmul case).
tensor_i::ptr_type LinAlgOp::vecdot(tensor_i::ptr_type a, tensor_i::ptr_type b, int axis)
{
    return TypeDispatch<x::LinAlgOp>(a, b, axis);
}

src/MPITransceiver.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,21 @@ void MPITransceiver::alltoall(const void* buffer_send,
9292
to_mpi(datatype_recv),
9393
MPI_COMM_WORLD);
9494
}
95+
96+
// Point-to-point exchange: sends the contents of buffer_send to 'dest' and
// overwrites the buffer in place with the data received from 'source'
// (MPI_Sendrecv_replace).  count_send is an element count, not bytes.
void MPITransceiver::send_recv(void* buffer_send,
                               int count_send,
                               DTypeId datatype_send,
                               int dest,
                               int source)
{
    // Arbitrary fixed tag; both ends of the exchange use the same value.
    constexpr int SRTAG = 505;
    MPI_Sendrecv_replace(buffer_send,
                         count_send,
                         to_mpi(datatype_send),
                         dest,
                         SRTAG,
                         source,
                         SRTAG,
                         MPI_COMM_WORLD,
                         MPI_STATUS_IGNORE);
}

src/ReduceOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ namespace x {
6363
};
6464
} // namespace x
6565

66-
tensor_i::ptr_type ReduceOp::op(ReduceOpId op, x::DPTensorBaseX::ptr_type a, const dim_vec_type & dim)
66+
// Type-erased entry point for reductions: forwards to the typed
// x::ReduceOp implementation via TypeDispatch.
tensor_i::ptr_type ReduceOp::op(ReduceOpId op, tensor_i::ptr_type a, const dim_vec_type & dim)
{
    return TypeDispatch<x::ReduceOp>(a, op, dim);
}

src/ddptensor.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using namespace pybind11::literals; // to bring _a
2929
#include "ddptensor/ManipOp.hpp"
3030
#include "ddptensor/SetGetItem.hpp"
3131
#include "ddptensor/Random.hpp"
32+
#include "ddptensor/LinAlgOp.hpp"
3233

3334
// #########################################################################
3435
// The following classes are wrappers bridging pybind11 defs to TypeDispatch
@@ -85,6 +86,9 @@ PYBIND11_MODULE(_ddptensor, m) {
8586
py::class_<ManipOp>(m, "ManipOp")
8687
.def("reshape", &ManipOp::reshape);
8788

89+
py::class_<LinAlgOp>(m, "LinAlgOp")
90+
.def("vecdot", &LinAlgOp::vecdot);
91+
8892
py::class_<tensor_i, tensor_i::ptr_type>(m, "DPTensorX")
8993
.def_property_readonly("dtype", &tensor_i::dtype)
9094
.def_property_readonly("shape", &tensor_i::shape)

src/include/ddptensor/LinAlgOp.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
3+
#pragma once
4+
5+
#include "UtilsAndTypes.hpp"
6+
#include "tensor_i.hpp"
7+
#include "p2c_ids.hpp"
8+
9+
// Linear-algebra operations on distributed tensors (array-API "LinAlgOp"
// category).  Only vecdot is bound natively so far; matmul is expressed
// through vecdot on the Python side.
struct LinAlgOp
{
    // Dot product of two tensors contracted over 'axis'.  For two 1d
    // tensors this yields a replicated 0d result; for two 2d tensors a
    // distributed matrix product (see x::LinAlgOp in LinAlgOp.cpp).
    static tensor_i::ptr_type vecdot(tensor_i::ptr_type a, tensor_i::ptr_type b, int axis);
};

0 commit comments

Comments
 (0)