perf: hardware-adaptive GPU optimization (DGX Spark UMA + A100/H100 HBM) #32

@thc1006

Problem

All GPU code currently runs as vanilla PyTorch with no hardware-specific optimization, which leaves significant performance on the table across all target platforms.

Current State (measured on DGX Spark GB10)

| Feature | Available | Utilized |
|---|---|---|
| TF32 matmul | Yes (Blackwell) | No (allow_tf32 = False) |
| BF16 | Yes | No (all FP32/FP64) |
| 128GB UMA | Yes | No (unnecessary .cpu().numpy() roundtrips) |
| Tensor Cores (5th gen) | Yes | No |
| torch.compile | Yes | No |

Target Platforms

| Platform | Memory | Key Optimization |
|---|---|---|
| DGX Spark GB10 | 128GB UMA (shared) | Minimize .cpu()/.to(device) transfers (zero-copy) |
| A100 (40/80GB) | Dedicated HBM2e | Maximize batch size, fuse transfers |
| H100 (80GB) | Dedicated HBM3 | FP8 tensor cores, transformer engine |

Proposed Changes (by effort/impact)

Quick wins (1-2 lines each)

  • torch.backends.cuda.matmul.allow_tf32 = True
  • torch.set_float32_matmul_precision('high')
  • torch.compile(nqs) for NQS inference
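The quick wins above could look roughly like the sketch below. The small `nn.Sequential` is only a hypothetical stand-in for the real NQS module, and `enable_fast_matmul`, `compile_for_inference`, and `infer` are illustrative names, not existing functions in this repo. The sketch uses only `torch.set_float32_matmul_precision('high')`, which as far as I understand covers the same matmul TF32 switch as `allow_tf32 = True`.

```python
import torch
import torch.nn as nn


def enable_fast_matmul() -> None:
    # "high" lets FP32 matmuls use TF32 tensor-core math where the hardware
    # supports it (Ampere and newer). It has no effect on CPU-only runs.
    torch.set_float32_matmul_precision("high")


def compile_for_inference(model: nn.Module) -> nn.Module:
    # torch.compile wraps the module lazily: the actual compilation is
    # triggered by the first forward call, not by this function.
    model.eval()
    return torch.compile(model)


@torch.no_grad()
def infer(model: nn.Module, configs: torch.Tensor) -> torch.Tensor:
    # Inference only: no autograd graph is recorded.
    return model(configs)


enable_fast_matmul()

# Hypothetical stand-in for the NQS network (assumption, not the repo's model).
nqs = nn.Sequential(nn.Linear(16, 64), nn.Tanh(), nn.Linear(64, 1))
```

Calling `compiled = compile_for_inference(nqs)` once at startup and reusing `compiled` for every batch keeps the compilation cost out of the hot loop. Whether TF32 precision is acceptable for energy estimates would need to be checked against an FP32/FP64 baseline.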

Medium effort (~50 lines)

  • torch.amp.autocast('cuda', dtype=torch.bfloat16) around NQS training loop
  • Audit and remove unnecessary GPU-to-CPU-to-GPU roundtrips (.cpu().numpy() followed by .to(device)), especially on UMA
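As a minimal sketch of the two medium-effort items, assuming a generic model and an MSE loss in place of the real NQS objective: `train_step` wraps only the forward pass in BF16 autocast, and `top_k_indices` shows one kind of roundtrip that can stay on the device. Both function names are hypothetical, and the device falls back to CPU so the snippet also runs without a GPU.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_step(model, optimizer, inputs, targets):
    optimizer.zero_grad(set_to_none=True)
    # Forward pass in BF16; the parameters themselves stay in FP32.
    with torch.amp.autocast(device.type, dtype=torch.bfloat16):
        pred = model(inputs)
    # Compute the loss in FP32. BF16 keeps the FP32 exponent range, so no
    # gradient scaler is used here.
    loss = nn.functional.mse_loss(pred.float(), targets)
    loss.backward()
    optimizer.step()
    # Return a detached device tensor instead of calling .item() or .cpu(),
    # which would force a host synchronization on every step.
    return loss.detach()


def top_k_indices(scores: torch.Tensor, k: int) -> torch.Tensor:
    # On-device replacement for a pattern such as
    #   np.argsort(scores.cpu().numpy())[-k:]  followed by  .to(device)
    # Indices are returned in order of descending score.
    return torch.topk(scores, k).indices
```

Logging the loss every N steps rather than every step keeps the remaining synchronization points rare.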

Larger effort (separate PR)

  • Hardware-adaptive runtime: detect UMA vs HBM and adjust data movement strategy
  • CUDA kernel optimization for diagonal_elements_batch and PT2 scoring
  • Numba/CUDA acceleration for get_connections inner loop (currently pure Python)
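For the hardware-adaptive runtime, one possible shape is a small profile object chosen once at startup. The sketch below is an assumption-heavy illustration: `MemoryProfile`, `profile_from_flags`, and `detect_profile` are hypothetical names, the policy values are placeholders to be tuned by measurement, and it assumes the CUDA device properties expose an `is_integrated` flag that is set on the GB10 (the `getattr` fallback treats a missing flag as a discrete GPU).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MemoryProfile:
    kind: str              # "uma", "hbm", or "cpu"
    pin_host_memory: bool  # stage host tensors in pinned memory
    non_blocking: bool     # issue host-to-device copies asynchronously


def profile_from_flags(has_cuda: bool, is_integrated: bool) -> MemoryProfile:
    # Pure policy function, kept separate from detection so it can be tested
    # without a GPU. The chosen values are illustrative placeholders.
    if not has_cuda:
        return MemoryProfile("cpu", False, False)
    if is_integrated:
        # Shared CPU/GPU memory: the priority is to avoid copies altogether.
        return MemoryProfile("uma", False, False)
    # Discrete GPU with dedicated HBM: overlap copies with compute.
    return MemoryProfile("hbm", True, True)


def detect_profile() -> MemoryProfile:
    # torch is imported lazily so the policy above stays importable anywhere.
    import torch

    if not torch.cuda.is_available():
        return profile_from_flags(False, False)
    props = torch.cuda.get_device_properties(0)
    # Assumption: is_integrated is present and reports UMA devices correctly.
    integrated = bool(getattr(props, "is_integrated", False))
    return profile_from_flags(True, integrated)
```

Data-loading and transfer code could then branch on `profile.kind` or read `profile.pin_host_memory` instead of hard-coding one strategy per platform.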

Related
