
Allow freezing of FunctionGraph for hashing#1908

Draft
jessegrabowski wants to merge 6 commits into pymc-devs:main from jessegrabowski:hashable-inner-graphs

Conversation

@jessegrabowski
Member

Closes #1606

LLM disclosure: this PR made heavy use of Claude in the planning and first cut stages, though I was heavily involved. Still, the code should be subject to extra scrutiny as a result.

The purpose of the PR is to refactor Ops with inner graphs to allow comparison. The linked issue has an exhaustive discussion of the factors at play. There was an attempt in the aesara days to attack this, but it was perhaps too aggressive: it cons-hashed all Apply nodes, which necessitated changes across the codebase. @ricardoV94 suggested a weakref dict approach for subgraphs. This is implemented at the Op level. The plan is for Ops that have inner graphs (Composite, ScalarLoop, Scan, OpFromGraph, etc) to have a _cache class attribute, and implement the op-specific logic for caching, pickling, unpickling, etc. It didn't look super generalizable to me at first blush, but we can argue about it maybe.
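
The weakref-dict cache idea can be sketched in isolation. This is a hedged illustration, not the PR's code: `InnerGraphOp` and its `signature` key are hypothetical stand-ins for an Op with an inner graph and its structural key.

```python
import weakref


class InnerGraphOp:
    """Toy Op whose class-level _cache interns instances by a structural key.

    Hypothetical sketch of the proposed _cache class attribute; a real key
    would be derived from the frozen inner graph's structural hash.
    """

    _cache: "weakref.WeakValueDictionary" = weakref.WeakValueDictionary()

    def __new__(cls, signature):
        cached = cls._cache.get(signature)
        if cached is not None:
            # Same structural signature: reuse the live instance
            return cached
        obj = super().__new__(cls)
        obj.signature = signature
        cls._cache[signature] = obj
        return obj


op1 = InnerGraphOp(("sin", "mul"))
op2 = InnerGraphOp(("sin", "mul"))
assert op1 is op2  # interned: equality checks reduce to identity
```

Because the dictionary holds weak references, cached Ops are garbage-collected once no graph uses them, so the cache cannot grow without bound.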

Changes to FunctionGraph:

  • FunctionGraph now has a method freeze that returns a FrozenFunctionGraph.
  • The FrozenFunctionGraph does cons-hashing of Apply nodes within its scope only
  • It generates a hash based on its inner graph
  • Two FrozenFunctionGraphs with the same inner graph will evaluate as equal, but their Apply nodes won't be references to the same objects (this is the "conservatism" of my approach)

Specific implementation details:

  • The structural_hash of a FrozenFunctionGraph is built from a list of 3-tuples: (name, type, inputs), plus the outputs. For constants, inputs is replaced with the hash of the input data.
  • Equality between FrozenFunctionGraphs is done by comparing hashes, then falling back to equal_computation if the hash misses.
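
The structural hash described above can be illustrated on a toy graph representation (a dict of named nodes). This is a hedged sketch, not the PR's implementation; the node encoding is invented for the example:

```python
def structural_hash(graph, outputs):
    """Hash a toy graph. Each entry maps a variable name to (op, type, extra),
    where extra is a tuple of input names for an apply, a positional index for
    an "input", or raw data for a "const". Variable names never enter the
    hash, so alpha-equivalent graphs hash equal."""
    var_hash = {}

    def h(var):
        if var not in var_hash:
            op, typ, extra = graph[var]
            if op in ("input", "const"):
                # Constants contribute a hash of their data; inputs, their position
                var_hash[var] = hash((op, typ, extra))
            else:
                var_hash[var] = hash((op, typ, tuple(h(v) for v in extra)))
        return var_hash[var]

    return hash(tuple(h(o) for o in outputs))


# sin(a) * sqr(b) under two different variable namings
g1 = {"a": ("input", "float64", 0), "b": ("input", "float64", 1),
      "s": ("sin", "float64", ("a",)), "q": ("sqr", "float64", ("b",)),
      "m": ("mul", "float64", ("s", "q"))}
g2 = {"c": ("input", "float64", 0), "d": ("input", "float64", 1),
      "t": ("sin", "float64", ("c",)), "r": ("sqr", "float64", ("d",)),
      "o": ("mul", "float64", ("t", "r"))}
assert structural_hash(g1, ["m"]) == structural_hash(g2, ["o"])
```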

A consequence of the cons-hashing in this approach is that the inner graph is de-duplicated when we call fg.freeze(). So a MergeOptimizer pass is no longer required. Usage is demonstrated on the Composite Op. If we like the approach I can move forward with refactoring other Ops, but I wanted to stop here and discuss the approach.

Code example:

import pytensor.tensor as pt
import pytensor

a, b, c, d = pt.dscalars('a', 'b', 'c', 'd')
eq1 = pt.sin(a) * b ** 2
eq2 = pt.sin(c) * d ** 2

with pytensor.config.change_flags(optimizer_verbose=True):
    f = pytensor.function([a, b, c, d], [eq1, eq2])

f.dprint()

Result:

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id A] 1
 ├─ a [id B]
 └─ b [id C]
Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id D] 0
 ├─ c [id E]
 └─ d [id F]

Inner graphs:

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id A]
 ← mul [id G]
    ├─ sin [id H]
    │  └─ *0-<float64> [id I]
    └─ sqr [id J]
       └─ *1-<float64> [id K]

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id D]
 ← mul [id G]
    └─ ···

Member

@ricardoV94 ricardoV94 left a comment


Why did you not go all out?

If you already deduplicate and do internal hash-cons, you are one step away from getting hashing for free across different FunctionGraphs. Just do the hash-cons globally. Then FrozenFunctionGraph([x, y], [foo(x, y)]) is equal to another FunctionGraph if and only if fgraph.outputs == other_fgraph.outputs. No need for recursive hashing or expensive equal_computations.
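
The global hash-cons being suggested can be sketched generically (hypothetical names, not pytensor API). With every apply node interned in a global table keyed by the op and the identities of its inputs, structurally equal graphs share the same node objects, and equality degenerates to an identity check:

```python
# Global interning table for apply nodes, keyed by the op and the
# *identities* of the input objects (hash-consing).
_node_table: dict = {}


def make_node(op, *inputs):
    key = (op, tuple(id(v) for v in inputs))
    node = _node_table.get(key)
    if node is None:
        node = ("apply", op, inputs)
        _node_table[key] = node
    return node


x, y = object(), object()  # stand-ins for input variables
out1 = make_node("mul", make_node("sin", x), y)
out2 = make_node("mul", make_node("sin", x), y)
assert out1 is out2  # comparing outputs is an identity check, no recursion
```

A real version would hold the table entries through weak references so unused nodes can be collected; the sketch keeps a plain dict for brevity.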

As it stands, you are not doing much better than sneaking a default MergeOptimizer in at __init__ and adding a FunctionGraph class that has no replace method.

And cheap hashing/equality is not just a nice-to-have; it's really valuable for not slowing down compilation. In some of my benchmarks on previous work, some graphs could spend an inordinate amount of time on equality checks.

Comments regardless of which way we go:

  • Don't create FrozenFunctionGraph as a subclass of FunctionGraph; let's follow the general principle: shared abstract classes, no subclassing of concrete classes. Then you don't need check_frozen, since the methods simply don't exist on the frozen class.
  • You could create a FrozenApply that uses tuples for inputs/outputs instead of lists. That will help ensure immutability, because all our current rewrite machinery works by overwriting entries in those lists; accidentally trying to mutate a frozen graph would fail there 99% of the time.
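
The FrozenApply suggestion can be sketched as a frozen dataclass with tuple fields; this class is hypothetical, not the PR's code:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FrozenApply:
    """Immutable apply node: tuples instead of lists, so the rewrite
    machinery's in-place `node.inputs[i] = new_var` pattern fails loudly."""
    op: object
    inputs: tuple
    outputs: tuple = ()


node = FrozenApply(op="add", inputs=("x", "y"))
failed = False
try:
    node.inputs[0] = "z"  # list-style mutation used by rewrites
except TypeError:
    failed = True
assert failed  # tuples reject item assignment
```

As a bonus, frozen=True makes instances hashable and gives field-wise equality for free.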

Member

@ricardoV94 ricardoV94 left a comment


This is starting to look good, how are you feeling about it?

Notes:

  • Add a FrozenFunctionGraph.unfreeze(), that yields a FunctionGraph?
  • Really try to avoid the FrozenConstant stuff
  • Ops with inner graph (at least the ones you touched now) should only have a FrozenFunctionGraph internally (not a mutable one as well). Maybe that's already the case.

We need some follow-up issues open:

  • Optimizing OpFromGraph: There should be an explicit rewrite that creates a new OpFromGraph with its updated frozen graph (so it is also reflected immediately in dprint). We should never do any further rewrites of the internal fgraph during compilation.
  • Scan/Minimize/Root: Use the new FrozenFunctionGraph as well. This should immediately address #1601.
  • When compiling OpFromGraph in jitted contexts, we should try to avoid recreating inner numba/jax functions when the same OFG is compiled multiple times in a function; this will likely speed up compilation. In the C backend that already happens thanks to the caching of _fn. That's how we can deliver on the promised compilation speedups, and it's especially relevant for a library like pytensor-ml that may want to chain hundreds of the same "LayerOp"s in sequence.
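
The jit-caching point in the last bullet can be sketched with a memoized compile step. Here `compile_inner` is a hypothetical stand-in for the numba/jax funcify step, keyed by the Op's structural hash:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def compile_inner(op_hash):
    # Stand-in for an expensive numba/jax compilation of the inner graph
    compile_inner.calls += 1
    return lambda *args: sum(args)


compile_inner.calls = 0
f1 = compile_inner(0x1234)
f2 = compile_inner(0x1234)  # same structural hash: cache hit, no recompilation
assert f1 is f2
assert compile_inner.calls == 1  # compiled once, reused thereafter
```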

def clone(self, **kwargs):
    return self

def equals(self, other):
Member


Why in the base class? For instance, the np.array_equal looks very tensor-oriented, but we have types like Slice, RNG, ...

Member Author


I think this commit was needed in an intermediate form, but it can now simply be dropped.

return self.data


class FrozenConstant(Constant):
Member


Again, this sounds like it's specializing on numerical array types. How does MergeOptimizer find that two constants are equal for merging? Can we reuse that logic?

I wouldn't expect we need a FrozenConstant class in the end, since Constants are already frozen by our standards. The challenge here is more about finding whether a new constant was already seen before.
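
One way to avoid a FrozenConstant class, sketched here with a hypothetical helper (`constant_signature` is not pytensor API): key already-seen constants by a hashable signature of their type and raw data, in the same spirit as the constant-equality logic that merging relies on:

```python
import numpy as np


def constant_signature(type_repr, data):
    """Hashable key for a constant: its type plus dtype/shape/bytes of the data."""
    arr = np.asarray(data)
    return (type_repr, arr.dtype.str, arr.shape, arr.tobytes())


seen = {}
sig1 = constant_signature("dvector", np.array([1.0, 2.0]))
seen.setdefault(sig1, "const_0")
sig2 = constant_signature("dvector", np.array([1.0, 2.0]))
assert sig1 == sig2 and seen[sig2] == "const_0"  # same constant, already seen
```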

from pytensor.graph.op import HasInnerGraph

new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
Member


Check the boolean clone_inner_graph first, which is cheaper than the isinstance check.

resolved_outputs.append(mapped)
self.outputs: tuple[Variable, ...] = tuple(resolved_outputs)

self._structural_hash: int = hash(tuple(var_hash[o] for o in self.outputs))
Member


May be right, but why do we need to hash intermediate variables? Can't we just hash the outputs?

Member Author


I was just erring on the side of caution. Can you think of a case where two graphs with different inputs would lead to different outputs?

Member


You mean with equal inputs? But no, if we implement it correctly it shouldn't happen

Comment on lines +1044 to +1047
# Hash match but output identity mismatch — likely a hash collision
# or interning bug. Fall back to structural comparison.
import warnings

Member


Why? If it's a bug, raise. You already showed the outputs are different.


if isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
elif isinstance(x, Constant) and isinstance(x.type, ScalarType):
Member


Bug due to FrozenConstants? I really don't think we should have those, because they'll lack all the attributes that ScalarVariables have, and they will fail isinstance(x, ScalarVariable) checks

@@ -4140,38 +4116,17 @@ def prepare_node(self, node, storage_map, compute_map, impl):
def __eq__(self, other):
if self is other:
Member


Can't we have regular __props__-based equality/hashing now?

self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs)
self.nout = len(outputs)
fgraph = FunctionGraph(inputs, outputs, clone=clone_graph)
Member


Can't we build a frozen fgraph immediately? This is doing double the effort. Also, we can get rid of clone_graph now.

Member Author


Most of the make_node methods do graph mutations. For example, Composite has logic to flatten a sequence of nested Composites using rebuild_collect_shared, which eventually wants to do a clone with new inputs.

I could try to rip out/update this machinery too? I am already working on an unfreeze method as a workaround (you brought this up in another comment), but I don't know if we'd judge that as too expensive.

Member

@ricardoV94 ricardoV94 Feb 25, 2026


The nested thing is also an eager rewrite mistake; it should be its own rewrite, unless compilation somehow fails when a Composite has another Composite inside (which I don't think is the case, because ScalarLoop can definitely handle other Composite/ScalarLoop Ops inside).

Member Author


Agreed, I was thinking the same thing, but I'm not sure it's for this PR.

Anyway we will still have to unfreeze the fgraph at some point, whether during the rewrite or during init. We already have local_inline_composite_constants that mutates Composites and will need to unfreeze -> modify -> refreeze.

Member


I just mean there's no point in having two FunctionGraphs inside Composite.

Member Author


yep I'm on the same page

@@ -4273,12 +4209,6 @@ def __str__(self):

@property
Member


Remove the property; if it already exists, just store it as self.fgraph.

e = x + y * x

op1 = OpFromGraph([x, y], [e])
op2 = OpFromGraph([x, y], [e])
Member


Can you also test with distinct x, y variables? They should still be identical, since those are just nominal/dummy variables.

return FrozenFunctionGraph(self.inputs, self.outputs)


class FrozenFunctionGraph:
Member


They could inherit from a shared base class. Then, for instance, x_funcify_FunctionGraph could dispatch on the base class, since it doesn't care whether it gets a Frozen or regular FunctionGraph?


Development

Successfully merging this pull request may close these issues.

Equality of Ops with InnerGraph
