Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions tests/tensor/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,36 +732,46 @@ def setup_method(self):
super().setup_method()

def test_vec_vec_kron_raises(self):
"""Ensure kron raises an error for 1D inputs."""
x = vector()
y = vector()
with pytest.raises(
TypeError, match="kron: inputs dimensions must sum to 3 or more"
):
kron(x, y)

@pytest.mark.parametrize("static_shape", [True, False])
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1):
def test_perform(self, static_shape, shp0, shp1):
"""Test kron execution and symbolic shape inference."""
if len(shp0) + len(shp1) == 2:
pytest.skip("Sum of shp0 and shp1 must be more than 2")

x = tensor(dtype="floatX", shape=shp0)
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)

y = tensor(dtype="floatX", shape=shp1)
b = self.rng.random(shp1).astype(config.floatX)

# Using np.kron to evaluate expected numerical output and dimensionality
np_val = np.kron(a, b)

# Determine tensor shapes
shape_x = shp0 if static_shape else (None,) * len(shp0)
shape_y = shp1 if static_shape else (None,) * len(shp1)
shape_out = np_val.shape if static_shape else (None,) * np_val.ndim

x = tensor(dtype="floatX", shape=shape_x)
y = tensor(dtype="floatX", shape=shape_y)

kron_xy = kron(x, y)

# Assert symbolic shape inference immediately after node creation
assert kron_xy.type.shape == shape_out

f = function([x, y], kron_xy)
out = f(a, b)

# Using np.kron to compare outputs
np_val = np.kron(a, b)
np.testing.assert_allclose(out, np_val)

# Regression test for issue #1867
assert kron_xy.type.shape == np_val.shape

@pytest.mark.parametrize(
"i, shp0, shp1",
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
Expand Down
Loading