Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.recipe import load_config
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
from modelopt.torch.utils import print_rank_0
Expand Down Expand Up @@ -167,10 +168,14 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic
eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert()
dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert()
"""
merged = OmegaConf.load(config_path)
# Resolve $import / imports: via modelopt's loader, then layer OmegaConf
# dotlist overrides on top.
cfg = load_config(config_path)
assert isinstance(cfg, dict), f"Top-level recipe must be a YAML mapping: {config_path}"
if overrides:
merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides)))
cfg = OmegaConf.to_container(merged, resolve=True)
merged = OmegaConf.merge(OmegaConf.create(cfg), OmegaConf.from_dotlist(list(overrides)))
cfg = OmegaConf.to_container(merged, resolve=True)
assert isinstance(cfg, dict)

# Eagle/DFlash sections map directly to config fields — no field enumeration needed.
eagle_cfg = cfg.get("eagle", {})
Expand Down Expand Up @@ -318,8 +323,15 @@ def train():
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
elif training_args.mode == "dflash":
# Mask-token resolution: recipe value wins; otherwise fall back to the
# tokenizer's built-in mask_token_id. DFlashConfig still raises if neither
# source provides one.
if dflash_cfg.get("dflash_mask_token_id") is None:
tok_mask_id = getattr(tokenizer, "mask_token_id", None)
if tok_mask_id is not None:
dflash_cfg["dflash_mask_token_id"] = tok_mask_id
dflash_cfg = DFlashConfig.model_validate(
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
dflash_cfg, context={"data_args": data_args}
).model_dump()
mtsp.convert(model, [("dflash", dflash_cfg)])
else:
Expand Down
20 changes: 8 additions & 12 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@

from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config

# Permissive schema for `model:` / `data:` / `training:` recipe snippets used
# via $import in modelopt_recipes/configs/speculative_decoding/. Real field
# validation happens downstream in transformers.HfArgumentParser.parse_dict()
# (examples/speculative_decoding/main.py); this alias exists so snippets can
# satisfy load_config()'s requirement that modelopt-schema paths resolve under
# the modelopt.* namespace.
SpeculativeArgsSnippet = dict

kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)

eagle3_default_config = deepcopy(default_eagle_config)
Expand Down Expand Up @@ -132,18 +140,6 @@ def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
data["dflash_offline"] = getattr(data_args, "offline_data_path", None) is not None
return data

@model_validator(mode="before")
@classmethod
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
"""Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None:
return data
ctx = info.context if info.context else {}
tokenizer = ctx.get("tokenizer")
if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
data["dflash_mask_token_id"] = tokenizer.mask_token_id
return data

