Skip to content

Commit cbeab3b

Browse files
committed
GAB and TAB experimentation, nbt for transformers, learnable rope.
1 parent 5e38948 commit cbeab3b

6 files changed

Lines changed: 1954 additions & 417 deletions

File tree

python/benchmark_fresh_model.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def main():
3535
parser.add_argument('-print-per-tensor-counts', help='Print parameter counts per tensor', action='store_true')
3636
parser.add_argument('-no-compile', help='Do not torch.compile', action='store_true')
3737
parser.add_argument('-use-tf32-matmul', help='Reduce float32 precision for speed on some gpus', action='store_true')
38+
parser.add_argument('-override-config', help='Override model config params, e.g. "gab_d1=16,tab_num_freqs=8"', type=str, default=None)
3839
args = vars(parser.parse_args())
3940

4041
model_kind = args["model_kind"]
@@ -48,6 +49,7 @@ def main():
4849
print_per_tensor = args["print_per_tensor_counts"]
4950
no_compile = args["no_compile"]
5051
use_tf32_matmul = args["use_tf32_matmul"]
52+
override_config_str = args["override_config"]
5153

5254
device = torch.device(f"cuda:{gpu_idx}")
5355

@@ -60,7 +62,45 @@ def main():
6062

6163
# Load model config and create model
6264
assert model_kind in modelconfigs.config_of_name, f"Unknown model kind: {model_kind}, available: {list(modelconfigs.config_of_name.keys())}"
63-
model_config = modelconfigs.config_of_name[model_kind]
65+
model_config = modelconfigs.config_of_name[model_kind].copy()
66+
67+
# Apply config overrides
68+
if override_config_str:
69+
for kv in override_config_str.split(","):
70+
kv = kv.strip()
71+
if not kv:
72+
continue
73+
key, val_str = kv.split("=", 1)
74+
key = key.strip()
75+
val_str = val_str.strip()
76+
if key in model_config:
77+
orig = model_config[key]
78+
if isinstance(orig, bool):
79+
model_config[key] = val_str.lower() in ("true", "1", "yes")
80+
elif isinstance(orig, int):
81+
model_config[key] = int(val_str)
82+
elif isinstance(orig, float):
83+
model_config[key] = float(val_str)
84+
elif isinstance(orig, str):
85+
model_config[key] = val_str
86+
else:
87+
import json
88+
model_config[key] = json.loads(val_str)
89+
print(f"Config override: {key} = {model_config[key]} (was {orig})")
90+
else:
91+
# New key: infer type from value string
92+
if val_str.lower() in ("true", "false"):
93+
model_config[key] = val_str.lower() == "true"
94+
else:
95+
try:
96+
model_config[key] = int(val_str)
97+
except ValueError:
98+
try:
99+
model_config[key] = float(val_str)
100+
except ValueError:
101+
model_config[key] = val_str
102+
print(f"Config override (new): {key} = {model_config[key]}")
103+
64104
print(f"Model kind: {model_kind}")
65105
print(f"Optimizer: {optimizer_kind}")
66106
print(f"Batch size: {batch_size}")
@@ -143,6 +183,7 @@ def main():
143183
forward_times = benchmark_forward(model, batch, num_iters, warmup_iters)
144184
print_timing_stats("Forward", forward_times)
145185
print()
186+
torch.cuda.empty_cache()
146187

147188
# Benchmark forward + backward + optimizer step with attribution
148189
print("=" * 80)
@@ -157,6 +198,7 @@ def main():
157198
total_times = [f + b + o for f, b, o in zip(fwd_times, bwd_times, opt_times)]
158199
print_timing_stats("Total ", total_times)
159200
print()
201+
torch.cuda.empty_cache()
160202

161203
# Print proportions
162204
mean_fwd = sum(fwd_times) / len(fwd_times)
@@ -247,6 +289,7 @@ def benchmark_forward(model, batch, num_iters, warmup_iters):
247289

248290
torch.cuda.synchronize()
249291
t1 = time.perf_counter()
292+
del model_outputs
250293

251294
if i >= warmup_iters:
252295
times.append(t1 - t0)
@@ -303,6 +346,7 @@ def benchmark_full_step(model, raw_model, optimizer, metrics_obj, batch, model_c
303346

304347
torch.cuda.synchronize()
305348
t_opt_end = time.perf_counter()
349+
del model_outputs, postprocessed, metrics, loss
306350

307351
if i >= warmup_iters:
308352
fwd_times.append(t_bwd_start - t_fwd_start)
@@ -348,6 +392,7 @@ def benchmark_full_step_throughput(model, raw_model, optimizer, metrics_obj, bat
348392

349393
torch.cuda.synchronize()
350394
t1 = time.perf_counter()
395+
del model_outputs, postprocessed, metrics, loss
351396

352397
if i >= warmup_iters:
353398
times.append(t1 - t0)

python/katago/train/load_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
torch.serialization.add_safe_globals([float])
1616

1717
def load_model_state_dict(state_dict):
18-
# Strip off any "module." from when the model was saved with DDP or other things
18+
# Strip off any "module." from DDP or "_orig_mod." from torch.compile
1919
model_state_dict = {}
2020
for key in state_dict["model"]:
2121
old_key = key
22-
while key.startswith("module."):
23-
key = key[7:]
22+
while key.startswith("module.") or key.startswith("_orig_mod."):
23+
if key.startswith("module."):
24+
key = key[len("module."):]
25+
elif key.startswith("_orig_mod."):
26+
key = key[len("_orig_mod."):]
2427
# Filter out some extra keys that were present in older checkpoints
2528
if "score_belief_offset_vector" in key or "score_belief_offset_bias_vector" in key or "score_belief_parity_vector" in key:
2629
continue

0 commit comments

Comments (0)