Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,8 @@ array tile(
}
expand_shape.push_back(shape[i]);
broad_shape.push_back(shape[i]);
final_shape.push_back(reps[i] * shape[i]);
final_shape.push_back(
safe_cast(static_cast<int64_t>(reps[i]) * shape[i], "tile"));
}

auto x = reshape(arr, std::move(expand_shape), s);
Expand Down Expand Up @@ -6350,7 +6351,10 @@ array roll(
if (size == 0) {
continue; // skip rolling this axis if it has size 0
}
auto split_index = (sh < 0) ? (-sh) % size : size - sh % size;
// Promote to 64-bit so negating a shift of INT_MIN does not overflow.
int64_t sh64 = sh;
auto split_index = static_cast<ShapeElem>(
(sh64 < 0) ? (-sh64) % size : size - sh64 % size);

auto parts = split(result, Shape{split_index}, ax, s);
std::swap(parts[0], parts[1]);
Expand All @@ -6369,11 +6373,11 @@ array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
}

array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) {
int total_shift = 0;
for (auto& s : shift) {
total_shift += s;
int64_t total_shift = 0;
for (auto& sh : shift) {
total_shift += sh;
}
return roll(a, total_shift, s);
return roll(a, safe_cast(total_shift, "roll"), s);
}

array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
Expand All @@ -6394,11 +6398,12 @@ array roll(
const Shape& shift,
int axis,
StreamOrDevice s /* = {} */) {
int total_shift = 0;
for (auto& s : shift) {
total_shift += s;
int64_t total_shift = 0;
for (auto& sh : shift) {
total_shift += sh;
}
return roll(a, Shape{total_shift}, std::vector<int>{axis}, s);
return roll(
a, Shape{safe_cast(total_shift, "roll")}, std::vector<int>{axis}, s);
}

array real(const array& a, StreamOrDevice s /* = {} */) {
Expand Down
21 changes: 21 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4284,3 +4284,24 @@ TEST_CASE("test max min with nan") {
CHECK(array_equal(max_result, expected, true).item<bool>());
CHECK(array_equal(min_result, expected, true).item<bool>());
}

TEST_CASE("roll and tile shape overflow") {
// Shape arithmetic must not overflow (signed-int UB) for large but otherwise
// valid int32 inputs; out-of-range results are rejected gracefully.
// https://github.com/ml-explore/mlx/issues/3601

// tile: reps * dim exceeding int32 raises instead of overflowing.
CHECK_THROWS_AS(tile(zeros({2}), {2147483647}), std::overflow_error);

// roll: a shift sum exceeding int32 raises instead of overflowing.
CHECK_THROWS_AS(
roll(zeros({4}), Shape{2147483647, 2147483647}), std::overflow_error);
CHECK_THROWS_AS(
roll(zeros({4}), Shape{2147483647, 2147483647}, 0), std::overflow_error);

// roll: a shift of INT_MIN must not negate-overflow. INT_MIN mod 4 == 0, so
// rolling a size-4 axis by INT_MIN is the identity.
auto x = array({1, 2, 3, 4});
auto rolled = roll(x, -2147483647 - 1);
CHECK(array_equal(rolled, x).item<bool>());
}