@model_validator(mode="after")
def _check_mask_token_id(self) -> "DFlashConfig":
"""Validate that mask_token_id is set after all resolution attempts."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Default DFlashConfig values for DFlash training. Imported into the `dflash:`
# section of recipes. ``dflash_mask_token_id`` is intentionally omitted so the
# snippet schema is the permissive ``SpeculativeArgsSnippet`` (DFlashConfig's
# after-validator would otherwise raise during snippet load); per-model recipes
# should provide ``dflash_mask_token_id`` explicitly, and main.py falls back to
# ``tokenizer.mask_token_id`` when neither does.

# modelopt-schema: modelopt.torch.speculative.config.SpeculativeArgsSnippet
dflash_block_size: 8
dflash_num_anchors: 512
dflash_use_torch_compile: false
dflash_self_logit_distillation: true
dflash_loss_decay_factor: 4.0
dflash_architecture_config:
  num_hidden_layers: 5
# dflash_mask_token_id: deliberately omitted — set it in per-model recipes;
# otherwise main.py falls back to tokenizer.mask_token_id (see header note).
# sliding_window and layer_types are inherited from base model config automatically
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Default `training:` section values for DFlash training. Imported into the
# `training:` section of recipes. Real field validation is performed by
# transformers.HfArgumentParser.parse_dict() in main.py.

# modelopt-schema: modelopt.torch.speculative.config.SpeculativeArgsSnippet

# --- commonly modified ---
mode: dflash
# Intentionally null: the importing recipe (or a CLI dotlist override) must set it.
output_dir:
num_train_epochs: 10
per_device_train_batch_size: 1
learning_rate: 6.0e-4
warmup_steps: 100
training_seq_len: 4096
logging_steps: 100
save_steps: 5000
cp_size: 1
dp_shard_size: 1
disable_tqdm: true
estimate_ar: false
ar_validate_steps: 0
answer_only_loss: true

# --- rarely modified ---
do_eval: false
lr_scheduler_type: linear
save_strategy: steps
weight_decay: 0.0
dataloader_drop_last: true
bf16: true
tf32: true
remove_unused_columns: false
ddp_find_unused_parameters: true
ddp_timeout: 1800
report_to: tensorboard
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Default EagleConfig values for EAGLE3 training. Imported into the `eagle:` section of recipes.
# eagle_offline is derived from data.offline_data_path; do not set here.

# modelopt-schema: modelopt.torch.speculative.config.EagleConfig
eagle_decoder_type: llama
eagle_ttt_steps: 3
eagle_mix_hidden_states: false
eagle_use_torch_compile: true
eagle_self_logit_distillation: true
eagle_freeze_base_model: true
eagle_loss_decay_factor: 0.9
eagle_hidden_state_distillation: false
eagle_reuse_base_decoder: false
eagle_report_acc: true
eagle_enable_nvtx: false
# Rope scaling: disable during training (default_config.py uses rope_type=default),
# inject YaRN during export for long-context inference.
eagle_export_rope_scaling:
  rope_type: yarn
  factor: 32.0
  original_max_position_embeddings: 2048
# Overrides applied on top of modelopt/torch/speculative/eagle/default_config.py;
# an empty mapping keeps those defaults unchanged.
eagle_architecture_config: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Default `training:` section values for EAGLE3 training. Imported into the
# `training:` section of recipes. Real field validation is performed by
# transformers.HfArgumentParser.parse_dict() in main.py.

# modelopt-schema: modelopt.torch.speculative.config.SpeculativeArgsSnippet

# --- commonly modified ---
mode: eagle3
# Intentionally null: the importing recipe (or a CLI dotlist override) must set it.
output_dir:
num_train_epochs: 1
per_device_train_batch_size: 1
learning_rate: 1.0e-4
warmup_steps: 1000
training_seq_len: 2048
logging_steps: 100
save_steps: 8192
cp_size: 1
disable_tqdm: false
estimate_ar: false
ar_validate_steps: -1
answer_only_loss: false

# --- rarely modified ---
do_eval: false
lr_scheduler_type: linear
save_strategy: steps
weight_decay: 0.0
dataloader_drop_last: true
bf16: true
tf32: true
remove_unused_columns: false
50 changes: 11 additions & 39 deletions modelopt_recipes/general/speculative_decoding/dflash.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI.
# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI
# or by importing this file from a per-model recipe in modelopt_recipes/models/.

imports:
dflash_default: configs/speculative_decoding/dflash/default
dflash_training_default: configs/speculative_decoding/dflash/training_default

# maps to ModelArguments (main.py)
model:
Expand All @@ -17,44 +22,11 @@ data:

# maps to TrainingArguments (main.py)
training:
# --- commonly modified ---
mode: dflash
output_dir:
num_train_epochs: 10
per_device_train_batch_size: 1
learning_rate: 6.0e-4
warmup_steps: 100
training_seq_len: 4096
logging_steps: 100
save_steps: 5000
cp_size: 1
dp_shard_size: 1
disable_tqdm: true
estimate_ar: false
ar_validate_steps: 0
answer_only_loss: true

# --- rarely modified ---
do_eval: false
lr_scheduler_type: linear
save_strategy: steps
weight_decay: 0.0
dataloader_drop_last: true
bf16: true
tf32: true
remove_unused_columns: false
ddp_find_unused_parameters: true
ddp_timeout: 1800
report_to: tensorboard
$import: dflash_training_default

# maps to DFlashConfig (modelopt/torch/speculative/config.py).
# Per-model recipes should also set ``dflash_mask_token_id``; otherwise main.py
# falls back to ``tokenizer.mask_token_id``, and DFlashConfig raises if neither
# source provides one.
dflash:
dflash_block_size: 8
dflash_num_anchors: 512
dflash_use_torch_compile: false
dflash_self_logit_distillation: true
dflash_loss_decay_factor: 4.0
dflash_architecture_config:
num_hidden_layers: 5
# mask_token_id: auto-detected from model vocab (override for specific models)
# sliding_window and layer_types are inherited from base model config automatically
$import: dflash_default
54 changes: 8 additions & 46 deletions modelopt_recipes/general/speculative_decoding/eagle3.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI.
# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI
# or by importing this file from a per-model recipe in modelopt_recipes/models/.

imports:
eagle_default: configs/speculative_decoding/eagle/default
eagle_training_default: configs/speculative_decoding/eagle/training_default

# maps to ModelArguments (main.py)
model:
Expand All @@ -16,51 +21,8 @@ data:

# maps to TrainingArguments (main.py)
training:
# --- commonly modified ---
mode: eagle3
output_dir:
num_train_epochs: 1
per_device_train_batch_size: 1
learning_rate: 1.0e-4
warmup_steps: 1000
training_seq_len: 2048
logging_steps: 100
save_steps: 8192
cp_size: 1
disable_tqdm: false
estimate_ar: false
ar_validate_steps: -1
answer_only_loss: false

# --- rarely modified ---
do_eval: false
lr_scheduler_type: linear
save_strategy: steps
weight_decay: 0.0
dataloader_drop_last: true
bf16: true
tf32: true
remove_unused_columns: false
$import: eagle_training_default

# maps to EagleConfig (modelopt/torch/speculative/config.py).
eagle:
# eagle_offline is derived from data.offline_data_path; do not set here.
eagle_decoder_type: llama
eagle_ttt_steps: 3
eagle_mix_hidden_states: false
eagle_use_torch_compile: true
eagle_self_logit_distillation: true
eagle_freeze_base_model: true
eagle_loss_decay_factor: 0.9
eagle_hidden_state_distillation: false
eagle_reuse_base_decoder: false
eagle_report_acc: true
eagle_enable_nvtx: false
# Rope scaling: disable during training (default_config.py uses rope_type=default),
# inject YaRN during export for long-context inference.
eagle_export_rope_scaling:
rope_type: yarn
factor: 32.0
original_max_position_embeddings: 2048
# overwrite to modelopt/torch/speculative/eagle/default_config.py
eagle_architecture_config: {}
$import: eagle_default
26 changes: 26 additions & 0 deletions modelopt_recipes/models/Kimi-K2.5/dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model DFlash offline training recipe for Kimi-K2.5.

# Aliases for shared config snippets; referenced below via `$import:`.
imports:
  dflash_default: configs/speculative_decoding/dflash/default
  dflash_training_default: configs/speculative_decoding/dflash/training_default

# maps to ModelArguments (main.py)
model:
  model_name_or_path: moonshotai/Kimi-K2.5
  trust_remote_code: true
  # NOTE(review): presumably skips loading real base weights for offline
  # training — confirm against ModelArguments in main.py.
  use_fake_base_for_offline: true

# maps to DataArguments (main.py)
data:
  # Placeholder — replace with the actual offline data directory.
  offline_data_path: <path to offline data>

# maps to TrainingArguments (main.py)
training:
  $import: dflash_training_default
  output_dir: ckpts/kimi-k25-dflash

# maps to DFlashConfig (modelopt/torch/speculative/config.py)
dflash:
  $import: dflash_default
  # If unset, main.py falls back to tokenizer.mask_token_id; DFlashConfig
  # raises if neither this field nor the tokenizer provides one.
  # dflash_mask_token_id:
26 changes: 26 additions & 0 deletions modelopt_recipes/models/Kimi-K2.5/eagle3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model EAGLE3 offline training recipe for Kimi-K2.5.
# Mirrors examples/speculative_decoding/scripts/train_kimi_k25_offline.sh.

# Aliases for shared config snippets; referenced below via `$import:`.
imports:
  eagle_default: configs/speculative_decoding/eagle/default
  eagle_training_default: configs/speculative_decoding/eagle/training_default

# maps to ModelArguments (main.py)
model:
  model_name_or_path: moonshotai/Kimi-K2.5
  trust_remote_code: true
  # NOTE(review): presumably skips loading real base weights for offline
  # training — confirm against ModelArguments in main.py.
  use_fake_base_for_offline: true

# maps to DataArguments (main.py)
data:
  # Placeholder — replace with the actual offline data directory.
  offline_data_path: <path to offline data>

# maps to TrainingArguments (main.py)
training:
  $import: eagle_training_default
  output_dir: ckpts/kimi-k25-eagle3
  # Longer sequence length than the training_default snippet (2048).
  training_seq_len: 4096

# maps to EagleConfig (modelopt/torch/speculative/config.py)
eagle:
  $import: eagle_default
  eagle_decoder_type: kimik2
Loading