Skip to content

Commit d4dbff5

Browse files
committed
Support streaming lazy masked expressions
1 parent fd7ba8d commit d4dbff5

3 files changed

Lines changed: 174 additions & 5 deletions

File tree

include/xtensor/core/xmath.hpp

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <type_traits>
2222

2323
#include <xtl/xcomplex.hpp>
24+
#include <xtl/xmasked_value.hpp>
2425
#include <xtl/xsequence.hpp>
2526
#include <xtl/xtype_traits.hpp>
2627

@@ -569,13 +570,64 @@ namespace xt
569570

570571
namespace math
571572
{
573+
namespace detail
574+
{
575+
template <typename T>
576+
constexpr decltype(auto) masked_data(const T& value) noexcept
577+
{
578+
return value;
579+
}
580+
581+
template <typename T, typename B>
582+
constexpr decltype(auto) masked_data(const xtl::xmasked_value<T, B>& value) noexcept
583+
{
584+
return value.value();
585+
}
586+
587+
template <typename T>
588+
constexpr bool masked_visible(const T&) noexcept
589+
{
590+
return true;
591+
}
592+
593+
template <typename T, typename B>
594+
constexpr bool masked_visible(const xtl::xmasked_value<T, B>& value) noexcept
595+
{
596+
return static_cast<bool>(value.visible());
597+
}
598+
}
599+
572600
template <class T = void>
573601
struct minimum
574602
{
575603
template <class A1, class A2>
576604
constexpr auto operator()(const A1& t1, const A2& t2) const noexcept
577605
{
578-
return xtl::select(t1 < t2, t1, t2);
606+
if constexpr (xtl::is_xmasked_value<std::decay_t<A1>>::value || xtl::is_xmasked_value<std::decay_t<A2>>::value)
607+
{
608+
using value_type = xtl::promote_type_t<
609+
std::decay_t<decltype(detail::masked_data(t1))>,
610+
std::decay_t<decltype(detail::masked_data(t2))>>;
611+
using return_type = xtl::xmasked_value<value_type, bool>;
612+
613+
if (detail::masked_visible(t1) && detail::masked_visible(t2))
614+
{
615+
return return_type(
616+
static_cast<value_type>(
617+
detail::masked_data(t1) < detail::masked_data(t2)
618+
? detail::masked_data(t1)
619+
: detail::masked_data(t2)
620+
),
621+
true
622+
);
623+
}
624+
625+
return return_type(value_type(0), false);
626+
}
627+
else
628+
{
629+
return xtl::select(t1 < t2, t1, t2);
630+
}
579631
}
580632

581633
template <class A1, class A2>
@@ -591,7 +643,31 @@ namespace xt
591643
template <class A1, class A2>
592644
constexpr auto operator()(const A1& t1, const A2& t2) const noexcept
593645
{
594-
return xtl::select(t1 > t2, t1, t2);
646+
if constexpr (xtl::is_xmasked_value<std::decay_t<A1>>::value || xtl::is_xmasked_value<std::decay_t<A2>>::value)
647+
{
648+
using value_type = xtl::promote_type_t<
649+
std::decay_t<decltype(detail::masked_data(t1))>,
650+
std::decay_t<decltype(detail::masked_data(t2))>>;
651+
using return_type = xtl::xmasked_value<value_type, bool>;
652+
653+
if (detail::masked_visible(t1) && detail::masked_visible(t2))
654+
{
655+
return return_type(
656+
static_cast<value_type>(
657+
detail::masked_data(t1) > detail::masked_data(t2)
658+
? detail::masked_data(t1)
659+
: detail::masked_data(t2)
660+
),
661+
true
662+
);
663+
}
664+
665+
return return_type(value_type(0), false);
666+
}
667+
else
668+
{
669+
return xtl::select(t1 > t2, t1, t2);
670+
}
595671
}
596672

597673
template <class A1, class A2>

include/xtensor/views/xmasked_view.hpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,15 @@ namespace xt
118118
using bool_load_type = xtl::xmasked_value<typename data_type::bool_load_type, mask_type>;
119119

120120
using shape_type = typename data_type::shape_type;
121-
using strides_type = typename data_type::strides_type;
121+
using strides_type = xtl::mpl::eval_if_t<has_strides<data_type>, detail::expr_strides_type<data_type>, get_strides_type<shape_type>>;
122+
using backstrides_type = xtl::mpl::eval_if_t<has_strides<data_type>, detail::expr_backstrides_type<data_type>, get_strides_type<shape_type>>;
122123

123124
static constexpr layout_type static_layout = data_type::static_layout;
124125
static constexpr bool contiguous_layout = false;
125126

126127
using inner_shape_type = typename data_type::inner_shape_type;
127-
using inner_strides_type = typename data_type::inner_strides_type;
128-
using inner_backstrides_type = typename data_type::inner_backstrides_type;
128+
using inner_strides_type = xtl::mpl::eval_if_t<has_strides<data_type>, detail::expr_inner_strides_type<data_type>, get_strides_type<shape_type>>;
129+
using inner_backstrides_type = xtl::mpl::eval_if_t<has_strides<data_type>, detail::expr_inner_backstrides_type<data_type>, get_strides_type<shape_type>>;
129130

130131
using expression_tag = xtensor_expression_tag;
131132

@@ -163,7 +164,12 @@ namespace xt
163164

164165
size_type size() const noexcept;
165166
const inner_shape_type& shape() const noexcept;
167+
template <typename DT = data_type>
168+
requires has_strides<DT>::value
166169
const inner_strides_type& strides() const noexcept;
170+
171+
template <typename DT = data_type>
172+
requires has_strides<DT>::value
167173
const inner_backstrides_type& backstrides() const noexcept;
168174
using accessible_base::dimension;
169175
using accessible_base::shape;
@@ -202,6 +208,9 @@ namespace xt
202208
template <class S>
203209
bool has_linear_assign(const S& strides) const noexcept;
204210

211+
template <typename S>
212+
bool broadcast_shape(S& shape, bool reuse_cache = false) const;
213+
205214
data_type& value() noexcept;
206215
const data_type& value() const noexcept;
207216

@@ -338,6 +347,8 @@ namespace xt
338347
* Returns the strides of the xmasked_view.
339348
*/
340349
template <class CTD, class CTM>
350+
template <typename DT>
351+
requires has_strides<DT>::value
341352
inline auto xmasked_view<CTD, CTM>::strides() const noexcept -> const inner_strides_type&
342353
{
343354
return m_data.strides();
@@ -347,6 +358,8 @@ namespace xt
347358
* Returns the backstrides of the xmasked_view.
348359
*/
349360
template <class CTD, class CTM>
361+
template <typename DT>
362+
requires has_strides<DT>::value
350363
inline auto xmasked_view<CTD, CTM>::backstrides() const noexcept -> const inner_backstrides_type&
351364
{
352365
return m_data.backstrides();
@@ -370,6 +383,13 @@ namespace xt
370383
return false;
371384
}
372385

386+
template <typename CTD, typename CTM>
387+
template <typename S>
388+
inline bool xmasked_view<CTD, CTM>::broadcast_shape(S& shape, bool) const
389+
{
390+
return xt::broadcast_shape(m_data.shape(), shape);
391+
}
392+
373393
/**
374394
* Fills the data with the given value.
375395
* @param value the value to fill the data with.

test/test_xmasked_view.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <sstream>
1111

12+
#include "xtensor/core/xmath.hpp"
1213
#include "xtensor/io/xio.hpp"
1314
#include "xtensor/optional/xoptional_assembly.hpp"
1415
#include "xtensor/views/xmasked_view.hpp"
@@ -40,6 +41,42 @@ namespace xt
4041
return masked_view(data, std::move(mask));
4142
}
4243

44+
template <class A, class B, class M, class = void>
45+
struct is_masked_minimum_streamable : std::false_type
46+
{
47+
};
48+
49+
template <class A, class B, class M>
50+
struct is_masked_minimum_streamable<
51+
A,
52+
B,
53+
M,
54+
std::void_t<decltype(
55+
std::declval<std::ostream&>()
56+
<< minimum(
57+
masked_view(std::declval<const A&>(), std::declval<const M&>()),
58+
masked_view(std::declval<const B&>(), std::declval<const M&>())
59+
))>> : std::true_type
60+
{
61+
};
62+
63+
template <class A, class B, class M, class = void>
64+
struct is_masked_view_of_minimum_streamable : std::false_type
65+
{
66+
};
67+
68+
template <class A, class B, class M>
69+
struct is_masked_view_of_minimum_streamable<
70+
A,
71+
B,
72+
M,
73+
std::void_t<decltype(
74+
std::declval<std::ostream&>()
75+
<< masked_view(minimum(std::declval<const A&>(), std::declval<const B&>()), std::declval<const M&>()))>>
76+
: std::true_type
77+
{
78+
};
79+
4380
TEST(xmasked_view, dimension)
4481
{
4582
auto data = make_test_data();
@@ -226,6 +263,27 @@ namespace xt
226263
" {masked, 5, masked}}";
227264
EXPECT_EQ(out.str(), expected);
228265
}
266+
267+
TEST(xmasked_view, lazy_expression_stream)
268+
{
269+
using array_type = xarray<double>;
270+
const array_type a = {1., 1., 1., 1.};
271+
const array_type b = {0.1, 0.7, 0.3, 0.9};
272+
273+
const auto mask = b < 0.5;
274+
const auto expected = eval(masked_view(minimum(a, b), mask));
275+
276+
std::stringstream expected_out;
277+
expected_out << expected;
278+
279+
std::stringstream masked_min_out;
280+
masked_min_out << minimum(masked_view(a, mask), masked_view(b, mask));
281+
EXPECT_EQ(masked_min_out.str(), expected_out.str());
282+
283+
std::stringstream masked_expr_out;
284+
masked_expr_out << masked_view(minimum(a, b), mask);
285+
EXPECT_EQ(masked_expr_out.str(), expected_out.str());
286+
}
229287

230288
TEST(xmasked_view, assign)
231289
{
@@ -240,6 +298,21 @@ namespace xt
240298
EXPECT_EQ(data, expected1);
241299
}
242300

301+
TEST(xmasked_view, assign_const_masked_view_rhs)
302+
{
303+
xarray<double> data = {{1., -2., 3.}, {4., 5., -6.}, {7., 8., -9.}};
304+
const xarray<double> data2 = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}};
305+
xarray<bool> mask = {{true, true, true}, {true, false, false}, {true, false, true}};
306+
307+
auto masked_data = masked_view(data, mask);
308+
const auto masked_data2 = masked_view(data2, mask);
309+
310+
masked_data = masked_data2;
311+
312+
xarray<double> expected = {{0.1, 0.2, 0.3}, {0.4, 5., -6.}, {0.7, 8., 0.9}};
313+
EXPECT_EQ(data, expected);
314+
}
315+
243316
TEST(xmasked_view, view)
244317
{
245318
xt::xarray<size_t> data = {{0, 1}, {2, 3}, {4, 5}};

0 commit comments

Comments
 (0)