Allow freezing of FunctionGraph for hashing #1908
base: main
Changes from all commits
06a0f17
c3b95d4
5652a86
59cb109
d60bdab
08609fe
@@ -2,6 +2,7 @@

```python
import abc
import warnings
import weakref
from collections.abc import (
    Hashable,
    Iterable,
```
@@ -824,6 +825,16 @@ def __repr__(self):

```python
    def clone(self, **kwargs):
        return self

    def equals(self, other):
        if not isinstance(other, type(self)):
            return False
        if self.type != other.type:
            return False
        try:
            return np.array_equal(self.data, other.data, equal_nan=True)
        except (TypeError, ValueError):
            return self.data == other.data

    @property
    def owner(self) -> None:
        return None
```
@@ -838,6 +849,119 @@ def value(self):

```python
        return self.data


class FrozenConstant(Constant):
```
Member: Again sounds like it's specializing on numerical array types. How does MergeOptimizer find that two constants are equal for merging? Can we reuse that logic? I wouldn't expect we need a FrozenConstant class in the end, since Constants are frozen by our standards already. The challenge here is more finding whether a new constant was already seen before?
| """A globally-interned Constant for use in frozen graphs. | ||
|
|
||
| Two ``FrozenConstant`` instances with the same type and data are the same object. | ||
| """ | ||
|
|
||
| _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() | ||
| _filtered: Any | ||
|
|
||
| # FrozenConstant doesn't inherit the scalar mixin that provides .dtype, | ||
| # but scalar C code generation expects it on all variables. | ||
| @property | ||
| def dtype(self): | ||
| return self.type.dtype | ||
|
|
||
| def __new__(cls, type: _TypeType, data: Any, name: str | None = None): | ||
| filtered = type.filter(data) | ||
| cache_key = cls._make_key(type, filtered) | ||
| if cache_key is not None: | ||
| cached = cls._cache.get(cache_key) | ||
| if cached is not None: | ||
| return cached | ||
| instance = object.__new__(cls) | ||
| # Store filtered data now so __init__ can skip re-filtering | ||
| instance._filtered = filtered | ||
| if cache_key is not None: | ||
| cls._cache[cache_key] = instance | ||
| return instance | ||
|
|
||
| def __init__(self, type: _TypeType, data: Any, name: str | None = None): | ||
| if hasattr(self, "data"): | ||
| return | ||
| # Use pre-filtered data from __new__ to avoid a second type.filter() call | ||
| AtomicVariable.__init__(self, type, name=name) | ||
| self.data = self._filtered | ||
| del self._filtered | ||
| add_tag_trace(self) | ||
|
|
||
| @staticmethod | ||
| def _make_key(type, filtered): | ||
| if isinstance(filtered, np.ndarray): | ||
| from pytensor.tensor.utils import hash_from_ndarray | ||
|
|
||
| return type, hash_from_ndarray(filtered) | ||
| if isinstance(filtered, np.generic): | ||
| from pytensor.tensor.utils import hash_from_ndarray | ||
|
|
||
| return type, hash_from_ndarray(np.asarray(filtered)) | ||
| try: | ||
| return type, hash(filtered) | ||
| except TypeError: | ||
| return None | ||
|
|
||
| def __reduce__(self): | ||
| return (type(self), (self.type, self.data, self.name)) | ||
|
|
||
|
|
||
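To make the interning the docstring describes concrete, here is a hypothetical usage sketch (the scalar `float64` type is only an illustration; this snippet is not part of the diff):

```python
# Hypothetical sketch: equal (type, data) pairs map onto one shared FrozenConstant.
from pytensor.graph.basic import FrozenConstant
from pytensor.scalar.basic import float64

a = FrozenConstant(float64, 1.5)
b = FrozenConstant(float64, 1.5)

assert a is b       # interned: same (type, data) gives back the same object
assert a.equals(b)  # and the equals() added above agrees
```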
```python
class FrozenApply(Apply):
    """An immutable, globally-interned Apply node for frozen graphs.

    Uses tuples for ``inputs`` and ``outputs`` so mutation raises ``TypeError``
    at the language level. Interned by ``(op, inputs)`` — constructing a
    ``FrozenApply`` with an ``op`` and ``inputs`` that match an existing live
    instance returns that instance.
    """

    _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

    def __new__(
        cls, op: "Op", inputs: tuple[Variable, ...], output_types: tuple["Type", ...]
    ):
        # Canonicalize inputs through their owner's outputs to ensure cache hits after unpickling.
        inputs = tuple(
            inp.owner.outputs[inp.index]
            if inp.owner is not None and isinstance(inp.owner, FrozenApply)
            else inp
            for inp in inputs
        )
        key = (op, inputs)
        cached = cls._cache.get(key)
        if cached is not None:
            return cached

        instance = object.__new__(cls)
        instance.op = op
        instance.inputs = inputs  # type: ignore[assignment]
        instance.outputs = tuple(  # type: ignore[assignment]
            t.variable_type(type=t, owner=instance, index=i)
            for i, t in enumerate(output_types)
        )
        instance.tag = Scratchpad()
        cls._cache[key] = instance
        return instance

    def __init__(self, op, inputs, output_types):
        # All initialization is done in __new__
        pass

    def clone(self, clone_inner_graph: bool = False) -> "Apply":
        """Clone into a mutable Apply node."""
        from pytensor.graph.op import HasInnerGraph

        new_op = self.op
        if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
```
Member: check the boolean
```python
            new_op = new_op.clone()

        return Apply(new_op, list(self.inputs), [o.clone() for o in self.outputs])

    def __reduce__(self):
        output_types = tuple(o.type for o in self.outputs)
        return (type(self), (self.op, self.inputs, output_types))


def clone(
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
```
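A minimal sketch of the interning the `FrozenApply` docstring describes, using the scalar `add` op. The constructor signature is taken from the diff, but the op/type imports and the output-type tuple are assumptions, not part of the PR:

```python
# Hypothetical sketch: FrozenApply nodes built from the same (op, inputs)
# are the very same object, so frozen graphs that reuse them compare by identity.
from pytensor.graph.basic import FrozenApply
from pytensor.scalar.basic import add, float64

x, y = float64("x"), float64("y")

node1 = FrozenApply(add, (x, y), (float64,))
node2 = FrozenApply(add, (x, y), (float64,))

assert node1 is node2                   # interned on (op, inputs)
assert isinstance(node1.inputs, tuple)  # immutable containers, per the docstring
```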
```diff
@@ -1154,14 +1278,14 @@ def equal_computations(

     for x, y in zip(xs, ys, strict=True):
         if not isinstance(x, Variable) and not isinstance(y, Variable):
-            return np.array_equal(x, y)
+            return np.array_equal(x, y, equal_nan=True)
         if not isinstance(x, Variable):
             if isinstance(y, Constant):
-                return np.array_equal(y.data, x)
+                return np.array_equal(y.data, x, equal_nan=True)
             return False
         if not isinstance(y, Variable):
             if isinstance(x, Constant):
-                return np.array_equal(x.data, y)
+                return np.array_equal(x.data, y, equal_nan=True)
             return False
         x_is_owned, y_is_owned = (x.owner is not None, y.owner is not None)
         if x_is_owned != y_is_owned:
```
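For context, the reason for threading `equal_nan=True` through these comparisons is plain NumPy behavior (this snippet is only an illustration, not part of the diff):

```python
import numpy as np

a = np.array([1.0, np.nan])

np.array_equal(a, a)                  # False: NaN != NaN under the default rules
np.array_equal(a, a, equal_nan=True)  # True: NaN-containing constants now compare equal
```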
@@ -10,6 +10,8 @@

```python
from pytensor.graph.basic import (
    Apply,
    AtomicVariable,
    Constant,
    NominalVariable,
    Variable,
    clone_get_equiv,
)
```
@@ -928,3 +930,154 @@ def dprint(self, **kwargs):

```python
        from pytensor.printing import debugprint

        return debugprint(self, **kwargs)

    def freeze(self) -> "FrozenFunctionGraph":
        """Return a frozen, hashable version of this FunctionGraph."""
        return FrozenFunctionGraph(self.inputs, self.outputs)


class FrozenFunctionGraph:
```
Member: They could inherit from a shared base-class. Then for instance
| """Immutable, hashable function graph for inner graphs of Ops. | ||
|
|
||
| All internal nodes are globally interned via ``FrozenApply`` and ``FrozenConstant``. Two ``FrozenFunctionGraph`` | ||
| instances built from structurally identical source graphs share the same internal objects, so equality reduces to | ||
| an identity check on the output tuples. | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| from pytensor.scalar.basic import float64, add | ||
| from pytensor.graph.fg import FunctionGraph | ||
|
|
||
| x, y = float64("x"), float64("y") | ||
| fg = FunctionGraph([x, y], [add(x, y)]) | ||
| frozen = fg.freeze() | ||
| frozen2 = FunctionGraph([x, y], [add(x, y)]).freeze() | ||
|
|
||
| assert frozen == frozen2 | ||
| assert {frozen: "value"}[frozen2] == "value" | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| inputs: Sequence[Variable], | ||
| outputs: Sequence[Variable], | ||
| ): | ||
| from pytensor.graph.basic import ( | ||
| FrozenApply, | ||
| FrozenConstant, | ||
| ) | ||
|
|
||
| nominal_inputs = tuple( | ||
| NominalVariable(i, inp.type, name=inp.name) for i, inp in enumerate(inputs) | ||
| ) | ||
|
|
||
| memo: dict[Variable, Variable] = dict(zip(inputs, nominal_inputs, strict=True)) | ||
|
|
||
| var_hash: dict[Variable, int] = {} | ||
| for i, nm in enumerate(nominal_inputs): | ||
| var_hash[nm] = hash(("input", i, nm.type)) | ||
|
|
||
| for node in toposort(outputs, blockers=inputs): | ||
| for inp in node.inputs: | ||
| if inp not in memo: | ||
| if isinstance(inp, Constant): | ||
| fc = FrozenConstant(inp.type, inp.data) | ||
| memo[inp] = fc | ||
| if fc not in var_hash: | ||
| var_hash[fc] = hash(fc) | ||
| elif isinstance(inp, AtomicVariable): | ||
| # AtomicVariables (e.g. NominalVariables from outer | ||
| # scopes) are already interned and hashable. | ||
| memo[inp] = inp | ||
| if inp not in var_hash: | ||
| var_hash[inp] = hash(inp) | ||
| else: | ||
| raise ValueError( | ||
| f"Non-Constant, non-AtomicVariable orphan {inp} found " | ||
| "in the graph. All variables must be graph inputs, " | ||
| "Constants, AtomicVariables, or produced by Apply " | ||
| "nodes reachable from the inputs." | ||
| ) | ||
|
|
||
| new_inputs = tuple(memo[i] for i in node.inputs) | ||
| output_types = tuple(out.type for out in node.outputs) | ||
| new_node = FrozenApply(node.op, new_inputs, output_types) | ||
|
|
||
| input_hashes = tuple(var_hash[i] for i in new_inputs) | ||
| node_hash = hash((node.op, input_hashes)) | ||
| for old_out, new_out in zip(node.outputs, new_node.outputs, strict=True): | ||
| memo[old_out] = new_out | ||
| var_hash[new_out] = hash((node_hash, new_out.index)) | ||
|
|
||
| self.inputs: tuple[Variable, ...] = nominal_inputs | ||
|
|
||
| resolved_outputs = [] | ||
| for o in outputs: | ||
| mapped = memo.get(o) | ||
| # After unpickling, o may be a fresh object whose owner is the (correctly interned) FrozenApply. | ||
| # We thus resolve it through its owner to get back the original variable. | ||
| if mapped is None and o.owner is not None: | ||
| mapped = memo.get(o.owner.outputs[o.index]) | ||
| if mapped is None: | ||
| raise ValueError( | ||
| f"Output variable {o} could not be mapped to a frozen graph variable. " | ||
| "All outputs must be graph inputs, constants, or produced by Apply nodes " | ||
| "reachable from the inputs." | ||
| ) | ||
| resolved_outputs.append(mapped) | ||
| self.outputs: tuple[Variable, ...] = tuple(resolved_outputs) | ||
|
|
||
| self._structural_hash: int = hash(tuple(var_hash[o] for o in self.outputs)) | ||
|
Member: May be right, but why do we need to hash intermediate variables? Can't we just hash the outputs?

Member (Author): I was just erring on the side of caution. Can you think of a case where two graphs with different inputs would lead to different outputs?

Member: You mean with equal inputs? But no, if we implement it correctly it shouldn't happen
```python
    def __hash__(self):
        return self._structural_hash

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, FrozenFunctionGraph):
            return False
        if self._structural_hash != other._structural_hash:
            return False
        if self.outputs == other.outputs:
            return True
        # Hash match but output identity mismatch — likely a hash collision
        # or interning bug. Fall back to structural comparison.
        import warnings
```
Comment on lines +1044 to +1047

Member: Why? If it's a bug raise. You already showed the outputs are different
```python
        from pytensor.graph.basic import equal_computations

        if (
            len(self.outputs) == len(other.outputs)
            and len(self.inputs) == len(other.inputs)
            and equal_computations(
                list(self.outputs),
                list(other.outputs),
                in_xs=list(self.inputs),
                in_ys=list(other.inputs),
            )
        ):
            warnings.warn(
                "FrozenFunctionGraph: structurally equal graphs did not share "
                "interned objects. This may indicate an interning bug.",
                stacklevel=2,
            )
            return True
        return False

    def __repr__(self):
        return f"FrozenFunctionGraph(inputs={list(self.inputs)}, outputs={list(self.outputs)})"

    def __reduce__(self):
        return (type(self), (list(self.inputs), list(self.outputs)))
```
Member: These properties are expensive. Any reason not to define them at

Member (Author): I guess just to avoid paying the up-front cost in case we don't end up using them. I can save the results the first time and re-use thereafter.
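One way to do the lazy caching the author mentions is `functools.cached_property`; a hypothetical sketch, not what the PR currently does (the `applys_between` import path is assumed):

```python
from functools import cached_property

from pytensor.graph.basic import applys_between  # assumed import path


class FrozenFunctionGraphSketch:
    """Illustration only: compute the expensive views once, on first access."""

    def __init__(self, inputs, outputs):
        self.inputs = tuple(inputs)
        self.outputs = tuple(outputs)

    @cached_property
    def apply_nodes(self):
        # cached_property stores the result on the instance after the first call,
        # which is safe here because the frozen graph never changes.
        return set(applys_between(self.inputs, self.outputs))
```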
```python
    @property
    def apply_nodes(self) -> set[Apply]:
        return set(applys_between(self.inputs, self.outputs))

    def toposort(self) -> list[Apply]:
        return list(toposort(self.outputs, blockers=self.inputs))

    @property
    def variables(self) -> set[Variable]:
        return set(vars_between(self.inputs, self.outputs))
```
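Since `__reduce__` rebuilds the frozen graph through `__init__` (which re-interns everything, as the comment about unpickled outputs explains), equality and hash should survive pickling. A sketch reusing `fg` from the docstring example above; not part of the diff:

```python
import pickle

frozen = fg.freeze()
restored = pickle.loads(pickle.dumps(frozen))

# __init__ resolves the unpickled outputs back onto the interned FrozenApply
# outputs, so the restored graph hashes and compares equal to the original.
assert restored == frozen
assert hash(restored) == hash(frozen)
```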
Why in the base class? For instance the np.array_equal looks very tensor oriented, but we have types like Slice, RNG, ...

I think this commit was needed in an intermediate form, but now can simply be dropped.