Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
MIX_HIDDEN_STATES="${1#*=}"
;;
--disable_torch_compile*)
if [[ "$1" != *=* ]]; then shift; fi
DISABLE_TORCH_COMPILE="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand Down Expand Up @@ -158,6 +162,7 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
LOG_STEPS=${LOG_STEPS:-100}
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"}
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}


Expand Down Expand Up @@ -245,6 +250,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--estimate_ar $ESTIMATE_AR \
--ar_validate_steps $AR_VALIDATE_STEPS \
--mix_hidden_states $MIX_HIDDEN_STATES \
--disable_torch_compile $DISABLE_TORCH_COMPILE \
$DRAFT_VOCAB_CACHE_ARGS \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
Expand Down
12 changes: 9 additions & 3 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ class EagleArguments:
default=False,
metadata={"help": "Whether to mix hidden states from previous TTT step."},
)
disable_torch_compile: bool = field(
default=False,
metadata={"help": "Disable torch.compile on eagle forward/loss methods."},
)
num_ttt_steps: int = field(
default=3,
metadata={"help": "Number of train-time-test steps to use during training."},
Expand All @@ -149,9 +153,10 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is an unrelated bugfix for #1045 (it does not fully solve the issue — it is only a single-GPU workaround)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in Slack, this issue is due to a transformers version mismatch. It should be fixed after updating transformers.

training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
Expand Down Expand Up @@ -212,6 +217,7 @@ def train():
"eagle_decoder_type": eagle_args.eagle_decoder_type,
"eagle_offline": use_offline_training,
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
"eagle_use_torch_compile": not eagle_args.disable_torch_compile,
"eagle_ttt_steps": eagle_args.num_ttt_steps,
"eagle_architecture_config": custom_config,
}
Expand Down
5 changes: 5 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,8 @@ class EagleConfig(ModeloptBaseConfig):
"Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost."
),
)

eagle_use_torch_compile: bool = ModeloptField(
default=True,
description="Whether to use torch.compile on eagle forward/loss methods for faster training.",
)
1 change: 1 addition & 0 deletions modelopt/torch/speculative/eagle/eagle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ def modify(
self.eagle_decoder_type = config.eagle_decoder_type
self.eagle_ttt_steps = config.eagle_ttt_steps
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
self.eagle_use_torch_compile = config.eagle_use_torch_compile
16 changes: 16 additions & 0 deletions modelopt/torch/speculative/eagle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

"""Eagle model utils."""

from contextlib import nullcontext

import torch


Expand Down Expand Up @@ -70,3 +72,17 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
inverted_mask = 1.0 - expanded_mask

return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def maybe_nvtx_range(*args, **kwargs):
    """Create an NVTX profiling range if NVTX is available, else a no-op.

    The returned object is usable both as a context manager
    (``with maybe_nvtx_range("name"): ...``) and as a decorator
    (``@maybe_nvtx_range("name")``), mirroring the interface of
    ``torch.cuda.nvtx.range``.

    Args:
        *args: Positional arguments forwarded to ``torch.cuda.nvtx.range``
            (typically the range message string).
        **kwargs: Keyword arguments forwarded to ``torch.cuda.nvtx.range``.

    Returns:
        A context-manager/decorator that pushes/pops an NVTX range, or a
        no-op equivalent when NVTX is unavailable.
    """
    try:
        import torch.cuda.nvtx as nvtx

        # Probe NVTX eagerly: range_push raises on builds without NVTX
        # support, which routes us to the no-op fallback below.
        nvtx.range_push("nvtx init")
        nvtx.range_pop()

        return nvtx.range(*args, **kwargs)
    except Exception:
        # Bugfix: contextlib.nullcontext() is not callable, so the previous
        # fallback crashed with TypeError when maybe_nvtx_range was used as
        # a decorator. A @contextmanager-based no-op supports both the
        # `with` form and the decorator form.
        from contextlib import contextmanager

        @contextmanager
        def _noop_range():
            yield

        return _noop_range()
97 changes: 73 additions & 24 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
)
from ..eagle.conversion import EagleDMRegistry
from ..eagle.eagle_model import EagleModel
from ..eagle.utils import expand_mask, make_causal_mask
from ..eagle.utils import expand_mask, make_causal_mask, maybe_nvtx_range
from ..medusa.conversion import MedusaDMRegistry
from ..medusa.medusa_model import MedusaModel
from ..utils import (
Expand Down Expand Up @@ -292,6 +292,10 @@ def __init__(self, config, decoder_layer_cls, bias=False):
num_layers=self.config.parallel_draft_heads_num_layers,
)

