Skip to content

Commit c878eef

Browse files
authored
buffer supports multiplication now (#232)
1 parent c6d2354 commit c878eef

7 files changed

Lines changed: 216 additions & 1 deletion

File tree

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
.. Copyright 2023 NWChemEx-Project
2+
..
3+
.. Licensed under the Apache License, Version 2.0 (the "License");
4+
.. you may not use this file except in compliance with the License.
5+
.. You may obtain a copy of the License at
6+
..
7+
.. http://www.apache.org/licenses/LICENSE-2.0
8+
..
9+
.. Unless required by applicable law or agreed to in writing, software
10+
.. distributed under the License is distributed on an "AS IS" BASIS,
11+
.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
.. See the License for the specific language governing permissions and
13+
.. limitations under the License.
14+
15+
###############################
16+
Adding Operations to Contiguous
17+
###############################
18+
19+
The ``Contiguous`` class is the workhorse of most tensor operations because it
20+
provides the kernels that non-contiguous tensors are built on. As such, we may
21+
need to add operations to it from time to time. This document describes how to
22+
do that.
23+
24+
**********************************
25+
Understanding How Contiguous Works
26+
**********************************
27+
28+
.. figure:: assets/how_contiguous_works.png
29+
:align: center
30+
31+
Control flow for an operation resulting in a ``Contiguous`` buffer object.
32+
33+
For concreteness, we'll trace how ``subtraction_assignment`` is implemented.
34+
Other binary operations are implemented nearly identically and the
35+
implementation of unary operations is extremely similar.
36+
37+
1. The input objects, ``lhs`` and ``rhs`` are converted to ``Contiguous``
38+
objects. N.b., we should eventually use performance models to decide whether
39+
the time to convert to ``Contiguous`` objects is worth it, or if we should
40+
rely on algorithms which do not require contiguous data.
41+
2. We work out the shape of the output tensor.
42+
3. A visitor for the desired operation is created. For
43+
``subtraction_assignment``, this is ``detail_::SubtractionVisitor``.
44+
45+
- Visitor definitions live in ``wtf/src/tensorwrapper/buffer/detail_/``.
46+
47+
4. Control enters ``wtf::buffer::visit_contiguous_buffer`` to restore floating-
   point types.
5. ``lhs`` and ``rhs`` are converted to ``std::span`` objects.
6. Control enters the visitor.
7. With types known, the output tensor can be initialized (and is).
8. The visitor converts the ``std::span`` objects into the tensor backend's
   tensor objects.

   - Backend implementations live in ``wtf/src/tensorwrapper/backends/``.

9. The backend's implementation of the operation is invoked.
58+
59+
**********************
60+
Adding a New Operation
61+
**********************
62+
63+
1. Verify that one of the backends supports the desired operation. If not, add
64+
it to a backend first.
65+
2. Create a visitor for it.
66+
3. Add the operation to ``wtf::buffer::Contiguous``.
179 KB
Loading

docs/source/developer/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ Developer Documentation
2121
:caption: Contents:
2222

2323
design/index
24+
adding_operations_to_contiguous

src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,29 @@ class SubtractionVisitor : public BinaryOperationVisitor {
135135
}
136136
};
137137

138+
/// Visitor that calls hadamard_assignment or contraction_assignment
139+
class MultiplicationVisitor : public BinaryOperationVisitor {
140+
public:
141+
using BinaryOperationVisitor::BinaryOperationVisitor;
142+
using BinaryOperationVisitor::operator();
143+
144+
template<typename FloatType>
145+
void operator()(std::span<FloatType> lhs, std::span<FloatType> rhs) {
146+
using clean_t = std::decay_t<FloatType>;
147+
auto pthis = this->make_this_eigen_tensor_<clean_t>();
148+
auto plhs = this->make_lhs_eigen_tensor_(lhs);
149+
auto prhs = this->make_rhs_eigen_tensor_(rhs);
150+
151+
if(this_labels().is_hadamard_product(lhs_labels(), rhs_labels()))
152+
pthis->hadamard_assignment(this_labels(), lhs_labels(),
153+
rhs_labels(), *plhs, *prhs);
154+
else if(this_labels().is_contraction(lhs_labels(), rhs_labels()))
155+
pthis->contraction_assignment(this_labels(), lhs_labels(),
156+
rhs_labels(), *plhs, *prhs);
157+
else
158+
throw std::runtime_error(
159+
"MultiplicationVisitor: Batched contraction NYI");
160+
}
161+
};
162+
138163
} // namespace tensorwrapper::buffer::detail_

