Description
Bug report
With `expansion_factor_real_data > 1`, placeholder data is generated and the inputs are later truncated to the correct size inside `loss_fn`, which is supposed to eliminate the placeholder data. With gradient accumulation, `train_step` reshapes the inputs to introduce a `gradient_accumulation_steps` dimension.
If we truncated first and then reshaped, this would work correctly. However, we reshape and then truncate, which means the later gradient accumulation steps train on the placeholder data.
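A minimal NumPy sketch of the ordering problem (the sizes and the `-1` placeholder marker are illustrative, not the repo's actual code):

```python
import numpy as np

real_batch, padded_batch, accum_steps = 4, 8, 2

# Rows 0..3 are real data; rows 4..7 are placeholder padding (marked -1).
batch = np.concatenate(
    [np.arange(real_batch), np.full(padded_batch - real_batch, -1)]
)

# Correct order: truncate away the padding first, then split into
# gradient accumulation steps -- every microbatch contains only real rows.
good = batch[:real_batch].reshape(accum_steps, -1)
# good -> [[0, 1], [2, 3]]

# Buggy order: split into accumulation steps first. The padding rows all
# land in the later microbatches, and no per-microbatch truncation inside
# loss_fn can remove them afterwards.
bad = batch.reshape(accum_steps, -1)
# bad -> [[0, 1, 2, 3], [-1, -1, -1, -1]]  (step 1 is entirely placeholder)

assert (good >= 0).all()
assert (bad[1] == -1).all()
```

The same slicing-versus-reshaping order applies whether the arrays are NumPy or JAX: once the leading axis has been folded into `(accum_steps, microbatch, ...)`, a truncation of the microbatch axis cannot recover the real rows that were routed into later steps.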
I believe `max_checkify` does not catch this issue because the reshape happens too early in the process.
(Internally we're on an older fork of this codebase, so I apologize if this has already been fixed. I looked through the relevant code and it appeared to have the same issue.)
Logs/Output
No response
Environment Information
No response
Additional Context
No response