Conversation

@dxqb
Contributor

@dxqb dxqb commented Dec 2, 2025

addresses #12776

What does this PR do?

This PR keeps the tuples, but moves the splitting of tensors into tuples of tensors into the transformer blocks, to avoid issues with checkpointing. By passing a tensor directly, torch.utils.checkpoint() identifies it as an input and saves it accordingly, without running backward through it multiple times.
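For illustration, a minimal, hypothetical sketch of the pattern (the block and its layers are made up and are not the actual diffusers modules): the split into per-stream tensors happens inside the block, so the checkpointed call receives the full tensor as its input.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy transformer-block stand-in: the split into text/image streams
    happens *inside* the block, so checkpoint() sees a single tensor input."""

    def __init__(self, dim: int):
        super().__init__()
        self.txt_proj = nn.Linear(dim, dim)
        self.img_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, txt_len: int) -> torch.Tensor:
        # Split inside the block instead of in the outer model's forward().
        txt = hidden_states[:, :txt_len]
        img = hidden_states[:, txt_len:]
        txt = self.txt_proj(txt)
        img = self.img_proj(img)
        return torch.cat([txt, img], dim=1)


dim, txt_len = 8, 3
block = Block(dim)
x = torch.randn(2, 10, dim, requires_grad=True)

# checkpoint() receives the full tensor `x` directly and saves it as an input,
# rather than elements of a tuple produced by a split outside the block.
out = checkpoint(block, x, txt_len, use_reentrant=False)
out.sum().backward()
```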

This is a draft. If you agree with this change I can clean it up. Among other things:

  • the type hints are now incorrect and need to be updated
  • the splitting might not be necessary anymore, since the resulting tensors are used immediately afterwards

Who can review?

@yiyixuxu and @asomoza

@github-actions
Contributor

github-actions bot commented Jan 9, 2026

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the "stale" label (Issues that haven't received updates) Jan 9, 2026