@@ -124,7 +124,7 @@ def _test_mcore_gpt_pruning(
     uneven_pp,
     position_embedding_type,
     skip_sorting,
-    ckpt_path,
+    ckpt_dir,
     rank,
     size,
 ):
@@ -198,11 +198,11 @@ def forward_loop(m):
     constraints = {"export_config": export_config}
 
     config = {
-        "checkpoint": ckpt_path,
+        "checkpoint": ckpt_dir,
         "skip_sorting": skip_sorting,
     }
     if skip_sorting:
-        assert ckpt_path is None
+        assert ckpt_dir is None
     else:
         config["forward_loop"] = forward_loop
     model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor)
@@ -238,11 +238,11 @@ def forward_loop(m):
     output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size)
 
     # Assert re-pruning from checkpoint works without running the forward loop again
-    if ckpt_path:
+    if ckpt_dir:
         model_rerun = _get_model(initialize_megatron=False)
         model_rerun.load_state_dict(sd)
         model_rerun, pruning_scores = prune_minitron(
-            model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor
+            model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor
         )
 
         output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
@@ -307,7 +307,7 @@ def test_mcore_gpt_pruning(
             uneven_pp,
             position_embedding_type,
             skip_sorting,
-            tmp_path / "minitron_scores.pth" if test_ckpt else None,
+            tmp_path / "minitron_scores" if test_ckpt else None,
         ),
     )
 
@@ -394,7 +394,7 @@ def test_mcore_gpt_moe_parameter_sorting(dist_workers):
     dist_workers.run(_test_mcore_gpt_moe_parameter_sorting)
 
 
-def _test_mcore_gpt_pruning_moe(ckpt_path, rank, size):
+def _test_mcore_gpt_pruning_moe(ckpt_dir, rank, size):
     channel_divisor = 4
 
     num_layers = size
@@ -446,7 +446,7 @@ def forward_loop(m):
     prune_minitron(
         model,
         constraints,
-        {"checkpoint": ckpt_path, "forward_loop": forward_loop},
+        {"checkpoint": ckpt_dir, "forward_loop": forward_loop},
         channel_divisor,
     )
 
@@ -483,14 +483,14 @@ def forward_loop(m):
     # Assert re-pruning from checkpoint works without running the forward loop again
     model_rerun = _get_model(initialize_megatron=False)
     model_rerun.load_state_dict(sd)
-    prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor)
+    prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor)
 
     output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
     assert torch.allclose(output, output_rerun, atol=1e-5)
 
 
 def test_mcore_gpt_pruning_moe(dist_workers, tmp_path):
-    dist_workers.run(partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"))
+    dist_workers.run(partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores"))
 
 
 def test_generate_search_space_combos():
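
For context, a minimal sketch of the calling pattern these hunks exercise, using only the names visible in the diff (prune_minitron, _get_model, forward_loop, export_config, channel_divisor are helpers/fixtures from this test file, not part of this change); the rename from ckpt_path to ckpt_dir and the dropped .pth suffix suggest the checkpoint argument now points at a directory of pruning scores rather than a single .pth file:

# First pass: calibrate via forward_loop and save pruning scores under ckpt_dir.
ckpt_dir = tmp_path / "minitron_scores"  # a directory after this change, not a .pth file
constraints = {"export_config": export_config}
model, pruning_scores = prune_minitron(
    model, constraints, {"checkpoint": ckpt_dir, "forward_loop": forward_loop}, channel_divisor
)

# Second pass: re-prune a fresh model from the saved scores alone; no
# forward_loop is needed because the scores are loaded back from ckpt_dir.
model_rerun = _get_model(initialize_megatron=False)
model_rerun.load_state_dict(sd)
model_rerun, _ = prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor)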