581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

@@ -40,6 +40,7 @@
 import os
 import sys
 
+from flax import nnx
 import jax
 from jax import random
 from jax.sharding import Mesh
@@ -48,11 +49,15 @@
 from maxtext.common import checkpointing
 from maxtext.common.common_types import MODEL_MODE_TRAIN
 from maxtext.layers import quantizations
+from maxtext.layers import train_state_nnx
 from maxtext.models.models import transformer_as_linen
 from maxtext.optimizers import optimizers
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
+from maxtext.utils import maxtext_utils_nnx
+from maxtext.utils import model_creation_utils
+from maxtext.utils import train_utils
 import numpy as np
 from psutil import Process
 import tensorstore as ts
@@ -87,13 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
   devices_array = maxtext_utils.create_device_mesh(cfg)
   mesh = Mesh(devices_array, cfg.mesh_axes)
 
-  quant = quantizations.configure_quantization(cfg)
   if cfg.pure_nnx:
-    raise NotImplementedError("Pure NNX support has not been implemented yet.")
+    rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
+    model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
+    _, tx = train_utils.create_training_optimizer(cfg, model)
+    _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
+
+    def init_state_fn():
+      nnx_model = _create_model_partial()
+      optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
+      return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
+
   else:
+    quant = quantizations.configure_quantization(cfg)
     model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
-  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
-  tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
+    learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
+    tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
+    init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
 
   checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
       cfg.checkpoint_dir,
@@ -102,11 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       cfg.checkpoint_period,
   )
 
-  if cfg.pure_nnx:
-    # NNX has a different function to init the training state.
-    raise NotImplementedError("Pure NNX support has not been implemented yet.")
-  else:
-    init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
   state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
   max_logging.log("start")
   max_utils.print_mem_stats("After params initialized")
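The pure-NNX branch above builds its training state by bundling the model and an `nnx.Optimizer` into a single `TrainStateNNX` object, so the old separate `init_initial_state` path is no longer needed. A minimal sketch of that pattern with a toy model, assuming recent `flax.nnx` and `optax`; `ToyTrainState` is a hypothetical stand-in for `train_state_nnx.TrainStateNNX`, and `nnx.Linear` stands in for the MaxText transformer:

```python
# Sketch only: mirrors the shape of init_state_fn above with a toy model.
import optax
from flax import nnx


class ToyTrainState(nnx.Module):  # hypothetical stand-in for TrainStateNNX
  def __init__(self, model, optimizer):
    self.model = model
    self.optimizer = optimizer


def init_state_fn():
  model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # stand-in for _create_model_partial()
  optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
  return ToyTrainState(model, optimizer)


state = init_state_fn()
graphdef, tree = nnx.split(state)  # exposes 'model' and 'optimizer' subtrees
```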
@@ -191,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
   }
 
-  state_map = {
-      ".step": ("step", None),
-      ".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
-  }
+  if cfg.pure_nnx:
+    # NNX state-tree paths after `nnx.split(TrainStateNNX)`:
+    #   model params -> ['model']<rest>.value
+    #   adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
+    #   step -> ['optimizer']['step'].value
+    #   opt count -> ['optimizer']['opt_state']['count'].value
+    state_map = {
+        ".optimizer.step.value": ("step", None),
+        ".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
+    }
+  else:
+    state_map = {
+        ".step": ("step", None),
+        ".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
+    }
 
   def get_layer_prefix(keystr_pax):
     # different path format between decoder_layer variable
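The keys in `state_map` are later matched against `jax.tree_util.keystr(key_path)` inside `verify_fn` and `map_fn`. A small sketch of how such path strings are rendered, using plain `flax.struct` dataclasses and dicts as a stand-in for the real split state (the real NNX state additionally wraps each leaf in a `Variable`, which is why the keys above also end in `.value`):

```python
# Illustrative fake of the nested train state; attribute fields render as ".name",
# dict keys as "['name']", matching the mixed style of the state_map keys.
import jax
from flax import struct


@struct.dataclass
class FakeOptState:
  count: int


@struct.dataclass
class FakeOptimizer:
  step: int
  opt_state: FakeOptState


@struct.dataclass
class FakeTrainState:
  model: dict
  optimizer: FakeOptimizer


state = FakeTrainState(
    model={"decoder": {"decoder_norm": {"scale": 0.0}}},
    optimizer=FakeOptimizer(step=0, opt_state=FakeOptState(count=0)),
)
for path, _ in jax.tree_util.tree_flatten_with_path(state)[0]:
  print(jax.tree_util.keystr(path))
# .model['decoder']['decoder_norm']['scale']
# .optimizer.step
# .optimizer.opt_state.count
```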
@@ -206,19 +227,27 @@ def get_layer_prefix(keystr_pax):
     return prefix_pax_opt_state
 
   for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
-    # model variable
-    state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
     prefix_pax_opt_state = get_layer_prefix(keystr_pax)
-    # first momentum in optimizer state
-    state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
-        f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
-        transform_fn,
-    )
-    # second momentum in optimizer state
-    state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
-        f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
-        transform_fn,
-    )
+    if cfg.pure_nnx:
+      state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
+      state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
+          transform_fn,
+      )
+      state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
+          transform_fn,
+      )
+    else:
+      state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
+      state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
+          transform_fn,
+      )
+      state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
+          transform_fn,
+      )
 
   def verify_fn(key_path, _):
     keystr = jax.tree_util.keystr(key_path)
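For one concrete iteration of the loop, the generated entries expand as follows; the `decoder_norm` bias pair is taken from `keystr_map` above, while the `no_prefix_0` opt-state prefix is illustrative (`get_layer_prefix` chooses the real one per variable):

```python
# Worked expansion of one loop iteration (the prefix value is an assumption).
keystr_maxtext = "['decoder']['decoder_norm']['bias']"
keystr_pax = ".params.lm.final_ln.bias"
prefix_pax_opt_state = "no_prefix_0"  # hypothetical; decided by get_layer_prefix

# pure_nnx branch:
#   ".model['decoder']['decoder_norm']['bias'].value"
#       -> "mdl_vars.params.lm.final_ln.bias"
#   ".optimizer.opt_state.mu['decoder']['decoder_norm']['bias'].value"
#       -> "opt_states_0.no_prefix_0.m.params.lm.final_ln.bias"
# linen branch:
#   ".params['params']['decoder']['decoder_norm']['bias']"
#       -> "mdl_vars.params.lm.final_ln.bias"
assert f".model{keystr_maxtext}.value" == ".model['decoder']['decoder_norm']['bias'].value"
```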
@@ -270,10 +299,11 @@ def map_fn(key_path, value):
   max_logging.log("converted state finished")
   max_utils.print_mem_stats("converted state finished")
 
-  if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
-    max_logging.log(f"saved a checkpoint at step {converted_state.step}")
+  step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
+  if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
+    max_logging.log(f"saved a checkpoint at step {step_value}")
   # Upon preemption, exit when and only when all ongoing saves are complete.
-  if checkpoint_manager.reached_preemption(converted_state.step):
+  if checkpoint_manager.reached_preemption(step_value):
     checkpoint_manager.wait_until_finished()
     sys.exit()
 
7 changes: 7 additions & 0 deletions src/maxtext/configs/base.yml
@@ -560,6 +560,13 @@ logical_axis_rules: [
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
+  # ==========================================
+  # Pipeline Parallelism
+  # ==========================================
+  ['layers_outside_pipeline', []],
+  ['layers_per_stage', []],
+  ['num_activations', []],
+  ['circular_repeats', []],
   # ==========================================
   # Deprecated / Scheduled for Removal
   # ==========================================
   ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
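Each new rule pairs a pipeline-parallelism logical axis with `[]`, i.e. no mesh axis, so tensors carrying those dimensions stay unsharded along them. A hedged sketch of how a rule resolves to a `PartitionSpec` (the `activation_batch` mapping here is illustrative, not MaxText's full rule set):

```python
# Sketch: resolving logical axis names against logical_axis_rules.
from flax import linen as nn

rules = (
    ("activation_batch", "data"),  # illustrative existing rule
    ("layers_per_stage", None),    # YAML ['layers_per_stage', []]: no mesh axis
)
spec = nn.logical_to_mesh_axes(("activation_batch", "layers_per_stage"), rules)
print(spec)  # PartitionSpec('data', None) -- second dim replicated
```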
3 changes: 1 addition & 2 deletions src/maxtext/configs/pyconfig_deprecated.py
@@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
 
 
 def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
+  del enable_nnx  # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
   if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
     raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
-  if num_vocab_tiling > 1 and enable_nnx:  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
-    raise ValueError("We currently don't support vocab tiling on NNX module.")
 
 
 def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
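The surviving check is pure tiling arithmetic: the flattened token dimension must divide evenly into `num_vocab_tiling` chunks so every logits tile has the same shape. Illustrative numbers:

```python
# Why (batch * seq) % num_vocab_tiling must be 0 (values illustrative).
per_device_batch_size, max_target_length, num_vocab_tiling = 4, 2048, 8
tokens = per_device_batch_size * max_target_length  # 8192 flattened positions
assert tokens % num_vocab_tiling == 0               # what the validator enforces
tile = tokens // num_vocab_tiling                   # 1024 positions per logits tile
```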
3 changes: 1 addition & 2 deletions src/maxtext/configs/types.py
@@ -2841,8 +2841,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_devices)
         and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
     ):
       raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
-    if self.num_vocab_tiling > 1 and self.enable_nnx:
-      raise ValueError("We currently don't support vocab tiling on NNX module.")
+    # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py.
     if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
       if "gpu" not in self.hardware:
         raise ValueError(