Issue with gradient accumulation and expansion_factor_real_data>1 #3381

@philip-essential

Bug report

Setting expansion_factor_real_data > 1 generates placeholder data; the inputs are then truncated in loss_fn to the correct size, which is supposed to eliminate the placeholder data. With gradient accumulation, train_step reshapes the inputs to introduce a gradient_accumulation_steps dimension before loss_fn runs.

If we truncated first and then reshaped, this would work correctly. However, we reshape and then truncate, which means later gradient accumulation steps use the placeholder data.
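
To illustrate the ordering problem, here is a minimal NumPy sketch. It is not the actual train_step or loss_fn code: the sizes, the `-1` placeholder marker, the helper names, and the assumption that placeholder rows are appended after the real rows are all hypothetical.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
micro_batch_size = 2              # rows each accumulation step should train on
gradient_accumulation_steps = 2
expansion_factor_real_data = 2

real_rows = micro_batch_size * gradient_accumulation_steps   # 4 real rows
total_rows = real_rows * expansion_factor_real_data          # 8 rows loaded

# Assumption: real rows come first; placeholder rows are marked with -1.
batch = np.full((total_rows, 3), -1)
batch[:real_rows] = np.arange(real_rows)[:, None]

def reshape_then_truncate(x):
    # Order described in this issue: split into steps, then truncate each step.
    steps = x.reshape(gradient_accumulation_steps, -1, *x.shape[1:])
    return steps[:, :micro_batch_size]

def truncate_then_reshape(x):
    # Proposed order: drop the placeholder rows first, then split into steps.
    x = x[:real_rows]
    return x.reshape(gradient_accumulation_steps, -1, *x.shape[1:])

buggy = reshape_then_truncate(batch)
fixed = truncate_then_reshape(batch)

print(buggy[:, :, 0])  # second step contains only placeholder rows
print(fixed[:, :, 0])  # every step contains real rows
```

Under these assumptions, the reshape-then-truncate order also silently drops real rows 2 and 3, since they land in the part of the first step that gets truncated away.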

I believe max_checkify does not catch this issue because it happens too early in the process.

(Internally we're on an older fork of this codebase, so I apologize if this has already been fixed. I looked through the relevant code, and it appears to have the same issue.)


Labels: bug