Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14602,6 +14602,10 @@ defmodule Nx do
[%T{vectorized_axes: vectorized_axes} = tensor, indices] =
broadcast_vectors([tensor, indices], align_ranks: false)

if indices.shape == {} do
raise ArgumentError, "expected indices rank to be at least 1, got: 0"
end

axes = indexed_axes(tensor, indices, opts)

unless Nx.Type.integer?(indices.type) do
Expand Down Expand Up @@ -16823,32 +16827,37 @@ defmodule Nx do
raise ArgumentError, "expected n to be a non-negative integer, got: #{inspect(n)}"
end

{iota_shape, start, stop} =
case {start.shape, stop.shape} do
{shape, shape} ->
iota_shape = Tuple.insert_at(shape, tuple_size(shape), n)
{iota_shape, new_axis(start, -1, opts[:name]), new_axis(stop, -1, opts[:name])}
if n == 1 do
# Special case: single point returns start value
start |> new_axis(-1, opts[:name]) |> as_type(opts[:type])
else
{iota_shape, start, stop} =
case {start.shape, stop.shape} do
{shape, shape} ->
iota_shape = Tuple.insert_at(shape, tuple_size(shape), n)
{iota_shape, new_axis(start, -1, opts[:name]), new_axis(stop, -1, opts[:name])}

{start_shape, stop_shape} ->
raise ArgumentError,
"expected start and stop to have the same shape. Got shapes #{inspect(start_shape)} and #{inspect(stop_shape)}"
end
{start_shape, stop_shape} ->
raise ArgumentError,
"expected start and stop to have the same shape. Got shapes #{inspect(start_shape)} and #{inspect(stop_shape)}"
end

iota = iota(iota_shape, axis: -1, type: opts[:type], vectorized_axes: vectorized_axes)
iota = iota(iota_shape, axis: -1, type: opts[:type], vectorized_axes: vectorized_axes)

divisor =
if opts[:endpoint] do
n - 1
else
n
end
divisor =
if opts[:endpoint] do
n - 1
else
n
end

step = Nx.subtract(stop, start) |> Nx.divide(divisor)
step = Nx.subtract(stop, start) |> Nx.divide(divisor)

iota
|> multiply(step)
|> add(start)
|> as_type(opts[:type])
iota
|> multiply(step)
|> add(start)
|> as_type(opts[:type])
end
end

@doc """
Expand Down
4 changes: 3 additions & 1 deletion nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1675,7 +1675,7 @@ defmodule Nx.BinaryBackend do
{acc_offset, acc_binary} ->
num_vals_before = div(offset - acc_offset, output_size)
vals_before = List.duplicate(init_binary, num_vals_before)
source_val = to_binary(value)
source_val = value |> Nx.as_type(output_type) |> to_binary()
new_binary = :erlang.list_to_bitstring([vals_before, source_val])

{offset + output_size, <<acc_binary::bitstring, new_binary::bitstring>>}
Expand Down Expand Up @@ -1849,6 +1849,8 @@ defmodule Nx.BinaryBackend do
|> then(&from_binary(out, &1))
end

defp bin_slice(data, _shape, _size, [], [], [], _output_shape), do: data

defp bin_slice(data, shape, size, start_indices, lengths, strides, output_shape) do
start_indices = clamp_indices(start_indices, shape, lengths)

Expand Down
173 changes: 173 additions & 0 deletions nx/test/nx/edge_cases_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
defmodule Nx.EdgeCasesTest do
@moduledoc """
Regression tests for Nx edge cases:
- window_scatter_max/min on f64
- Nx.slice on scalar tensor
- Nx.linspace with n=1
- Nx.gather with scalar indices

