1 change: 1 addition & 0 deletions .github/workflows/_example_tests_runner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
pip uninstall -y apex || true
fi


find examples/${{ inputs.example }} -name "requirements.txt" | while read req_file; do pip install -r "$req_file" || exit 1; done
- name: Run tests
run: |
4 changes: 2 additions & 2 deletions .github/workflows/example_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/pytorch:25.06-py3"
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
example: ${{ matrix.example }}
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-l4-latest-1
Expand All @@ -81,7 +81,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/pytorch:25.06-py3"
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
example: ${{ matrix.example }}
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-h100-latest-2
2 changes: 1 addition & 1 deletion .github/workflows/gpu_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
runs-on: linux-amd64-gpu-l4-latest-1
timeout-minutes: 120
container: &gpu_container
image: nvcr.io/nvidia/pytorch:25.06-py3
image: nvcr.io/nvidia/pytorch:25.08-py3
env:
GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py
PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages
17 changes: 10 additions & 7 deletions examples/speculative_decoding/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,21 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,

### Docker

Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.12-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.

Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.

### Local Installation

Install Modelopt with `hf` dependencies and other requirements for this example:
To set up the environment locally, first install the latest ModelOpt with Hugging Face support:

```bash
pip install -e "../..[hf]"
```

Next, install any additional dependencies required for this example:

```bash
pip install -U nvidia-modelopt[hf]
pip install -r requirements.txt
```

Expand All @@ -56,7 +61,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst
## Getting Started: Simplified Workflow

```bash
bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct
```

This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
Expand All @@ -74,12 +79,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data input_conversations/daring-anteater.jsonl \
--num_gpu $NUM_GPU \
--num_epochs $NUM_EPOCH \
--eagle_config eagle_config.json
```

This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`.
The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
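
As a rough illustration of how the parallelism sizes relate: `launch_train.sh` defaults the data-parallel shard size to the GPU count divided by `--cp_size`. The sketch below mirrors that arithmetic in Python. The function name `default_dp_shard_size` and the divisibility check are assumptions made for illustration; they are not part of the example scripts.

```python
def default_dp_shard_size(gpu_count: int, cp_size: int = 1) -> int:
    """Mirror the shell default DP_SHARD_SIZE = GPU_COUNT / CP_SIZE."""
    if gpu_count < 1 or cp_size < 1:
        raise ValueError("gpu_count and cp_size must be positive")
    # Assumption: reject uneven splits instead of silently truncating,
    # which is what shell integer division $(( ... )) would do.
    if gpu_count % cp_size != 0:
        raise ValueError(
            f"gpu_count={gpu_count} is not divisible by cp_size={cp_size}"
        )
    return gpu_count // cp_size


print(default_dp_shard_size(8, cp_size=2))  # 4
```

With 8 GPUs and `--cp_size 2`, each context-parallel group spans 2 GPUs, leaving 4 data-parallel shards.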

## Training Draft Model with Offline Base Model
Expand Down Expand Up @@ -118,7 +122,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
--num_gpu $NUM_GPU \
--num_epochs $NUM_EPOCH \
--eagle_config eagle_config.json \
--offline-data $HIDDEN_STATES_DIR
148 changes: 148 additions & 0 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from types import FrameType
from typing import Any

import numpy as np
import torch
import transformers
from datasets import load_dataset
from packaging.version import Version
from PIL import Image
from scripts.ar_validate import validate_ar
from torch.distributed.tensor.experimental._attention import _SDPAMerger
from torch.utils.data import Dataset
from transformers import AutoProcessor, Trainer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother

import modelopt
from modelopt.torch.speculative.utils import get_ttt_msk_func
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import is_master

Expand Down Expand Up @@ -566,3 +576,141 @@ def on_step_end(self, args, state, control, **kwargs):
except Exception:
print_rank_0("AR validation not available.")
return control


def get_patched_templated_ring_attn(orig_templated_attn: Callable):
"""
Return patched version of
torch.distributed.tensor.experimental._attention._templated_ring_attention
to support TTT.
"""

def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
"""Get chunk-interleaved TTT mask for current rank.
e.g.:
2 ranks, ttt_step=1;
full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
[x, 0, 0, 0, 0, x, 0, 0],
[x, x, 0, 0, 0, 0, x, 0],
[x, x, x, 0, 0, 0, 0, x]]

rank 0, step0: [[0, 0, x, 0],
[x, 0, 0, x]]

rank 1, step0: [[0, 0, x, 0],
[x, 0, 0, x]]

rank 0, step1: [[0, 0, 0, 0],
[0, 0, 0, 0]]

rank 1, step1: [[x, x, 0, 0],
[x, x, 0, 0]]

"""
device = torch.cuda.current_device()
q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
kv_indices = (
torch.arange(q_len * size * (ttt_step + 1), device=device)
.view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
.reshape(-1)
)
msk_func = get_ttt_msk_func(q_len * size, ttt_step)
attn_mask = msk_func(
None,
None,
q_indices.view(1, 1, -1, 1),
kv_indices.view(1, 1, 1, -1),
)
attn_bias = torch.where(
attn_mask,
torch.zeros((), dtype=dtype, device=attn_mask.device),
torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
)

return attn_bias

def patched_templated_attn(*args, **kwargs):
"""Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
# Get original attention op
# Sensitive to impl of _templated_ring_attention
original_op = args[2]

# This patch is only enabled for the EAGLE model (via a context manager), not the base model.
patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH

if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

# Unset is_causal to use custom attn mask
if patch_enabled:
kwargs["is_causal"] = False

def patched_op(*args, **kwargs):
# Inspect the parent frame to get current shard info
# This is sensitive to torch _templated_ring_attention impl
try:
frame: FrameType = inspect.currentframe()
f_back: FrameType = frame.f_back
rank = f_back.f_locals["rank"]
size = f_back.f_locals["size"]
query = f_back.f_locals["query"]
key = f_back.f_locals["key"]
i = f_back.f_locals["i"]
ttt_step = (key.shape[2] // query.shape[2]) - 1
except Exception as e:
raise RuntimeError(
f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
) from e
# Set attn mask to permuted TTT mask
if "attn_bias" in kwargs:
kwargs["attn_bias"] = _get_sharded_ttt_msk(
i, rank, size, query.shape[2], ttt_step, query.dtype
)
# Perform shard attention
return original_op(*args, **kwargs)

return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)

return patched_templated_attn
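
The frame inspection used by `patched_op` can be illustrated in isolation. The sketch below is plain Python with hypothetical names (`ring_loop`, `peek_caller_locals`) that do not exist in `eagle_utils.py` or PyTorch. It only shows how a callee can read its caller's loop variables through `inspect`, and why the approach fails as soon as the caller renames or removes one of those variables.

```python
import inspect


def peek_caller_locals(*names: str) -> dict:
    """Read selected local variables from the calling function's frame."""
    frame = inspect.currentframe()
    try:
        caller = frame.f_back
        return {name: caller.f_locals[name] for name in names}
    except (AttributeError, KeyError) as e:
        # A missing frame or a renamed variable in the caller ends up here.
        raise RuntimeError(f"Failed to capture caller variables: {e}") from e
    finally:
        del frame  # avoid a reference cycle through the frame object


def ring_loop(size: int) -> list:
    """Hypothetical stand-in for a loop whose locals the callee needs."""
    seen = []
    for i in range(size):
        rank = 0  # stand-in for the real process rank
        seen.append(peek_caller_locals("i", "rank"))
    return seen


print(ring_loop(2))  # [{'i': 0, 'rank': 0}, {'i': 1, 'rank': 0}]
```

Because the lookup depends on the caller's variable names, wrapping it in a `try`/`except` that raises a descriptive `RuntimeError` (as the patch above does) makes an upstream refactor fail loudly rather than silently compute a wrong mask.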


def patch_ring_attention_for_ttt():
"""Patch torch ring attention to support context parallelism for TTT."""
# Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.

if not (
Version(torch.__version__) >= Version("2.8.0")
and Version(torch.__version__) < Version("2.9.0")
):
raise RuntimeError(
"Context parallel TTT is currently supported only for PyTorch 2.8.x. "
f"Got {torch.__version__}. Please use cp_size=1 or switch to torch 2.8.x."
)

# 1. Disable load balancing, which is designed for causal masks.
# This affects how buffers are sharded, so it must be done permanently before accelerate/HF Trainer init.
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False

# 2. Patch templated ring attention for TTT mask.
original_templated_ring_attention = (
torch.distributed.tensor.experimental._attention._templated_ring_attention
)
original_templated_ring_attention_backward = (
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
)
torch.distributed.tensor.experimental._attention._templated_ring_attention = (
get_patched_templated_ring_attn(original_templated_ring_attention)
)
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
get_patched_templated_ring_attn(original_templated_ring_attention_backward)
)

# 3. Patch the merger to skip blank shards and avoid differences in the output.
original_sdpa_merger_step = _SDPAMerger.step

def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
if lse.sum() <= 0:
return
return original_sdpa_merger_step(self, out, lse, partial)

_SDPAMerger.step = patched_sdpa_merger_step
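
To make the chunk-interleaved layout in `_get_sharded_ttt_msk` concrete, here is a minimal pure-Python sketch of the key/value index selection for one ring step. It assumes, as the docstring example suggests, that keys are laid out as `ttt_step + 1` blocks of `size * q_len` positions each. `shard_kv_indices` is a hypothetical helper and not part of `eagle_utils.py`.

```python
def shard_kv_indices(i: int, rank: int, size: int, q_len: int, ttt_step: int) -> list[int]:
    """Global key/value positions seen by `rank` at ring step `i`."""
    # At step i, the KV shard held locally originated from this rank.
    src = (rank - i) % size
    indices = []
    for block in range(ttt_step + 1):
        # Each TTT block spans size * q_len positions; pick the slice owned by src.
        start = block * size * q_len + src * q_len
        indices.extend(range(start, start + q_len))
    return indices


# Matches the docstring example: 2 ranks, ttt_step=1, q_len=2 per rank.
print(shard_kv_indices(i=0, rank=0, size=2, q_len=2, ttt_step=1))  # [0, 1, 4, 5]
print(shard_kv_indices(i=1, rank=0, size=2, q_len=2, ttt_step=1))  # [2, 3, 6, 7]
```

Selecting those columns (together with the rank's own query rows) from the full TTT mask yields the per-rank, per-step sub-masks listed in the docstring.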
1 change: 1 addition & 0 deletions examples/speculative_decoding/fsdp_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"fsdp_version":2}
39 changes: 23 additions & 16 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
EAGLE_CONFIG="${1#*=}"
;;
--fsdp_transformer_layer_cls_to_wrap*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
;;
--num_gpu*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_GPU="${1#*=}"
;;
--disable_tqdm*)
if [[ "$1" != *=* ]]; then shift; fi
DISABLE_TQDM="${1#*=}"
Expand All @@ -102,6 +94,14 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
AR_VALIDATE_STEPS="${1#*=}"
;;
--cp_size*)
if [[ "$1" != *=* ]]; then shift; fi
CP_SIZE="${1#*=}"
;;
--dp_size*)
if [[ "$1" != *=* ]]; then shift; fi
DP_SHARD_SIZE="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand Down Expand Up @@ -129,15 +129,15 @@ LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-4}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
DISABLE_TQDM=${DISABLE_TQDM:-False}
VLM_PROCESSOR=${VLM_PROCESSOR:-}
VLM_IMG_DIR=${VLM_IMG_DIR:-}
AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
ESTIMATE_AR=${ESTIMATE_AR:-False}
CP_SIZE=${CP_SIZE:-1}
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}

if [[ "$MODE" == "medusa" ]]; then
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
Expand All @@ -163,21 +163,25 @@ else
OFFLINE_TRAINING_ARGS=""
fi

if [[ "$NUM_GPU" == 1 ]]; then
MULTI_GPU=""
else
MULTI_GPU="--multi_gpu"
fi

if [[ "$VLM_PROCESSOR" != "" ]]; then
VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR"
else
VLM_ARGS=""
fi

if [[ "$GPU_COUNT" -gt 1 ]]; then
# Use FSDP2 when multiple GPUs are available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json"
else
# Otherwise, train on a single GPU
FSDP_ARGS=""
fi


# Disable tokenizers parallelism to avoid warning
export TOKENIZERS_PARALLELISM=False
CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
CMD="accelerate launch --mixed_precision bf16 main.py \
--mode $MODE \
--eagle_decoder_type $EAGLE_DECODER_TYPE \
--model_name_or_path $MODEL \
Expand Down Expand Up @@ -206,6 +210,9 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
$SPECULATIVE_ARGS \
$FSDP_ARGS \
--cp_size $CP_SIZE \
--dp_shard_size $DP_SHARD_SIZE \
"

start_time=$(date +%s)
17 changes: 16 additions & 1 deletion examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@

import torch
import transformers
from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, make_eagle_supervised_data_module
from accelerate import ParallelismConfig
from eagle_utils import (
EagleTrainerWithAccLog,
EagleTrainingPlot,
make_eagle_supervised_data_module,
patch_ring_attention_for_ttt,
)
from medusa_utils import make_medusa_supervised_data_module
from transformers.trainer_utils import get_last_checkpoint

Expand Down Expand Up @@ -100,6 +106,8 @@ class TrainingArguments(transformers.TrainingArguments):
remove_unused_columns: bool = field(
default=False, metadata={"help": "Set to False to keep extra args for VLM."}
)
cp_size: int = field(default=1, metadata={"help": "Context parallelism size."})
dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."})


@dataclass
Expand Down Expand Up @@ -130,6 +138,13 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
# Workaround specific to accelerate 1.12.0; can be removed after moving to 1.13.0.
training_args.parallelism_config.sp_backend = None
print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}")

# Detecting last checkpoint.