Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/assumptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytensor.assumptions.permutation
import pytensor.assumptions.positive_definite
import pytensor.assumptions.reshape
import pytensor.assumptions.scan
import pytensor.assumptions.selection
import pytensor.assumptions.shape
import pytensor.assumptions.subtensor
Expand Down
181 changes: 181 additions & 0 deletions pytensor/assumptions/scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from pytensor.assumptions.core import (
ALL_KEYS,
AssumptionFeature,
FactState,
check_assumption,
register_assumption,
)
from pytensor.assumptions.specify import SpecifyAssumptions
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor.subtensor import IncSubtensor


def _recurrent_init_fact(buffer_var, key, feature, fallback):
"""Return the *key* fact of a recurrence's initial value.

Scan stores a ``sit-sot`` recurrence's initial value as
``SetSubtensor{:n_taps}(AllocEmpty(...), init)``. The buffer's own fact is
UNKNOWN -- the rows scan has yet to fill are uninitialised -- so read the
fact of the written ``init`` instead. Return *fallback* when the buffer is
not that shape.
"""
owner = buffer_var.owner
if (
owner is not None
and isinstance(owner.op, IncSubtensor)
and owner.op.set_instead_of_inc
):
return feature.get(owner.inputs[1], key)
return fallback


def scan_delegate(key, op, feature, fgraph, node, input_states):
"""Infer *key* for a :class:`Scan`'s outer outputs by delegating into its inner graph.

The outer-input facts seed the matching inner inputs; the inner graph is then
inferred and the inner-output facts are mapped back onto the outer outputs.

For a non-recurrent (``nit-sot``) output the per-step inner output is stacked
along a new leading axis, so the fact carries straight through. For a
recurrent (``sit-sot``) output the carried state is seeded from the
recurrence's initial value, and the fact is kept only when the loop body
reproduces it -- a one-step fixpoint, exact because the per-key lattice
(``UNKNOWN`` < ``TRUE``) leaves no room to iterate. Multi-output recurrences
(``mit-mot``) are left UNKNOWN.
"""
mappings = op.get_oinp_iinp_iout_oout_mappings()
inner_inputs = op.inner_inputs
inner_outputs = op.inner_outputs

# Inner inputs that carry recurrent state. Their outer input is the sit-sot
# buffer; the buffer's own fact is UNKNOWN, so seed from the initial value.
recurrent_iinps = {
iidx
for iidxs in mappings["inner_inp_from_outer_out"].values()
for iidx in iidxs
}

inner_feature = AssumptionFeature()
op.fgraph.attach_feature(inner_feature)
try:
# Seed each inner input with the fact of the outer input feeding it.
# The seed is written straight into the cache: an inner input is a
# graph leaf, so this is the only way to inject a fact onto it.
for iinp_idx, iinp in enumerate(inner_inputs):
outer_iidx = mappings["outer_inp_from_inner_inp"][iinp_idx]
seed = input_states[outer_iidx]
if iinp_idx in recurrent_iinps:
seed = _recurrent_init_fact(node.inputs[outer_iidx], key, feature, seed)
if seed is not FactState.UNKNOWN:
inner_feature.cache[(iinp, key)] = seed
inner_feature._var_to_keys.setdefault(iinp, set()).add(key)

out_states = [FactState.UNKNOWN] * len(node.outputs)
for outer_oidx in range(len(node.outputs)):
inner_oidxs = mappings["inner_out_from_outer_out"].get(outer_oidx, [])
if len(inner_oidxs) != 1:
# Multi-output (mit-mot) recurrences are left UNKNOWN for now.
continue
fact = inner_feature.get(inner_outputs[inner_oidxs[0]], key)

if mappings["inner_inp_from_outer_out"].get(outer_oidx):
# Recurrent: the fact survives only if the loop body reproduces
# the initial value's fact.
outer_iidx = mappings["outer_inp_from_outer_out"][outer_oidx]
init_fact = _recurrent_init_fact(
node.inputs[outer_iidx], key, feature, input_states[outer_iidx]
)
if fact is not init_fact:
fact = FactState.UNKNOWN
out_states[outer_oidx] = fact
return out_states
finally:
op.fgraph.remove_feature(inner_feature)


for _key in ALL_KEYS:
register_assumption(_key, Scan)(scan_delegate)


@node_rewriter([Scan])
def push_assumptions_into_scan(fgraph, node):
"""Push structural assumptions from a Scan's sequence and non-sequence inputs
onto the matching inner inputs.

An inner input is a bare leaf, so an ``assume`` on the outer variable is
invisible to rewrites of the inner graph. This re-asserts it with a
:class:`SpecifyAssumptions` node inside, so those rewrites can fire -- e.g.
``inv(X) @ y`` of a positive-definite :math:`X` specializes to a Cholesky
solve within the loop body. Matrix properties are invariant to batch axes,
so the assertion is valid for every per-step slice.

Recurrent inner inputs are excluded: the loop body need not preserve the
initial value's properties, so the carried state cannot be assumed to keep
them past the first step.
"""
scan_op = node.op
inner_inputs = scan_op.inner_inputs
non_recurrent = set(scan_op.inner_seqs(inner_inputs))
non_recurrent.update(scan_op.inner_non_seqs(inner_inputs))
outer_from_inner = scan_op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
]

new_facts = {}
for inner_idx, inner_inp in enumerate(inner_inputs):
if inner_inp not in non_recurrent:
continue
clients = scan_op.fgraph.clients.get(inner_inp, ())
if any(
not isinstance(client, str) and isinstance(client.op, SpecifyAssumptions)
for client, _ in clients
):
# Already carries an inner assertion -- skip to avoid re-firing.
continue
outer_inp = node.inputs[outer_from_inner[inner_idx]]
facts = {
key.name: FactState.TRUE
for key in ALL_KEYS
if check_assumption(fgraph, outer_inp, key)
}
if facts:
new_facts[inner_inp] = facts

