Skip to content

Cache graph functions and add If gradient test#637

Open
nfeybesse wants to merge 2 commits into
tensorflow:masterfrom
nfeybesse:custom/graph-function-cache
Open

Cache graph functions and add If gradient test#637
nfeybesse wants to merge 2 commits into
tensorflow:masterfrom
nfeybesse:custom/graph-function-cache

Conversation

@nfeybesse
Copy link
Copy Markdown
Contributor

This PR introduces a small cache of ConcreteFunction instances attached to a Graph.

When functions are attached via attachFunction, they are stored in a local cache
indexed by their defined name. This avoids repeatedly scanning the native
TensorFlow function library when resolving functions during gradient construction.

A helper method getFunctionCached(String prefix) is also added to allow quick lookup
of cached functions by name prefix.

In addition, this PR introduces IfGradientTest, a unit test validating correct
gradient propagation through a StatefulIf operation.

@nfeybesse nfeybesse force-pushed the custom/graph-function-cache branch from 846b6cf to 16b27d2 Compare March 12, 2026 13:17
@Craigacp
Copy link
Copy Markdown
Collaborator

We've removed Windows support in preparation for the 1.2 release, please don't add it back to the CI. Google don't release the libtensorflow binaries on Windows which we've used for the past few releases, and building libtensorflow from source isn't tenable with the GitHub Actions resources we have access to.

@nfeybesse nfeybesse force-pushed the custom/graph-function-cache branch from 16b27d2 to 9c3b19d Compare March 12, 2026 14:30
* @return a cached {@link ConcreteFunction} whose name starts with {@code prefix}, or {@code
* null} if none is found
*/
public ConcreteFunction getFunctionCached(String prefix) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @nfeybesse, can you please remove this method? Looks like it is not being used, and it might not be desirable neither since if you have multiple functions with the same prefix, you don't know which one it gonna return (unless you return the whole list of matching functions?)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback, that makes sense regarding the prefix-based lookup.

I tried removing access to the cached functions completely and relying only on Graph.getFunction(exactName), but this makes the implementation substantially more complicated. In particular, during custom gradient construction, calling Graph.getFunction(...) may end up scanning/querying the native function library while the graph is already being manipulated by the gradient builder. In my test case this can hang, so resolving the gradient functions through the native function library does not seem safe in that context.

I can still avoid the ambiguous prefix lookup by keeping an exact-name Java-side map in the test/code that creates the gradient functions. That works, but it means duplicating bookkeeping outside Graph even though Graph already has the information.

Maybe a middle-ground would be to expose a read-only view of the cached function names, for example a keySet() or functionNames() method. Then callers could resolve ambiguity themselves, choose an exact name deterministically, and still avoid exposing a method that returns an arbitrary function for a prefix.

Tell me what you prefer

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @nfeybesse , sorry I thought I had replied to your last message, but looks like not, don't know what happened. But yes, the read-only view of the cached function names keyset works for us, can you please make the change? We are in the process of releasing 1.2.0, but if you can't push your changes before we'll simply release 1.2.1 after

@karllessard
Copy link
Copy Markdown
Collaborator

Hi @nfeybesse , just checking quickly if you plan to push any time soon an updated version of that PR based on our previous discussion, if so we'll wait a bit before releasing 1.2.0

@nfeybesse
Copy link
Copy Markdown
Contributor Author

Hi Karl, I updated the PR following your suggestion.

getFunctionCached(prefix) has been removed and replaced with a read-only functionNames() view of the cached function names. The If gradient test now resolves the prefix match on the caller side, checks that it is unambiguous, and then calls getFunction(exactName).

I ran the targeted If gradient test, the module spotless check, and a full mvn install locally. I also verified the downstream gs-keras If gradient test suite against the locally installed artifacts.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants