Skip to content

Commit 1e9faff

Browse files
committed
Expose cached graph function names
1 parent 9c3b19d commit 1e9faff

3 files changed

Lines changed: 22 additions & 20 deletions

File tree

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ public GraphOperationBuilder opBuilder(String type, String name, Scope scope) {
404404
* immediately without re-registering it.
405405
*
406406
* <p>The function is also stored in an internal cache to speed up subsequent lookups performed by
407-
* {@link #getFunction(String)} and {@link #getFunctionCached(String)}.
407+
* {@link #getFunction(String)}.
408408
*/
409409
@Override
410410
public void attachFunction(ConcreteFunction function) {
@@ -915,26 +915,15 @@ Set<Operation> initializers() {
915915
new ConcurrentHashMap<>();
916916

917917
/**
918-
* Returns a cached {@link ConcreteFunction} whose name starts with the provided prefix.
918+
* Returns a read-only view of the function names cached by this graph.
919919
*
920-
* <p>This is a lightweight lookup helper used when the exact function name is not known but
921-
* follows a deterministic prefix (for example functions generated for control-flow constructs or
922-
* custom gradient expansions).
920+
* <p>This exposes only the function names so callers can resolve ambiguous matches themselves
921+
* before calling {@link #getFunction(String)} with an exact name.
923922
*
924-
* <p>The search is performed only in the local cache and does not query the native TensorFlow
925-
* function library.
926-
*
927-
* @param prefix function name prefix
928-
* @return a cached {@link ConcreteFunction} whose name starts with {@code prefix}, or {@code
929-
* null} if none is found
923+
* @return a read-only view of cached function names
930924
*/
931-
public ConcreteFunction getFunctionCached(String prefix) {
932-
for (Map.Entry<String, ConcreteFunction> e : functionCache.entrySet()) {
933-
if (e.getKey().startsWith(prefix)) {
934-
return e.getValue();
935-
}
936-
}
937-
return null;
925+
public Set<String> functionNames() {
926+
return Collections.unmodifiableSet(functionCache.keySet());
938927
}
939928

940929
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/AttributeMetadata.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public class AttributeMetadata {
3030

3131
/** The size of the list if this attribute is a list, undefined otherwise. */
3232
public final long listSize;
33+
3334
/**
3435
* The type of this attribute, or the type of the list values if it is a list.
3536
*

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,18 @@ private static Operand<?> sumAll(Ops tf, Operand<?> v) {
144144
return tf.reduceSum((Operand) v, axes);
145145
}
146146

147+
private static ConcreteFunction getSingleFunctionByPrefix(Graph graph, String prefix) {
148+
List<String> matches =
149+
graph.functionNames().stream()
150+
.filter(name -> name.startsWith(prefix))
151+
.collect(Collectors.toList());
152+
if (matches.size() != 1) {
153+
throw new IllegalStateException(
154+
"Expected one cached function for prefix=" + prefix + ", found=" + matches);
155+
}
156+
return graph.getFunction(matches.get(0));
157+
}
158+
147159
@Test
148160
public void testStatefullIfGradient() {
149161
TensorFlow.registerCustomGradient(
@@ -199,8 +211,8 @@ public void testStatefullIfGradient() {
199211
final String thenPrefix = op.name() + "/then_grad"; // op has unique name
200212
final String elsePrefix = op.name() + "/else_grad";
201213

202-
ConcreteFunction thenGrad = op.env().getFunctionCached(thenPrefix);
203-
ConcreteFunction elseGrad = op.env().getFunctionCached(elsePrefix);
214+
ConcreteFunction thenGrad = getSingleFunctionByPrefix(op.env(), thenPrefix);
215+
ConcreteFunction elseGrad = getSingleFunctionByPrefix(op.env(), elsePrefix);
204216

205217
if (thenGrad == null || elseGrad == null) {
206218
throw new IllegalStateException("If grad functions not primed for op=" + op.name());

0 commit comments

Comments
 (0)