Problem
All GPU code runs as vanilla PyTorch with no hardware-specific optimization, which leaves significant performance on the table across all target platforms.
Current State (measured on DGX Spark GB10)
| Feature | Available | Utilized |
|---|---|---|
| TF32 matmul | Yes (Blackwell) | No (`allow_tf32 = False`) |
| BF16 | Yes | No (all FP32/FP64) |
| 128GB UMA | Yes | No (unnecessary `.cpu().numpy()` roundtrips) |
| Tensor Cores (5th gen) | Yes | No |
| `torch.compile` | Yes | No |
Target Platforms
| Platform | Memory | Key Optimization |
|---|---|---|
| DGX Spark GB10 | 128GB UMA (shared) | Minimize `.cpu()`/`.to(device)` transfers (zero-copy) |
| A100 (40/80GB) | Dedicated HBM2e | Maximize batch size, fuse transfers |
| H100 (80GB) | Dedicated HBM3 | FP8 tensor cores, transformer engine |
Proposed Changes (by effort/impact)
Quick wins (1-2 lines each)
- `torch.backends.cuda.matmul.allow_tf32 = True`
- `torch.set_float32_matmul_precision('high')`
- `torch.compile(nqs)` for NQS inference
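A minimal sketch of how the quick wins could be wired together. `TinyNQS`, `enable_tf32`, `prepare_for_inference`, and `log_amplitudes` are hypothetical names used for illustration only; the real NQS class in this repo will differ. Per the PyTorch docs, `'high'` permits TF32 for FP32 matmuls, so the first two bullets overlap and the sketch sets only one of them.

```python
import torch
import torch.nn as nn


def enable_tf32() -> None:
    # 'high' lets FP32 matmuls use TF32 tensor cores where the hardware has them.
    torch.set_float32_matmul_precision("high")


class TinyNQS(nn.Module):
    """Hypothetical stand-in for the project's NQS network."""

    def __init__(self, n_sites: int = 16, hidden: int = 64) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_sites, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )

    def forward(self, configs: torch.Tensor) -> torch.Tensor:
        # One scalar output per input configuration.
        return self.net(configs).squeeze(-1)


def prepare_for_inference(model: nn.Module, compile_model: bool = True) -> nn.Module:
    model = model.eval()
    # torch.compile is lazy: the actual compilation happens on the first call.
    return torch.compile(model) if compile_model else model


@torch.inference_mode()
def log_amplitudes(model: nn.Module, configs: torch.Tensor) -> torch.Tensor:
    # Move inputs to wherever the model lives and keep the result there.
    device = next(model.parameters()).device
    return model(configs.to(device))
```

Typical use would be `enable_tf32()` once at startup, then `nqs = prepare_for_inference(nqs.to("cuda"))`. The first compiled call pays a one-time compilation cost, so it should be excluded from any benchmark.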
Medium effort (~50 lines)
- `torch.amp.autocast('cuda', dtype=torch.bfloat16)` around NQS training loop
- Audit and remove unnecessary `.cpu().numpy()` → `.to(device)` roundtrips (especially on UMA)
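A rough sketch of what the two medium-effort items might look like in a training loop. `train_step` and `train` are hypothetical helpers, and the MSE loss is only a placeholder for whatever objective the NQS training actually uses. BF16 autocast is generally used without a gradient scaler because BF16 shares FP32's exponent range; the loss stays on the device and is read back once at the end, not every step.

```python
import torch
import torch.nn as nn


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    configs: torch.Tensor,
    targets: torch.Tensor,
) -> torch.Tensor:
    optimizer.zero_grad(set_to_none=True)
    # Forward pass runs under BF16 autocast; parameters remain FP32.
    with torch.autocast(device_type=configs.device.type, dtype=torch.bfloat16):
        pred = model(configs)
        # Placeholder loss (assumption): compute it in FP32 for stability.
        loss = nn.functional.mse_loss(pred.float(), targets)
    loss.backward()
    optimizer.step()
    # Return a device tensor; no .cpu()/.item() sync inside the hot loop.
    return loss.detach()


def train(model: nn.Module, optimizer: torch.optim.Optimizer, batches: list) -> float:
    total = None
    for configs, targets in batches:
        loss = train_step(model, optimizer, configs, targets)
        total = loss if total is None else total + loss
    # Single device-to-host read for logging.
    return (total / len(batches)).item()
```

The same pattern applies to the roundtrip audit: any intermediate that is converted with `.cpu().numpy()` and later moved back with `.to(device)` is a candidate for staying a tensor throughout.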
Larger effort (separate PR)
- Hardware-adaptive runtime: detect UMA vs HBM and adjust data movement strategy
- CUDA kernel optimization for `diagonal_elements_batch` and PT2 scoring
- Numba/CUDA acceleration for `get_connections` inner loop (currently pure Python)
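One possible shape for the hardware-adaptive runtime, as a sketch only. `TransferPolicy`, `choose_policy`, and `detect_policy` are hypothetical names. The sketch assumes the `is_integrated` field of `torch.cuda.get_device_properties` is available and that GB10 reports itself as integrated; both should be verified on the actual hardware. The policy values themselves are assumptions to be tuned by measurement.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TransferPolicy:
    unified_memory: bool   # host and device share physical memory
    pin_host_memory: bool  # stage host tensors in pinned memory
    non_blocking: bool     # issue async host-to-device copies


def choose_policy(has_cuda: bool, is_integrated: bool) -> TransferPolicy:
    if not has_cuda:
        return TransferPolicy(False, False, False)
    if is_integrated:
        # UMA (assumed for GB10): priority is avoiding copies altogether,
        # so no pinned staging buffers are requested here.
        return TransferPolicy(True, False, False)
    # Discrete HBM (A100/H100): pinned memory plus non_blocking copies
    # lets transfers overlap with compute.
    return TransferPolicy(False, True, True)


def detect_policy() -> TransferPolicy:
    # Imported lazily so choose_policy can be unit-tested without PyTorch.
    import torch

    if not torch.cuda.is_available():
        return choose_policy(False, False)
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # Fall back to "discrete" if this PyTorch build lacks the field.
    integrated = bool(getattr(props, "is_integrated", False))
    return choose_policy(True, integrated)
```

Keeping the decision in a pure function separate from the device probe means the policy table can be tested on a machine with no GPU, and a new platform only needs one extra branch.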
Related
- `compute_pt2_scores` is a CPU-bound Python loop