Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
LanguageDataCollator,
ShardedDataset,
VisionLanguageDataCollator,
_get_bucket_size,
)

try:
Expand Down Expand Up @@ -89,8 +90,9 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
class EagleOfflineDataCollator:
"""Data collator that truncate or pads data for offline training."""

def __init__(self, train_len):
def __init__(self, train_len, bucket_granularity=0):
self.train_len = train_len
self.bucket_granularity = bucket_granularity

def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):
"""Pad or truncate a tensor to length along a given dimension."""
Expand All @@ -110,13 +112,19 @@ def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):
return out

def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
if self.bucket_granularity > 0:
batch_max = max(item["input_ids"].shape[0] for item in features)
pad_len = _get_bucket_size(batch_max, self.train_len, self.bucket_granularity)
else:
pad_len = self.train_len

base_batch = {
k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features])
k: torch.stack([self._pad_or_truncate(item[k], pad_len) for item in features])
for k in ["input_ids", "attention_mask", "loss_mask", "labels"]
}

base_model_outputs = {
k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features])
k: torch.stack([self._pad_or_truncate(item[k], pad_len) for item in features])
for k in ["base_model_hidden_states", "aux_hidden_states"]
}

Expand All @@ -131,6 +139,7 @@ def make_eagle_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
train_len=None,
bucket_granularity=0,
) -> dict:
if data_args.offline_data_path is None:
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
Expand All @@ -140,6 +149,7 @@ def make_eagle_supervised_data_module(
tokenizer=tokenizer,
train_len=train_len,
return_labels=True,
bucket_granularity=bucket_granularity,
)
else:
data_collator = VisionLanguageDataCollator(
Expand All @@ -159,7 +169,9 @@ def make_eagle_supervised_data_module(
raise ValueError(f"No .pt files found in {data_args.offline_data_path}")

train_dataset = OfflineSupervisedDataset(dumped_files)
data_collator = EagleOfflineDataCollator(train_len=train_len)
data_collator = EagleOfflineDataCollator(
train_len=train_len, bucket_granularity=bucket_granularity
)

return {
"train_dataset": train_dataset,
Expand Down
7 changes: 6 additions & 1 deletion examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
NUM_NODES="${1#*=}"
;;
--bucket_granularity*)
if [[ "$1" != *=* ]]; then shift; fi
BUCKET_GRANULARITY="${1#*=}"
;;
--head_node_ip*)
if [[ "$1" != *=* ]]; then shift; fi
HEAD_NODE_IP="${1#*=}"
Expand Down Expand Up @@ -164,7 +168,7 @@ DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"}
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}

BUCKET_GRANULARITY=${BUCKET_GRANULARITY:-512}

if [[ "$MODE" == "eagle3" ]]; then
if [[ -n "$EAGLE_CONFIG" ]]; then
Expand Down Expand Up @@ -259,6 +263,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--cp_size $CP_SIZE \
--dp_shard_size $DP_SHARD_SIZE \
--num_ttt_steps $NUM_TTT_STEPS \
--bucket_granularity $BUCKET_GRANULARITY \
"

start_time=$(date +%s)
Expand Down
19 changes: 18 additions & 1 deletion examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ class TrainingArguments(transformers.TrainingArguments):
)
cp_size: int = field(default=1, metadata={"help": "Context parallelism size."})
dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."})
bucket_granularity: int = field(
default=512,
metadata={
"help": (
"Pad sequences to the nearest multiple of this value instead of training_seq_len. "
"Set to 0 to disable (always pad to training_seq_len)."
)
},
)


@dataclass
Expand Down Expand Up @@ -237,8 +246,16 @@ def train():

print_rank_0("Loading dataset...")
if training_args.mode == "eagle3":
bucket_gran = training_args.bucket_granularity
if bucket_gran > 0 and training_args.cp_size > 1:
from math import lcm

bucket_gran = lcm(bucket_gran, training_args.cp_size)
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
tokenizer,
data_args,
train_len=training_args.training_seq_len,
bucket_granularity=bucket_gran,
)

trainer = EagleTrainerWithAccLog(
Expand Down
21 changes: 12 additions & 9 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def _activate_torch_compile(self):
import torch._dynamo

torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode
torch._dynamo.config.cache_size_limit = 2048

# Individual try-catch for each function to maximize torch.compile usage
try:
Expand All @@ -634,24 +635,23 @@ def _activate_torch_compile(self):
print("Disabling torch.compile for _prepare_eagle_inputs due to compilation error.")

try:
self._eagle_forward = torch.compile(
self._eagle_forward, dynamic=False, mode="max-autotune"
)
self._eagle_forward = torch.compile(self._eagle_forward, dynamic=False)
except Exception:
print("Disabling torch.compile for _eagle_forward due to compilation error.")

try:
self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True)
self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False)
except Exception:
print("Disabling torch.compile for _eagle_loss due to compilation error.")

def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
# Compile and cache flex attention masks on the first call.
if ttt_step not in self._cached_attn_blk_masks:
self._cached_attn_blk_masks.update(
{ttt_step: self._compute_ttt_attention_mask(batch_size, seq_length, ttt_step)}
cache_key = (ttt_step, seq_length)
if cache_key not in self._cached_attn_blk_masks:
self._cached_attn_blk_masks[cache_key] = self._compute_ttt_attention_mask(
batch_size, seq_length, ttt_step
)
return self._cached_attn_blk_masks[ttt_step]
return self._cached_attn_blk_masks[cache_key]

def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, past_key_values_length, device, dtype
Expand Down Expand Up @@ -1100,7 +1100,10 @@ def pseudo_speculative_generate(
)

# Use SDPA attention during generation for both stability and performance
with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"):
with (
temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"),
torch.compiler.set_stance("force_eager"),
):
_, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
eagle_input_hidden_states,
self._base_model_embeddings(eagle_ids),
Expand Down
33 changes: 31 additions & 2 deletions modelopt/torch/utils/plugins/transformers_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os

import torch
import torch.nn.functional as F
import transformers
from datasets import load_dataset
from transformers.trainer_pt_utils import LabelSmoother
Expand Down Expand Up @@ -112,6 +113,13 @@ def _load_dataset(self):
self._raw_samples = shard


def _get_bucket_size(seq_len: int, max_len: int, granularity: int) -> int:
"""Round seq_len up to the nearest multiple of granularity, capped at max_len."""
if granularity <= 0:
return max_len
return min(((seq_len + granularity - 1) // granularity) * granularity, max_len)


class LanguageDataCollator:
"""Data collator for language modeling tasks.

Expand All @@ -129,6 +137,7 @@ def __init__(
answer_only_loss: bool = False,
json_key: str = "text",
return_labels: bool = False,
bucket_granularity: int = 0,
):
"""Initialize the LanguageDataset."""
if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase):
Expand All @@ -143,6 +152,7 @@ def __init__(
self.answer_only_loss = answer_only_loss
self.json_key = json_key
self.return_labels = return_labels
self.bucket_granularity = bucket_granularity

if chat_template is not None:
self.tokenizer.chat_template = chat_template
Expand Down Expand Up @@ -172,31 +182,50 @@ def _post_process_chat_template(self):
)

def _process_chat_sample(self, examples: list):
padding = "longest" if self.bucket_granularity > 0 else "max_length"
tokenized_examples = self.tokenizer.apply_chat_template(
examples,
return_tensors="pt",
return_dict=True,
padding="max_length",
padding=padding,
truncation=True,
max_length=self.train_len,
add_generation_prompt=self.add_generation_prompt,
return_assistant_tokens_mask=self.answer_only_loss,
)
if self.bucket_granularity > 0:
tokenized_examples = self._pad_to_bucket(tokenized_examples)
if self.return_labels:
input_ids = tokenized_examples["input_ids"]
labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID)
labels[..., :-1] = input_ids[..., 1:]
tokenized_examples["labels"] = labels
return tokenized_examples

def _pad_to_bucket(self, tokenized_examples):
    """Right-pad input_ids and attention_mask to the bucketed sequence length.

    The target length is the current length rounded up to the configured
    bucket granularity, capped at self.train_len. Padding uses the
    tokenizer's pad token id for input_ids and 0 for attention_mask.
    Returns the (possibly mutated) tokenized_examples dict.
    """
    current_len = tokenized_examples["input_ids"].shape[1]
    target_len = _get_bucket_size(current_len, self.train_len, self.bucket_granularity)
    extra = target_len - current_len
    if extra <= 0:
        # Already at (or beyond) the bucket boundary — nothing to pad.
        return tokenized_examples
    tokenized_examples["input_ids"] = F.pad(
        tokenized_examples["input_ids"], (0, extra), value=self.tokenizer.pad_token_id
    )
    tokenized_examples["attention_mask"] = F.pad(
        tokenized_examples["attention_mask"], (0, extra), value=0
    )
    return tokenized_examples

def _process_text_sample(self, examples: list):
padding = "longest" if self.bucket_granularity > 0 else "max_length"
tokenized_examples = self.tokenizer(
examples,
return_tensors="pt",
padding="max_length",
padding=padding,
truncation=True,
max_length=self.train_len,
)
if self.bucket_granularity > 0:
tokenized_examples = self._pad_to_bucket(tokenized_examples)
return tokenized_examples

def __call__(self, examples):
Expand Down
Loading