Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 184 additions & 131 deletions pytensor/compile/builders.py

Large diffs are not rendered by default.

130 changes: 127 additions & 3 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import warnings
import weakref
from collections.abc import (
Hashable,
Iterable,
Expand Down Expand Up @@ -824,6 +825,16 @@ def __repr__(self):
def clone(self, **kwargs):
    # Atomic variables carry no owner/graph state to copy, so "cloning"
    # returns the instance itself. NOTE(review): presumably safe because
    # these instances are treated as immutable — confirm against callers.
    return self

def equals(self, other):
    """Return whether `other` is a constant of the same class and type with equal data.

    NaNs are treated as equal (``equal_nan=True``), so NaN-containing
    constants compare equal to themselves. Always returns a plain ``bool``.
    """
    if not isinstance(other, type(self)):
        return False
    if self.type != other.type:
        return False
    try:
        # Array-like data: element-wise comparison with NaN == NaN.
        return np.array_equal(self.data, other.data, equal_nan=True)
    except (TypeError, ValueError):
        # Data numpy cannot compare (e.g. slices, RNG states) falls back to
        # plain equality; coerce so the result is always a bool, not
        # whatever object ``==`` happens to return.
        return bool(self.data == other.data)

@property
def owner(self) -> None:
    # This variable is never the output of an Apply node, so its owner is
    # always None.
    return None
Expand All @@ -838,6 +849,119 @@ def value(self):
return self.data


class FrozenConstant(Constant):
    """A globally-interned Constant for use in frozen graphs.

    Two ``FrozenConstant`` instances with the same type and data are the same
    object, which makes equality (and hashing) of frozen graphs an identity
    check on their variables.
    """

    # Maps (type, canonical-data key) -> live instance. Weak values so interned
    # constants can be garbage-collected once no graph references them.
    _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
    # Pre-filtered data stashed by __new__, consumed (and deleted) by __init__.
    _filtered: Any

    # FrozenConstant doesn't inherit the scalar mixin that provides .dtype,
    # but scalar C code generation expects it on all variables.
    @property
    def dtype(self):
        return self.type.dtype

    def __new__(cls, type: _TypeType, data: Any, name: str | None = None):
        filtered = type.filter(data)
        cache_key = cls._make_key(type, filtered)
        if cache_key is not None:
            cached = cls._cache.get(cache_key)
            if cached is not None:
                return cached
        instance = object.__new__(cls)
        # Store filtered data now so __init__ can skip re-filtering.
        instance._filtered = filtered
        if cache_key is not None:
            cls._cache[cache_key] = instance
        return instance

    def __init__(self, type: _TypeType, data: Any, name: str | None = None):
        # A cache hit in __new__ returns an already-initialized instance;
        # skip re-initialization in that case.
        if hasattr(self, "data"):
            return
        # Use pre-filtered data from __new__ to avoid a second type.filter() call.
        AtomicVariable.__init__(self, type, name=name)
        self.data = self._filtered
        del self._filtered
        add_tag_trace(self)

    @staticmethod
    def _make_key(type, filtered):
        """Return a hashable cache key for ``filtered``, or None if unhashable.

        Arrays are keyed by a content hash. Other data is keyed by the value
        itself rather than by ``hash(value)``: distinct values can share a
        hash (e.g. ``hash(-1) == hash(-2)`` in CPython), and keying on the
        bare hash would wrongly intern them to the same constant.
        """
        if isinstance(filtered, (np.ndarray, np.generic)):
            from pytensor.tensor.utils import hash_from_ndarray

            return type, hash_from_ndarray(np.asarray(filtered))
        try:
            hash(filtered)
        except TypeError:
            return None
        return type, filtered

    def __reduce__(self):
        # Reconstructing through __new__ on unpickle re-interns, preserving
        # the "same type+data -> same object" invariant.
        return (type(self), (self.type, self.data, self.name))

class FrozenApply(Apply):
    """An immutable, globally-interned Apply node for frozen graphs.

    Uses tuples for ``inputs`` and ``outputs`` so mutation raises ``TypeError``
    at the language level. Interned by ``(op, inputs)`` — constructing a
    ``FrozenApply`` with an ``op`` and ``inputs`` that match an existing live
    instance returns that instance.
    """

    # Maps (op, canonicalized inputs) -> live node; weak values so nodes that
    # no graph references can be garbage-collected.
    _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

    def __new__(
        cls, op: "Op", inputs: tuple[Variable, ...], output_types: tuple["Type", ...]
    ):
        # Canonicalize inputs through their owner's outputs to ensure cache
        # hits after unpickling.
        inputs = tuple(
            inp.owner.outputs[inp.index]
            if inp.owner is not None and isinstance(inp.owner, FrozenApply)
            else inp
            for inp in inputs
        )
        key = (op, inputs)
        cached = cls._cache.get(key)
        if cached is not None:
            return cached

        instance = object.__new__(cls)
        instance.op = op
        instance.inputs = inputs  # type: ignore[assignment]
        instance.outputs = tuple(  # type: ignore[assignment]
            t.variable_type(type=t, owner=instance, index=i)
            for i, t in enumerate(output_types)
        )
        instance.tag = Scratchpad()
        cls._cache[key] = instance
        return instance

    def __init__(self, op, inputs, output_types):
        # All initialization is done in __new__ (cache hits must not be
        # re-initialized).
        pass

    def clone(self, clone_inner_graph: bool = False) -> "Apply":
        """Clone into a mutable `Apply` node.

        Parameters
        ----------
        clone_inner_graph
            If True, ops that carry an inner graph are cloned as well.
        """
        from pytensor.graph.op import HasInnerGraph

        new_op = self.op
        # Test the cheap boolean before the (more expensive) isinstance check.
        if clone_inner_graph and isinstance(new_op, HasInnerGraph):
            new_op = new_op.clone()

        return Apply(new_op, list(self.inputs), [o.clone() for o in self.outputs])

    def __reduce__(self):
        # Unpickling reconstructs via __new__, which re-interns by (op, inputs).
        output_types = tuple(o.type for o in self.outputs)
        return (type(self), (self.op, self.inputs, output_types))


def clone(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
Expand Down Expand Up @@ -1154,14 +1278,14 @@ def equal_computations(

for x, y in zip(xs, ys, strict=True):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y)
return np.array_equal(x, y, equal_nan=True)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return np.array_equal(y.data, x)
return np.array_equal(y.data, x, equal_nan=True)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return np.array_equal(x.data, y)
return np.array_equal(x.data, y, equal_nan=True)
return False
x_is_owned, y_is_owned = (x.owner is not None, y.owner is not None)
if x_is_owned != y_is_owned:
Expand Down
153 changes: 153 additions & 0 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pytensor.graph.basic import (
Apply,
AtomicVariable,
Constant,
NominalVariable,
Variable,
clone_get_equiv,
)
Expand Down Expand Up @@ -928,3 +930,154 @@ def dprint(self, **kwargs):
from pytensor.printing import debugprint

return debugprint(self, **kwargs)

def freeze(self) -> "FrozenFunctionGraph":
    """Return a frozen, hashable version of this FunctionGraph.

    The result is immutable and its variables are interned, so structurally
    identical graphs freeze to equal (and equal-hashing) objects that can be
    used as dict keys. This FunctionGraph itself is not modified.
    """
    return FrozenFunctionGraph(self.inputs, self.outputs)


class FrozenFunctionGraph:
    """Immutable, hashable function graph for inner graphs of Ops.

    All internal nodes are globally interned via ``FrozenApply`` and
    ``FrozenConstant``. Two ``FrozenFunctionGraph`` instances built from
    structurally identical source graphs share the same internal objects, so
    equality reduces to an identity check on the output tuples.

    .. code-block:: python

        from pytensor.scalar.basic import float64, add
        from pytensor.graph.fg import FunctionGraph

        x, y = float64("x"), float64("y")
        fg = FunctionGraph([x, y], [add(x, y)])
        frozen = fg.freeze()
        frozen2 = FunctionGraph([x, y], [add(x, y)]).freeze()

        assert frozen == frozen2
        assert {frozen: "value"}[frozen2] == "value"
    """

    def __init__(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
    ):
        from pytensor.graph.basic import (
            FrozenApply,
            FrozenConstant,
        )

        nominal_inputs = tuple(
            NominalVariable(i, inp.type, name=inp.name) for i, inp in enumerate(inputs)
        )

        # memo maps each original variable to its frozen/interned counterpart.
        memo: dict[Variable, Variable] = dict(zip(inputs, nominal_inputs, strict=True))

        # var_hash assigns every frozen variable a structural hash so the
        # graph hash depends only on structure, never on object identity.
        var_hash: dict[Variable, int] = {}
        for i, nm in enumerate(nominal_inputs):
            var_hash[nm] = hash(("input", i, nm.type))

        def _freeze_leaf(var: Variable) -> Variable:
            # Map a Constant/AtomicVariable leaf to its interned counterpart
            # and record its structural hash.
            if isinstance(var, Constant):
                frozen = FrozenConstant(var.type, var.data)
            elif isinstance(var, AtomicVariable):
                # AtomicVariables (e.g. NominalVariables from outer scopes)
                # are already interned and hashable.
                frozen = var
            else:
                raise ValueError(
                    f"Non-Constant, non-AtomicVariable orphan {var} found "
                    "in the graph. All variables must be graph inputs, "
                    "Constants, AtomicVariables, or produced by Apply "
                    "nodes reachable from the inputs."
                )
            memo[var] = frozen
            var_hash.setdefault(frozen, hash(frozen))
            return frozen

        for node in toposort(outputs, blockers=inputs):
            for inp in node.inputs:
                if inp not in memo:
                    _freeze_leaf(inp)

            new_inputs = tuple(memo[i] for i in node.inputs)
            output_types = tuple(out.type for out in node.outputs)
            new_node = FrozenApply(node.op, new_inputs, output_types)

            input_hashes = tuple(var_hash[i] for i in new_inputs)
            node_hash = hash((node.op, input_hashes))
            for old_out, new_out in zip(node.outputs, new_node.outputs, strict=True):
                memo[old_out] = new_out
                var_hash[new_out] = hash((node_hash, new_out.index))

        self.inputs: tuple[Variable, ...] = nominal_inputs

        resolved_outputs = []
        for o in outputs:
            mapped = memo.get(o)
            # After unpickling, o may be a fresh object whose owner is the
            # (correctly interned) FrozenApply. We thus resolve it through
            # its owner to get back the original variable.
            if mapped is None and o.owner is not None:
                mapped = memo.get(o.owner.outputs[o.index])
            if mapped is None and o.owner is None:
                # A Constant/AtomicVariable output that feeds no Apply node is
                # never visited by the toposort loop above; intern it here
                # instead of wrongly rejecting the graph.
                mapped = _freeze_leaf(o)
            if mapped is None:
                raise ValueError(
                    f"Output variable {o} could not be mapped to a frozen graph variable. "
                    "All outputs must be graph inputs, constants, or produced by Apply nodes "
                    "reachable from the inputs."
                )
            resolved_outputs.append(mapped)
        self.outputs: tuple[Variable, ...] = tuple(resolved_outputs)

        self._structural_hash: int = hash(tuple(var_hash[o] for o in self.outputs))

    def __hash__(self):
        return self._structural_hash

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, FrozenFunctionGraph):
            return False
        if self._structural_hash != other._structural_hash:
            return False
        if self.outputs == other.outputs:
            return True
        # Hash match but output identity mismatch: either a benign hash
        # collision between genuinely different graphs (returns False below)
        # or an interning bug. We warn rather than raise because raising from
        # __eq__ would break ordinary dict/set usage.
        import warnings

        from pytensor.graph.basic import equal_computations

        if (
            len(self.outputs) == len(other.outputs)
            and len(self.inputs) == len(other.inputs)
            and equal_computations(
                list(self.outputs),
                list(other.outputs),
                in_xs=list(self.inputs),
                in_ys=list(other.inputs),
            )
        ):
            warnings.warn(
                "FrozenFunctionGraph: structurally equal graphs did not share "
                "interned objects. This may indicate an interning bug.",
                stacklevel=2,
            )
            return True
        return False

    def __repr__(self):
        return f"FrozenFunctionGraph(inputs={list(self.inputs)}, outputs={list(self.outputs)})"

    def __reduce__(self):
        # Rebuild from (inputs, outputs); __init__ re-interns everything.
        return (type(self), (list(self.inputs), list(self.outputs)))

    @property
    def apply_nodes(self) -> set[Apply]:
        # Computed lazily and cached: the graph is immutable, so the result
        # can never go stale. A fresh set is returned each call so callers
        # cannot corrupt the cache by mutating it.
        try:
            cached = self._apply_nodes
        except AttributeError:
            cached = self._apply_nodes = frozenset(
                applys_between(self.inputs, self.outputs)
            )
        return set(cached)

    def toposort(self) -> list[Apply]:
        # Cached for the same reason as apply_nodes; return a fresh list so
        # callers may mutate their copy freely.
        try:
            order = self._toposort
        except AttributeError:
            order = self._toposort = tuple(toposort(self.outputs, blockers=self.inputs))
        return list(order)

    @property
    def variables(self) -> set[Variable]:
        # Cached; fresh set returned each call (see apply_nodes).
        try:
            cached = self._variables
        except AttributeError:
            cached = self._variables = frozenset(
                vars_between(self.inputs, self.outputs)
            )
        return set(cached)
Loading
Loading