Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,53 @@ array tile(
return reshape(x, std::move(final_shape), s);
}

array reflect_pad(
const array& a,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
bool include_edge,
StreamOrDevice s /* = {} */) {
// Reflect (include_edge=false) or symmetric (include_edge=true) padding.
// Matches numpy.pad for arbitrary pad sizes (the reflection repeats as needed).
// For an out-of-range coordinate r (relative to the original axis [0, n)),
// map it back into [0, n) by reflection:
// reflect -> period 2(n-1), edge NOT repeated
// symmetric -> period 2n, edge repeated
auto reflect_coord = [](int r, int n, bool include_edge) -> int {
if (n == 1) {
return 0;
}
if (include_edge) {
int period = 2 * n;
int m = ((r % period) + period) % period;
return m < n ? m : (2 * n - 1 - m);
} else {
int period = 2 * (n - 1);
int m = ((r % period) + period) % period;
return m < n ? m : (period - m);
}
};
array out = a;
for (size_t i = 0; i < axes.size(); i++) {
int ax = axes[i];
int L = low_pad_size[i];
int H = high_pad_size[i];
if (L == 0 && H == 0) {
continue;
}
int n = out.shape(ax);
int total = L + n + H;
std::vector<int32_t> idx_vec(total);
for (int p = 0; p < total; p++) {
idx_vec[p] = reflect_coord(p - L, n, include_edge);
}
array idx = array(idx_vec.begin(), {total}, int32);
out = take(out, idx, ax, s);
}
return out;
}

array edge_pad(
const array& a,
const std::vector<int>& axes,
Expand Down Expand Up @@ -1449,6 +1496,10 @@ array pad(
{a, astype(pad_value, a.dtype(), s)});
} else if (mode == "edge") {
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else if (mode == "reflect") {
return reflect_pad(a, axes, low_pad_size, high_pad_size, false, s);
} else if (mode == "symmetric") {
return reflect_pad(a, axes, low_pad_size, high_pad_size, true, s);
} else {
std::ostringstream msg;
msg << "Invalid padding mode (" << mode << ") passed to pad";
Expand Down
4 changes: 3 additions & 1 deletion python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3271,7 +3271,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Pad an array with a constant value

Expand All @@ -3286,6 +3286,8 @@ void init_ops(nb::module_& m) {
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
"reflect": Pads with the reflection of the array, without repeating the edge values.
"symmetric": Pads with the reflection of the array, repeating the edge values.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.

Expand Down
32 changes: 32 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,6 +2061,38 @@ def test_nan_to_num(self):
out_mx = mx.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000)
self.assertTrue(np.allclose(out_mx, out_np))

def test_pad_reflect_symmetric(self):
# mx.pad reflect/symmetric must match numpy.pad exactly (it is a gather).
# Covers in-bounds, multi-reflect (pad larger than the axis), asymmetric
# per-axis widths, zero-width sides, and degenerate axes (n == 1, n == 2).
cases = [
((8,), [(2, 3)]),
((8,), [(0, 4)]),
((8,), [(3, 0)]),
((8,), [(7, 8)]),
((4,), [(10, 7)]), # multi-reflect
((4,), [(20, 20)]), # multi-reflect, both sides
((3,), [(9, 1)]), # multi-reflect
((1,), [(3, 2)]), # degenerate axis
((2,), [(5, 6)]), # smallest non-trivial, multi-reflect
((5, 6), [(2, 3), (1, 2)]),
((5, 6), [(9, 9), (11, 0)]), # both axes multi-reflect
((3, 4, 5), [(1, 1), (0, 0), (2, 2)]),
((3, 4, 5), [(4, 4), (0, 0), (7, 3)]),
]
for mode in ("reflect", "symmetric"):
for shape, pw in cases:
a_npy = np.random.randn(*shape).astype(np.float32)
a_mlx = mx.array(a_npy)
b_npy = np.pad(a_npy, pw, mode=mode)
b_mlx = mx.pad(a_mlx, pw, mode=mode)
self.assertEqual(b_mlx.shape, tuple(b_npy.shape))
self.assertTrue(
np.array_equal(np.array(b_mlx), b_npy),
msg=f"mismatch mode={mode} shape={shape} pad={pw}",
)
self.assertEqual(b_mlx.dtype, mx.float32)

def test_as_strided(self):
x_npy = np.random.randn(128).astype(np.float32)
x_mlx = mx.array(x_npy)
Expand Down
35 changes: 35 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2868,6 +2868,41 @@ TEST_CASE("test pad") {
0.0f},
{4, 4});
CHECK(array_equal(padded_x, expected).item<bool>());

// reflect padding (mirror without repeating the edge value)
x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {5});
CHECK(array_equal(
pad(x, {{2, 2}}, array(0.0f), "reflect"),
array(
{3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f},
{9}))
.item<bool>());
CHECK(array_equal(
pad(x, {{0, 3}}, array(0.0f), "reflect"),
array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f}, {8}))
.item<bool>());

// symmetric padding (mirror repeating the edge value)
CHECK(array_equal(
pad(x, {{2, 2}}, array(0.0f), "symmetric"),
array(
{2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 5.0f, 4.0f},
{9}))
.item<bool>());
CHECK(array_equal(
pad(x, {{3, 0}}, array(0.0f), "symmetric"),
array({3.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {8}))
.item<bool>());

// multi-reflect: pad larger than the axis repeats the reflection (numpy parity)
x = array({1.0f, 2.0f, 3.0f}, {3});
CHECK(array_equal(
pad(x, {{5, 5}}, array(0.0f), "reflect"),
array(
{2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f,
2.0f, 3.0f, 2.0f},
{13}))
.item<bool>());
}

TEST_CASE("test power") {
Expand Down