Commit 30a2736

Merge branch 'main' into shengliangx/docs-preview
2 parents 8ddaec5 + b61fb4e commit 30a2736

5 files changed: 38 additions & 30 deletions

File tree

examples/megatron_bridge/prune_minitron.py
examples/pruning/README.md
modelopt/torch/opt/searcher.py
tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py

examples/megatron_bridge/prune_minitron.py

Lines changed: 5 additions & 7 deletions
@@ -111,8 +111,8 @@ def get_args() -> argparse.Namespace:
         type=str,
         default=None,
         help=(
-            "Path to save/restore intermediate pruning scores for resuming / faster re-run. "
-            "If not provided, it will default to `<output_path>/modelopt_pruning_scores.pth`"
+            "Directory to save/restore per-rank intermediate pruning scores for resuming / faster re-run. "
+            "If not provided, it will default to `<output_path>/modelopt_pruning_scores`"
         ),
     )

@@ -187,13 +187,11 @@ def get_args() -> argparse.Namespace:
     # Post-process arguments
     if args.prune_intermediate_ckpt is None:
         if args.output_megatron_path:
-            args.prune_intermediate_ckpt = (
-                f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
-            )
+            args.prune_intermediate_ckpt = f"{args.output_megatron_path}/modelopt_pruning_scores"
         elif args.output_hf_path:
-            args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores.pth"
+            args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores"
         print_rank_0(
-            "No checkpoint provided to cache intermediate pruning scores. "
+            "No directory provided to cache per-rank intermediate pruning scores. "
             f"Setting to: {args.prune_intermediate_ckpt}"
         )
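For illustration, a hedged sketch of the default-path logic in the second hunk above; the `default_scores_dir` helper name and the bare `argparse.Namespace` are illustrative only, not part of the script:

```python
import argparse


def default_scores_dir(args: argparse.Namespace) -> str | None:
    """Mirror the post-processing above: prefer the Megatron output path, else the HF one."""
    if args.output_megatron_path:
        return f"{args.output_megatron_path}/modelopt_pruning_scores"
    if args.output_hf_path:
        return f"{args.output_hf_path}/modelopt_pruning_scores"
    return None


# Example: only an HF output path is given, so the scores directory lands next to it.
args = argparse.Namespace(output_megatron_path=None, output_hf_path="/ckpts/pruned-hf")
print(default_scores_dir(args))  # /ckpts/pruned-hf/modelopt_pruning_scores
```

Note that the default is now a directory rather than a single `.pth` file, matching the per-rank caching introduced in `searcher.py` below.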
examples/pruning/README.md

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ This mode can be useful when you know the exact dimensions you want to prune to
 # Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
 # Save minitron scores at checkpoint so we can re-run pruning with different constraints without running the forward loop again
 constraints = {"export_config": {"num_layers": 32, "hidden_size": 3584, "ffn_hidden_size": 10240}}
-config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"}
+config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores/"}

 mtp.prune(...)
 ```

@@ -130,7 +130,7 @@ def score_func(m):
 constraints = {"params": 6e9}  # Prune to 6B parameters
 config = {
     "forward_loop": forward_loop,
-    "checkpoint": "/path/to/cache/pruning/scores.pth",
+    "checkpoint": "/path/to/cache/pruning/scores/",
     "score_func": score_func,
     # Optional: Configure search space constraints (showing defaults)
     "max_width_pruning": 0.4,  # Maximum 40% per width pruning hparams (hidden_size, ffn_hidden_size, etc.)

modelopt/torch/opt/searcher.py

Lines changed: 14 additions & 2 deletions
@@ -26,7 +26,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from contextlib import nullcontext
-from typing import Any, final
+from typing import TYPE_CHECKING, Any, final

 import numpy as np
 import pulp

@@ -36,6 +36,9 @@
 from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0

+if TYPE_CHECKING:
+    from pathlib import Path
+
 LimitsTuple = tuple[float, float]
 ConstraintsDict = dict[str, str | float | dict | None]
 Deployment = dict[str, str]

@@ -238,9 +241,18 @@ def state_dict(self) -> SearchStateDict:

     def _get_checkpoint_path(self) -> str | None:
         """Get per-rank checkpoint path when distributed, otherwise the original path."""
-        checkpoint = self.config["checkpoint"]
+        checkpoint: str | Path | None = self.config["checkpoint"]
         if checkpoint is None:
             return None
+        checkpoint = str(checkpoint)
+        # Detect directory: exists as dir, ends with separator, or has no file extension
+        is_dir_path = (
+            os.path.isdir(checkpoint)
+            or checkpoint.endswith(os.sep)
+            or not os.path.splitext(checkpoint)[1]
+        )
+        if is_dir_path:
+            return os.path.join(checkpoint, f"rank{dist.rank()}.pth")
         if dist.is_initialized():
             dirname, basename = os.path.split(checkpoint)
             name, ext = os.path.splitext(basename)
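For context, a minimal standalone sketch of the directory handling added to `_get_checkpoint_path`; the `per_rank_checkpoint_path` helper and the example paths are illustrative only, not part of the ModelOpt API:

```python
import os


def per_rank_checkpoint_path(checkpoint: str | None, rank: int) -> str | None:
    """Resolve a directory-style checkpoint to one .pth file per rank inside it."""
    if checkpoint is None:
        return None
    checkpoint = str(checkpoint)
    # A path counts as a directory if it already exists as one, ends with a
    # separator, or carries no file extension.
    is_dir_path = (
        os.path.isdir(checkpoint)
        or checkpoint.endswith(os.sep)
        or not os.path.splitext(checkpoint)[1]
    )
    if is_dir_path:
        return os.path.join(checkpoint, f"rank{rank}.pth")
    return checkpoint


print(per_rank_checkpoint_path("/tmp/pruning/scores", rank=0))      # /tmp/pruning/scores/rank0.pth
print(per_rank_checkpoint_path("/tmp/pruning/scores.pth", rank=0))  # file path left unchanged
```

Paths that look like files fall through to the pre-existing per-rank handling shown in the remainder of the hunk.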

tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 10 additions & 10 deletions
@@ -124,7 +124,7 @@ def _test_mcore_gpt_pruning(
     uneven_pp,
     position_embedding_type,
     skip_sorting,
-    ckpt_path,
+    ckpt_dir,
     rank,
     size,
 ):

@@ -198,11 +198,11 @@ def forward_loop(m):
     constraints = {"export_config": export_config}

     config = {
-        "checkpoint": ckpt_path,
+        "checkpoint": ckpt_dir,
         "skip_sorting": skip_sorting,
     }
     if skip_sorting:
-        assert ckpt_path is None
+        assert ckpt_dir is None
     else:
         config["forward_loop"] = forward_loop
     model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor)

@@ -238,11 +238,11 @@ def forward_loop(m):
     output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size)

     # Assert re-pruning from checkpoint works without running the forward loop again
-    if ckpt_path:
+    if ckpt_dir:
         model_rerun = _get_model(initialize_megatron=False)
         model_rerun.load_state_dict(sd)
         model_rerun, pruning_scores = prune_minitron(
-            model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor
+            model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor
         )

         output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)

@@ -307,7 +307,7 @@ def test_mcore_gpt_pruning(
             uneven_pp,
             position_embedding_type,
             skip_sorting,
-            tmp_path / "minitron_scores.pth" if test_ckpt else None,
+            tmp_path / "minitron_scores" if test_ckpt else None,
         ),
     )

@@ -394,7 +394,7 @@ def test_mcore_gpt_moe_parameter_sorting(dist_workers):
     dist_workers.run(_test_mcore_gpt_moe_parameter_sorting)


-def _test_mcore_gpt_pruning_moe(ckpt_path, rank, size):
+def _test_mcore_gpt_pruning_moe(ckpt_dir, rank, size):
     channel_divisor = 4

     num_layers = size

@@ -446,7 +446,7 @@ def forward_loop(m):
     prune_minitron(
         model,
         constraints,
-        {"checkpoint": ckpt_path, "forward_loop": forward_loop},
+        {"checkpoint": ckpt_dir, "forward_loop": forward_loop},
         channel_divisor,
     )

@@ -483,14 +483,14 @@ def forward_loop(m):
     # Assert re-pruning from checkpoint works without running the forward loop again
     model_rerun = _get_model(initialize_megatron=False)
     model_rerun.load_state_dict(sd)
-    prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor)
+    prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor)

     output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
     assert torch.allclose(output, output_rerun, atol=1e-5)


 def test_mcore_gpt_pruning_moe(dist_workers, tmp_path):
-    dist_workers.run(partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"))
+    dist_workers.run(partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores"))


 def test_generate_search_space_combos():

tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py

Lines changed: 7 additions & 9 deletions
@@ -120,7 +120,7 @@ def test_mcore_mamba_parameter_sorting(dist_workers):
     dist_workers.run(_test_mcore_mamba_parameter_sorting)


-def _test_mcore_mamba_hybrid_pruning(ckpt_path, rank, size):
+def _test_mcore_mamba_hybrid_pruning(ckpt_dir, rank, size):
     channel_divisor = 4

     num_layers = min(size * 2, 8)

@@ -196,7 +196,7 @@ def forward_loop(m):
     prune_minitron(
         model,
         constraints,
-        {"forward_loop": forward_loop, "checkpoint": ckpt_path},
+        {"forward_loop": forward_loop, "checkpoint": ckpt_dir},
        channel_divisor,
     )

@@ -224,16 +224,14 @@ def forward_loop(m):

     # Assert re-pruning from checkpoint works without running the forward loop again
     model = _get_model(initialize_megatron=False)
-    prune_minitron(model, constraints, {"checkpoint": ckpt_path}, channel_divisor)
+    prune_minitron(model, constraints, {"checkpoint": ckpt_dir}, channel_divisor)


 def test_mcore_mamba_hybrid_pruning(dist_workers, tmp_path):
-    dist_workers.run(
-        partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth")
-    )
+    dist_workers.run(partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "minitron_scores"))


-def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size):
+def _test_mcore_mamba_hybrid_pruning_nas(ckpt_dir, rank, size):
     set_seed(SEED)
     channel_divisor = 4


@@ -299,7 +297,7 @@ def score_func(m):
     constraints = {"params": int(param_count * 0.7)}
     config = {
         "forward_loop": forward_loop,
-        "checkpoint": ckpt_path,
+        "checkpoint": ckpt_dir,
         "score_func": score_func,
         "max_width_pruning": 0.5,
         "max_depth_pruning": 0.5,

@@ -365,5 +363,5 @@ def score_func(m):
 )
 def test_mcore_mamba_hybrid_pruning_nas(dist_workers, tmp_path):
     dist_workers.run(
-        partial(_test_mcore_mamba_hybrid_pruning_nas, tmp_path / "modelopt_minitron_scores.pth"),
+        partial(_test_mcore_mamba_hybrid_pruning_nas, tmp_path / "minitron_scores"),
     )
