Skip to content

Commit bdb917b

Browse files
authored
ReplicatedView supports slicing (#239)
1 parent 926511d commit bdb917b

20 files changed

Lines changed: 1107 additions & 146 deletions

include/tensorwrapper/buffer/buffer_base.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class BufferBase : public BufferBaseCommon<BufferBase>,
142142

143143
const_layout_reference layout_() const { return *m_layout_; }
144144

145+
layout_reference layout_() { return *m_layout_; }
146+
145147
dsl_reference addition_assignment_(label_type this_labels,
146148
const_labeled_reference lhs,
147149
const_labeled_reference rhs) override;

include/tensorwrapper/buffer/buffer_base_common.hpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,14 @@ class BufferBaseCommon {
3535
/// Type of *this
3636
using my_type = BufferBaseCommon<Derived>;
3737

38-
/// Traits for my_type
39-
using traits_type = types::ClassTraits<my_type>;
40-
4138
/// Traits for Derived
42-
using derived_traits = types::ClassTraits<Derived>;
39+
using traits_type = types::ClassTraits<Derived>;
4340

4441
public:
4542
///@{
4643
using layout_type = typename traits_type::layout_type;
44+
using layout_reference = typename traits_type::layout_reference;
45+
using layout_pointer = typename traits_type::layout_pointer;
4746
using const_layout_reference = typename traits_type::const_layout_reference;
4847
using rank_type = typename traits_type::rank_type;
4948
///@}
@@ -60,6 +59,18 @@ class BufferBaseCommon {
6059
*/
6160
bool has_layout() const noexcept { return derived_().has_layout_(); }
6261

62+
/** @brief Retrieves the layout of *this.
63+
*
64+
* @return A reference to the layout.
65+
*
66+
* @throw std::runtime_error if *this does not have a layout. Strong throw
67+
* guarantee.
68+
*/
69+
layout_reference layout() {
70+
assert_layout_();
71+
return derived_().layout_();
72+
}
73+
6374
/** @brief Retrieves the layout of *this.
6475
*
6576
* @return A read-only reference to the layout.
@@ -143,7 +154,7 @@ class BufferBaseCommon {
143154

144155
/// Access derived for CRTP
145156
const Derived& derived_() const noexcept {
146-
return *static_cast<const Derived*>(this);
157+
return static_cast<const Derived&>(*this);
147158
}
148159
};
149160

include/tensorwrapper/buffer/buffer_fwd.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,19 @@ class BufferBase;
2626
template<typename BufferBaseType>
2727
class BufferViewBase;
2828

29-
class Contiguous;
30-
3129
class Local;
3230

31+
template<typename LocalType>
32+
class LocalView;
33+
34+
template<typename Derived>
35+
class ReplicatedCommon;
36+
3337
class Replicated;
3438

39+
template<typename ReplicatedType>
40+
class ReplicatedView;
41+
42+
class Contiguous;
43+
3544
} // namespace tensorwrapper::buffer

include/tensorwrapper/buffer/buffer_view_base.hpp

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,28 @@
1515
*/
1616

1717
#pragma once
18+
#include <memory>
1819
#include <tensorwrapper/buffer/buffer_base.hpp>
1920
#include <tensorwrapper/buffer/buffer_base_common.hpp>
21+
#include <tensorwrapper/buffer/detail_/buffer_view_base_pimpl.hpp>
2022
#include <type_traits>
2123

2224
namespace tensorwrapper::buffer {
2325

2426
/** @brief View of a BufferBase that aliases existing state instead of owning
25-
* it.
27+
* it.
2628
*
2729
* BufferViewBase has the same layout/equality API as BufferBase (has_layout(),
28-
* layout(), rank(), operator==, operator!=, approximately_equal) but holds a
29-
* non-owning pointer to a BufferBase and delegates all operations to it.
30+
* layout(), rank(), operator==, operator!=, approximately_equal) but uses a
31+
* PIMPL. The view delegates layout operations to the PIMPL.
3032
*
31-
* BufferViewBase is templated on the type of the aliased buffer, which must
32-
* be either BufferBase or const BufferBase. This controls whether the view is
33-
* a mutable or const view of the underlying BufferBase.
33+
* BufferViewBase is templated on the type of the aliased buffer (BufferBase or
34+
* const BufferBase) for API compatibility; construction from a buffer copies
35+
* a non-owning pointer to that buffer's layout into the PIMPL.
3436
*
35-
* The aliased buffer must outlive this view. Default-constructed or
36-
* moved-from views have no aliased buffer (has_layout() is false, layout()
37-
* throws).
37+
* The referenced layout (and its owner) must outlive this view. Default-
38+
* constructed or moved-from views have no layout (has_layout() is false,
39+
* layout() throws).
3840
*
3941
* @tparam BufferBaseType Either BufferBase or const BufferBase.
4042
*/
@@ -48,61 +50,87 @@ class BufferViewBase : public BufferBaseCommon<BufferViewBase<BufferBaseType>> {
4850

4951
/// Type *this derives from
5052
using my_base_type = BufferBaseCommon<BufferViewBase<BufferBaseType>>;
51-
using typename my_base_type::const_layout_reference;
5253

53-
using aliased_type = BufferBaseType;
54-
using aliased_pointer = aliased_type*;
54+
/// Type of the PIMPL
55+
using pimpl_type = detail_::BufferViewBasePIMPL<BufferBaseType>;
56+
using pimpl_reference = pimpl_type&;
57+
using const_pimpl_reference = const pimpl_type&;
5558

5659
public:
60+
using typename my_base_type::const_layout_reference;
61+
using typename my_base_type::layout_pointer;
62+
using typename my_base_type::layout_reference;
63+
using typename my_base_type::layout_type;
5764
// -------------------------------------------------------------------------
5865
// -- Ctors and assignment
5966
// -------------------------------------------------------------------------
6067

61-
/** @brief Creates a view that aliases no buffer.
68+
/** @brief Creates a view with no layout.
6269
*
6370
* @throw None No throw guarantee.
6471
*/
65-
BufferViewBase() noexcept : m_aliased_(nullptr) {}
72+
BufferViewBase() noexcept : m_pimpl_(nullptr) {}
6673

67-
/** @brief Creates a view that aliases @p buffer.
74+
/** @brief Creates a view that aliases the layout of @p buffer.
6875
*
69-
* @param[in] buffer The buffer to alias. Must outlive *this.
76+
* @param[in] buffer The buffer whose layout to alias. The layout must
77+
* outlive *this.
7078
*
7179
* @throw None No throw guarantee.
7280
*/
73-
explicit BufferViewBase(aliased_type& buffer) noexcept :
74-
m_aliased_(&buffer) {}
81+
explicit BufferViewBase(BufferBaseType& buffer) noexcept :
82+
m_pimpl_(buffer.has_layout() ?
83+
std::make_unique<pimpl_type>(&buffer.layout()) :
84+
nullptr) {}
85+
86+
/** Creates a read-only view from a mutable buffer. */
87+
template<typename OtherBufferBaseType>
88+
requires(!std::is_const_v<OtherBufferBaseType> &&
89+
std::is_const_v<BufferBaseType>)
90+
explicit BufferViewBase(OtherBufferBaseType& other) noexcept :
91+
m_pimpl_(other.has_layout() ?
92+
std::make_unique<pimpl_type>(&other.layout()) :
93+
nullptr) {}
94+
95+
explicit BufferViewBase(layout_pointer layout) noexcept :
96+
m_pimpl_(std::make_unique<pimpl_type>(layout)) {}
7597

76-
/** @brief Creates a view that aliases the same buffer as @p other.
98+
/** @brief Creates a view that aliases the same layout as @p other.
7799
*
78100
* @param[in] other The view to copy.
79101
*
80102
* @throw None No throw guarantee.
81103
*/
82-
BufferViewBase(const BufferViewBase& other) noexcept = default;
104+
BufferViewBase(const BufferViewBase& other) noexcept :
105+
m_pimpl_(other.m_pimpl_ ? other.m_pimpl_->clone() : nullptr) {}
83106

84-
/** @brief Creates a view by taking the alias from @p other.
107+
/** @brief Creates a view by taking the PIMPL from @p other.
85108
*
86-
* After construction *this aliases the buffer @p other did, and @p other
87-
* aliases no buffer.
109+
* After construction *this aliases the layout @p other did, and @p other
110+
* has no layout.
88111
*
89112
* @param[in,out] other The view to move from.
90113
*
91114
* @throw None No throw guarantee.
92115
*/
93116
BufferViewBase(BufferViewBase&& other) noexcept = default;
94117

95-
/** @brief Makes *this alias the same buffer as @p rhs.
118+
/** @brief Makes *this alias the same layout as @p rhs.
96119
*
97120
* @param[in] rhs The view to copy.
98121
*
99122
* @return *this.
100123
*
101124
* @throw None No throw guarantee.
102125
*/
103-
BufferViewBase& operator=(const BufferViewBase& rhs) noexcept = default;
126+
BufferViewBase& operator=(const BufferViewBase& rhs) noexcept {
127+
if(this != &rhs) {
128+
m_pimpl_ = rhs.m_pimpl_ ? rhs.m_pimpl_->clone() : nullptr;
129+
}
130+
return *this;
131+
}
104132

105-
/** @brief Replaces the alias in *this with that of @p rhs.
133+
/** @brief Replaces the PIMPL in *this with that of @p rhs.
106134
*
107135
* @param[in,out] rhs The view to move from.
108136
*
@@ -133,41 +161,53 @@ class BufferViewBase : public BufferBaseCommon<BufferViewBase<BufferBaseType>> {
133161
// -------------------------------------------------------------------------
134162

135163
bool has_layout_() const noexcept {
136-
return m_aliased_ != nullptr && m_aliased_->has_layout();
164+
return m_pimpl_ != nullptr && m_pimpl_->has_layout();
165+
}
166+
167+
layout_reference layout_() { return pimpl_().layout(); }
168+
169+
const_layout_reference layout_() const { return pimpl_().layout(); }
170+
171+
// Will be polymorphic eventually
172+
template<typename OtherBufferBaseType>
173+
bool approximately_equal_(const BufferViewBase<OtherBufferBaseType>& rhs,
174+
double) const {
175+
return *this == rhs;
176+
}
177+
178+
// Will be polymorphic eventually
179+
bool approximately_equal_(const BufferBase& rhs, double) const {
180+
return *this == rhs;
137181
}
138182

139-
const_layout_reference layout_() const {
140-
if(m_aliased_ == nullptr) {
183+
private:
184+
void assert_pimpl_() const {
185+
if(!m_pimpl_) {
141186
throw std::runtime_error(
142-
"Buffer has no layout. Was it default initialized?");
187+
"BufferViewBase has no PIMPL. Was it default initialized?");
143188
}
144-
return m_aliased_->layout();
145189
}
146-
147-
template<typename OtherBufferBase>
148-
bool approximately_equal_(const BufferViewBase<OtherBufferBase>& rhs,
149-
double tol) const {
150-
if(m_aliased_ == nullptr) return !rhs.has_layout();
151-
return m_aliased_->approximately_equal(*rhs.m_aliased_, tol);
190+
pimpl_reference pimpl_() {
191+
assert_pimpl_();
192+
return *m_pimpl_;
152193
}
153194

154-
bool approximately_equal_(const BufferBase& rhs, double tol) const {
155-
if(m_aliased_ == nullptr) return !rhs.has_layout();
156-
return m_aliased_->approximately_equal(rhs, tol);
195+
const_pimpl_reference pimpl_() const {
196+
assert_pimpl_();
197+
return *m_pimpl_;
157198
}
158199

159-
private:
160-
/// The buffer *this aliases (non-owning)
161-
aliased_pointer m_aliased_;
200+
/// PIMPL holding non-owning pointer to the aliased layout
201+
std::unique_ptr<pimpl_type> m_pimpl_;
162202
};
163203

164-
// Out-of-line definition so both BufferBase and BufferViewBase are complete
204+
// Out-of-line definition so both BufferBase and BufferViewBase are complete.
205+
165206
template<typename BufferBaseType>
166207
bool BufferBase::approximately_equal_(const BufferViewBase<BufferBaseType>& rhs,
167208
double tol) const {
168209
if(!rhs.has_layout()) return !has_layout();
169-
return approximately_equal_(
170-
*static_cast<const BufferBaseType*>(rhs.m_aliased_), tol);
210+
return !this->layout().are_different(rhs.layout());
171211
}
172212

173213
} // namespace tensorwrapper::buffer

0 commit comments

Comments
 (0)