Improve GPU utilization #62

@PatrickRMiles

Description

  • Pass fused=True when constructing the Adam optimizer
  • Pass foreach=True to the clip_grad_norm_(...) call
  • Call zero_grad(set_to_none=True) in the training loop
  • Can we compute local_voxel_count and global_total_voxels outside the batch loop to save an all-reduce?
  • torch.compile() the model
  • torch.compile() the unscale->clip->optimizer step->update block
  • Can we reduce per-batch .item() calls for things like gradient logging?

Context

Pass fused=True to Adam construction:

  • Fuses the entire Adam update into a single kernel per parameter group, instead of a sequence of separate elementwise kernels. This should reduce kernel launch overhead.
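
A minimal sketch of what this change could look like. The model, learning rate, and input shapes here are placeholders, not the project's actual values. It assumes the fused path is only requested when CUDA is available, since fused support on other devices depends on the PyTorch version.

```python
import torch
from torch import nn

# Placeholder model; the real training script would use its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(16, 4).to(device)

# Request the fused implementation only on CUDA. Passing None elsewhere
# lets torch pick its default implementation for that device.
fused = True if device.type == "cuda" else None
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=fused)

# One illustrative training step.
x = torch.randn(8, 16, device=device)
loss = model(x).pow(2).mean()
loss.backward()
before = model.weight.detach().clone()
optimizer.step()
```

The parameters need to be on the target device before the optimizer is constructed, so the `.to(device)` call comes first.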

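Pass foreach=True to clip_grad_norm_(...) call:

  • The non-foreach path of clip_grad_norm_ iterates per parameter in Python, launching one norm kernel per tensor. With foreach=True, torch uses torch._foreach_norm to batch the gradient norms into multi-tensor kernels. This collapses the dozens to hundreds of individual norm kernels plus a Python reduction loop into a handful of launches. Note that in recent PyTorch versions the default (foreach=None) already picks the foreach path for CUDA and CPU native tensors and silently falls back otherwise, so passing foreach=True mainly makes the choice explicit and turns a silent fallback into an error.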
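
For illustration, a small sketch of the clipping call with foreach=True. The model and max_norm value are placeholders. The final lines only recompute the post-clipping norm to show the effect and would not be needed in the training loop.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Placeholder model and loss, just to populate .grad on several tensors.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
loss = model(torch.randn(8, 16)).pow(2).sum()
loss.backward()

max_norm = 1.0  # illustrative threshold

# Returns the total gradient norm measured before clipping, as a tensor.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=max_norm, foreach=True
)

# Recompute the norm after clipping (for demonstration only).
clipped_norm = torch.linalg.vector_norm(
    torch.stack([p.grad.norm() for p in model.parameters()])
)
```

Keeping `total_norm` as a tensor and logging it only occasionally avoids adding another per-batch `.item()` sync.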

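Do zero_grad(set_to_none=True) in training:

  • With set_to_none=False, torch runs one memset kernel per parameter to fill the gradient tensors with zeros. With set_to_none=True, it simply drops the .grad reference, and the next backward pass assigns fresh gradients. set_to_none=True has been the default since PyTorch 2.0, so this matters only if we pass False explicitly or run on an older version.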
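
The difference between the two modes can be seen in a short sketch. The model and optimizer below are placeholders chosen only to keep the example small.

```python
import torch
from torch import nn

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# set_to_none=False: gradient tensors stay allocated and are filled with zeros.
model(torch.randn(8, 16)).sum().backward()
optimizer.zero_grad(set_to_none=False)
zero_filled = all(
    p.grad is not None and bool((p.grad == 0).all()) for p in model.parameters()
)

# set_to_none=True: the .grad references are dropped entirely.
model(torch.randn(8, 16)).sum().backward()
optimizer.zero_grad(set_to_none=True)
dropped = all(p.grad is None for p in model.parameters())
```

Any code that reads `.grad` between `zero_grad` and the next backward pass (for example gradient logging) has to handle `None` in the second mode.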