IEEE 754 overflow/domain/divzero tests are skipped pending
upstream fix in the Complex library (elixir-nx/complex#29).
"""
use ExUnit.Case, async: true

# ── IEEE 754 tests (pending Complex library fix) ───────────────────
# These tests require elixir-nx/complex#29 to be released.
# Once Complex handles :math overflow/domain errors, these
# will pass without any changes to BinaryBackend.

describe "unary overflow returns Inf instead of crashing" do
@tag :skip
test "exp(large) returns Inf" do
assert Nx.to_number(Nx.exp(Nx.tensor(1000.0))) == :infinity
end

@tag :skip
test "expm1(large) returns Inf" do
assert Nx.to_number(Nx.expm1(Nx.tensor(1000.0))) == :infinity
end

@tag :skip
test "sinh(large positive) returns Inf" do
assert Nx.to_number(Nx.sinh(Nx.tensor(1000.0))) == :infinity
end

@tag :skip
test "sinh(large negative) returns -Inf" do
assert Nx.to_number(Nx.sinh(Nx.tensor(-1000.0))) == :neg_infinity
end

@tag :skip
test "cosh(large) returns Inf" do
assert Nx.to_number(Nx.cosh(Nx.tensor(1000.0))) == :infinity
end

@tag :skip
test "sigmoid(large positive) returns 1.0" do
assert Nx.to_number(Nx.sigmoid(Nx.tensor(1.0e6))) == 1.0
end

@tag :skip
test "sigmoid(large negative) returns 0.0" do
assert Nx.to_number(Nx.sigmoid(Nx.tensor(-1.0e6))) == 0.0
end
end

describe "domain errors return NaN instead of crashing" do
@tag :skip
test "asin outside [-1, 1]" do
assert Nx.to_number(Nx.asin(Nx.tensor(2.0))) == :nan
end

@tag :skip
test "acos outside [-1, 1]" do
assert Nx.to_number(Nx.acos(Nx.tensor(2.0))) == :nan
end

@tag :skip
test "acosh below 1" do
assert Nx.to_number(Nx.acosh(Nx.tensor(0.5))) == :nan
end

@tag :skip
test "atanh outside (-1, 1)" do
assert Nx.to_number(Nx.atanh(Nx.tensor(2.0))) == :nan
end

@tag :skip
test "atanh at boundaries returns Inf/-Inf" do
assert Nx.to_number(Nx.atanh(Nx.tensor(1.0))) == :infinity
assert Nx.to_number(Nx.atanh(Nx.tensor(-1.0))) == :neg_infinity
end
end

describe "division by zero returns Inf/NaN instead of crashing" do
@tag :skip
test "positive / 0.0 = Inf" do
assert Nx.to_number(Nx.divide(Nx.tensor(1.0), Nx.tensor(0.0))) == :infinity
end

@tag :skip
test "negative / 0.0 = -Inf" do
assert Nx.to_number(Nx.divide(Nx.tensor(-1.0), Nx.tensor(0.0))) == :neg_infinity
end

@tag :skip
test "0.0 / 0.0 = NaN" do
assert Nx.to_number(Nx.divide(Nx.tensor(0.0), Nx.tensor(0.0))) == :nan
end

@tag :skip
test "normal division still works" do
assert Nx.to_number(Nx.divide(Nx.tensor(10.0), Nx.tensor(2.0))) == 5.0
end
end

# ── Active tests (fixes in this PR) ────────────────────────────────

describe "window_scatter_max/min on f64" do
test "window_scatter_max works with f64" do
t = Nx.iota({6}, type: :f64)
s = Nx.iota({3}, type: :f64)
init = Nx.tensor(0.0, type: :f64)
result = Nx.window_scatter_max(t, s, init, {2}, strides: [2], padding: :valid)
assert Nx.type(result) == {:f, 64}
assert Nx.shape(result) == {6}
end

test "window_scatter_min works with f64" do
t = Nx.iota({6}, type: :f64)
s = Nx.iota({3}, type: :f64)
init = Nx.tensor(0.0, type: :f64)
result = Nx.window_scatter_min(t, s, init, {2}, strides: [2], padding: :valid)
assert Nx.type(result) == {:f, 64}
assert Nx.shape(result) == {6}
end
end

describe "scalar slice" do
test "slice of scalar tensor returns scalar" do
t = Nx.tensor(42)
result = Nx.slice(t, [], [])
assert Nx.to_number(result) == 42
end

test "scalar slice with f64" do
t = Nx.tensor(3.14, type: :f64)
result = Nx.slice(t, [], [])
assert_in_delta Nx.to_number(result), 3.14, 1.0e-10
end
end

describe "linspace n=1" do
test "linspace n=1 returns start value" do
result = Nx.linspace(0, 10, n: 1)
assert Nx.shape(result) == {1}
assert Nx.to_flat_list(result) == [0.0]
end

test "linspace n=1 with same start/stop" do
result = Nx.linspace(5, 5, n: 1)
assert Nx.to_flat_list(result) == [5.0]
end

test "linspace n=2 still works" do
result = Nx.linspace(0, 10, n: 2)
assert Nx.to_flat_list(result) == [0.0, 10.0]
end
end

describe "gather scalar indices error" do
test "gather raises correct error on scalar indices" do
assert_raise ArgumentError, ~r/expected indices rank to be at least 1/, fn ->
Nx.gather(Nx.iota({3}), Nx.tensor(0))
end
end

test "gather with valid indices still works" do
t = Nx.iota({3, 4})
result = Nx.gather(t, Nx.tensor([[0, 0], [2, 3]]))
assert Nx.to_flat_list(result) == [0, 11]
end
end
end
Loading