Skip to content

Commit 42d124c

Browse files
authored
Fix xview assignment through leading newaxis slices (#2896)
# Checklist - [x] The title and commit message(s) are descriptive. - [x] Small commits made to fix your PR have been squashed to avoid history pollution. - [x] Tests have been added for new features or bug fixes. - [x] API of new functions and classes are documented. # Description This PR fixes incorrect index mapping in xview assignments when a view starts with one or more newaxis() slices. The change updates the index computation to use a newaxis-aware slice index before applying integral-slice adjustments. That keeps writes through views with multiple leading newaxis() entries aligned with the correct element in the underlying tensor. Co-authored-by: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com>
1 parent 68f9a77 commit 42d124c

2 files changed

Lines changed: 42 additions & 2 deletions

File tree

include/xtensor/views/xview.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,8 +1645,9 @@ namespace xt
16451645
{
16461646
if constexpr (lesser_condition<I>::value)
16471647
{
1648-
return sliced_access<I - integral_count_before<S...>(I) + newaxis_count_before<S...>(I + 1)>(
1649-
std::get<I + newaxis_count_before<S...>(I + 1)>(m_slices),
1648+
constexpr size_type slice_index = newaxis_skip<S...>(I);
1649+
return sliced_access<slice_index - integral_count_before<S...>(slice_index)>(
1650+
std::get<slice_index>(m_slices),
16501651
args...
16511652
);
16521653
}

test/test_xview.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,45 @@ namespace xt
15911591
EXPECT_EQ(a, b);
15921592
}
15931593

1594+
TEST(xview, assign_through_multiple_leading_newaxis)
1595+
{
1596+
SUBCASE("updates the underlying tensor for every element")
1597+
{
1598+
xt::xtensor<uint8_t, 2> tensor = xt::zeros<uint8_t>({4, 3});
1599+
auto view = xt::view(tensor, xt::newaxis(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all());
1600+
1601+
uint8_t value = 0;
1602+
for (std::size_t row = 0; row < 4; ++row)
1603+
{
1604+
for (std::size_t col = 0; col < 3; ++col)
1605+
{
1606+
view(std::size_t{0}, std::size_t{0}, std::size_t{0}, row, col) = value;
1607+
EXPECT_EQ(tensor(row, col), value);
1608+
++value;
1609+
}
1610+
}
1611+
1612+
EXPECT_EQ(tensor, xt::arange<uint8_t>(12).reshape({4, 3}));
1613+
}
1614+
1615+
SUBCASE("preserves bool assignment semantics")
1616+
{
1617+
xt::xtensor<bool, 2> tensor = xt::zeros<bool>({4, 3});
1618+
auto view = xt::view(tensor, xt::newaxis(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all());
1619+
1620+
for (std::size_t row = 0; row < 4; ++row)
1621+
{
1622+
for (std::size_t col = 0; col < 3; ++col)
1623+
{
1624+
view(std::size_t{0}, std::size_t{0}, std::size_t{0}, row, col) = true;
1625+
EXPECT_TRUE(tensor(row, col));
1626+
}
1627+
}
1628+
1629+
EXPECT_EQ(tensor, xt::ones<bool>({4, 3}));
1630+
}
1631+
}
1632+
15941633
TEST(xview, in_bounds)
15951634
{
15961635
xt::xtensor<size_t, 2> a = {{0, 1, 2}, {3, 4, 5}};

0 commit comments

Comments
 (0)