Unify RoPE Implementations: RotaryEncoder vs ReferenceRotaryEncoder #1213
What does this PR do? Please describe:
This PR extends both `RotaryEncoder` and `ReferenceRotaryEncoder` to support both the split-half (Qwen) and the consecutive (Llama) layout encodings, unifying the two implementations in either class. The two layouts have different performance implications, which motivates which one should become the default in the future.

Does your PR introduce any breaking changes? If yes, please list them:
There should be none.
Motivation
Currently, fairseq2 has two separate RoPE implementations:
- `RotaryEncoder`: uses the consecutive-pair layout `[real0, imag0, real1, imag1, ...]` with a complex (polar) representation
- `ReferenceRotaryEncoder`: uses the split-half layout `[real0, real1, ..., imag0, imag1, ...]` with explicit cos/sin operations

Complex number representations are currently not supported by inductor, which warns about potentially degraded performance. This creates the assumption that `RotaryEncoder`'s implementation is slower, which, as we will see, is not the case in practice.

```
/venv/main/lib/python3.12/site-packages/torch/_inductor/lowering.py:1917: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
```
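For concreteness, here is a small self-contained sketch (not code from this PR; all names are made up) showing the two layouts and that they describe the same rotation up to a fixed permutation of the encoding dimension:

```python
import torch

def rope_consecutive(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # Consecutive-pair layout [real0, imag0, real1, imag1, ...]:
    # view adjacent pairs as complex numbers and rotate in polar form.
    x_c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    rot = torch.polar(torch.ones_like(angles), angles)  # e^{i * angle}
    return torch.view_as_real(x_c * rot).flatten(-2)

def rope_split_half(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # Split-half layout [real0, ..., real_{d/2-1}, imag0, ..., imag_{d/2-1}]:
    # explicit cos/sin with the usual rotate-half trick.
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

S, D = 8, 16  # toy sequence length and encoding dim (D must be even)
x = torch.randn(S, D)
inv_freq = 1.0 / (10_000.0 ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(torch.arange(S).float(), inv_freq)  # (S, D/2)

def to_split_half(t: torch.Tensor) -> torch.Tensor:
    # Permute the encoding dim from interleaved to split-half order.
    return t.view(S, D // 2, 2).transpose(-1, -2).reshape(S, D)

# Same rotation, just a different ordering of the encoding dimension.
assert torch.allclose(
    to_split_half(rope_consecutive(x, angles)),
    rope_split_half(to_split_half(x), angles),
    atol=1e-5,
)
```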
Solution

Add an `impl` parameter to `ReferenceRotaryEncoder` that supports:

- `impl="reference"` (default): original split-half behavior
- `impl="llama"`: new consecutive-pair behavior that matches `RotaryEncoder`

Add an `impl` parameter to `RotaryEncoder` that supports:

- `impl="llama"` (default): original consecutive-pair behavior
- `impl="reference"`: new split-half behavior

`RotaryEncoder`'s implementation changes the incoming and outgoing encoding order during the `forward(...)` call, while `ReferenceRotaryEncoder`'s implementation changes the frequency instantiation and reorders the encoding dimension, similar to what `_rotate_half_way()` already does.
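A hedged sketch of the intended call sites; the `impl` argument is what this PR adds, while the import path and the remaining constructor arguments are assumptions about the current fairseq2 API and may differ:

```python
from fairseq2.nn import ReferenceRotaryEncoder, RotaryEncoder

# Keyword names below (encoding_dim, max_seq_len) are illustrative.
rre_ref = ReferenceRotaryEncoder(encoding_dim=128, max_seq_len=4096, impl="reference")
rre_llama = ReferenceRotaryEncoder(encoding_dim=128, max_seq_len=4096, impl="llama")

re_llama = RotaryEncoder(encoding_dim=128, max_seq_len=4096, impl="llama")
re_ref = RotaryEncoder(encoding_dim=128, max_seq_len=4096, impl="reference")

# `impl` only selects the layout of the encoding dimension; with matching
# `impl` values the two classes are meant to be numerically interchangeable.
```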
Technical Details

After verifying that both implementations are mathematically identical (I will commit the tests later, depending on which implementation we choose as primary), benchmarks with `torch.compile(..., backend="inductor")` show a mixed bag. Memory utilization is identical with `torch.compile`.

RE's performance is 1.12x faster than native RRE, which is why I assumed that adding the reordering to RE would make it the faster option in both cases. Instead, it resulted in extreme performance degradation due to two new kernel calls being executed (`triton_poi_fused_view_as_complex_0` and `triton_poi_fused_index_1`), which take roughly 4 ms (with B=4, S=8192, E=8192). These do not occur with any other implementation.

If using the reference implementation:
If using the llama implementation:
I've also played around with different implementations of the reordering using `view`/`flatten`/`unsqueeze`/`cat`, but they all resulted in identical compiled versions.
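For context, a few illustrative ways to express that reordering (these are stand-ins for the variants tried, not the exact code):

```python
import torch

x = torch.randn(2, 16, 8)  # (B, S, E) in split-half order; benchmarks used (4, 8192, 8192)

def interleave_unsqueeze_cat(x: torch.Tensor) -> torch.Tensor:
    # unsqueeze + cat + view
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1.unsqueeze(-1), x2.unsqueeze(-1)), dim=-1).view_as(x)

def interleave_stack_flatten(x: torch.Tensor) -> torch.Tensor:
    # stack + flatten
    x1, x2 = x.chunk(2, dim=-1)
    return torch.stack((x1, x2), dim=-1).flatten(-2)

def interleave_index(x: torch.Tensor) -> torch.Tensor:
    # a precomputed permutation of the last dimension
    e = x.size(-1)
    perm = torch.arange(e, device=x.device).view(2, e // 2).t().reshape(e)
    return x.index_select(-1, perm)

# All variants produce the same interleaved layout; the view/flatten/
# unsqueeze/cat formulations compiled to identical code in my runs.
assert torch.equal(interleave_unsqueeze_cat(x), interleave_stack_flatten(x))
assert torch.equal(interleave_unsqueeze_cat(x), interleave_index(x))
```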
Complete benchmarking log
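The numbers above come from my own script (see the log); for anyone wanting to reproduce a comparable measurement, a minimal hypothetical harness could look like the following. The RoPE functions are stand-ins rather than fairseq2's modules, and the shapes require a reasonably large GPU:

```python
import torch

def rope_consecutive(x, rot):
    # Complex/polar variant; compiling it emits the inductor warning quoted above.
    x_c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * rot).flatten(-2)

def rope_split_half(x, cos, sin):
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

def bench_ms(fn, *args, warmup=10, iters=50):
    # CUDA-event timing of a (compiled) callable; mean milliseconds per call.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

B, S, E = 4, 8192, 8192  # shapes quoted above
x = torch.randn(B, S, E, device="cuda")  # float32, since view_as_complex needs it
inv_freq = 1.0 / (10_000.0 ** (torch.arange(0, E, 2, device="cuda").float() / E))
angles = torch.outer(torch.arange(S, device="cuda").float(), inv_freq)
rot = torch.polar(torch.ones_like(angles), angles)
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)

re_fn = torch.compile(rope_consecutive, backend="inductor")
rre_fn = torch.compile(rope_split_half, backend="inductor")
print(f"consecutive (complex): {bench_ms(re_fn, x, rot):.3f} ms")
print(f"split-half (cos/sin):  {bench_ms(rre_fn, x, cos, sin):.3f} ms")
```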
Given this, I'd recommend using the RRE implementation as the baseline: it still degrades performance with the "llama" ordering, but a bit less than the other way around.
Check list: