213 changes: 130 additions & 83 deletions src/maxtext/common/metric_logger.py
@@ -30,6 +30,7 @@
from maxtext.common.gcloud_stub import mldiagnostics_modules
from maxtext.common.gcloud_stub import workload_monitor
from maxtext.common.managed_mldiagnostics import ManagedMLDiagnostics
from maxtext.utils import exceptions
from maxtext.utils import gcs_utils
from maxtext.utils import max_logging
from maxtext.utils import max_utils
@@ -99,38 +100,47 @@ def __init__(self, config, learning_rate_schedule):
self.performance_metric_queue = self.get_performance_metric_queue(config)
self.learning_rate_schedule = learning_rate_schedule
self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
self.buffered_train_metrics = None
# self.buffered_metrics is a deferred-write queue of tagged tuples. Entries are one of:
# ("train", train_step, metrics, step_time_delta)
# ("eval", eval_step, metrics, step_time_delta)
self.buffered_metrics = []
# Number of eval steps accumulated since the last reset_eval_metrics(). Used by
# buffer_and_write_metrics to detect the eval→train transition and trigger finalization.
self._pending_eval_step_count = 0
if self.config.managed_mldiagnostics:
ManagedMLDiagnostics(config) # Initialize the MLRun instance.

def reset_eval_metrics(self):
"""Resets the cumulative metrics dictionary for a new evaluation run."""
self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
self._pending_eval_step_count = 0

def write_metrics(self, metrics, step, is_training=True):
"""Entry point for all metrics writing in Train's Main."""
def write_metrics(self, metrics, step, metric_type="train"):
"""Entry point for all metrics writing. metric_type is one of 'train', 'eval', 'running_eval'."""
if metrics:
self.log_metrics(metrics, step, is_training)
self.log_metrics(metrics, step, metric_type)

if self.config.enable_tensorboard:
self.write_metrics_to_tensorboard(metrics, step, is_training)
if self.config.enable_tensorboard and metric_type != "running_eval":
self.write_metrics_to_tensorboard(metrics, step, metric_type)

if self.config.metrics_file:
self.write_metrics_locally(metrics, step)

if self.config.gcs_metrics and jax.process_index() == 0:
self.write_metrics_for_gcs(metrics, step, is_training)
self.write_metrics_for_gcs(metrics, step, metric_type)

if self.config.managed_mldiagnostics:
self.write_metrics_to_managed_mldiagnostics(metrics, step)

if is_training:
if metric_type == "train":
self._maybe_abort_after_write_metrics(metrics)

def log_metrics(self, metrics, step, is_training):
def log_metrics(self, metrics, step, metric_type):
"""Logs metrics via max_logging."""
if is_training:
if metric_type == "train":
self._log_training_metrics(metrics, step)
elif metric_type == "running_eval":
self._log_running_eval_metrics(metrics, step)
else:
self._log_eval_metrics(metrics, step)

@@ -196,23 +206,45 @@ def _log_training_metrics(self, metrics, step):
max_logging.log(", ".join(log_parts))

def _log_eval_metrics(self, metrics, step):
"""Handles evaluation-specific metric logging."""
"""Logs the final accumulated eval summary at the end of an eval run."""
scalars = metrics["scalar"]
log_parts = [
f"eval metrics after step: {step}",
f"Completed eval after train step {step}",
f"loss={scalars['eval/avg_loss']:.3f}",
f"perplexity={scalars['eval/avg_perplexity']:.3f}",
f"total_weights={scalars['eval/total_weights']}",
f"avg_z_loss={scalars.get('eval/avg_z_loss', 0.0):.3f}",
]

if self.config.num_experts > 1:
log_parts.append(f"avg_moe_lb_loss={scalars['eval/avg_moe_lb_loss']:.3f}")
if self.config.mtp_num_layers > 0:
log_parts.extend(
[
f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
]
)
if self.config.use_dpo:
log_parts.append(f"dpo_reward_accuracy={scalars['eval/dpo_reward_accuracy']:.3f}")
max_logging.log(", ".join(log_parts))

def _log_running_eval_metrics(self, metrics, step):
"""Logs a per-eval-step running average (deferred by one eval step)."""
scalars = metrics["scalar"]
log_parts = [
f"Completed eval step: {step}",
f"seconds: {scalars['eval/step_time_seconds']:.3f}",
f"running loss={scalars['eval/avg_loss']:.3f}",
f"running perplexity={scalars['eval/avg_perplexity']:.3f}",
f"running total_weights={scalars['eval/total_weights']}",
]
if self.config.mtp_num_layers > 0:
log_parts.extend(
[
f"running mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
f"running mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
]
)
max_logging.log(", ".join(log_parts))

def _is_profiler_boundary_step(self, step):
@@ -249,11 +281,11 @@ def write_metrics_locally(self, metrics, step):
metrics_dict = _prepare_metrics_for_json(metrics, step, self.config.run_name)
local_metrics_file.write(str(json.dumps(metrics_dict)) + "\n")

def write_metrics_for_gcs(self, metrics, step, is_training):
def write_metrics_for_gcs(self, metrics, step, metric_type):
"""Writes metrics to GCS."""
metrics_dict_step = _prepare_metrics_for_json(metrics, step, self.config.run_name)
self.running_gcs_metrics.append(metrics_dict_step)
if is_training and (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
if metric_type == "train" and (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
start_step = (step // self.config.log_period) * self.config.log_period
metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt"
with open(metrics_filename, "wt", encoding="utf8") as metrics_for_gcs:
@@ -266,15 +298,15 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
max_logging.log(f"File {metrics_filename} moved successfully!")
self.running_gcs_metrics = [] # reset running_metrics to empty list

def write_metrics_to_tensorboard(self, metrics, step, is_training):
def write_metrics_to_tensorboard(self, metrics, step, metric_type):
"""Writes metrics to TensorBoard."""
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
self.writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
self.writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

if is_training:
if metric_type == "train":
full_log = step % self.config.log_period == 0

if full_log and jax.process_index() == 0:
@@ -324,19 +356,68 @@ def get_performance_metric_queue(self, config):
gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
return performance_metric_queue

def buffer_and_write_train_metrics(self, metrics, step, step_time_delta):
def buffer_and_write_metrics(self, metrics, step, step_time_delta=None, is_training=True):
"""
Buffers metrics for the current training step and simultaneously writes the training metrics
for the previous step to GCS and/or TensorBoard. This buffering strategy allows for back-to-back
execution of training steps, by overlapping data loading for step n with the execution of step n−1.
This significantly boosts training efficiency.
"""
if self.buffered_train_metrics is not None:
(step_to_write, metrics_to_write) = self.buffered_train_metrics
self.write_metrics(metrics_to_write, step_to_write)
Per-step entry point for both train and eval metrics. Flushes the single deferred entry from
the previous call, then queues this step's metrics. The buffer holds at most one entry: a new
call means the previous step's work has already been dispatched, so its results are safe to materialize.

self.record_train_metrics(metrics, step, step_time_delta.total_seconds())
self.buffered_train_metrics = (step, metrics)
Both train and eval entries carry raw metrics and defer all processing
(record_train_metrics / _accumulate_eval_metrics) to _flush_one_buffered_entry so float()
never blocks the dispatch path.
"""
if self.buffered_metrics:
self._flush_one_buffered_entry(self.buffered_metrics.pop(0))
if is_training:
self.buffered_metrics.append(("train", step, metrics, step_time_delta))
if self._pending_eval_step_count > 0:
self._finalize_eval_metrics(step)
else:
self._pending_eval_step_count += 1
self.buffered_metrics.append(("eval", step, metrics, step_time_delta))
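
To make the buffering contract concrete, here is a minimal standalone sketch of the length-1 deferred-write pattern the docstring describes (hypothetical class and names, not the MaxText API):

class DeferredMetricBuffer:
  """Sketch: hold at most one pending entry; flushing on the next call lets
  step n dispatch before step n-1's results are materialized with float()."""

  def __init__(self, flush_fn):
    self._flush_fn = flush_fn  # callback that materializes and writes one entry
    self._buffered = []  # at most one tagged tuple, e.g. ("train", step, metrics, delta)

  def submit(self, entry):
    if self._buffered:
      self._flush_fn(self._buffered.pop(0))  # previous step has finished dispatching; safe to block
    self._buffered.append(entry)

  def drain(self):
    # Mirrors flush_metrics_and_cleanup below: write whatever is still pending.
    for entry in self._buffered:
      self._flush_fn(entry)
    self._buffered = []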

def _flush_one_buffered_entry(self, entry):
"""Dispatches a single buffered entry to the writer. All float() calls happen here."""
kind = entry[0]
if kind == "train":
_, step, metrics, step_time_delta = entry
self.record_train_metrics(metrics, step, step_time_delta.total_seconds())
self.write_metrics(metrics, step)
elif kind == "eval":
_, eval_step, raw_metrics, step_time_delta = entry
# _accumulate_eval_metrics makes the float() calls that materialize the metrics; that work is deferred to here
self._accumulate_eval_metrics(raw_metrics)
running_count = eval_step + 1 # eval_step is 0-indexed
cumulative = self.cumulative_eval_metrics["scalar"]
running_avg_loss = cumulative["eval/total_loss"] / (cumulative["eval/total_weights"] + EPS)
snapshot = {
"scalar": {
"eval/avg_loss": running_avg_loss,
"eval/avg_perplexity": np.exp(running_avg_loss),
"eval/total_weights": cumulative["eval/total_weights"],
"eval/avg_mtp_loss": cumulative["eval/mtp_loss"] / running_count,
"eval/avg_mtp_acceptance_rate_percent": (cumulative["eval/mtp_acceptance_rate_percent"] / running_count),
"eval/step_time_seconds": step_time_delta.total_seconds(),
}
}
self.write_metrics(snapshot, eval_step, metric_type="running_eval")
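
As a worked check of the snapshot math above (illustrative numbers; the EPS value is assumed):

import numpy as np

EPS = 1e-8  # assumed; stands in for the module-level EPS constant
cumulative = {"eval/total_loss": 1200.0, "eval/total_weights": 400.0, "eval/mtp_loss": 7.5}
running_count = 3  # i.e. eval_step == 2, since eval_step is 0-indexed

running_avg_loss = cumulative["eval/total_loss"] / (cumulative["eval/total_weights"] + EPS)
print(f"{running_avg_loss:.3f}")          # ~3.000 (loss per unit of token weight)
print(f"{np.exp(running_avg_loss):.2f}")  # ~20.09 (running perplexity)
print(cumulative["eval/mtp_loss"] / running_count)  # 2.5 (per-eval-step average)

Note the two denominators: loss and perplexity average over token weights, while the MTP metrics average over the number of eval steps seen so far.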

def _accumulate_eval_metrics(self, metrics):
"""Accumulates one eval step's raw metrics into cumulative_eval_metrics (eager float())."""
scalar = metrics.get("scalar", {})
self.cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(scalar.get("evaluation/total_loss", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(scalar.get("evaluation/total_weights", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(scalar.get("evaluation/moe_lb_loss", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] += float(scalar.get("evaluation/indexer_loss", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(scalar.get("evaluation/mtp_loss", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float(
scalar.get("evaluation/mtp_acceptance_rate_percent", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/z_loss"] += float(scalar.get("evaluation/z_loss", 0.0))
if self.config.use_dpo:
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float(
scalar.get("evaluation/dpo_reward_accuracy", 0.0)
)
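
The accumulation above also renames the namespace: raw per-step metrics arrive under "evaluation/*" and are summed under "eval/*". A minimal sketch of that remap (illustrative keys and values):

from collections import defaultdict

cumulative = defaultdict(float)
raw = {"scalar": {"evaluation/total_loss": 410.0, "evaluation/total_weights": 128.0}}
for name in ("total_loss", "total_weights"):
  # float() eagerly materializes device values; in the logger this runs at flush time, not dispatch time.
  cumulative[f"eval/{name}"] += float(raw["scalar"].get(f"evaluation/{name}", 0.0))
print(dict(cumulative))  # {'eval/total_loss': 410.0, 'eval/total_weights': 128.0}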

def record_train_metrics(self, metrics, step, step_time):
"""Records training metrics for the current step."""
@@ -354,58 +435,24 @@ def record_train_metrics(self, metrics, step, step_time):
if self.performance_metric_queue:
self.performance_metric_queue.put(step_time)

def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
"""Records eval metrics and writes the metrics to GCS and/or to TensorBoard."""
if metrics:
self.cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(
metrics["scalar"].get("evaluation/total_loss", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(
metrics["scalar"].get("evaluation/total_weights", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(
metrics["scalar"].get("evaluation/moe_lb_loss", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] += float(
metrics["scalar"].get("evaluation/indexer_loss", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(metrics["scalar"].get("evaluation/mtp_loss", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float(
metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/z_loss"] += float(metrics["scalar"].get("evaluation/z_loss", 0.0))
if self.config.use_dpo:
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float(
metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)
)

if eval_step_count:
eval_loss = self.cumulative_eval_metrics["scalar"]["eval/total_loss"] / (
self.cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS
)
self.cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
self.cumulative_eval_metrics["scalar"]["eval/avg_perplexity"] = np.exp(eval_loss)
self.cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
)
self.cumulative_eval_metrics["scalar"]["eval/avg_indexer_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] / eval_step_count
)
self.cumulative_eval_metrics["scalar"]["eval/avg_mtp_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] / eval_step_count
)
self.cumulative_eval_metrics["scalar"]["eval/avg_mtp_acceptance_rate_percent"] = (
self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] / eval_step_count
)
self.cumulative_eval_metrics["scalar"]["eval/avg_z_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/z_loss"] / eval_step_count
)
if self.config.use_dpo:
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = (
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] / eval_step_count
)

self.write_metrics(self.cumulative_eval_metrics, step, is_training=False)
def _finalize_eval_metrics(self, train_step):
"""Computes final averaged eval metrics and writes them at train_step."""
eval_step_count = self._pending_eval_step_count
cumulative = self.cumulative_eval_metrics["scalar"]
eval_loss = cumulative["eval/total_loss"] / (cumulative["eval/total_weights"] + EPS)
cumulative["eval/avg_loss"] = eval_loss
cumulative["eval/avg_perplexity"] = np.exp(eval_loss)
cumulative["eval/avg_moe_lb_loss"] = cumulative["eval/moe_lb_loss"] / eval_step_count
cumulative["eval/avg_indexer_loss"] = cumulative["eval/indexer_loss"] / eval_step_count
cumulative["eval/avg_mtp_loss"] = cumulative["eval/mtp_loss"] / eval_step_count
cumulative["eval/avg_mtp_acceptance_rate_percent"] = cumulative["eval/mtp_acceptance_rate_percent"] / eval_step_count
cumulative["eval/avg_z_loss"] = cumulative["eval/z_loss"] / eval_step_count
if self.config.use_dpo:
cumulative["eval/dpo_reward_accuracy"] = cumulative["eval/dpo_reward_accuracy"] / eval_step_count
self.write_metrics(self.cumulative_eval_metrics, train_step, metric_type="eval")
self._pending_eval_step_count = 0
if self.config.target_eval_loss and eval_loss <= self.config.target_eval_loss:
raise exceptions.StopTraining(f"Target loss {self.config.target_eval_loss=} is achieved.")
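
A condensed sketch of the early-stop gate at the end of _finalize_eval_metrics (stand-in exception class and illustrative values; in MaxText the real exceptions.StopTraining is caught by the trainer loop, as in the GRPO diff below):

class StopTraining(Exception):  # stand-in for exceptions.StopTraining
  pass

def maybe_stop(eval_loss, target_eval_loss):
  # A falsy target (0 or None) disables the check, mirroring the guard above.
  if target_eval_loss and eval_loss <= target_eval_loss:
    raise StopTraining(f"Target loss target_eval_loss={target_eval_loss} is achieved.")

try:
  maybe_stop(eval_loss=1.95, target_eval_loss=2.0)
except StopTraining as e:
  print(f"Training stopped: {e}")  # the trainer logs this and completes gracefully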

def flush_metrics_and_cleanup(self):
"""
@@ -414,8 +461,8 @@ def flush_metrics_and_cleanup(self):
logger instance should not be used to add or write more metrics as the
underlying writer objects (e.g., TensorBoard SummaryWriter) will be closed.
"""
if self.buffered_train_metrics is not None:
(step_to_write, metrics_to_write) = self.buffered_train_metrics
self.write_metrics(metrics_to_write, step_to_write)
for entry in self.buffered_metrics:
self._flush_one_buffered_entry(entry)
self.buffered_metrics = []

max_utils.close_summary_writer(self.writer)
20 changes: 13 additions & 7 deletions src/maxtext/experimental/rl/grpo_trainer.py
@@ -860,7 +860,6 @@ def generation_worker_fn(
data_buffer.clear()

step_time_delta = datetime.datetime.now() - last_step_completion
last_step_completion = datetime.datetime.now()

state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)
@@ -877,27 +876,33 @@ def generation_worker_fn(

if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
assert eval_data_iterator
# Explicitly reset the eval iterator and counters before starting the eval loop
eval_data_iterator.reset()
metric_logger.reset_eval_metrics()
max_logging.log(f"Starting eval after train step {step}")
eval_step_count = 0
last_eval_step_completion = datetime.datetime.now()
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
if 0 < config.eval_steps <= eval_step_count:
break
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
eval_step_time_delta = datetime.datetime.now() - last_eval_step_completion
last_eval_step_completion = datetime.datetime.now()
metric_logger.buffer_and_write_metrics(
eval_metrics, eval_step_count, step_time_delta=eval_step_time_delta, is_training=False
)
max_logging.log(f"Completed eval step {eval_step_count}")
eval_step_count += 1
metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
prof.deactivate()
raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
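
The new eval loop reuses the train loop's step-timing pattern: measure the wall-clock delta since the previous completion, then reset the marker. A self-contained sketch (time.sleep stands in for p_eval_step):

import datetime
import time

last_completion = datetime.datetime.now()
for eval_step in range(3):  # stands in for iterating eval_data_iterator
  time.sleep(0.01)  # stands in for p_eval_step(state, eval_batch, rng)
  delta = datetime.datetime.now() - last_completion
  last_completion = datetime.datetime.now()
  print(f"eval step {eval_step}: {delta.total_seconds():.3f}s")  # passed as step_time_delta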

prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")

metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
last_step_completion = datetime.datetime.now()
metric_logger.buffer_and_write_metrics(metrics, step, step_time_delta)

if config.save_checkpoint_on_completion:
state_to_save = _split_grpo_state(state)[0]
@@ -907,6 +912,7 @@
checkpoint_manager.wait_until_finished()
_job_completed_gracefully = True
except exceptions.StopTraining as e:
prof.deactivate()
max_logging.log(f"Training stopped: {str(e)}")
_job_completed_gracefully = True
finally:
2 changes: 1 addition & 1 deletion src/maxtext/trainers/post_train/sft/hooks.py
@@ -165,7 +165,7 @@ def on_eval_step_end(self, train_ctx: peft_trainer.PeftTrainer, eval_loss: float
"eval/total_weights": self.eval_metadata["total_weights"],
}
}
self.metric_logger.write_metrics(metrics, train_ctx.train_steps, is_training=False)
self.metric_logger.write_metrics(metrics, train_ctx.train_steps, metric_type="eval")
self.eval_metadata.clear()

if avg_loss <= self.config.target_eval_loss:
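
With the boolean is_training flag replaced by a metric_type string, call sites name the write path explicitly. A usage sketch (metric names and values are illustrative; metric_logger is an existing MetricLogger instance):

step = 1000
eval_summary = {"scalar": {"eval/avg_loss": 2.31, "eval/avg_perplexity": 10.07}}
metric_logger.write_metrics(eval_summary, step, metric_type="eval")          # final eval summary
metric_logger.write_metrics(eval_summary, step, metric_type="running_eval")  # per-eval-step; skips TensorBoard
metric_logger.write_metrics({"scalar": {"train/loss": 2.5}}, step)           # metric_type defaults to "train"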