Skip to content

Commit 3abdac3

Browse files
authored
Adds Slicing to Layout (#237)
* shapes support slicing * more tests * fix docs errors * move shape_traits for consistency * can slice layout
1 parent 346b968 commit 3abdac3

24 files changed

Lines changed: 1019 additions & 118 deletions

include/tensorwrapper/layout/layout_base.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <tensorwrapper/shape/shape_base.hpp>
2222
#include <tensorwrapper/sparsity/pattern.hpp>
2323
#include <tensorwrapper/symmetry/group.hpp>
24+
#include <tensorwrapper/types/layout_traits.hpp>
2425

2526
namespace tensorwrapper::layout {
2627

@@ -29,6 +30,10 @@ namespace tensorwrapper::layout {
2930
*/
3031
class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
3132
public tensorwrapper::detail_::DSLBase<LayoutBase> {
33+
private:
34+
/// Type defining types for *this
35+
using traits_type = types::ClassTraits<LayoutBase>;
36+
3237
public:
3338
/// Type all layouts derive from
3439
using layout_base = LayoutBase;
@@ -70,7 +75,7 @@ class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
7075
using sparsity_pointer = std::unique_ptr<sparsity_type>;
7176

7277
/// Type used for indexing and offsets
73-
using size_type = std::size_t;
78+
using size_type = typename traits_type::size_type;
7479

7580
// -------------------------------------------------------------------------
7681
// -- Ctors and dtor
@@ -186,6 +191,9 @@ class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
186191
return *m_sparsity_;
187192
}
188193

194+
/** @brief True if *this is a NULL layout and false otherwise. */
195+
bool is_null() const noexcept { return !static_cast<bool>(m_shape_); }
196+
189197
/** @brief The rank of the tensor this layout describes.
190198
*
191199
* This method is convenience function for calling the rank methods on one
@@ -214,6 +222,10 @@ class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
214222
* @throw None No throw guarantee.
215223
*/
216224
bool operator==(const layout_base& rhs) const noexcept {
225+
if(is_null() && rhs.is_null())
226+
return true;
227+
else if(is_null() || rhs.is_null())
228+
return false;
217229
if(m_shape_->are_different(*rhs.m_shape_)) return false;
218230
if(m_symmetry_->are_different(*rhs.m_symmetry_)) return false;
219231
if(m_sparsity_->are_different(*rhs.m_sparsity_)) return false;
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright 2026 NWChemEx-Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
#include <tensorwrapper/layout/layout_base.hpp>
19+
#include <tensorwrapper/types/layout_traits.hpp>
20+
21+
namespace tensorwrapper::layout {
22+
23+
template<typename Derived>
24+
class LayoutCommon : public LayoutBase {
25+
private:
26+
/// Type of *this
27+
using my_type = LayoutCommon<Derived>;
28+
29+
/// Type defining the types for *this
30+
using traits_type = types::ClassTraits<my_type>;
31+
32+
public:
33+
///@{
34+
using slice_type = typename traits_type::slice_type;
35+
using offset_il_type = typename traits_type::offset_il_type;
36+
///@}
37+
38+
/// Pull in base class's ctors
39+
using LayoutBase::LayoutBase;
40+
41+
/** @brief Slices a layout given two initializer lists.
42+
*
43+
* C++ doesn't allow templates to work with initializer lists, therefore
44+
* we must provide a special overload for when the input containers are
45+
* initializer lists. This method simply dispatches to the range-based
46+
* method by calling begin()/end() on each initializer list. See the
47+
* description of that method for more details.
48+
*
49+
* @param[in] first_elem An initializer list containing the offsets of
50+
* the first element IN the slice such that
51+
* `first_elem[i]` is the offset along mode i.
52+
* @param[in] last_elem An initializer list containing the offsets of
53+
* the first element NOT IN the slice such that
54+
* `last_elem[i]` is the offset along mode i.
55+
*
56+
* @return The requested slice.
57+
*
58+
* @throws ??? If the range-based method throws. Same throw guarantee.
59+
*/
60+
slice_type slice(offset_il_type first_elem,
61+
offset_il_type last_elem) const {
62+
return slice(first_elem.begin(), first_elem.end(), last_elem.begin(),
63+
last_elem.end());
64+
}
65+
66+
/** @brief Slices a layout given two containers.
67+
*
68+
* @tparam ContainerType0 The type of first_elem. Assumed to have
69+
* begin()/end() methods.
70+
* @tparam ContainerType1 The type of last_elem. Assumed to have
71+
* begin()/end() methods.
72+
*
73+
* Element indices are usually stored in containers. This overload is a
74+
* convenience method for calling begin()/end() on the containers before
75+
* dispatching to the range-based overload. See the documentation for the
76+
* range-based overload for more details.
77+
*
78+
* @param[in] first_elem A container containing the offsets of
79+
* the first element IN the slice such that
80+
* `first_elem[i]` is the offset along mode i.
81+
* @param[in] last_elem A container containing the offsets of
82+
* the first element NOT IN the slice such that
83+
* `last_elem[i]` is the offset along mode i.
84+
*
85+
* @return The requested slice.
86+
*
87+
* @throws ??? If the range-based method throws. Same throw guarantee.
88+
*/
89+
template<typename ContainerType0, typename ContainerType1>
90+
slice_type slice(ContainerType0&& first_elem, ContainerType1&& last_elem) {
91+
return slice(first_elem.begin(), first_elem.end(), last_elem.begin(),
92+
last_elem.end());
93+
}
94+
95+
/** @brief Implements slicing given two ranges.
96+
*
97+
* @tparam BeginItr The type of the iterators pointing to offsets in the
98+
* container holding the first element of the slice.
99+
* @tparam EndItr The type of the iterators pointing to the offsets in
100+
* the container holding the first element NOT in the
101+
* slice.
102+
*
103+
* All other slice functions dispatch to this method.
104+
*
105+
* Slices are assumed to be contiguous, meaning we can uniquely specify
106+
* the slice by providing the first element IN the slice and the first
107+
* element NOT IN the slice.
108+
*
109+
* Specifying an element of a rank @f$r@f$ tensor requires providing
110+
* @f$r@f$ offsets (one for each mode). Generally speaking, this requires
111+
* the offsets to be in a container. This method takes iterators to those
112+
* containers such that the @f$r@f$ elements in the range
113+
* [first_elem_begin, first_elem_end) are the offsets of first element IN
114+
* the slice and [last_elem_begin, last_elem_end) are the offsets of the
115+
* first element NOT IN the slice.
116+
*
117+
* @note Both [first_elem_begin, first_elem_end) and
118+
* [last_elem_begin, last_elem_end) being empty is allowed as long
119+
* as *this is null or for a scalar. In these cases you will get back
120+
* the only slice possible, which is the entire shape, i.e. a copy of
121+
* *this.
122+
*
123+
* @param[in] first_elem_begin An iterator to the offset along mode 0 for
124+
* the first element in the slice.
125+
* @param[in] first_elem_end An iterator pointing to just past the offset
126+
* along mode "r-1" (r being the rank of *this) for the first
127+
* element in the slice.
128+
* @param[in] last_elem_begin An iterator to the offset along mode 0 for
129+
* the first element NOT in the slice.
130+
* @param[in] last_elem_end An iterator pointing to just past the offset
131+
* along mode "r-1" (r being the rank of *this) for the first
132+
* element NOT in the slice.
133+
*
134+
* @return The requested slice.
135+
*
136+
* @throw std::runtime_error if the range
137+
* [first_elem_begin, first_elem_end) does not contain the same
138+
* number of elements as [last_elem_begin, last_elem_end).
139+
* Strong throw guarantee.
140+
* @throw std::runtime_error if the offsets in the range
141+
* [first_elem_begin, first_elem_end) do not come before the
142+
* offsets in [last_elem_begin, last_elem_end). Strong throw
143+
* guarantee.
144+
* @throw std::runtime_error if [first_elem_begin, first_elem_end) and
145+
* [last_elem_begin, last_elem_end) contain the
146+
* same number of offsets, but that number is NOT
147+
* equal to the rank of *this. Strong throw
148+
* guarantee.
149+
*
150+
*/
151+
template<typename BeginItr, typename EndItr>
152+
slice_type slice(BeginItr first_elem_begin, BeginItr first_elem_end,
153+
EndItr last_elem_begin, EndItr last_elem_end) const;
154+
};
155+
156+
template<typename Derived>
157+
template<typename BeginItr, typename EndItr>
158+
inline auto LayoutCommon<Derived>::slice(BeginItr first_elem_begin,
159+
BeginItr first_elem_end,
160+
EndItr last_elem_begin,
161+
EndItr last_elem_end) const
162+
-> slice_type {
163+
if(this->is_null()) return Derived{};
164+
auto new_shape = shape().as_smooth().slice(first_elem_begin, first_elem_end,
165+
last_elem_begin, last_elem_end);
166+
auto new_symmetry = symmetry().slice(first_elem_begin, first_elem_end,
167+
last_elem_begin, last_elem_end);
168+
auto new_sparsity = sparsity().slice(first_elem_begin, first_elem_end,
169+
last_elem_begin, last_elem_end);
170+
return slice_type{new_shape, new_symmetry, new_sparsity};
171+
}
172+
173+
} // namespace tensorwrapper::layout
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright 2026 NWChemEx-Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
namespace tensorwrapper::layout {
20+
21+
class LayoutBase;
22+
23+
template<typename Derived>
24+
class LayoutCommon;
25+
26+
class Logical;
27+
class Physical;
28+
29+
} // namespace tensorwrapper::layout

include/tensorwrapper/layout/physical.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
*/
1616

1717
#pragma once
18-
#include <tensorwrapper/layout/layout_base.hpp>
1918

19+
#include <tensorwrapper/layout/layout_common.hpp>
2020
namespace tensorwrapper::layout {
2121

2222
/** @brief Specializes a LayoutBase for a layout describing how a tensor is
@@ -26,10 +26,10 @@ namespace tensorwrapper::layout {
2626
* to hold details such as row major vs column major that matter for the
2727
* physical layout, but not the logical layout.
2828
*/
29-
class Physical : public LayoutBase {
29+
class Physical : public LayoutCommon<Physical> {
3030
private:
3131
/// Type *this derives from
32-
using my_base_type = LayoutBase;
32+
using my_base_type = LayoutCommon<Physical>;
3333

3434
public:
3535
/// Pull in base class's types

include/tensorwrapper/shape/shape_base.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
#pragma once
1818
#include <cstddef>
19-
#include <memory>
2019
#include <tensorwrapper/detail_/dsl_base.hpp>
2120
#include <tensorwrapper/detail_/polymorphic_base.hpp>
22-
#include <tensorwrapper/shape/shape_traits.hpp>
2321
#include <tensorwrapper/shape/smooth_view.hpp>
22+
#include <tensorwrapper/types/shape_traits.hpp>
2423

2524
namespace tensorwrapper::shape {
2625

@@ -42,7 +41,7 @@ class ShapeBase : public tensorwrapper::detail_::PolymorphicBase<ShapeBase>,
4241
public tensorwrapper::detail_::DSLBase<ShapeBase> {
4342
private:
4443
/// Type implementing the traits of this
45-
using traits_type = ShapeTraits<ShapeBase>;
44+
using traits_type = types::ClassTraits<ShapeBase>;
4645

4746
protected:
4847
/// Typedef of the PolymorphicBase class of *this

include/tensorwrapper/shape/shape_traits.hpp

Lines changed: 0 additions & 82 deletions
This file was deleted.

include/tensorwrapper/shape/smooth.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
#pragma once
1818
#include <functional>
1919
#include <numeric>
20-
#include <shape/shape_traits.hpp>
2120
#include <shape/smooth_common.hpp>
2221
#include <tensorwrapper/shape/shape_base.hpp>
22+
#include <tensorwrapper/types/shape_traits.hpp>
2323
#include <vector>
2424

2525
namespace tensorwrapper::shape {

0 commit comments

Comments
 (0)