fix(inject_hyperparams): treat plain Python ints as static to fix adafactor jit (#412) #1686
Open
nileshpatil6 wants to merge 1 commit into
Plain Python ints (e.g. min_dim_size_to_factor in adafactor) are used in Python-level control flow and shape decisions inside optimizers. Previously inject_hyperparams traced them as JAX arrays, causing a TracerBoolConversionError when the wrapped optimizer was jit-compiled. This change moves plain Python ints to other_hps (not traced) while keeping jax.Array / np.ndarray integers in numeric_hps so they remain schedulable. Fixes adafactor wrapped with inject_hyperparams under jit without requiring the caller to specify static_args manually. Fixes google-deepmind#412
Fixes #412
What was wrong
inject_hyperparams puts plain Python int values into numeric_hps, which converts them to traced JAX arrays.
adafactor passes min_dim_size_to_factor (an int) to scale_by_factored_rms, which then uses it in a Python comparison inside _factored_dims.
Under jit, that comparison returns a traced bool, and a Python if on it raises TracerBoolConversionError.
Fix
Move plain Python int values (excluding bool, which was already handled) from numeric_hps to other_hps in inject_hyperparams. These values are structural parameters used for shape decisions and cannot realistically be scheduled as JAX arrays anyway. Users who want a schedulable integer can pass a jax.Array (e.g. jnp.int32(10)), which still goes through the numeric path.
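The classification rule described above can be sketched as follows. This is an illustrative approximation, not the actual optax implementation: partition_hyperparams and the example hyperparameter names passed to it are hypothetical, and only the three bucket names (numeric_hps, sched_hps, other_hps) are modeled on the description.

```python
import jax
import jax.numpy as jnp
import numpy as np


def partition_hyperparams(hparams, static_args=()):
    """Hypothetical helper: split hyperparameters into three buckets."""
    numeric_hps, sched_hps, other_hps = {}, {}, {}
    for name, value in hparams.items():
        if name in static_args or isinstance(value, bool):
            # bool is checked first because it subclasses int.
            other_hps[name] = value
        elif callable(value):
            # Schedules are callables evaluated at each step.
            sched_hps[name] = value
        elif isinstance(value, int):
            # Plain Python int: structural, so keep it out of tracing.
            other_hps[name] = value
        elif isinstance(value, (float, jax.Array, np.ndarray)):
            # Floats and array-typed values stay on the numeric path.
            numeric_hps[name] = value
        else:
            other_hps[name] = value
    return numeric_hps, sched_hps, other_hps


numeric_hps, sched_hps, other_hps = partition_hyperparams({
    "learning_rate": 1e-3,              # float: numeric
    "min_dim_size_to_factor": 128,      # plain int: static
    "factored": True,                   # bool: static
    "decay": lambda step: 0.5,          # callable: schedule
    "clip": jnp.int32(10),              # jax.Array integer: numeric
})
```

With this ordering, an integer wrapped as jnp.int32(10) is not an instance of int, so it falls through to the numeric branch, while a bare 128 is kept static.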
bool subclasses int in Python, so the existing isinstance(value, bool) check continues to catch booleans before this new branch.
Result
The test in alias_test.py that previously worked around this with static_args=('min_dim_size_to_factor',) is updated to use plain inject_hyperparams(adafactor). A dedicated regression test with large (200x200) params is added to cover the factored second-moment path.
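For readers who want to see the failure mode in isolation, here is a minimal sketch that does not call optax at all. toy_factored_dims and second_moment are hypothetical stand-ins written for this example, and the threshold of 128 is an assumed value chosen to sit below 200. The sketch shows that a traced threshold breaks the Python-level comparison under jit, that a static threshold works, and that a 200x200 array is large enough to reach the factored branch.

```python
import jax
import jax.numpy as jnp


def toy_factored_dims(shape, min_dim_size_to_factor):
    """Hypothetical stand-in: return the two largest dims, or None."""
    if len(shape) < 2:
        return None
    order = sorted(range(len(shape)), key=lambda i: shape[i])
    # Python-level comparison: needs a concrete bool, not a tracer.
    if shape[order[-2]] < min_dim_size_to_factor:
        return None
    return order[-2], order[-1]


def second_moment(grad, min_dim_size_to_factor):
    """Toy statistic: factored means if large enough, else full squares."""
    dims = toy_factored_dims(grad.shape, min_dim_size_to_factor)
    sq = grad ** 2
    if dims is None:
        return (sq,)
    d0, d1 = dims
    return sq.mean(axis=d0), sq.mean(axis=d1)


big = jnp.ones((200, 200))

# Threshold passed as a regular jit argument: it is traced, and the
# Python `if` fails. TracerBoolConversionError is a subclass of
# ConcretizationTypeError in recent JAX versions.
try:
    jax.jit(second_moment)(big, 128)
    traced_call_failed = False
except jax.errors.ConcretizationTypeError:
    traced_call_failed = True

# Threshold marked static: the comparison runs on a plain Python int.
static_fn = jax.jit(second_moment, static_argnames="min_dim_size_to_factor")
factored = static_fn(big, min_dim_size_to_factor=128)
unfactored = static_fn(jnp.ones((8, 8)), min_dim_size_to_factor=128)
```

In this sketch an 8x8 input takes the unfactored branch, so a test using only small parameters would never exercise the factored statistics.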