def _maybe_init_rope(self):
    """Lazily create rotary position embeddings for llama-style decoders.

    Construction is deferred until first use (rather than done in
    ``__init__``) to avoid save/load errors with meta tensors; non-llama
    decoder types never build a rotary embedding.
    """
    wants_llama_rope = self.config.eagle_decoder_type == "llama"
    if wants_llama_rope and not hasattr(self, "rotary_emb"):
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

def _expand_first_attn_in_dim(self, first_layer_attn):
"""Modify qkv projection in first layer to accept 2h hidden size."""
# Find Linear modules to expand
Expand Down Expand Up @@ -372,11 +376,6 @@ def forward(
self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)

if self.config.eagle_decoder_type == "llama":
# Lazy init rope to avoid save/load meta tensor error
if not hasattr(self, "rotary_emb"):
self.rotary_emb = LlamaRotaryEmbedding(
config=self.config, device=hidden_states.device
)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
else:
position_embeddings = None
Expand Down Expand Up @@ -618,8 +617,34 @@ def modify(
# https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
self.is_quantized = False

if self.eagle_use_torch_compile:
self._activate_torch_compile()

self._cached_attn_blk_masks = {}

def _activate_torch_compile(self):
    """Wrap the eagle hot-path methods in ``torch.compile``.

    Each method is compiled independently so a failure on one does not
    disable compilation for the others; dynamo is additionally configured
    to fall back to eager execution instead of raising.
    """
    import torch._dynamo

    # Allow fallback to eager mode rather than surfacing dynamo errors.
    torch._dynamo.config.suppress_errors = True

    # Per-method torch.compile options; compiled one at a time to
    # maximize torch.compile coverage even if one method fails to wrap.
    per_method_options = (
        ("_prepare_eagle_inputs", {"dynamic": False}),
        ("_eagle_forward", {"dynamic": False, "mode": "max-autotune"}),
        ("_eagle_loss", {"dynamic": False, "fullgraph": True}),
    )
    for method_name, compile_kwargs in per_method_options:
        try:
            compiled = torch.compile(getattr(self, method_name), **compile_kwargs)
            setattr(self, method_name, compiled)
        except Exception:
            print(f"Disabling torch.compile for {method_name} due to compilation error.")

def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
# compile and cached flex attention masks in first call
if ttt_step not in self._cached_attn_blk_masks:
Expand Down Expand Up @@ -657,6 +682,7 @@ def _prepare_decoder_attention_mask(

return combined_attention_mask

@maybe_nvtx_range("prepare_eagle_inputs")
def _prepare_eagle_inputs(
self,
input_ids,
Expand Down Expand Up @@ -716,7 +742,20 @@ def _prepare_eagle_inputs(
else:
eagle_position_ids = position_ids.view(-1, seq_length).long()

return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids
base_model_logits = base_outputs.logits
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()

return (
eagle_input_embeds,
eagle_input_hiddens,
eagle_attention_mask,
eagle_position_ids,
base_output_predict_tok,
base_output_softmax_logits,
)

def _compute_ttt_attention_mask(
self, batch_size, seq_length, ttt_step
Expand Down Expand Up @@ -746,6 +785,7 @@ def _compute_ttt_attention_mask(
tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1)
return tensor_mask

@maybe_nvtx_range("base_model_forward")
def _base_model_forward(
self,
input_ids,
Expand Down Expand Up @@ -794,6 +834,7 @@ def _map_logits_to_draft_vocab(self, full_logits):
)
return full_logits[:, :, reverse_mapping]

@maybe_nvtx_range("eagle_forward")
def _eagle_forward(
self,
eagle_input_hidden_states,
Expand Down Expand Up @@ -890,13 +931,17 @@ def forward(

# ====Prepare inputs for the first eagle forward pass====
eagle_loss = None
train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)]
num_parallel = self.eagle_config.parallel_draft_step
num_ttt = self.eagle_ttt_steps
train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
b, seq_length, _ = base_outputs.out_hiddens.shape
(
eagle_input_embeds,
eagle_input_hiddens,
eagle_attn_mask_0,
eagle_position_ids,
base_output_predict_tok,
base_output_softmax_logits,
) = self._prepare_eagle_inputs(
input_ids,
attention_mask,
Expand All @@ -905,6 +950,8 @@ def forward(
base_outputs,
)

self.eagle_module._maybe_init_rope()

# ====Run eagle forward with extra training-time-test steps====
for ttt_step in range(self.eagle_ttt_steps):
# TODO: (hg) during cp training, this mask is not used. Maybe turn it off then.
Expand Down Expand Up @@ -949,7 +996,8 @@ def forward(
# base model predict +1 tok, while eagle predict +2
# so we shift base model outputs compared to eagle outputs
# additionally, we mask the first n tok of eagle outputs at nth TTT step
base_outputs.logits[:, 1 + i + ttt_step :],
base_output_softmax_logits[:, 1 + i + ttt_step :],
base_output_predict_tok[:, 1 + i + ttt_step :],
eagle_logit[:, ttt_step : -(1 + i)],
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
)
Expand All @@ -958,10 +1006,13 @@ def forward(
eagle_loss = (
classification_loss if eagle_loss is None else eagle_loss + classification_loss
)
train_accs[i].append(acc)
train_accs[i, ttt_step] = acc
if not self.training:
break

# Slice by actual number of steps taken, in case of early return
train_accs = train_accs[:, : ttt_step + 1].tolist()

# Merge base model loss and eagle loss
if base_outputs.loss is None and eagle_loss is None:
loss = None
Expand All @@ -977,29 +1028,26 @@ def forward(
train_acc=train_accs,
)

@maybe_nvtx_range("eagle_loss")
def _eagle_loss(
self,
base_model_logits,
base_output_softmax_logits,
base_output_predict_tok,
eagle_logits,
loss_mask,
):
"""Function for EAGLE loss computing."""
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
loss_mask = loss_mask[:, : eagle_logits.shape[1], None]
classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)(
eagle_logits
)
classification_loss = -torch.sum(torch.sum(loss_mask * classification_loss, 2)) / (
loss_mask.sum() + 1e-5
)
# Compute accuracy
base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1)
eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1)
eagle_logsoft = torch.log_softmax(eagle_logits, dim=2)
classification_loss = -torch.sum(
torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2)
) / (loss_mask.sum() + 1e-5)
# Compute accuracy (returned as tensor to avoid sync; .item() called after TTT loop)
eagle_predict_tok = eagle_logits.detach().argmax(dim=-1)
valid = loss_mask[:, :, 0].bool()
correct = (base_predict_tok == eagle_predict_tok) & valid
correct = (base_output_predict_tok == eagle_predict_tok) & valid
denom = valid.sum().clamp_min(1).float()
accuracy = round(correct.sum().float().div(denom).item(), 3)
accuracy = correct.sum().float() / denom

return classification_loss, accuracy

Expand Down Expand Up @@ -1039,6 +1087,7 @@ def pseudo_speculative_generate(
else:
eagle_input_hidden_states = base_model_hidden_states

self.eagle_module._maybe_init_rope()
draft_tokens = []
for step in range(steps):
b, seq_length = eagle_ids.shape
Expand Down
Loading