609 changes: 609 additions & 0 deletions src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Large diffs are not rendered by default.

581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

32 changes: 27 additions & 5 deletions src/maxtext/common/checkpointing.py
@@ -20,6 +20,7 @@
from absl import flags
import datetime
from etils import epath
from flax import nnx
from flax.training import train_state
import jax
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -571,7 +572,7 @@ def load_state_if_possible(
load_parameters_from_path: str,
load_full_state_from_path: str,
checkpoint_storage_concurrent_gb: int,
abstract_unboxed_pre_state: train_state.TrainState,
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
enable_single_replica_ckpt_restoring: bool | None = False,
dataset_type: str | None = "tfds",
step: int = -1, # -1 means latest
@@ -639,9 +640,14 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()

restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
checkpoint_args = ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state,
item=restore_target,
restore_args=restore_args,
partial_restore=True,
)
@@ -679,9 +685,14 @@ def map_to_pspec(data):
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)

if load_parameters_from_path != "":
if isinstance(abstract_unboxed_pre_state, nnx.State):
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
else:
params = abstract_unboxed_pre_state.params

restored_params = load_params_from_path(
load_parameters_from_path,
abstract_unboxed_pre_state.params,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
@@ -773,7 +784,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
# Determine the effective step for saving a checkpoint.
# If 'step' is not provided, this call is for a potential final checkpoint,
# and we use the last completed step from the state.
actual_step = (int(state.step) - 1) if step is None else int(step)
if step is not None:
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
else:
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1

if config.pure_nnx:
# Convert nnx.State to dict.
state = state.to_pure_dict()

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
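A minimal sketch (not part of this PR) of the restore-target conversion the checkpointing change above introduces, using a toy nnx module. Only the `nnx.State` → pure-dict step is shown; the real abstract state in MaxText is built elsewhere.

```python
# Hedged sketch with a toy module: Orbax is handed the plain nested dict from
# nnx.State.to_pure_dict(), matching the layout used when NNX checkpoints are saved.
from flax import nnx

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.dense = nnx.Linear(4, 4, rngs=rngs)

abstract_unboxed_pre_state = nnx.state(TinyModel(nnx.Rngs(0)))

restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
  restore_target = abstract_unboxed_pre_state.to_pure_dict()

assert isinstance(restore_target, dict)  # e.g. {"dense": {"kernel": ..., "bias": ...}}
```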
3 changes: 2 additions & 1 deletion src/maxtext/layers/nnx_decoders.py
@@ -35,6 +35,7 @@
MODEL_MODE_TRAIN,
Config,
DecoderBlockType,
MultimodalInput,
ShardMode,
)
from maxtext.inference import page_manager
@@ -1061,10 +1062,10 @@ def __call__(
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
multimodal_input: None | Any = None,
kv_caches: list[jax.Array] | None = None,
attention_metadata=None,
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
multimodal_input: None | MultimodalInput = None,
):
cfg = self.config
assert decoder_input_tokens.ndim == 2 # [batch, len]
6 changes: 5 additions & 1 deletion src/maxtext/models/models.py
@@ -509,7 +509,11 @@ def __call__(
previous_chunk=previous_chunk,
slot=slot,
page_state=page_state,
multimodal_input=multimodal_input,
image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None,
image_masks=multimodal_input.image_masks if multimodal_input is not None else None,
audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None,
audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None,
bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None,
kv_caches=kv_caches,
attention_metadata=attention_metadata,
deepstack_visual_embeds=deepstack_visual_embeds,
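A sketch (not part of this PR) of the None-guarded unpacking above. `MultimodalInput` here is a stand-in dataclass carrying only the fields referenced at this call site; the real MaxText class may define more.

```python
# Hedged sketch: unpack an optional multimodal container into per-field kwargs,
# with every field falling back to None when no multimodal input is given.
import dataclasses
from typing import Any

@dataclasses.dataclass
class MultimodalInput:  # stand-in; field names taken from the call site above
  image_embeddings: Any = None
  image_masks: Any = None
  audio_embeddings: Any = None
  audio_masks: Any = None
  bidirectional_mask: Any = None

def unpack(multimodal_input: MultimodalInput | None) -> dict:
  fields = ("image_embeddings", "image_masks", "audio_embeddings",
            "audio_masks", "bidirectional_mask")
  return {f: getattr(multimodal_input, f) if multimodal_input is not None else None
          for f in fields}

print(unpack(None))                                   # every value is None
print(unpack(MultimodalInput(image_masks=[[1, 0]])))  # only image_masks populated
```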
4 changes: 3 additions & 1 deletion src/maxtext/optimizers/optimizers.py
@@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu):
else:
updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)

step_size = -1.0 * learning_rate_fn(count)
# learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped
# by optax.inject_hyperparams, it is passed as a pre-evaluated scalar).
step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
# Finally, fold in step size.
updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)

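A quick illustration (not part of this PR) of why the `callable` check is needed: when a schedule is wrapped with `optax.inject_hyperparams`, the inner update receives the already-evaluated scalar rather than the schedule itself.

```python
# Hedged sketch: the same step-size expression handles both a schedule
# (callable) and a pre-evaluated scalar learning rate.
import optax

schedule = optax.linear_schedule(init_value=1e-3, end_value=0.0, transition_steps=100)
scalar_lr = schedule(10)  # what inject_hyperparams would hand the inner update at step 10
count = 10

for learning_rate_fn in (schedule, scalar_lr):
  step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
  print(step_size)  # -0.0009 in both cases
```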
78 changes: 45 additions & 33 deletions src/maxtext/trainers/post_train/distillation/train_distill.py
@@ -273,30 +273,45 @@ def wrt_filter(path, x):
# Inherits _shard_optimizer from PeftTrainer.

def _train_step(self, model, optimizer, inputs):
"""Overrides the main JIT block to natively handle ModelBundle module."""
"""Overrides the main JIT block to natively handle ModelBundle module.

Uses jax.value_and_grad with explicit split/merge to avoid nesting
nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
conflicting outer_index values and raises:
ValueError: The graph structure of a node added to cached_partial was
mutated inside the transformation.
"""
batch = self.gen_model_input_fn(inputs)
student = model.student_model
teacher = model.teacher_model
current_step = model.training_step[...]

def loss_wrapper(student, teacher, batch):
if "teacher_output" in batch:
teacher_output = batch["teacher_output"]
else:
teacher_output = self.strategy.teacher_forward_fn(
model=teacher,
input_tokens=batch["input_tokens"],
positions=batch["positions"],
attention_mask=batch.get("attention_mask"),
decoder_segment_ids=batch.get("decoder_segment_ids"),
decoder_target_tokens=batch.get("targets", None),
decoder_target_mask=batch.get("targets_segmentation", None),
cache=None,
)
# Run teacher inference outside of value_and_grad.
# The teacher is frozen (stop_gradient), so its output is a constant
# from the perspective of the student gradient computation.
if "teacher_output" in batch:
teacher_output = batch["teacher_output"]
else:
teacher_output = self.strategy.teacher_forward_fn(
model=teacher,
input_tokens=batch["input_tokens"],
positions=batch["positions"],
attention_mask=batch.get("attention_mask"),
decoder_segment_ids=batch.get("decoder_segment_ids"),
decoder_target_tokens=batch.get("targets", None),
decoder_target_mask=batch.get("targets_segmentation", None),
cache=None,
)
teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)

teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
# Split student into differentiable params and non-differentiable rest.
# Capture graphdef outside of jax.value_and_grad for stable graph tracking.
student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)

def loss_wrapper_pure(diff_params, rest):
local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
student_output = self.strategy.student_forward_fn(
model=student,
model=local_student,
input_tokens=batch["input_tokens"],
positions=batch["positions"],
attention_mask=batch.get("attention_mask"),
@@ -305,29 +320,26 @@ def loss_wrapper(student, teacher, batch)
decoder_target_mask=batch.get("targets_segmentation", None),
cache=None,
)
# we should apply a mask for labels to disable segment-separator tokens
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)

# Because student is the 0th argument, argnums=0 guarantees
# we only compute gradients for the student.
grad_fn = nnx.value_and_grad(
loss_wrapper,
argnums=nnx.DiffState(0, self.wrt_filter),
has_aux=True,
)
loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
# Capture updated non-param state (e.g. RNG counters) from local_student.
_, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
return loss, (aux, new_rest)

out, grads = grad_fn(model.student_model, model.teacher_model, batch)
grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)
(loss, (aux, new_rest)), grads = grad_fn(diff_params, rest)

model.training_step.set_value(current_step + 1)
# Propagate updated non-param state back to student.
nnx.update(student, new_rest)

tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
optimizer.update(student, grads)

optimizer.update(model.student_model, grads)
model.training_step.set_value(current_step + 1)

tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
if tunix_expects_grad_norm:
return out[0], out[1], optax.global_norm(grads)
return out[0], out[1]
return loss, aux, optax.global_norm(grads)
return loss, aux

def _eval_step(self, model, inputs):
"""Evaluation only needs the student."""
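A toy version (not part of this PR) of the split/merge pattern `_train_step` now uses: gradients come from `jax.value_and_grad` over the param state only, and any state the forward pass mutates is merged back afterwards.

```python
# Hedged sketch on a toy model; the real student model and loss live in the trainer.
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
x, y = jnp.ones((4, 2)), jnp.zeros((4, 1))

graphdef, diff_params, rest = nnx.split(model, nnx.Param, ...)

def loss_wrapper_pure(diff_params, rest):
  local_model = nnx.merge(graphdef, diff_params, rest)
  loss = jnp.mean((local_model(x) - y) ** 2)
  # Hand back whatever non-param state the forward pass mutated (e.g. RNG counters).
  _, _, new_rest = nnx.split(local_model, nnx.Param, ...)
  return loss, new_rest

(loss, new_rest), grads = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)(diff_params, rest)
nnx.update(model, new_rest)  # propagate non-param state back to the live module
```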
43 changes: 42 additions & 1 deletion src/maxtext/trainers/post_train/rl/train_rl.py
@@ -55,6 +55,42 @@
import os
import pathwaysutils

# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all
# mesh axes are Explicit. tpu_inference still expects resharding semantics.
# Patch: try the original (works for Auto axes); on AssertionError (Explicit
# mesh) fall back to jax.sharding.reshard.
_orig_wsc = jax.lax.with_sharding_constraint


def _compat_wsc(x, shardings):
try:
return _orig_wsc(x, shardings)
except AssertionError:
return jax.sharding.reshard(x, shardings)


jax.lax.with_sharding_constraint = _compat_wsc

# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights
# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the
# incoming bfloat16 MaxText weights → float32 to match the target. This leaves v_proj
# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion),
# causing a dtype mismatch in the ragged paged attention kernel.
# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16.
import jax.numpy as _jnp
import tunix.generate.utils as _tunix_utils

_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast # pylint: disable=protected-access


def _no_bf16_to_f32_cast(val, tgt_dtype, src_key):
if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32:
return val # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init
return _orig_apply_dtype_cast(val, tgt_dtype, src_key)


_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast # pylint: disable=protected-access

from absl import app
from absl import logging as absl_logging
from etils import epath
@@ -418,6 +454,8 @@ def create_rl_components(
"hf_overrides": trainer_config.vllm_hf_overrides,
"enable_expert_parallel": sampler_config.enable_expert_parallel,
"enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts
# Ensures vLLM model initializes with correct dtype (not float32 default)
"dtype": trainer_config.weight_dtype,
},
rollout_vllm_sampling_kwargs={
"stop": trainer_config.stop_strings,
@@ -563,7 +601,10 @@ def rl_train(argv: Sequence[str], kwargs: dict):
max_train_steps = get_max_train_steps(trainer_config)

# Create model tokenizer
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
model_tokenizer = AutoTokenizer.from_pretrained(
trainer_config.tokenizer_path,
token=trainer_config.hf_access_token or None,
)

train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer)

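A standalone illustration (not part of this PR) of the weight-sync cast patch above: bfloat16 → float32 upcasts are skipped while other casts still go through. The original cast is mocked with a plain `astype`.

```python
# Hedged sketch of the patched cast behavior; `original_cast` stands in for
# tunix's internal _apply_dtype_cast.
import jax.numpy as jnp

def original_cast(val, tgt_dtype):
  return val.astype(tgt_dtype)

def no_bf16_to_f32_cast(val, tgt_dtype):
  if hasattr(val, "dtype") and val.dtype == jnp.bfloat16 and tgt_dtype == jnp.float32:
    return val  # keep synced MaxText weights in bfloat16
  return original_cast(val, tgt_dtype)

w = jnp.ones((2, 2), dtype=jnp.bfloat16)
assert no_bf16_to_f32_cast(w, jnp.float32).dtype == jnp.bfloat16  # upcast skipped
assert no_bf16_to_f32_cast(w, jnp.float16).dtype == jnp.float16   # other casts still applied
```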
70 changes: 68 additions & 2 deletions src/maxtext/trainers/post_train/sft/train_sft.py
@@ -35,14 +35,15 @@
eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16
"""

from typing import Sequence
from typing import Any, Sequence

from absl import app
import os
import jax
import optax
import pathwaysutils

from flax import nnx
from flax.linen import partitioning as nn_partitioning

from orbax import checkpoint as ocp
@@ -68,6 +69,70 @@
from maxtext.utils import model_creation_utils


class MaxTextPeftTrainer(peft_trainer.PeftTrainer):
"""MaxText-specific PeftTrainer that avoids nested NNX transformations.

Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside
nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index
values to graph nodes, resulting in:
ValueError: The graph structure of a node added to cached_partial was
mutated inside the transformation.

This subclass overrides create_train_step_fn to use jax.value_and_grad
with an explicit split/merge pattern (matching MaxText's pre-training NNX
train_step), which avoids the nested NNX transformation issue entirely.
"""

def create_train_step_fn(self):
"""Creates a train step using jax.value_and_grad with explicit NNX split/merge."""
loss_fn_ref = self.loss_fn
has_aux = self._has_aux
gen_fn = self.gen_model_input_fn
is_lora_enabled = self._lora_enabled
wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param
tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)

# Capture the graphdef once outside of JIT so that split/merge inside
# jax.value_and_grad can use a stable (non-traced) structural descriptor.
graphdef, _, _ = nnx.split(self.model, wrt, ...)

def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any):
inputs = gen_fn(inputs)

# Split model into differentiable params and non-differentiable rest.
# Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX
# transforms inside nnx.jit, which would corrupt outer_index tracking.
_, diff_params, rest = nnx.split(model, wrt, ...)

def loss_wrapper(diff_params, rest, **inputs_kw):
local_model = nnx.merge(graphdef, diff_params, rest, copy=True)
out = loss_fn_ref(local_model, **inputs_kw)
# Capture updated non-param state (e.g. RNG counters) from local_model.
_, _, new_rest = nnx.split(local_model, wrt, ...)
if has_aux:
loss, aux = out
return loss, (aux, new_rest)
else:
return out, (None, new_rest)

grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True)
(out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs)

# Propagate updated non-param state (RNG counters, etc.) back to model.
nnx.update(model, new_rest)

# Apply optimizer update. grads has the same nnx.State(wrt) structure
# as diff_params, which is compatible with optimizer.update.
optimizer.update(model, grads)

aux_out = aux if has_aux else None
if tunix_expects_grad_norm:
return out_val, aux_out, optax.global_norm(grads)
return out_val, aux_out

return train_step


def get_tunix_config(mt_config):
"""Gets the Tunix training configurations from the MaxText config.

@@ -109,6 +174,7 @@ def get_tunix_config(mt_config):
checkpointing_options=checkpointing_options,
metrics_logging_options=metrics_logging_options,
profiler_options=profiler_options,
data_sharding_axis=tuple(mt_config.data_sharding),
)


@@ -162,7 +228,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
# Provide rules context so 'norm' is translated to mesh axes during maybe_restore
with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
trainer = MaxTextPeftTrainer(model, optimizer, tunix_config)
trainer.with_training_hooks(training_hooks)
trainer.with_data_hooks(data_hooks)
trainer = use_maxtext_loss_function(trainer, mt_config)