Bug report
A ValueError is raised while initializing the Transformer class in maxtext/models/models.py. Flax NNX treats the decoder attribute as static metadata, so assigning it a ToNNX wrapper, which NNX classifies as a data value, fails.
Logs/Output
ValueError: Cannot assign data value of type '<class 'maxtext.layers.nnx_wrappers.ToNNX'>' to static attribute 'decoder' of Pytree type '<class 'maxtext.models.models.Transformer'>'. To override the status explicitly wrap the value with nnx.data on assignment:
_.decoder = nnx.data(...)
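A toy model of the check in the error message may help with triage. The sketch below is plain Python, not Flax's implementation: `ToyPytree`, `ToyModule`, `data` and `ToyTransformer` are hypothetical stand-ins. It assumes that an attribute's static/data status is fixed by its first assignment, and it uses a hypothetical `self.decoder = None` to make `decoder` static. The report does not say how `decoder` actually became static in MaxText.

```python
# Toy model (NOT Flax's implementation) of the static/data attribute check.
# All class and function names here are hypothetical.


class DataMarker:
    """Hypothetical stand-in for nnx.data(...): marks a value as data explicitly."""

    def __init__(self, value):
        self.value = value


def data(value):
    """Wrap a value so the toy pytree records it as a data attribute."""
    return DataMarker(value)


class ToyModule:
    """Hypothetical stand-in for a submodule such as ToNNX; counts as data."""


class ToyPytree:
    def __init__(self):
        # Per-instance bookkeeping: attribute name -> "static" or "data".
        object.__setattr__(self, "_status", {})

    @staticmethod
    def _is_data(value):
        # Assumption: submodules are data; everything else (None, ints, str) is static.
        return isinstance(value, ToyModule)

    def __setattr__(self, name, value):
        if isinstance(value, DataMarker):
            # Explicit override: record as data regardless of the previous status.
            self._status[name] = "data"
            object.__setattr__(self, name, value.value)
            return
        status = "data" if self._is_data(value) else "static"
        if self._status.get(name) == "static" and status == "data":
            raise ValueError(
                f"Cannot assign data value of type {type(value)!r} "
                f"to static attribute {name!r}"
            )
        # Assumption: the first assignment fixes the attribute's status.
        self._status.setdefault(name, status)
        object.__setattr__(self, name, value)


class ToyTransformer(ToyPytree):
    def __init__(self, wrap):
        super().__init__()
        self.decoder = None  # hypothetical: first assignment records "static"
        if wrap:
            self.decoder = data(ToyModule())  # explicit override succeeds
        else:
            self.decoder = ToyModule()  # data value on a static attribute: raises


try:
    ToyTransformer(wrap=False)
except ValueError as err:
    print("without wrapper:", err)

print("with wrapper:", type(ToyTransformer(wrap=True).decoder).__name__)
```

In the toy, wrapping the assignment mirrors the override the error message suggests (`_.decoder = nnx.data(...)`); whether that is the right fix for MaxText, rather than avoiding the earlier static assignment, is not established by this report.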
Environment Information
MaxText Branch: main (as of March 2026).
Hardware: Single-host TPU v5e 2x4
JAX Version: 0.4.25
Flax Version: >=0.10.0
Additional Context
No response