
refactor checkpointing#6

Merged
nkern merged 2 commits into
mainfrom
refactor_checkpoint
May 5, 2026
Conversation

@nkern
Owner

@nkern nkern commented May 1, 2026

Refactor checkpointing to use `Accelerator.save_state()` from accelerate. The noise_scheduler, lr_scheduler, and optimizer are no longer pickled.

A PyTorch state dict stores only an object's state, not the object itself, so we have to "know" the class names in order to instantiate the objects before loading their state. That is why we now write a ckpt_config.yaml with this info to the checkpoint directory.

In principle, we could get this info from the original config.yaml file, but the idea here is for train() to be able to operate on its own, without needing a config.yaml.
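
For illustration, a minimal sketch of the class-name bookkeeping described above. It is not the code from this PR: the helper names (`class_path`, `resolve_class`, `write_ckpt_config`, `read_ckpt_classes`) are hypothetical, JSON stands in for YAML so the example needs only the standard library, and standard-library classes stand in for the real optimizer and schedulers.

```python
import collections
import importlib
import json
import tempfile
from pathlib import Path


def class_path(obj):
    """Return 'module.ClassName' for an object (top-level classes only)."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__name__}"


def resolve_class(path):
    """Import and return the class named by a 'module.ClassName' string."""
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def write_ckpt_config(ckpt_dir, components):
    """Record each component's class name in the checkpoint directory."""
    config = {name: class_path(obj) for name, obj in components.items()}
    # Hypothetical file name; the PR writes ckpt_config.yaml instead.
    (Path(ckpt_dir) / "ckpt_config.json").write_text(json.dumps(config, indent=2))


def read_ckpt_classes(ckpt_dir):
    """Read the recorded class names back and resolve them to classes."""
    config = json.loads((Path(ckpt_dir) / "ckpt_config.json").read_text())
    return {name: resolve_class(path) for name, path in config.items()}


# Stand-in objects; in training these would be the optimizer and schedulers.
components = {
    "optimizer": collections.Counter(),
    "lr_scheduler": collections.OrderedDict(),
}

with tempfile.TemporaryDirectory() as ckpt_dir:
    write_ckpt_config(ckpt_dir, components)
    classes = read_ckpt_classes(ckpt_dir)
```

With the classes resolved from the checkpoint directory alone, each object can be instantiated first and then have its saved state loaded into it, which is what lets a resume work without the original config.yaml.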

Also fixed a bug so that training can now be resumed from a checkpoint.

@nkern nkern merged commit fc74b3a into main May 5, 2026
2 checks passed
@nkern nkern deleted the refactor_checkpoint branch May 5, 2026 19:53