170 changes: 170 additions & 0 deletions scripts/parameter_norms/compute_layer_norms.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
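"""Compute per-parameter L2 norms across one or more Modalities DCP checkpoints.

The script is distributed-aware and is typically launched with torchrun so that
torch.distributed is initialized. A sketch of an invocation (hypothetical paths):

    torchrun --nproc_per_node=8 scripts/parameter_norms/compute_layer_norms.py \
        --config-file-path configs/pretrain.yaml \
        --experiments-root-path /path/to/experiments \
        --checkpoint-dir-paths /ckpts/seen_steps_1000 /ckpts/seen_steps_2000
"""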

import argparse
import json
import os
import re
from pathlib import Path
from typing import cast

import torch
import torch.distributed as dist
from pydantic import BaseModel
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading
from modalities.checkpointing.stateful.app_state import AppState
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType
from modalities.main import Main
from modalities.running_env.cuda_env import CudaEnv
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method


class ComponentsInstantiationModel(BaseModel):
app_state: PydanticAppStateType
device_mesh: PydanticDeviceMeshIFType | None = None


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Load one or more Modalities DCP checkpoints and compute per-parameter L2 norms."
    )
parser.add_argument("--config-file-path", type=Path, required=True, help="Path to the YAML config file.")
parser.add_argument(
"--experiments-root-path",
type=Path,
required=True,
help="Path passed to Main for resolver/context setup.",
)
parser.add_argument(
"--checkpoint-dir-paths",
type=Path,
nargs="+",
required=True,
help="Paths to multiple checkpoint directories containing *.distcp files.",
)
parser.add_argument(
"--json-output-path",
type=Path,
default=Path("layer_norms_across_checkpoints.json"),
help="Output path for raw per-checkpoint norms as JSON.",
)
return parser.parse_args()


def _resolve_checkpoint_dir_paths(args: argparse.Namespace) -> list[Path]:
return list(args.checkpoint_dir_paths)


def _normalize_parameter_name(parameter_name: str) -> str:
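    """Strip common wrapper prefixes (DDP's "module.", torch.compile's "_orig_mod.",
    FSDP's "_fsdp_wrapped_module.") so names are comparable across checkpoints.
    Each prefix is stripped at most once, in the order listed."""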
name = parameter_name
for prefix in ("module.", "_orig_mod.", "_fsdp_wrapped_module."):
if name.startswith(prefix):
name = name[len(prefix) :]
return name


def _get_dp_shard_group(device_mesh: DeviceMesh | None) -> dist.ProcessGroup | None:
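    """Return the DP-shard process group, or None so that dist.all_reduce falls
    back to the default (world) process group."""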
if device_mesh is None:
return None
try:
return get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.DP_SHARD).get_group()
except Exception:
# Fallback to the default process group if a dedicated DP-shard group is unavailable.
return None


def _compute_and_print_parameter_norms(
    app_state: AppState, dp_shard_group: dist.ProcessGroup | None
) -> dict[str, float]:
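    """Compute the global L2 norm of every trainable parameter and print the results on rank 0."""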
parameter_sq_sums: dict[str, torch.Tensor] = {}

for model_part_idx, model_part in enumerate(app_state.model_parts):
for name, parameter in model_part.named_parameters():
if not parameter.requires_grad:
continue
raw_name = f"model_part_{model_part_idx}.{name}" if len(app_state.model_parts) > 1 else name
parameter_name = _normalize_parameter_name(raw_name)

# FSDP2 parameters can be DTensors. Convert to local shard first so c10d all_reduce
# operates on plain tensors instead of DTensors.
local_param = parameter.to_local() if isinstance(parameter, DTensor) else parameter
local_sq_sum = local_param.detach().float().pow(2).sum()
parameter_sq_sums[parameter_name] = local_sq_sum

# Aggregate over the DP-shard group to reconstruct global norms for sharded parameters.
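    # For a parameter sharded over ranks r with local shards s_r, the global norm is
    # ||p||_2 = sqrt(sum_r ||s_r||_2^2), so summing the squared sums before the final
    # sqrt reconstructs it exactly.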
    # dist.all_reduce (with the default async_op=False) modifies sq_sum in place,
    # so no reassignment is needed.
    for sq_sum in parameter_sq_sums.values():
        dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=dp_shard_group)

parameter_norms = {name: torch.sqrt(sq_sum).item() for name, sq_sum in parameter_sq_sums.items()}

if dist.get_rank() == 0:
print("Per-parameter L2 norms (global across DP-shards):")
for parameter_name in sorted(parameter_norms):
print(f"{parameter_name}: {parameter_norms[parameter_name]:.6f}")

return parameter_norms


def _extract_checkpoint_label(checkpoint_dir_path: Path) -> str:
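    """Derive a short label from the directory name, e.g. a directory containing
    "seen_steps_2000" becomes "steps_2000"; directories without a seen_steps
    component keep their full name."""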
match = re.search(r"seen_steps_(\d+)", checkpoint_dir_path.name)
if match:
return f"steps_{match.group(1)}"
return checkpoint_dir_path.name


def _save_json_results(results: list[dict], output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)


def main() -> None:
args = _parse_args()
checkpoint_dir_paths = _resolve_checkpoint_dir_paths(args)

with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
rank = dist.get_rank()
collected_results: list[dict] = []

for checkpoint_dir_path in checkpoint_dir_paths:
# Rebuild components per checkpoint because AppState only supports one load call.
main_obj = Main(
config_path=args.config_file_path,
experiments_root_path=args.experiments_root_path,
)
components = cast(
ComponentsInstantiationModel,
main_obj.build_components(components_model_type=ComponentsInstantiationModel),
)

            app_state = cast(AppState, components.app_state)
            device_mesh = cast(DeviceMesh | None, components.device_mesh)

loader = DCPCheckpointLoading(global_rank=rank)
loader.load_checkpoint_(app_state=app_state, checkpoint_dir_path=checkpoint_dir_path)

dp_shard_group = _get_dp_shard_group(device_mesh)
if rank == 0:
print(f"\n=== {checkpoint_dir_path} ===")
parameter_norms = _compute_and_print_parameter_norms(app_state, dp_shard_group)

if rank == 0:
collected_results.append(
{
"checkpoint_path": str(checkpoint_dir_path),
"checkpoint_label": _extract_checkpoint_label(checkpoint_dir_path),
"parameter_norms": parameter_norms,
}
)
print(
f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} "
f"(pid={os.getpid()})."
)

if rank == 0:
_save_json_results(collected_results, args.json_output_path)
print(f"Saved raw parameter norms JSON to {args.json_output_path}")


if __name__ == "__main__":
main()
145 changes: 145 additions & 0 deletions scripts/parameter_norms/plot_layer_norms.py
@@ -0,0 +1,145 @@
#!/usr/bin/env python3
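"""Plot per-layer parameter norms across checkpoints from the JSON written by
compute_layer_norms.py in this directory.

A sketch of an invocation (hypothetical paths):

    python scripts/parameter_norms/plot_layer_norms.py \
        --layer-norms-json-path layer_norms_across_checkpoints.json \
        --layer-filter-regex "transformer"
"""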

import argparse
import json
import re
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Plot parameter norms across checkpoints from a JSON log file.")
parser.add_argument(
"--layer-norms-json-path",
type=Path,
required=True,
help="Path to JSON produced by scripts/compute_layer_norms.py.",
)
parser.add_argument(
"--plot-output-path",
type=Path,
default=Path("parameter_norms_grouped_by_layer.pdf"),
help="Output PDF path containing one plot page per layer.",
)
parser.add_argument(
"--layer-filter-regex",
type=str,
default=r".*",
help="Regex to select layer keys in the visualization.",
)
return parser.parse_args()


def _load_results(path: Path) -> list[dict]:
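    """Load the list of per-checkpoint results. Each entry is expected to match
    the output of compute_layer_norms.py, e.g. (illustrative names):

        {"checkpoint_path": "...", "checkpoint_label": "steps_1000",
         "parameter_norms": {"transformer.h.0.attn.weight": 12.34}}
    """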
with open(path, "r", encoding="utf-8") as f:
results = json.load(f)
if not isinstance(results, list) or not results:
raise ValueError("Expected a non-empty JSON list of checkpoint results.")
return results


def _extract_layer_key(parameter_name: str) -> str:
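    """Map a parameter name to its layer group, e.g. "transformer.h.0.attn.weight"
    -> "transformer.h.0" and "layers.12.mlp.weight" -> "layers.12"; names without
    a numbered block fall back to everything up to the last component."""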
tokens = parameter_name.split(".")
for i in range(len(tokens) - 1):
if tokens[i] in {"h", "layers", "blocks"} and tokens[i + 1].isdigit():
if i > 0:
return ".".join(tokens[i - 1 : i + 2])
return ".".join(tokens[i : i + 2])
return ".".join(tokens[:-1]) if len(tokens) > 1 else parameter_name


def _layer_sort_key(layer_key: str) -> tuple:
# Prefer numeric ordering for transformer block keys like h.0, layers.12, blocks.3.
match = re.search(r"(?:^|\.)(h|layers|blocks)\.(\d+)(?:\.|$)", layer_key)
if match:
return (0, match.group(1), int(match.group(2)), layer_key)
return (1, layer_key)


def _plot_checkpoint_comparison(
results: list[dict],
plot_output_path: Path,
layer_filter_regex: str,
) -> None:
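    """Write a summary page followed by one PDF page per layer, each page holding
    one curve per parameter across all checkpoints."""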
metric_key = "parameter_norms" if "parameter_norms" in results[0] else "layer_norms"
layer_pattern = re.compile(layer_filter_regex)
filtered_parameters = sorted(
{
parameter_name
for checkpoint_result in results
for parameter_name in checkpoint_result[metric_key].keys()
if layer_pattern.search(parameter_name)
}
)
if not filtered_parameters:
raise ValueError(f"No layer names matched --layer-filter-regex={layer_filter_regex!r}.")

checkpoint_labels = [checkpoint_result["checkpoint_label"] for checkpoint_result in results]

grouped_parameters: dict[str, list[str]] = {}
for parameter_name in filtered_parameters:
layer_key = _extract_layer_key(parameter_name)
grouped_parameters.setdefault(layer_key, []).append(parameter_name)
ordered_layer_keys = sorted(grouped_parameters, key=_layer_sort_key)

plot_output_path.parent.mkdir(parents=True, exist_ok=True)
with PdfPages(plot_output_path) as pdf:
# First page: quick summary of layers and parameter counts.
summary_lines = [
f"checkpoints: {len(checkpoint_labels)}",
f"layers: {len(grouped_parameters)}",
f"parameters plotted: {len(filtered_parameters)}",
"",
"Layer -> #parameters",
]
for layer_key in ordered_layer_keys:
summary_lines.append(f"{layer_key}: {len(grouped_parameters[layer_key])}")

fig, ax = plt.subplots(figsize=(10, 12))
ax.axis("off")
ax.text(0.01, 0.99, "\n".join(summary_lines), va="top", ha="left", fontsize=10)
fig.tight_layout()
pdf.savefig(fig)
plt.close(fig)

# One page per layer with all parameter curves for that layer.
x = list(range(len(checkpoint_labels)))
for layer_key in ordered_layer_keys:
parameter_names = sorted(grouped_parameters[layer_key])
fig, ax = plt.subplots(figsize=(12, 6))
for parameter_name in parameter_names:
y = [checkpoint_result[metric_key].get(parameter_name, float("nan")) for checkpoint_result in results]
short_name = (
parameter_name[len(layer_key) + 1 :]
if parameter_name.startswith(layer_key + ".")
else parameter_name
)
ax.plot(x, y, marker="o", linewidth=1.5, label=short_name)

ax.set_title(f"{layer_key} parameter norms across checkpoints")
ax.set_xlabel("Checkpoint")
ax.set_ylabel("L2 norm")
ax.set_xticks(x)
ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right")
ax.grid(True, alpha=0.25)
ax.legend(loc="best", fontsize=8)
fig.tight_layout()
pdf.savefig(fig)
plt.close(fig)


def main() -> None:
args = _parse_args()
results = _load_results(args.layer_norms_json_path)
_plot_checkpoint_comparison(
results=results,
plot_output_path=args.plot_output_path,
layer_filter_regex=args.layer_filter_regex,
)
print(f"Saved grouped parameter-norm plots to {args.plot_output_path}")


if __name__ == "__main__":
main()