5 changes: 5 additions & 0 deletions examples/speculative_decoding/README.md
@@ -319,6 +319,11 @@ trainer.save_state()
trainer.save_model("<path to the output directory>")
```

### Observability

If W&B is installed, it will be used automatically for logging. If `--tensorboard` is provided,
TensorBoard is used instead, writing event files to the training output directory.
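
For example, a sketch (flags other than `--tensorboard` are placeholders for your usual launch arguments):

```sh
# Enable TensorBoard logging; the flag accepts both "--tensorboard True"
# and "--tensorboard=True".
./launch_train.sh --model <base model> --data <training data> --tensorboard True

# Inspect the logged metrics.
tensorboard --logdir <path to the output directory>
```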

## Support Matrix

| Model | Medusa | EAGLE1/2 | EAGLE3 |
71 changes: 61 additions & 10 deletions examples/speculative_decoding/eagle_utils.py
@@ -29,6 +29,7 @@
from packaging.version import Version
from scripts.ar_validate import validate_ar
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import Trainer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother

@@ -170,34 +171,47 @@ def make_eagle_supervised_data_module(
class EagleTrainerWithAccLog(Trainer):
"""Wrapper around Trainer that logs training accuracy."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
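# Keep Trainer from forwarding loss kwargs (e.g. num_items_in_batch) to the model's forward.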
self.model_accepts_loss_kwargs = False

def compute_loss(self, *args, **kwargs):
"""Override compute_loss to save train accs in trainer state."""
if not hasattr(self.state, "training_accs"):
self.state.training_accs = []
kwargs.pop("num_items_in_batch", None)
loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
return_outputs = kwargs.pop("return_outputs", False)
loss, outputs = super().compute_loss(*args, return_outputs=True, **kwargs)
if hasattr(outputs, "train_acc"):
self.state.training_accs.append(outputs.train_acc)
return loss
return (loss, outputs) if return_outputs else loss


class EagleTrainingPlot(TrainerCallback):
"""Callback that plot training acc and AR during training."""

def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
def __init__(
self,
ar_validate_steps: int = 1000,
estimate_ar: bool = False,
tb_writer: SummaryWriter | None = None,
):
self.ar_validate_steps = ar_validate_steps
if wandb and is_master():
wandb.init()
self.estimate_ar = estimate_ar
self.tb_writer = tb_writer
self.last_seen_step = -1

def on_log(self, args, state, control, **kwargs):
"""Log training acc and estimate AR during log step."""
def _report_stats(self, state, eval_mode: bool, **kwargs):
if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
return control
return
average_acc = np.mean(state.training_accs, axis=0)
mode_name = "Eval" if eval_mode else "Training"
mode_id = mode_name.lower()
if self.estimate_ar:
# Calculate mean AR since last log
# NOTE: This is only an estimate of the real AR.
est_ar = 1
acc_cumprod = 1
for step_acc in average_acc[0]:
@@ -207,7 +221,7 @@ def on_log(self, args, state, control, **kwargs):
for draft_acc in average_acc[1:]:
acc_cumprod *= draft_acc[-1]
est_ar += acc_cumprod
print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
print_rank_0(f"Step {state.global_step} Estimated {mode_name} AR: {est_ar:.4f}")

# log to wandb
if wandb and is_master():
@@ -217,11 +231,44 @@ def on_log(self, args, state, control, **kwargs):
for i, draft_acc in enumerate(average_acc):
for j, step_acc in enumerate(draft_acc):
wandb.log(
{f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
{f"parallel_{i}_step_{j}_{mode_id}_acc": step_acc}, step=state.global_step
)
if self.estimate_ar:
wandb.log({f"estimated_{mode_id}_ar": est_ar}, step=state.global_step)

if self.tb_writer:
# TODO: What is in "kwargs.logs"?
for i, draft_acc in enumerate(average_acc):
for j, step_acc in enumerate(draft_acc):
self.tb_writer.add_scalar(
f"{mode_id}/parallel_{i}_step_{j}_{mode_id}_acc",
step_acc,
state.global_step,
)
if self.estimate_ar:
wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
self.tb_writer.add_scalar(f"{mode_id}/estimated_ar", est_ar, state.global_step)

def on_log(self, args, state, control, **kwargs):
"""Log training acc and estimate AR during log step."""
if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
self.last_seen_step = state.global_step
return control

if state.global_step != self.last_seen_step:
# Eval mode doesn't increment the global step, so we can use that to detect eval vs training
self._report_stats(state, eval_mode=False, **kwargs)
# reset training_accs
state.training_accs = []

self.last_seen_step = state.global_step
return control

def on_evaluate(self, args, state, control, **kwargs):
"""Log eval acc and estimate AR during eval step."""
if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
return control

self._report_stats(state, eval_mode=True, **kwargs)
# reset training_accs
state.training_accs = []
return control
@@ -242,6 +289,10 @@ def on_step_end(self, args, state, control, **kwargs):
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb and is_master():
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
if self.tb_writer:
self.tb_writer.add_scalar(
"custom/validate_ar", sum(ars) / len(ars), state.global_step
)
except Exception:
print_rank_0("AR validation not available.")
return control
15 changes: 13 additions & 2 deletions examples/speculative_decoding/launch_train.sh
@@ -118,6 +118,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
MIX_HIDDEN_STATES="${1#*=}"
;;
--tensorboard*)
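# Accept both "--tensorboard <value>" and "--tensorboard=<value>":
# shift to the value if there is no "=", then strip everything up to it.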
if [[ "$1" != *=* ]]; then shift; fi
ENABLE_TENSORBOARD="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -152,14 +156,14 @@ DISABLE_TQDM=${DISABLE_TQDM:-False}
VLM_PROCESSOR=${VLM_PROCESSOR:-}
VLM_IMG_DIR=${VLM_IMG_DIR:-}
AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
ESTIMATE_AR=${ESTIMATE_AR:-False}
ESTIMATE_AR=${ESTIMATE_AR:-True}
CP_SIZE=${CP_SIZE:-1}
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
LOG_STEPS=${LOG_STEPS:-100}
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}

ENABLE_TENSORBOARD=${ENABLE_TENSORBOARD:-"False"}

if [[ "$MODE" == "eagle3" ]]; then
if [[ -n "$EAGLE_CONFIG" ]]; then
@@ -216,6 +220,12 @@ else
MULTI_NODE_ARGS=""
fi

if [[ "$ENABLE_TENSORBOARD" != "False" ]]; then
OBSERVABILITY_ARGS="--report_to tensorboard"
else
OBSERVABILITY_ARGS=""
fi

# Disable tokenizers parallelism to avoid warning
export TOKENIZERS_PARALLELISM=False
CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/main.py \
@@ -253,6 +263,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/main.py \
--cp_size $CP_SIZE \
--dp_shard_size $DP_SHARD_SIZE \
--num_ttt_steps $NUM_TTT_STEPS \
$OBSERVABILITY_ARGS \
"

start_time=$(date +%s)
24 changes: 22 additions & 2 deletions examples/speculative_decoding/main.py
@@ -43,6 +43,8 @@
make_eagle_supervised_data_module,
patch_ring_attention_for_ttt,
)
from torch.utils.tensorboard import SummaryWriter
from transformers.integrations import TensorBoardCallback
from transformers.trainer_utils import get_last_checkpoint

import modelopt.torch.opt as mto
@@ -102,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments):
bf16: bool = field(default=True)
mode: Literal["eagle3", "medusa"] = "eagle3"
estimate_ar: bool = field(
default=False, metadata={"help": "Whether to estimate AR during training for logging."}
default=True, metadata={"help": "Whether to estimate AR during training for logging."}
)
ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})
@@ -235,11 +237,29 @@ def train():
tokenizer, data_args, train_len=training_args.training_seq_len
)

callbacks = []
tb_writer = None
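# Create one SummaryWriter and share it between HF's TensorBoardCallback and
# EagleTrainingPlot so all metrics land in a single event file; "tensorboard" is
# then removed from report_to so Trainer does not attach a second callback.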
if "tensorboard" in training_args.report_to:
log_dir = training_args.output_dir
tb_writer = SummaryWriter(log_dir=log_dir)
if isinstance(training_args.report_to, list):
training_args.report_to.remove("tensorboard")
else:
training_args.report_to = "none"
callbacks.append(TensorBoardCallback(tb_writer=tb_writer))
callbacks.append(
EagleTrainingPlot(
training_args.ar_validate_steps,
tb_writer=tb_writer,
estimate_ar=training_args.estimate_ar,
)
)

trainer = EagleTrainerWithAccLog(
model=model,
processing_class=tokenizer,
args=training_args,
callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
callbacks=callbacks,
**data_module,
)
