@@ -146,3 +146,69 @@ TEMPLATE_LIST_TEST_CASE("SubtractionVisitor", "[buffer][detail_]",
146146 REQUIRE (empty_buffer.at (3 ) == TestType (0.0 ));
147147 }
148148}
149+
150+ TEMPLATE_LIST_TEST_CASE (" MultiplicationVisitor" , " [buffer][detail_]" ,
151+ types::floating_point_types) {
152+ using VisitorType = buffer::detail_::MultiplicationVisitor;
153+ using buffer_type = typename VisitorType::buffer_type;
154+ using label_type = typename VisitorType::label_type;
155+ using shape_type = typename VisitorType::shape_type;
156+
157+ TestType one{1.0 }, two{2.0 }, three{3.0 }, four{4.0 };
158+ std::vector<TestType> this_data{one, two, three, four};
159+ std::vector<TestType> lhs_data{four, three, two, one};
160+ std::vector<TestType> rhs_data{one, one, one, one};
161+ shape_type shape ({4 });
162+ label_type labels (" i" );
163+
164+ std::span<TestType> lhs_span (lhs_data.data (), lhs_data.size ());
165+ std::span<const TestType> clhs_span (lhs_data.data (), lhs_data.size ());
166+ std::span<TestType> rhs_span (rhs_data.data (), rhs_data.size ());
167+ std::span<const TestType> crhs_span (rhs_data.data (), rhs_data.size ());
168+
169+ SECTION (" existing buffer: Hadamard" ) {
170+ buffer_type this_buffer (this_data);
171+ VisitorType visitor (this_buffer, labels, shape, labels, shape, labels,
172+ shape);
173+
174+ visitor (lhs_span, rhs_span);
175+ REQUIRE (this_buffer.at (0 ) == TestType (4.0 ));
176+ REQUIRE (this_buffer.at (1 ) == TestType (3.0 ));
177+ REQUIRE (this_buffer.at (2 ) == TestType (2.0 ));
178+ REQUIRE (this_buffer.at (3 ) == TestType (1.0 ));
179+ }
180+
181+ SECTION (" existing buffer: contraction" ) {
182+ buffer_type this_buffer (this_data);
183+ shape_type scalar_shape;
184+ VisitorType visitor (this_buffer, label_type (" " ), scalar_shape, labels,
185+ shape, labels, shape);
186+
187+ visitor (lhs_span, rhs_span);
188+ REQUIRE (this_buffer.size () == 1 );
189+ REQUIRE (this_buffer.at (0 ) == TestType (10.0 ));
190+ }
191+
192+ SECTION (" existing buffer: batched contraction" ) {
193+ buffer_type this_buffer (this_data);
194+ shape_type out_shape ({2 });
195+ label_type lhs_labels (" a,i" );
196+ label_type rhs_labels (" i,a" );
197+ VisitorType visitor (this_buffer, labels, out_shape, lhs_labels, shape,
198+ rhs_labels, shape);
199+
200+ REQUIRE_THROWS_AS (visitor (lhs_span, rhs_span), std::runtime_error);
201+ }
202+
203+ SECTION (" non-existing buffer" ) {
204+ buffer_type empty_buffer;
205+ VisitorType visitor (empty_buffer, labels, shape, labels, shape, labels,
206+ shape);
207+
208+ visitor (clhs_span, crhs_span);
209+ REQUIRE (empty_buffer.at (0 ) == TestType (4.0 ));
210+ REQUIRE (empty_buffer.at (1 ) == TestType (3.0 ));
211+ REQUIRE (empty_buffer.at (2 ) == TestType (2.0 ));
212+ REQUIRE (empty_buffer.at (3 ) == TestType (1.0 ));
213+ }
214+ }
0 commit comments