Skip to content

Add reflect and symmetric padding modes to mx.pad#3608

Open
katlun-lgtm wants to merge 1 commit into
ml-explore:mainfrom
katlun-lgtm:add-reflect-symmetric-pad
Open

Add reflect and symmetric padding modes to mx.pad#3608
katlun-lgtm wants to merge 1 commit into
ml-explore:mainfrom
katlun-lgtm:add-reflect-symmetric-pad

Conversation

@katlun-lgtm
Copy link
Copy Markdown

Summary

Adds numpy.pad-compatible "reflect" and "symmetric" modes to mx.pad. These join the existing "constant" and "edge" modes.

Both modes match numpy.pad semantics for arbitrary pad sizes — when the pad width exceeds the axis length the reflection repeats, exactly as NumPy does. (Earlier attempts at these modes were limited to pad < dim; this implementation removes that restriction.)

  • reflect — mirror padding that does not repeat the edge value (period 2(n-1)).
  • symmetric — mirror padding that does repeat the edge value (period 2n).

Implementation

reflect_pad (in mlx/ops.cpp) builds a per-axis index map with a triangle-wave reflection function and gathers with take — one take per padded axis. A degenerate axis of length 1 maps every coordinate to 0. No new primitive or kernel is introduced; it composes existing ops, so it works on every backend and is differentiable for free.

Files changed

  • mlx/ops.cppreflect_pad helper + reflect / symmetric dispatch branches in pad.
  • python/src/ops.cpp — extend the mode Literal and docstring.
  • python/tests/test_ops.pytest_pad_reflect_symmetric.
  • tests/ops_tests.cpp — reflect/symmetric CHECK cases incl. multi-reflect.

Testing

All run locally on an M3 Max:

  • Python test_pad_reflect_symmetric — 13 shape/pad-width cases × 2 modes compared element-for-element against numpy.pad (in-bounds, multi-reflect where pad ≫ axis, asymmetric per-axis, zero-width sides, and degenerate axes n==1, n==2). Exact match.
  • C++ tests/ops_tests.cpp "test pad" — 9 assertions pass.
  • Full C++ suite251 cases / 251 passed, 3442 assertions / 0 failed.
>>> import mlx.core as mx
>>> a = mx.array([1, 2, 3])
>>> mx.pad(a, 2, mode="reflect")
array([3, 2, 1, 2, 3, 2, 1], dtype=int32)
>>> mx.pad(a, 2, mode="symmetric")
array([2, 1, 1, 2, 3, 3, 2], dtype=int32)

Implements numpy.pad-compatible "reflect" and "symmetric" modes for
mx.pad, matching numpy semantics for arbitrary pad sizes (the reflection
repeats when the pad width exceeds the axis length).

- mlx/ops.cpp: reflect_pad helper builds a per-axis triangle-wave index
  map and gathers with take; one take per padded axis. reflect uses
  period 2(n-1) and skips the edge; symmetric uses period 2n and repeats
  the edge. n==1 maps to 0.
- python/src/ops.cpp: extend the pad mode Literal and docstring.
- python/tests/test_ops.py: test_pad_reflect_symmetric covers in-bounds,
  multi-reflect, asymmetric per-axis, zero-width sides, and degenerate
  axes (n==1, n==2), checked against numpy.pad.
- tests/ops_tests.cpp: reflect/symmetric CHECK cases incl. multi-reflect.
@katlun-lgtm katlun-lgtm marked this pull request as ready for review May 30, 2026 20:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant