@@ -54,8 +54,6 @@ namespace x {
5454 template <typename A, typename B>
5555 static ptr_type matmul_2d (const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr, int axis)
5656 {
57- if (!a_ptr->slice ().is_equally_tiled () || !b_ptr->slice ().is_equally_tiled ())
58- throw (std::runtime_error (" vecdoc_2d supported for eually tiled tensors only" ));
5957 if (a_ptr->slice ().split_dim () != 0 )
6058 throw (std::runtime_error (" vecdoc_2d supported for split_dim=0 only" ));
6159
@@ -64,37 +62,40 @@ namespace x {
6462 rank_type right = (me + 1 ) % nr;
6563 rank_type left = (nr + me - 1 ) % nr;
6664 auto tsz = b_ptr->slice ().tile_size (0 );
67- auto tshpa = a_ptr->slice ().tile_shape (0 );
68- auto tshpb = b_ptr->slice ().tile_shape (0 );
65+ auto my_tshp_a = a_ptr->slice ().tile_shape (me);
66+ auto tshp_b = b_ptr->slice ().tile_shape (0 );
67+ auto my_tshp_b = me == 0 ? tshp_b : b_ptr->slice ().tile_shape (me);
6968
7069 const auto & ax = a_ptr->xarray ();
7170 const auto & bx = b_ptr->xarray ();
72- xt::xarray<A> cx = xt::zeros<A>({tshpa[0 ], tshpb[1 ]});
73- auto buff = xt::empty_like (bx);
74- buff = bx;
71+ xt::xarray<A> cx = xt::zeros<A>({my_tshp_a[0 ], tshp_b[1 ]});
72+ auto buff = xt::empty<B>(tshp_b);
73+ if (tshp_b[0 ] == my_tshp_b[0 ]) {
74+ buff = bx;
75+ } else { // last partitions can be smaller -> need a view to assign values
76+ auto bv = xt::view (buff, xt::range (0 , my_tshp_b[0 ]), xt::range (0 , my_tshp_b[1 ]));
77+ bv = bx;
78+ }
7579
7680 // We use an algo similar to Canon's
81+ // the last partitions can be smaller -> k depends on "me", the source rank of current partition
7782 for (rank_type i = nr; i>0 ; --i) {
78- // std::cerr << me*tshpb[0] << " " << (1+me) * tshpb[0] << std::endl;
79- // auto av = xt::view(ax, xt::all(), xt::range(me * tshpb[0], (1+me) * tshpb[0]));
80- // cx = cx + xt::linalg::dot(av, buff);
81- TGEMM<A>::tgemm (CblasRowMajor,
82- CblasNoTrans,
83- CblasNoTrans,
84- tshpa[0 ],
85- tshpb[1 ],
86- tshpb[0 ],
87- 1 , // alpha
88- ax.data () + (me * tshpb[0 ]),
89- tshpa[1 ], // lda
90- buff.data (),
91- tshpb[1 ], // ldb
92- 1 , // beta
93- cx.data (),
94- tshpb[1 ]); // ldc
95-
83+ if (my_tshp_a[0 ]) {
84+ TGEMM<A>::tgemm (CblasRowMajor, CblasNoTrans, CblasNoTrans,
85+ my_tshp_a[0 ], tshp_b[1 ], me == 0 ? tshp_b[0 ] : b_ptr->slice ().tile_shape (me)[0 ],
86+ 1 , // alpha
87+ ax.data () + (me * tshp_b[0 ]),
88+ my_tshp_a[1 ], // lda
89+ buff.data (),
90+ tshp_b[1 ], // ldb
91+ 1 , // beta
92+ cx.data (),
93+ tshp_b[1 ]); // ldc
94+ }
95+
9696 if (i > 1 ) {
9797 // data exchange
98+ // FIXME: optimize data transfer: last partition might contain unused data
9899 theTransceiver->send_recv (buff.data (),
99100 tsz,
100101 DTYPE<A>::value,
@@ -103,7 +104,7 @@ namespace x {
103104 me = (me + 1 ) % nr;
104105 }
105106 }
106- return operatorx<A>::mk_tx (std::move (PVSlice ({tshpa [0 ], tshpb [1 ]})), cx);
107+ return operatorx<A>::mk_tx (std::move (PVSlice ({a_ptr-> shape () [0 ], b_ptr-> shape () [1 ]})), cx);
107108 }
108109 };
109110}
0 commit comments