Skip to content

fix(inject_hyperparams): treat plain Python ints as static to fix adafactor jit (#412)#1686

Open
nileshpatil6 wants to merge 1 commit into
google-deepmind:mainfrom
nileshpatil6:fix/adafactor-inject-hyperparams-jit
Open

fix(inject_hyperparams): treat plain Python ints as static to fix adafactor jit (#412)#1686
nileshpatil6 wants to merge 1 commit into
google-deepmind:mainfrom
nileshpatil6:fix/adafactor-inject-hyperparams-jit

Conversation

@nileshpatil6
Copy link
Copy Markdown

Fixes #412

What was wrong

inject_hyperparams puts plain Python int values into numeric_hps,
which converts them to traced JAX arrays. adafactor passes
min_dim_size_to_factor (an int) to scale_by_factored_rms, which
then uses it in a Python comparison inside _factored_dims:

if shape[sorted_dims[-2]] < min_dim_size_to_factor:

Under jit, that comparison returns a traced bool, and Python if on
it raises TracerBoolConversionError.

Fix

Move plain Python int values (excluding bool, which was already
handled) from numeric_hps to other_hps in inject_hyperparams.
These values are structural parameters used for shape decisions and
cannot realistically be scheduled as JAX arrays anyway. Users who want
a schedulable integer can pass a jax.Array (e.g. jnp.int32(10)),
which still goes through the numeric path.

bool subclasses int in Python, so the existing isinstance(value, bool) check continues to catch booleans before this new branch.

Result

optimizer = optax.inject_hyperparams(optax.adafactor)(learning_rate=0.01)
params = {'w': jnp.ones((200, 200))}
grads = {'w': jnp.ones((200, 200))}
opt_state = jax.jit(optimizer.init)(params)
updates, _ = jax.jit(optimizer.update)(grads, opt_state, params)
# works without specifying static_args

The test in alias_test.py that previously worked around this with
static_args=('min_dim_size_to_factor',) is updated to use plain
inject_hyperparams(adafactor). A dedicated regression test with large
(200x200) params is added to cover the factored second-moment path.

Plain Python ints (e.g. min_dim_size_to_factor in adafactor) are used
in Python-level control flow and shape decisions inside optimizers.
Previously inject_hyperparams traced them as JAX arrays, causing a
TracerBoolConversionError when the wrapped optimizer was jit-compiled.

This change moves plain Python ints to other_hps (not traced) while
keeping jax.Array / np.ndarray integers in numeric_hps so they remain
schedulable. Fixes adafactor wrapped with inject_hyperparams under jit
without requiring the caller to specify static_args manually.

Fixes google-deepmind#412
@nileshpatil6 nileshpatil6 force-pushed the fix/adafactor-inject-hyperparams-jit branch from 17582f4 to 5cd6893 Compare May 30, 2026 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Problems when jitting Adafactor with inject_hyperparams.

1 participant