Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,63 @@ def local_0_dot_x(fgraph, node):
return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([Dot])
def local_1_dot_x(fgraph, node):
    """Simplify ``dot`` when one or both operands is a tensor of all ones.

    A matrix product against an all-ones operand collapses to a reduction
    broadcast to the output shape:

    * ``dot(x, ones(n, p))``          -> ``alloc(sum(x, axis=1, keepdims=True), m, p)``
    * ``dot(ones(m, n), y)``          -> ``alloc(sum(y, axis=0, keepdims=True), m, p)``
    * ``dot(ones(m, n), ones(n, p))`` -> ``alloc(cast(n, dtype), m, p)``
    """
    lhs, rhs = node.inputs
    out = node.outputs[0]
    dtype = out.type.dtype

    def _is_all_ones(var):
        # `raise_not_constant=False` avoids raising for non-constant inputs;
        # only an (underlying) scalar constant equal to one matches here.
        return (
            get_underlying_scalar_constant_value(
                var, only_process_constants=False, raise_not_constant=False
            )
            == 1
        )

    lhs_ones = _is_all_ones(lhs)
    rhs_ones = _is_all_ones(rhs)

    if not (lhs_ones or rhs_ones):
        return None

    if lhs_ones and rhs_ones:
        # ones(m, n) @ ones(n, p) == n * ones(m, p): every output entry is
        # the (cast) inner dimension.
        replacement = cast(lhs.shape[1], dtype)
    elif rhs_ones:
        # x @ ones(n, p): each output column is the row-sum of x.
        replacement = cast(pt_sum(lhs, axis=1, keepdims=True), dtype)
    else:
        # ones(m, n) @ y: each output row is the column-sum of y.
        replacement = cast(pt_sum(rhs, axis=0, keepdims=True), dtype)

    new_out = alloc(replacement, lhs.shape[0], rhs.shape[1])
    copy_stack_trace(out, new_out)
    return [new_out]


@register_stabilize
@node_rewriter([blockwise_of(BlockDiagonal)])
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
Expand Down
66 changes: 66 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4884,6 +4884,72 @@ def test_local_dot_to_mul_unspecified_length_1():
)


@pytest.mark.parametrize(
    "x_shape, y_shape, x_is_ones, y_is_ones",
    [
        # Right-side ones: x @ ones(n, p)
        ((3, 4), (4, 2), False, True),
        # Left-side ones: ones(m, n) @ y
        ((3, 4), (4, 2), True, False),
        # Both ones
        ((3, 4), (4, 2), True, True),
        # 1x1 identity: x @ ones(1, 1)
        ((5, 1), (1, 1), False, True),
        # 1x1 identity: ones(1, 1) @ y
        ((1, 1), (1, 5), True, False),
    ],
    ids=[
        "right_ones",
        "left_ones",
        "both_ones",
        "right_1x1",
        "left_1x1",
    ],
)
def test_local_1_dot_x(x_shape, y_shape, x_is_ones, y_is_ones):
    """`dot` with an all-ones operand is rewritten away and stays numerically exact."""
    rng = np.random.default_rng(42)
    mode = get_default_mode()

    x_val = (
        np.ones(x_shape, dtype="float64")
        if x_is_ones
        else rng.normal(size=x_shape).astype("float64")
    )
    y_val = (
        np.ones(y_shape, dtype="float64")
        if y_is_ones
        else rng.normal(size=y_shape).astype("float64")
    )

    # The all-ones operand is passed as a constant so the rewrite can detect
    # it; any non-ones operand stays symbolic. Build the symbolic input only
    # in the branch that uses it, and record the call arguments once so a
    # single fn(*args) call serves all cases below.
    if x_is_ones and y_is_ones:
        out = dot(pt.constant(x_val), pt.constant(y_val))
        fn = pytensor.function([], out, mode=mode)
        args = []
    elif x_is_ones:
        y = pt.matrix("y", dtype="float64")
        out = dot(pt.constant(x_val), y)
        fn = pytensor.function([y], out, mode=mode)
        args = [y_val]
    else:
        x = pt.matrix("x", dtype="float64")
        out = dot(x, pt.constant(y_val))
        fn = pytensor.function([x], out, mode=mode)
        args = [x_val]

    # Verify no Dot node remains in the compiled graph.
    assert not any(isinstance(node.op, Dot) for node in fn.maker.fgraph.apply_nodes), (
        "Dot op should have been rewritten away"
    )

    # Verify numerical correctness against the plain matrix product.
    np.testing.assert_allclose(fn(*args), np.dot(x_val, y_val))


class TestBlockDiagDotToDotBlockDiag:
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
Expand Down