
feat: reduce train memory #44

Draft
yukkysaito wants to merge 4 commits into tier4-main from feature/reduce-train-memory

Conversation


@yukkysaito yukkysaito commented Mar 31, 2026

Summary

Reduce GPU memory usage during training so the train batch size can be increased without changing the input token counts.

This PR keeps the existing model inputs intact and focuses on memory reductions in the training path.

Changes

  • enable bf16 autocast during training when the GPU supports it
  • disable unused attention weight outputs by setting need_weights=False on the attention layers used in the encoder and DiT blocks
  • move EMA shadow weights from GPU memory to CPU memory
  • add activation checkpointing to the heavy transformer blocks in the encoder fusion and decoder DiT
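The first two changes can be illustrated with a minimal sketch (not the actual diff; the attention layer here is a placeholder for the encoder / DiT attention layers):

```python
import torch
import torch.nn as nn

# Placeholder attention layer standing in for the encoder / DiT attention.
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
attn = attn.to(device)
x = torch.randn(2, 16, 64, device=device)

# bf16 autocast around the forward pass: matmuls and their activations
# run in bfloat16 where safe, shrinking temporary-tensor memory.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    # need_weights=False skips materializing the (batch, seq, seq)
    # attention-weight tensor that nothing downstream consumes.
    out, weights = attn(x, x, x, need_weights=False)

print(out.shape)  # torch.Size([2, 16, 64])
print(weights)    # None — no attention weights were materialized
```

In the real training loop the autocast context would wrap the full forward pass rather than a single layer.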

Expected Effect

These changes target different parts of the memory footprint:

  • bf16: reduces activation memory and temporary tensor size during training
  • need_weights=False: avoids materializing unused attention-weight tensors
  • EMA on CPU: frees one extra model copy from GPU memory
  • activation checkpointing: trades extra recomputation for lower activation memory
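The EMA-on-CPU item amounts to keeping the shadow copy in host memory and blending one parameter at a time. A minimal sketch (a hypothetical helper, not the PR's actual EMA class):

```python
import torch
import torch.nn as nn

class CpuEMA:
    """Keeps EMA shadow weights in CPU memory instead of GPU memory.

    Illustrative only; the real training code will differ in naming
    and structure.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        # One full parameter copy, but on the CPU, so it does not
        # occupy GPU memory between updates.
        self.shadow = {
            name: p.detach().to("cpu", copy=True)
            for name, p in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model: nn.Module):
        for name, p in model.named_parameters():
            # Move one tensor at a time to CPU for the blend.
            self.shadow[name].mul_(self.decay).add_(
                p.detach().to("cpu"), alpha=1.0 - self.decay
            )

model = nn.Linear(8, 8)
ema = CpuEMA(model, decay=0.9)
ema.update(model)
```

The per-update host/device transfer is the cost paid for freeing the extra model copy from GPU memory.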

In practice, the main memory savings are expected to come from bf16 and activation checkpointing.
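For the checkpointing half of that, the mechanism is `torch.utils.checkpoint`: only the block's input is stored during forward, and the block is recomputed during backward. A minimal sketch with a placeholder "heavy" block:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder standing in for a heavy transformer block.
block = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))

x = torch.randn(4, 32, requires_grad=True)

# Without checkpointing, every intermediate activation inside the block
# stays alive until backward. With it, only x is kept and the forward
# is recomputed when gradients are needed.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()

print(x.grad is not None)  # True — gradients flow through the checkpoint
```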

Notes

  • input token counts are unchanged
  • training semantics are intended to be unchanged; the edits are limited to precision and memory-management paths
  • use_bf16 and use_activation_checkpointing are configurable and default to enabled
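The two flags could be surfaced as a small config object along these lines (the class name and wiring are illustrative; only the flag names come from the PR):

```python
from dataclasses import dataclass

@dataclass
class TrainMemoryConfig:
    # Both default to enabled, matching the PR.
    use_bf16: bool = True
    use_activation_checkpointing: bool = True

cfg = TrainMemoryConfig()
print(cfg)  # TrainMemoryConfig(use_bf16=True, use_activation_checkpointing=True)
```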

Validation

  • verified the touched Python files with python3 -m py_compile
  • full training / GPU memory profiling has not been run in this PR
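When that profiling is eventually run, peak GPU memory around a training step could be measured with something like the following (illustrative helper; only meaningful on a CUDA machine):

```python
import torch

def peak_gpu_memory_mb() -> float:
    """Return peak allocated GPU memory in MiB, or 0.0 without CUDA."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / (1024 ** 2)

# Typical usage around a training step:
# torch.cuda.reset_peak_memory_stats()
# train_one_step(...)  # hypothetical training step
# print(f"peak: {peak_gpu_memory_mb():.1f} MiB")
print(peak_gpu_memory_mb())
```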

