@@ -35,6 +35,7 @@ def main():
3535 parser .add_argument ('-print-per-tensor-counts' , help = 'Print parameter counts per tensor' , action = 'store_true' )
3636 parser .add_argument ('-no-compile' , help = 'Do not torch.compile' , action = 'store_true' )
3737 parser .add_argument ('-use-tf32-matmul' , help = 'Reduce float32 precision for speed on some gpus' , action = 'store_true' )
38+ parser .add_argument ('-override-config' , help = 'Override model config params, e.g. "gab_d1=16,tab_num_freqs=8"' , type = str , default = None )
3839 args = vars (parser .parse_args ())
3940
4041 model_kind = args ["model_kind" ]
@@ -48,6 +49,7 @@ def main():
4849 print_per_tensor = args ["print_per_tensor_counts" ]
4950 no_compile = args ["no_compile" ]
5051 use_tf32_matmul = args ["use_tf32_matmul" ]
52+ override_config_str = args ["override_config" ]
5153
5254 device = torch .device (f"cuda:{ gpu_idx } " )
5355
@@ -60,7 +62,45 @@ def main():
6062
6163 # Load model config and create model
6264 assert model_kind in modelconfigs .config_of_name , f"Unknown model kind: { model_kind } , available: { list (modelconfigs .config_of_name .keys ())} "
63- model_config = modelconfigs .config_of_name [model_kind ]
65+ model_config = modelconfigs .config_of_name [model_kind ].copy ()
66+
67+ # Apply config overrides
68+ if override_config_str :
69+ for kv in override_config_str .split ("," ):
70+ kv = kv .strip ()
71+ if not kv :
72+ continue
73+ key , val_str = kv .split ("=" , 1 )
74+ key = key .strip ()
75+ val_str = val_str .strip ()
76+ if key in model_config :
77+ orig = model_config [key ]
78+ if isinstance (orig , bool ):
79+ model_config [key ] = val_str .lower () in ("true" , "1" , "yes" )
80+ elif isinstance (orig , int ):
81+ model_config [key ] = int (val_str )
82+ elif isinstance (orig , float ):
83+ model_config [key ] = float (val_str )
84+ elif isinstance (orig , str ):
85+ model_config [key ] = val_str
86+ else :
87+ import json
88+ model_config [key ] = json .loads (val_str )
89+ print (f"Config override: { key } = { model_config [key ]} (was { orig } )" )
90+ else :
91+ # New key: infer type from value string
92+ if val_str .lower () in ("true" , "false" ):
93+ model_config [key ] = val_str .lower () == "true"
94+ else :
95+ try :
96+ model_config [key ] = int (val_str )
97+ except ValueError :
98+ try :
99+ model_config [key ] = float (val_str )
100+ except ValueError :
101+ model_config [key ] = val_str
102+ print (f"Config override (new): { key } = { model_config [key ]} " )
103+
64104 print (f"Model kind: { model_kind } " )
65105 print (f"Optimizer: { optimizer_kind } " )
66106 print (f"Batch size: { batch_size } " )
@@ -143,6 +183,7 @@ def main():
143183 forward_times = benchmark_forward (model , batch , num_iters , warmup_iters )
144184 print_timing_stats ("Forward" , forward_times )
145185 print ()
186+ torch .cuda .empty_cache ()
146187
147188 # Benchmark forward + backward + optimizer step with attribution
148189 print ("=" * 80 )
@@ -157,6 +198,7 @@ def main():
157198 total_times = [f + b + o for f , b , o in zip (fwd_times , bwd_times , opt_times )]
158199 print_timing_stats ("Total " , total_times )
159200 print ()
201+ torch .cuda .empty_cache ()
160202
161203 # Print proportions
162204 mean_fwd = sum (fwd_times ) / len (fwd_times )
@@ -247,6 +289,7 @@ def benchmark_forward(model, batch, num_iters, warmup_iters):
247289
248290 torch .cuda .synchronize ()
249291 t1 = time .perf_counter ()
292+ del model_outputs
250293
251294 if i >= warmup_iters :
252295 times .append (t1 - t0 )
@@ -303,6 +346,7 @@ def benchmark_full_step(model, raw_model, optimizer, metrics_obj, batch, model_c
303346
304347 torch .cuda .synchronize ()
305348 t_opt_end = time .perf_counter ()
349+ del model_outputs , postprocessed , metrics , loss
306350
307351 if i >= warmup_iters :
308352 fwd_times .append (t_bwd_start - t_fwd_start )
@@ -348,6 +392,7 @@ def benchmark_full_step_throughput(model, raw_model, optimizer, metrics_obj, bat
348392
349393 torch .cuda .synchronize ()
350394 t1 = time .perf_counter ()
395+ del model_outputs , postprocessed , metrics , loss
351396
352397 if i >= warmup_iters :
353398 times .append (t1 - t0 )
0 commit comments