src/tensorwrapper/buffer/mdbuffer.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
// Implements tensor multiplication for MDBuffer. Works out the result's
// shape from the labeled input shapes, then delegates the element-wise vs.
// contraction dispatch to detail_::MultiplicationVisitor.
auto MDBuffer::multiplication_assignment_(label_type this_labels,
                                          const_labeled_reference lhs,
                                          const_labeled_reference rhs)
  -> dsl_reference {
    // Recover the MDBuffer implementations behind the type-erased inputs.
    const auto& lhs_down  = downcast(lhs.object());
    const auto& rhs_down  = downcast(rhs.object());
    const auto& lhs_shape = lhs_down.m_shape_;
    const auto& rhs_shape = rhs_down.m_shape_;

    // Pair each input shape with its labels so the shape math can see which
    // modes line up.
    auto labeled_lhs_shape = lhs_shape(lhs.labels());
    auto labeled_rhs_shape = rhs_shape(rhs.labels());

    // Compute (and store) the shape of the result before touching the data.
    m_shape_.multiplication_assignment(this_labels, labeled_lhs_shape,
                                       labeled_rhs_shape);

    // The visitor holds the output buffer plus both label/shape pairs; it
    // decides between hadamard_assignment and contraction_assignment.
    detail_::MultiplicationVisitor visitor(m_buffer_, this_labels, m_shape_,
                                           lhs.labels(), lhs_shape,
                                           rhs.labels(), rhs_shape);

    // visit_contiguous_buffer restores the floating-point types of the two
    // contiguous buffers and then invokes the visitor on them.
    wtf::buffer::visit_contiguous_buffer<fp_types>(visitor, lhs_down.m_buffer_,
                                                   rhs_down.m_buffer_);

    // Buffer contents changed; flag any cached hash as stale (presumably —
    // semantics per mark_for_rehash_'s definition elsewhere in the class).
    mark_for_rehash_();
    return *this;
}
162181

163182
auto MDBuffer::permute_assignment_(label_type this_labels,

tests/cxx/unit_tests/tensorwrapper/buffer/detail_/binary_operation_visitor.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,69 @@ TEMPLATE_LIST_TEST_CASE("SubtractionVisitor", "[buffer][detail_]",
146146
REQUIRE(empty_buffer.at(3) == TestType(0.0));
147147
}
148148
}
149+
150+
// Unit tests for detail_::MultiplicationVisitor. Covers: Hadamard dispatch,
// full-contraction dispatch, the batched-contraction NYI error, and writing
// into a not-yet-allocated output buffer. Run for every supported
// floating-point type.
TEMPLATE_LIST_TEST_CASE("MultiplicationVisitor", "[buffer][detail_]",
                        types::floating_point_types) {
    using VisitorType = buffer::detail_::MultiplicationVisitor;
    using buffer_type = typename VisitorType::buffer_type;
    using label_type  = typename VisitorType::label_type;
    using shape_type  = typename VisitorType::shape_type;

    TestType one{1.0}, two{2.0}, three{3.0}, four{4.0};
    std::vector<TestType> this_data{one, two, three, four};
    std::vector<TestType> lhs_data{four, three, two, one};
    std::vector<TestType> rhs_data{one, one, one, one};
    shape_type shape({4});
    label_type labels("i");

    // Mutable and read-only views over the same input data; the visitor must
    // accept both.
    std::span<TestType> lhs_span(lhs_data.data(), lhs_data.size());
    std::span<const TestType> clhs_span(lhs_data.data(), lhs_data.size());
    std::span<TestType> rhs_span(rhs_data.data(), rhs_data.size());
    std::span<const TestType> crhs_span(rhs_data.data(), rhs_data.size());

    SECTION("existing buffer: Hadamard") {
        // Identical labels on output and both inputs => element-wise product.
        buffer_type this_buffer(this_data);
        VisitorType visitor(this_buffer, labels, shape, labels, shape, labels,
                            shape);

        // {4,3,2,1} .* {1,1,1,1} == {4,3,2,1}, overwriting {1,2,3,4}.
        visitor(lhs_span, rhs_span);
        REQUIRE(this_buffer.at(0) == TestType(4.0));
        REQUIRE(this_buffer.at(1) == TestType(3.0));
        REQUIRE(this_buffer.at(2) == TestType(2.0));
        REQUIRE(this_buffer.at(3) == TestType(1.0));
    }

    SECTION("existing buffer: contraction") {
        // Empty output labels with "i" on both inputs => full contraction
        // (a dot product) yielding a scalar.
        buffer_type this_buffer(this_data);
        shape_type scalar_shape;
        VisitorType visitor(this_buffer, label_type(""), scalar_shape, labels,
                            shape, labels, shape);

        // 4*1 + 3*1 + 2*1 + 1*1 == 10
        visitor(lhs_span, rhs_span);
        REQUIRE(this_buffer.size() == 1);
        REQUIRE(this_buffer.at(0) == TestType(10.0));
    }

    SECTION("existing buffer: batched contraction") {
        // "a" survives on both sides while "i" is summed => batched
        // contraction, which the visitor does not implement yet.
        buffer_type this_buffer(this_data);
        shape_type out_shape({2});
        label_type lhs_labels("a,i");
        label_type rhs_labels("i,a");
        VisitorType visitor(this_buffer, labels, out_shape, lhs_labels, shape,
                            rhs_labels, shape);

        REQUIRE_THROWS_AS(visitor(lhs_span, rhs_span), std::runtime_error);
    }

    SECTION("non-existing buffer") {
        // Output starts empty; the visitor must allocate/initialize it once
        // the floating-point type is known. Also exercises const spans.
        buffer_type empty_buffer;
        VisitorType visitor(empty_buffer, labels, shape, labels, shape, labels,
                            shape);

        visitor(clhs_span, crhs_span);
        REQUIRE(empty_buffer.at(0) == TestType(4.0));
        REQUIRE(empty_buffer.at(1) == TestType(3.0));
        REQUIRE(empty_buffer.at(2) == TestType(2.0));
        REQUIRE(empty_buffer.at(3) == TestType(1.0));
    }
}

tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,44 @@ TEMPLATE_LIST_TEST_CASE("MDBuffer", "", types::floating_point_types) {
317317
}
318318
}
319319

320+
// Tests MDBuffer::multiplication_assignment_ for scalar, vector, and matrix
// inputs. Each case multiplies a fixture tensor by itself, so every result
// element is the square of the corresponding input element.
SECTION("multiplication_assignment_") {
    // N.b., dispatching among hadamard, contraction, etc. is the visitor's
    // responsibility and happens there. Here we just test hadamard.

    SECTION("scalar") {
        label_type labels("");
        MDBuffer result;
        // 1.0 * 1.0 == 1.0 (rank-0 case)
        result.multiplication_assignment(labels, scalar(labels),
                                         scalar(labels));
        REQUIRE(result.shape() == scalar_shape);
        REQUIRE(result.get_elem({}) == TestType(1.0));
    }

    SECTION("vector") {
        label_type labels("i");
        MDBuffer result;
        // Element-wise square of {1,2,3,4}.
        result.multiplication_assignment(labels, vector(labels),
                                         vector(labels));
        REQUIRE(result.shape() == vector_shape);
        REQUIRE(result.get_elem({0}) == TestType(1.0));
        REQUIRE(result.get_elem({1}) == TestType(4.0));
        REQUIRE(result.get_elem({2}) == TestType(9.0));
        REQUIRE(result.get_elem({3}) == TestType(16.0));
    }

    SECTION("matrix") {
        label_type labels("i,j");
        MDBuffer result;
        // Element-wise square of the 2x2 fixture {{1,2},{3,4}}.
        result.multiplication_assignment(labels, matrix(labels),
                                         matrix(labels));
        REQUIRE(result.shape() == matrix_shape);
        REQUIRE(result.get_elem({0, 0}) == TestType(1.0));
        REQUIRE(result.get_elem({0, 1}) == TestType(4.0));
        REQUIRE(result.get_elem({1, 0}) == TestType(9.0));
        REQUIRE(result.get_elem({1, 1}) == TestType(16.0));
    }
}
357+
320358
SECTION("scalar_multiplication_") {
321359
// TODO: Test with other scalar types when public API supports it
322360
using scalar_type = double;

0 commit comments

Comments
 (0)