Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
609 changes: 609 additions & 0 deletions src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Large diffs are not rendered by default.

581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"""

import argparse
import functools
import gc
import os
import sys
Expand Down Expand Up @@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
mesh = Mesh(devices_array, cfg.mesh_axes)

quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand All @@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")

Expand Down
45 changes: 38 additions & 7 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl import flags
import datetime
from etils import epath
from flax import nnx
from flax.training import train_state
import jax
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
Expand Down Expand Up @@ -521,7 +522,7 @@ def load_state_if_possible(
load_parameters_from_path: str,
load_full_state_from_path: str,
checkpoint_storage_concurrent_gb: int,
abstract_unboxed_pre_state: train_state.TrainState,
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
enable_single_replica_ckpt_restoring: bool | None = False,
dataset_type: str | None = "tfds",
step: int = -1, # -1 means latest
Expand Down Expand Up @@ -589,8 +590,13 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()

restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)

match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
Expand Down Expand Up @@ -625,9 +631,14 @@ def map_to_pspec(data):
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)

if load_parameters_from_path != "":
if isinstance(abstract_unboxed_pre_state, nnx.State):
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
else:
params = abstract_unboxed_pre_state.params

restored_params = load_params_from_path(
load_parameters_from_path,
abstract_unboxed_pre_state.params,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
Expand Down Expand Up @@ -711,15 +722,35 @@ def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True):
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")


def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None):
"""Save checkpoint if checkpointing is enabled."""
def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None, force=False):
"""Save checkpoint if checkpointing is enabled.

Args:
checkpoint_manager: The checkpoint manager.
state: The training state to save.
config: The config object.
data_iterator: The data iterator.
step: The step number. If None, extracts from state (for Linen TrainState).
force: If True, force save the checkpoint regardless of checkpoint_period.
"""
if checkpoint_manager is None:
return

# Determine the effective step for saving a checkpoint.
# If 'step' is not provided, this call is for a potential final checkpoint,
# so we use the last completed step from the state.
actual_step = (int(state.step) - 1) if step is None else int(step)
if step is not None:
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
else:
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1

if config.pure_nnx:
# Convert nnx.State to dict.
state = state.to_pure_dict()

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
Expand Down
9 changes: 9 additions & 0 deletions src/maxtext/common/gcloud_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def is_decoupled() -> bool: # dynamic check so setting env after initial import
return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"


def is_pure_nnx() -> bool:  # evaluated at call time so env changes made after import are honored
  """Check whether pure NNX mode is requested via the PURE_NNX environment variable.

  Linen remains the default — this returns False unless PURE_NNX=TRUE is set.
  Opting in to NNX mode skips linen_only tests and runs nnx_only tests.
  """
  flag = os.environ.get("PURE_NNX", "FALSE")
  return flag.upper() == "TRUE"


T = TypeVar("T")


Expand Down
6 changes: 4 additions & 2 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['num_activations', []],
['engram_dim', ['tensor']],
['mhc', []],
['diloco', 'diloco'],
Expand Down Expand Up @@ -1106,8 +1107,9 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
enable_nnx: True
pure_nnx_decoder: True
pure_nnx: True

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
46 changes: 37 additions & 9 deletions src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This logical rule is designed to optimize pipeline parallelism for large-scale jobs.
# Key changes include removing expert weight sharding on the `q_lora` dimension, which
# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when
# EP x FSDP > 512.
# This logical rule is designed to optimize pipeline parallelism for large-scale jobs.
# Key changes include removing expert weight sharding on the `q_lora` dimension, which
# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when
# EP x FSDP > 512.
#
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
# second, it may be required for DCN communication.
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
# second, it may be required for DCN communication.
#
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
# store prefetched weights.
mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
Expand Down Expand Up @@ -71,4 +71,32 @@ logical_axis_rules: [
['exp_with_fsdp', 'fsdp'],
['paged_kv_heads', ['tensor']],
['engram_dim', ['tensor']],
# Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh
['activation_attn_length_no_exp', []],
['activation_length_no_exp', []],
['activation_norm_length', []],
['activation_q_length_no_exp', []],
['prefill_activation_length', []],
['prefill_activation_norm_length', []],
['activation_kv_length', []],
['decode_length', []],
['embed_tensor_transpose', []],
['q_lora_up_proj', []],
['kv_lora_up_proj', []],
['kv', []],
['qkv', []],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_kv', []],
['cache_sequence', []],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['num_activations', []],
['mhc', []],
['diloco', []],
]
57 changes: 56 additions & 1 deletion src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy
# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy
# for small-scale training and this rule simplifies the overall configuration.
mesh_axes: ['fsdp']
data_sharding: [['fsdp']]
logical_axis_rules: [
# Batch/data dimensions sharded on fsdp
['activation_batch', ['fsdp']],
['activation_batch_no_exp', ['fsdp']],
['activation_batch_moe', ['fsdp']],
Expand All @@ -27,11 +28,65 @@ logical_axis_rules: [
['activation_kv_batch', ['fsdp']],
['activation_kv_batch_no_exp', ['fsdp']],
['decode_batch', ['fsdp']],
# Weight dimensions sharded on fsdp
['embed', ['fsdp']],
['embed_no_exp', ['fsdp']],
['embed_moe', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['exp_with_fsdp', 'fsdp'],
# All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
['activation_heads', []],
['activation_kv_heads', []],
['activation_length', []],
['activation_attn_length', []],
['activation_attn_length_no_exp', []],
['activation_length_no_exp', []],
['activation_norm_length', []],
['activation_q_length', []],
['activation_q_length_no_exp', []],
['prefill_activation_length', []],
['prefill_activation_norm_length', []],
['activation_kv_length', []],
['activation_attn_embed', []],
['activation_embed', []],
['activation_mlp', []],
['activation_kv', []],
['activation_kv_head_dim', []],
['activation_vocab', []],
['activation_stage', []],
['activation_exp', []],
['decode_length', []],
['mlp', []],
['mlp_no_fsdp', []],
['vocab', []],
['heads', []],
['q_heads', []],
['kv_heads', []],
['embed_tensor_transpose', []],
['q_lora_up_proj', []],
['kv_lora_up_proj', []],
['norm', []],
['layers', []],
['qkv', []],
['kv', []],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_heads', []],
['cache_kv', []],
['cache_sequence', []],
['exp', []],
['paged_kv_heads', []],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['num_activations', []],
['engram_dim', []],
['mhc', []],
['diloco', []],
]
4 changes: 4 additions & 0 deletions src/maxtext/configs/decoupled_base_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ eval_dataset_name: 'c4/en:3.1.0'
# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
attention: "dot_product"

# Default to Linen mode for tests; NNX is opt-in via PURE_NNX=TRUE.
pure_nnx: False
pure_nnx_decoder: False

# Avoid HLO dump overhead.
dump_hlo: false
jax_cache_dir: ""
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ class HardwareAndMesh(BaseModel):
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")


class LayoutAndSharding(BaseModel):
Expand Down
32 changes: 26 additions & 6 deletions src/maxtext/experimental/rl/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,23 +546,43 @@ def setup_train_loop(
max_logging.log("Training mesh used for the workload")
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
training_devices = jax.devices()[num_inference_devices:]
model = mt.from_config(config, devices=training_devices)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = mt.from_config(config, devices=training_devices)
mesh = model.mesh
max_logging.log("Inference mesh used for the workload")
inference_devices = jax.devices()[:num_inference_devices]
inference_model = mt.from_config(config_inference, devices=inference_devices)
if config_inference.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_mesh = inference_model.mesh
init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh)
init_rng = jax.random.PRNGKey(config.init_weights_seed)
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)

with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh)
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
data_iterator, config, mesh, checkpoint_manager, init_state_fn
)

# create inference_state_mesh_shardings from inference_mesh
if config_inference.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False
config_inference, inference_mesh, init_inference_state_fn, is_training=False
)[2]
if not config.using_pipeline_parallelism:
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
Expand Down Expand Up @@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None):
data_buffer = []
data_buffer_lock = threading.Lock()

start_step = get_first_step(state) # this is the start_step for training
start_step = get_first_step(model, state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
inference_prof = profiler.Profiler(config_inference, offset_step=start_step)
data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
Expand Down
Loading
Loading