@zacharydenton
Contributor

This fixes an error in interpolate_position_embeddings when batch_size is > 1 (#428).

  interpolated_position_embeddings =
    input_position_embeddings
-   |> Nx.reshape({batch_size, original_positions, original_positions, spec.hidden_size})
+   |> Nx.reshape({1, original_positions, original_positions, spec.hidden_size})
Member

@jonatanklosko jonatanklosko Dec 18, 2025

@zacharydenton great find! I would suggest a small improvement:

position_embeddings_batch_size = Nx.axis_size(position_embeddings, 0)

and then use that in both places.

In practice it's always 1, but hardcoding the value implies an extra assumption that the reader may need to understand, which means going back to the caller side and figuring out the position embeddings shape :)
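A minimal sketch of the suggestion above, assuming the position embeddings tensor has shape {1, original_positions * original_positions, hidden_size}. The concrete sizes and the iota stand-in tensor are hypothetical, chosen only for illustration, and the Nx version in Mix.install is an assumption. It is not the merged code.

```elixir
# Assumption: any recent Nx release works; the version here is illustrative.
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical sizes, for illustration only.
hidden_size = 4
original_positions = 3

# Stand-in for the learned position embeddings, with a leading axis of size 1.
position_embeddings =
  Nx.iota({1, original_positions * original_positions, hidden_size}, type: {:f, 32})

# Read the leading axis from the tensor itself, rather than hardcoding 1
# or reusing the batch size of the model input.
position_embeddings_batch_size = Nx.axis_size(position_embeddings, 0)

reshaped =
  Nx.reshape(
    position_embeddings,
    {position_embeddings_batch_size, original_positions, original_positions, hidden_size}
  )

IO.inspect(Nx.shape(reshaped))
```

Because the size comes from the tensor, the reshape keeps the element count unchanged no matter what batch size the model input has, and the reader does not need to look up the position embeddings shape at the call site.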

Contributor Author

Done.

Member

@jonatanklosko jonatanklosko left a comment

Thank you!

@jonatanklosko jonatanklosko changed the title from "Fix DinoV2 crash when batch_size > 1." to "Fix DinoV2 crash when batch_size > 1" Dec 18, 2025
@jonatanklosko jonatanklosko merged commit 55ec9ac into elixir-nx:main Dec 18, 2025
2 checks passed