Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

# This conversion script reads paxml-format weights and emits a Linen-format
# MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`,
# `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen
# tree shape). Use the Linen path regardless of pure_nnx.
quant = quantizations.configure_quantization(cfg)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand All @@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")
Expand Down
7 changes: 7 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,13 @@ logical_axis_rules: [
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
# ==========================================
# Pipeline Parallelism
# ==========================================
['layers_outside_pipeline', []],
['layers_per_stage', []],
['num_activations', []],
['circular_repeats', []],
# ==========================================
# Deprecated / Scheduled for Removal
# ==========================================
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
Expand Down
3 changes: 1 addition & 2 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
raise ValueError("We currently don't support vocab tiling on NNX module.")


def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
Expand Down
3 changes: 1 addition & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2841,8 +2841,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
# Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py.
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
raise ValueError(
Expand Down
Loading
Loading