Skip to content

[NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter#3843

Draft
ecnal-cienet wants to merge 3 commits into
mainfrom
feat/nnx-linen-nnx-ckpt-utils
Draft

[NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter#3843
ecnal-cienet wants to merge 3 commits into
mainfrom
feat/nnx-linen-nnx-ckpt-utils

Conversation

@ecnal-cienet
Copy link
Copy Markdown
Collaborator

@ecnal-cienet ecnal-cienet commented May 7, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. 🔄 [This PR] Linen↔NNX checkpoint converter. Originally bundled with the comparator under PR4.5; further split into PR4.5 (converter) + PR4.6 (comparator) on 2026-05-07 to keep each reviewable.
    4.6. ❌ Linen↔NNX checkpoint comparator (stacked follow-up on this branch).
  5. ❌ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ❌ NNX-native DPO.
  7. ❌ NNX-native MaxEngine inference.
  8. ❌ NNX-native LoRA + GRPO.
  9. ❌ NNX-aware QK-Clip + remaining checkpoint utilities.
    9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

This PR adds linen_nnx_converter.py — a bidirectional Linen↔NNX checkpoint converter used during the NNX migration to translate checkpoints between the two formats. Originally bundled with a comparison utility; the comparator was split into a stacked follow-up (PR4.6) on 2026-05-07 so each PR stays narrowly reviewable. PR4.5 and PR4.6 are file-disjoint.

This is a pure addition — no existing files are modified, no production-code paths reference the utility, and no Linen or NNX runtime behavior changes. PR5+ do not depend on this branch.

Diff: +1450 / −0 across 2 new files.

What it does

src/maxtext/checkpoint_conversion/linen_nnx_converter.py:

  • Bidirectional conversion — Linen → NNX and NNX → Linen. Round-trips are designed to preserve byte values; the unit tests assert this on synthetic checkpoints.
  • Top-level key mapping (handled symmetrically in both directions):
    • Linen params/params/<model> ↔ NNX model/<model> (remove/add double-nesting; add/strip {value: ...} wrappers).
    • Linen opt_state ↔ NNX optimizer/opt_state (remove/add params level on mu / nu).
    • Linen step ↔ NNX optimizer/step (move in/out of optimizer).
  • Layer structure--scan_layers=True (default) stacks per-layer arrays into a single layers tensor (layer axis at position 1); --scan_layers=False keeps integer-keyed layers/{N}. NNX→Linen direction auto-detects the source layout.
  • Format detection--direction=auto picks Linen vs NNX from top-level keys (model → NNX, params → Linen).
  • CPU-only — sets JAX_PLATFORMS=cpu before importing JAX so it runs on a workstation without TPU/GPU access.

Tests

tests/unit/linen_nnx_converter_test.py — pure-CPU, 84 cases. Covers:

  • Format detection
  • Value-wrapper add/strip
  • Layer-axis transpose
  • Layer stacking / unstacking
  • Full Linen→NNX and NNX→Linen convert_* paths
  • Opt-state conversion in both directions
  • Checkpoint load/save (orbax mocked)
  • The CLI main entry point

Existing tests untouched.

Stats

  • Diff: +1450 / −0 across 2 files (2 new, 0 modified).
  • Production-code impact: none. No existing source file imports the utility.
  • Linen preservation: trivially preserved — no Linen file is touched.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
Part 1 — sharding diagnostics:
- maxtext_utils.py: extend print_shardings_params to support NNX (nnx.State input)
- run_sharding_dump.py: add --pure_nnx flag

Part 2 — post-training bugfixes (NNX-side):
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing
  the whole object as multimodal_input= kwarg; NNXDecoder only accepts the
  individual image/audio/mask fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py / train_sft.py / train_rl.py: avoid nesting nnx.value_and_grad
  inside nnx.jit (Tunix's default trainer), which raises "graph structure of a
  node added to cached_partial was mutated" — refactor to jax.value_and_grad
  with explicit nnx.split / nnx.merge; train_rl.py also adds with_sharding_constraint
  + dtype-cast compat shims for jax 0.9 / tpu_inference

Linen<->NNX checkpoint conversion utility and validation tool moved to a
follow-up PR (PR4.5) to keep this change reviewable.
Bidirectional Linen <-> NNX checkpoint conversion. Same on-disk shape
both directions; round-trips preserve byte values.

Top-level key mapping:
- Linen params/params/<model> <-> NNX model/<model> (double-nesting,
  {value:} wrappers).
- Linen opt_state <-> NNX optimizer/opt_state (params level on mu/nu).
- Linen step <-> NNX optimizer/step.

Layer structure:
- scan_layers=True (default): stack layers_N -> layers tensor.
- scan_layers=False: rename layers_N -> integer-keyed layers/{N}.

NNX->Linen direction auto-detects which layer layout the source uses.
--direction=auto picks Linen vs NNX from top-level keys.

Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Comparison utility split into PR4.6.
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-linen-nnx-ckpt-utils branch from 4b9231f to 4417cde Compare May 8, 2026 16:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant