|
| 1 | +#include <mpi.h> |
//#include <mkl.h> // NOTE(review): disabled, yet cblas_*gemm is used below — presumably the CBLAS declarations arrive transitively via xtensor-blas; verify against the build setup
| 3 | +#include "ddptensor/LinAlgOp.hpp" |
| 4 | +#include "ddptensor/TypeDispatch.hpp" |
| 5 | +#include "ddptensor/x.hpp" |
| 6 | +#include "xtensor-blas/xlinalg.hpp" |
| 7 | + |
| 8 | +namespace x { |
| 9 | + |
| 10 | + template<typename T> struct TGEMM; |
| 11 | + template<> struct TGEMM<double> { static constexpr auto tgemm = cblas_dgemm; }; |
| 12 | + template<> struct TGEMM<float> { static constexpr auto tgemm = cblas_sgemm; }; |
| 13 | + |
| 14 | + class LinAlgOp |
| 15 | + { |
| 16 | + public: |
| 17 | + using ptr_type = DPTensorBaseX::ptr_type; |
| 18 | + |
| 19 | + template<typename A, typename B> |
| 20 | + static ptr_type op(int axis, const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr) |
| 21 | + { |
| 22 | + if constexpr (std::is_floating_point<A>::value && std::is_same<A, B>::value) { |
| 23 | + const auto & ax = a_ptr->xarray(); |
| 24 | + const auto & bx = b_ptr->xarray(); |
| 25 | + auto nda = a_ptr->slice().ndims(); |
| 26 | + auto ndb = b_ptr->slice().ndims(); |
| 27 | + |
| 28 | + if(a_ptr->is_sliced() || b_ptr->is_sliced()) { |
| 29 | + if(nda != 1 || ndb != 1) |
| 30 | + throw(std::runtime_error("vecdoc on sliced tensors supported for 1d tensors only")); |
| 31 | + const auto & av = xt::strided_view(ax, a_ptr->lslice()); |
| 32 | + const auto & bv = xt::strided_view(bx, b_ptr->lslice()); |
| 33 | + return vecdot_1d(av, bv, axis); |
| 34 | + } |
| 35 | + |
| 36 | + if(nda == 1 && ndb == 1) { |
| 37 | + return vecdot_1d(ax, bx, axis); |
| 38 | + } else if(nda == 2 && ndb == 2) { |
| 39 | + return matmul_2d(a_ptr, b_ptr, axis); |
| 40 | + } |
| 41 | + throw(std::runtime_error("'vecdot' supported for two 1d or two 2d tensors only.")); |
| 42 | + } else |
| 43 | + throw(std::runtime_error("'vecdot' supported for 2 double or float tensors only.")); |
| 44 | + } |
| 45 | + |
| 46 | + template<typename T1, typename T2> |
| 47 | + static ptr_type vecdot_1d(const T1 & a, const T2 & b, int axis) |
| 48 | + { |
| 49 | + auto d = xt::linalg::dot(a, b)(); |
| 50 | + theTransceiver->reduce_all(&d, DTYPE<decltype(d)>::value, 1, SUM); |
| 51 | + return operatorx<decltype(d)>::mk_tx(d, REPLICATED); |
| 52 | + } |
| 53 | + |
| 54 | + template<typename A, typename B> |
| 55 | + static ptr_type matmul_2d(const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr, int axis) |
| 56 | + { |
| 57 | + if(!a_ptr->slice().is_equally_tiled() || !b_ptr->slice().is_equally_tiled()) |
| 58 | + throw(std::runtime_error("vecdoc_2d supported for eually tiled tensors only")); |
| 59 | + if(a_ptr->slice().split_dim() != 0) |
| 60 | + throw(std::runtime_error("vecdoc_2d supported for split_dim=0 only")); |
| 61 | + |
| 62 | + auto nr = theTransceiver->nranks(); |
| 63 | + auto me = theTransceiver->rank(); |
| 64 | + rank_type right = (me + 1) % nr; |
| 65 | + rank_type left = (nr + me - 1) % nr; |
| 66 | + auto tsz = b_ptr->slice().tile_size(0); |
| 67 | + auto tshpa = a_ptr->slice().tile_shape(0); |
| 68 | + auto tshpb = b_ptr->slice().tile_shape(0); |
| 69 | + |
| 70 | + const auto & ax = a_ptr->xarray(); |
| 71 | + const auto & bx = b_ptr->xarray(); |
| 72 | + xt::xarray<A> cx = xt::zeros<A>({tshpa[0], tshpb[1]}); |
| 73 | + auto buff = xt::empty_like(bx); |
| 74 | + buff = bx; |
| 75 | + |
| 76 | + // We use an algo similar to Canon's |
| 77 | + for(rank_type i = nr; i>0; --i) { |
| 78 | + // std::cerr << me*tshpb[0] << " " << (1+me) * tshpb[0] << std::endl; |
| 79 | + // auto av = xt::view(ax, xt::all(), xt::range(me * tshpb[0], (1+me) * tshpb[0])); |
| 80 | + // cx = cx + xt::linalg::dot(av, buff); |
| 81 | + TGEMM<A>::tgemm(CblasRowMajor, |
| 82 | + CblasNoTrans, |
| 83 | + CblasNoTrans, |
| 84 | + tshpa[0], |
| 85 | + tshpb[1], |
| 86 | + tshpb[0], |
| 87 | + 1, // alpha |
| 88 | + ax.data() + (me * tshpb[0]), |
| 89 | + tshpa[1], // lda |
| 90 | + buff.data(), |
| 91 | + tshpb[1], // ldb |
| 92 | + 1, // beta |
| 93 | + cx.data(), |
| 94 | + tshpb[1]); // ldc |
| 95 | + |
| 96 | + if(i > 1) { |
| 97 | + // data exchange |
| 98 | + theTransceiver->send_recv(buff.data(), |
| 99 | + tsz, |
| 100 | + DTYPE<A>::value, |
| 101 | + left, |
| 102 | + right); |
| 103 | + me = (me + 1) % nr; |
| 104 | + } |
| 105 | + } |
| 106 | + return operatorx<A>::mk_tx(std::move(PVSlice({a_ptr->slice().shape()[0], b_ptr->slice().shape()[1]})), cx); |
| 107 | + } |
| 108 | + }; |
| 109 | +} |
| 110 | + |
// Public 'vecdot' entry point: resolves the operands' element types at
// runtime and forwards to the matching x::LinAlgOp::op instantiation.
tensor_i::ptr_type LinAlgOp::vecdot(tensor_i::ptr_type a, tensor_i::ptr_type b, int axis)
{
    return TypeDispatch<x::LinAlgOp>(a, b, axis);
}
0 commit comments