if not new_facts:
return None

# Rebuild the inner graph over fresh leaves, splicing the assertions on.
replace = {}
new_inner_inputs = []
for inner_inp in inner_inputs:
dummy = inner_inp.type()
new_inner_inputs.append(dummy)
facts = new_facts.get(inner_inp)
replace[inner_inp] = SpecifyAssumptions(facts)(dummy) if facts else dummy
new_inner_outputs = clone_replace(scan_op.inner_outputs, replace=replace)

new_scan_op = Scan(
new_inner_inputs,
new_inner_outputs,
scan_op.info,
mode=scan_op.mode,
profile=scan_op.profile,
truncate_gradient=scan_op.truncate_gradient,
name=scan_op.name,
allow_gc=scan_op.allow_gc,
)
new_outs = new_scan_op.make_node(*node.inputs).outputs
copy_stack_trace(node.outputs, new_outs)
return new_outs


scan_seqopt1.register(
push_assumptions_into_scan.__name__,
dfs_rewriter(push_assumptions_into_scan, ignore_newtrees=True),
"fast_run",
"scan",
position=1,
)
94 changes: 67 additions & 27 deletions pytensor/scan/rewriting/push_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.replace import clone_replace
Expand Down Expand Up @@ -80,22 +80,71 @@ def add_to_replace(y):
assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs)

inner_seqs_set = set(inner_seqs)
inner_clients = op.fgraph.clients

def _is_base_pushable(x):
# Pullable without seeing through any marker: a non-sequence placeholder, an
# already-hoisted node's output, or a constant.
return (
x in inner_non_seqs_set
or x.owner in to_remove_set
or isinstance(x, Constant)
)

def _is_pushable_input(x):
# ``_is_base_pushable``, or -- seen through -- a marker op (``TypeCastingOp``,
# e.g. an ``assume()``) wrapping pullable inputs. Seeing through the marker
# lets loop-invariant work that flows through a declared fact (``R @ Q @ R.T``
# with R a selection) hoist out instead of being recomputed every step because
# it is anchored to the never-hoisted marker.
return _is_base_pushable(x) or (
x.owner is not None
and isinstance(x.owner.op, TypeCastingOp)
and all(_is_pushable_input(i) for i in x.owner.inputs)
)

def _feeds_inner_seq(nd):
# Does an in-loop consumer of ``nd`` read a sequence? If so that consumer may
# still specialize against the sequence using a fact carried by a marker among
# ``nd``'s inputs (e.g. ``inv(pd) @ y_t`` -> a Cholesky solve). Hoisting ``nd``
# out -- which only seeing through the marker enables -- would split it from
# the sequence and preempt that specialization, so keep it inside.
for out in nd.outputs:
for client, _ in inner_clients.get(out, ()):
if isinstance(client, Apply) and any(
i in inner_seqs_set for i in client.inputs
):
return True
return False

def _outer_input_for(x):
# The outer-graph stand-in for a pushable inner input.
if x in inner_non_seqs_set:
return outer_non_seqs[inner_non_seqs_map[x]]
if x in to_replace_set:
return replace_with_out[to_replace_map[x]]
if isinstance(x, Constant):
return x
# Marker seen through: rebuild it over its argument's outer stand-in so the
# declared fact is reproduced outside the loop.
return x.owner.op(*[_outer_input_for(i) for i in x.owner.inputs])

for nd in local_fgraph_topo:
if ( # we haven't already looked at this node
nd not in to_remove_set
and all(
(
(x in inner_non_seqs_set)
or (x.owner in to_remove_set)
or isinstance(x, Constant)
)
for x in nd.inputs
)
# We can (supposedly) do this because the assumption is that a
# `ViewOp` or `DeepCopyOp` will be just at the end of the
# function and not somewhere in the middle
and not isinstance(nd.op, ViewOp)
and all(_is_pushable_input(x) for x in nd.inputs)
# Marker ops carry no computation; hoisting them may strip an inner-graph
# hint a later rewrite needs -- and is zero compute saving anyway.
and not isinstance(nd.op, TypeCastingOp)
and not isinstance(nd.op, DeepCopyOp)
# A node pullable only by seeing through a marker, whose result still feeds
# a sequence-dependent op, stays in: that op may yet specialize against the
# sequence using the marked fact.
and not (
any(not _is_base_pushable(x) for x in nd.inputs)
and _feeds_inner_seq(nd)
)
):
# We have a candidate node to remove from the inner-graph

Expand All @@ -108,20 +157,11 @@ def add_to_replace(y):
to_remove_set.add(nd)
new_inputs = []
for old_input in nd.inputs:
if old_input in inner_non_seqs_set:
# This is case a), so we want to use the corresponding
# outer-graph input as the input to our new pushed-out node
_idx = inner_non_seqs_map[old_input]
new_input = outer_non_seqs[_idx]
elif old_input in to_replace_set:
# This is case b), so we want to use the new pushed-out node
# as the input to this new pushed-out node
new_input = replace_with_out[to_replace_map[old_input]]
else:
assert isinstance(old_input, Constant)
new_input = old_input

new_input = old_input.type.filter_variable(new_input)
# Map each inner input to its outer stand-in: a non-sequence's outer
# value (case a), an already-pushed-out node's output (case b), a
# constant, or a marker rebuilt over the outer stand-in of its
# argument (an ``assume()`` seen through).
new_input = old_input.type.filter_variable(_outer_input_for(old_input))
new_inputs.append(new_input)

pushed_out_node = nd.op.make_node(*new_inputs)
Expand Down
Loading
Loading