
Conversation


@cirquit cirquit commented Jun 20, 2025

What does this PR do? Please describe:
This PR extends both RotaryEncoder and ReferenceRotaryEncoder to support both the split-half (Qwen) and consecutive (Llama) layout encodings, unifying the two implementations in either class. The layouts have different performance implications, which motivate which one should become the default in the future.

Does your PR introduce any breaking changes? If yes, please list them:
There should be none.

Motivation

Currently, fairseq2 has two separate RoPE implementations:

  • RotaryEncoder: Uses the consecutive pair layout [real0, imag0, real1, imag1, ...] with a polar (complex) representation
  • ReferenceRotaryEncoder: Uses the split-half layout [real0, real1, ..., imag0, imag1, ...] with explicit cos/sin operations
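To make the two layouts concrete, here is a minimal standalone sketch (illustration only, not fairseq2 code) of how the same (real, imag) pairs are arranged in each:

```python
import torch

# Two rotation pairs: row i holds (real_i, imag_i).
pairs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

consecutive = pairs.flatten()      # [real0, imag0, real1, imag1]
split_half = pairs.t().flatten()   # [real0, real1, imag0, imag1]

print(consecutive.tolist())  # [1.0, 2.0, 3.0, 4.0]
print(split_half.tolist())   # [1.0, 3.0, 2.0, 4.0]
```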

Complex-number operators are currently not supported by Inductor, which warns about potentially degraded performance. This creates the assumption that RotaryEncoder's implementation is slower, which, as we will see, is not the case in practice.

/venv/main/lib/python3.12/site-packages/torch/_inductor/lowering.py:1917: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(

Solution

Add an impl parameter to ReferenceRotaryEncoder that supports:

  • impl="reference" (default): Original split-half behavior
  • impl="llama": New consecutive pair behavior that matches RotaryEncoder

Add an impl parameter to RotaryEncoder that supports:

  • impl="llama" (default): Original consecutive behavior
  • impl="reference": New split-half behavior

RotaryEncoder's implementation changes the incoming and outgoing encoding order during the forward(...) call, while ReferenceRotaryEncoder's implementation changes the frequency instantiation and reorders the encoding dimension, similar to what _rotate_half_way() already does.
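As a sanity check on why reordering makes the two paths interchangeable, here is a minimal standalone sketch (simplified to a single position; rope_consecutive and rope_split_half are illustrative names, not fairseq2 APIs) showing that the complex rotation on the consecutive layout and the cos/sin rotation on the split-half layout agree after a permutation:

```python
import torch

def rope_consecutive(x, theta):
    # x: (..., d) with consecutive (real, imag) pairs; rotate pair k by theta[k].
    xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    rot = torch.polar(torch.ones_like(theta), theta)  # e^{i * theta}
    return torch.view_as_real(xc * rot).flatten(-2)

def rope_split_half(x, theta):
    # x: (..., d) in split-half layout [reals | imags].
    x1, x2 = x.chunk(2, dim=-1)
    cos, sin = theta.cos(), theta.sin()
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

d = 8
x = torch.randn(d)
theta = torch.rand(d // 2)

# Permute consecutive -> split-half, apply, and permute back.
perm = torch.cat((torch.arange(0, d, 2), torch.arange(1, d, 2)))
out_a = rope_consecutive(x, theta)
out_b = rope_split_half(x[perm], theta)[perm.argsort()]

assert torch.allclose(out_a, out_b, atol=1e-6)
```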

Technical Details

After verifying that both implementations are mathematically identical (I will commit the tests later depending on which implementation we choose as primary), benchmarks with torch.compile(..., backend="inductor") show a mixed bag. Memory utilization is identical with torch.compile.

RE is 1.12x faster than native RRE, which is why I assumed that adding the reordering to RE would make it the faster option in both cases, but it resulted in extreme performance degradation due to two new kernel calls being executed (triton_poi_fused_view_as_complex_0 and triton_poi_fused_index_1), taking roughly 4ms (with B=4, S=8192, E=8192). These do not occur with any other implementation.

If using the reference implementation:

  • RE + "reference": 0.66x (degraded performance)
  • RRE: 1.0x (baseline; matches the current implementation)

If using the llama implementation:

  • RE: 1.0x (baseline; matches the current implementation)
  • RRE + "llama": 0.86x (degraded performance)

I've also played around with different implementations of the reordering using view/flatten/unsqueeze/cat, but all resulted in identical compiled versions.
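For illustration, two of the interleaving variants in that family can be sketched as follows (standalone example, not the actual fairseq2 code); both express the same split-half-to-consecutive reordering, which is consistent with them compiling to identical kernels:

```python
import torch

x = torch.arange(8.0)
half1, half2 = x.chunk(2)  # reals, imags in split-half layout

# Variant A: stack the halves pairwise, then flatten.
a = torch.stack((half1, half2), dim=-1).flatten()

# Variant B: unsqueeze + cat, then flatten.
b = torch.cat((half1.unsqueeze(-1), half2.unsqueeze(-1)), dim=-1).flatten()

assert torch.equal(a, b)  # both yield [0, 4, 1, 5, 2, 6, 3, 7]
```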

Complete benchmarking log
============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-8.4.1, pluggy-1.6.0 -- /venv/main/bin/python3.12
cachedir: .pytest_cache
rootdir: /root/fairseq2
configfile: pyproject.toml
plugins: asyncio-1.0.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function
collecting ... collected 55 items

benchmark.py::test_benchmark_positional_encoders[False-2048-1-1024] B1_S2048_E1024  Eager        0.1ms      0.4ms    0.2ms     0.3ms    3.95x    1.19x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-1-2048] B1_S2048_E2048  Eager        0.1ms      0.7ms    0.4ms     0.5ms    4.62x    1.21x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-1-4096] B1_S2048_E4096  Eager        0.2ms      1.3ms    0.7ms     0.9ms    5.09x    1.23x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-4-1024] B4_S2048_E1024  Eager        0.2ms      1.3ms    0.7ms     0.9ms    5.15x    1.21x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-4-2048] B4_S2048_E2048  Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.05x    1.18x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-4-4096] B4_S2048_E4096  Eager        0.9ms      4.8ms    3.0ms     3.5ms    5.62x    1.15x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-8-1024] B8_S2048_E1024  Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.46x    1.19x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-8-2048] B8_S2048_E2048  Eager        0.9ms      4.8ms    3.0ms     3.4ms    5.62x    1.16x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-2048-8-4096] B8_S2048_E4096  Eager        1.7ms      9.7ms    6.0ms     6.9ms    5.75x    1.14x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-1-1024] B1_S4096_E1024  Eager        0.1ms      0.7ms    0.4ms     0.5ms    4.62x    1.20x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-1-2048] B1_S4096_E2048  Eager        0.2ms      1.3ms    0.7ms     0.9ms    5.08x    1.20x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-1-4096] B1_S4096_E4096  Eager        0.5ms      2.5ms    1.5ms     1.8ms    5.36x    1.18x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-4-1024] B4_S4096_E1024  Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.45x    1.14x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-4-2048] B4_S4096_E2048  Eager        0.9ms      4.8ms    3.0ms     3.5ms    5.63x    1.16x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-4-4096] B4_S4096_E4096  Eager        1.7ms      9.6ms    6.0ms     6.9ms    5.74x    1.14x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-8-1024] B8_S4096_E1024  Eager        0.9ms      4.8ms    3.0ms     3.4ms    5.63x    1.16x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-8-2048] B8_S4096_E2048  Eager        1.7ms      9.6ms    6.0ms     6.9ms    5.74x    1.14x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-4096-8-4096] B8_S4096_E4096  Eager        3.3ms     19.1ms   12.2ms    13.6ms    5.77x    1.12x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-1-1024] B1_S16384_E1024 Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.36x    1.18x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-1-2048] B1_S16384_E2048 Eager        0.9ms      4.8ms    3.0ms     3.4ms    5.52x    1.15x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-1-4096] B1_S16384_E4096 Eager        1.7ms      9.7ms    6.1ms     6.9ms    5.61x    1.14x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-4-1024] B4_S16384_E1024 Eager        1.7ms      9.6ms    6.0ms     6.9ms    5.74x    1.14x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-4-2048] B4_S16384_E2048 Eager        3.3ms     19.1ms   12.3ms    13.6ms    5.77x    1.11x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-4-4096] B4_S16384_E4096 Eager        6.6ms     38.1ms   24.7ms    27.2ms    5.79x    1.10x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-8-1024] B8_S16384_E1024 Eager        3.3ms     19.1ms   12.3ms    13.6ms    5.77x    1.11x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-8-2048] B8_S16384_E2048 Eager        6.6ms     38.2ms   24.8ms    27.2ms    5.80x    1.10x
PASSED
benchmark.py::test_benchmark_positional_encoders[False-16384-8-4096] B8_S16384_E4096 Eager       13.1ms     76.4ms   50.1ms    54.6ms    5.82x    1.09x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-1-1024] B1_S2048_E1024  Compiled     0.2ms      0.2ms    0.2ms     0.2ms    1.01x    0.74x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-1-2048] B1_S2048_E2048  Compiled     0.3ms      0.3ms    0.4ms     0.3ms    1.07x    0.70x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-1-4096] B1_S2048_E4096  Compiled     0.5ms      0.5ms    0.7ms     0.5ms    1.07x    0.68x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-4-1024] B4_S2048_E1024  Compiled     0.5ms      0.5ms    0.7ms     0.5ms    1.06x    0.69x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-4-2048] B4_S2048_E2048  Compiled     0.8ms      0.9ms    1.3ms     0.9ms    1.14x    0.67x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-4-4096] B4_S2048_E4096  Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.15x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-8-1024] B8_S2048_E1024  Compiled     0.9ms      0.9ms    1.3ms     0.9ms    1.06x    0.67x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-8-2048] B8_S2048_E2048  Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.15x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-2048-8-4096] B8_S2048_E4096  Compiled     2.9ms      3.3ms    5.1ms     3.4ms    1.16x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-1-1024] B1_S4096_E1024  Compiled     0.3ms      0.3ms    0.4ms     0.3ms    1.03x    0.70x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-1-2048] B1_S4096_E2048  Compiled     0.4ms      0.5ms    0.7ms     0.5ms    1.11x    0.68x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-1-4096] B1_S4096_E4096  Compiled     0.8ms      0.9ms    1.3ms     0.9ms    1.10x    0.67x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-4-1024] B4_S4096_E1024  Compiled     0.8ms      0.9ms    1.3ms     0.9ms    1.07x    0.67x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-4-2048] B4_S4096_E2048  Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.16x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-4-4096] B4_S4096_E4096  Compiled     2.9ms      3.3ms    5.1ms     3.4ms    1.16x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-8-1024] B8_S4096_E1024  Compiled     1.6ms      1.7ms    2.6ms     1.7ms    1.11x    0.67x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-8-2048] B8_S4096_E2048  Compiled     2.9ms      3.3ms    5.1ms     3.4ms    1.16x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-4096-8-4096] B8_S4096_E4096  Compiled     5.7ms      6.6ms   10.0ms     6.6ms    1.17x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-1-1024] B1_S16384_E1024 Compiled     0.9ms      0.9ms    1.4ms     0.9ms    1.05x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-1-2048] B1_S16384_E2048 Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.14x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-1-4096] B1_S16384_E4096 Compiled     2.9ms      3.3ms    5.1ms     3.3ms    1.14x    0.65x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-4-1024] B4_S16384_E1024 Compiled     2.9ms      3.4ms    5.1ms     3.4ms    1.14x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-4-2048] B4_S16384_E2048 Compiled     5.6ms      6.6ms   10.0ms     6.6ms    1.17x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-4-4096] B4_S16384_E4096 Compiled    11.2ms     13.1ms   20.0ms    13.1ms    1.17x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-8-1024] B8_S16384_E1024 Compiled     5.7ms      6.6ms   10.0ms     6.6ms    1.15x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-8-2048] B8_S16384_E2048 Compiled    11.2ms     13.1ms   20.0ms    13.1ms    1.17x    0.66x
PASSED
benchmark.py::test_benchmark_positional_encoders[True-16384-8-4096] B8_S16384_E4096 Compiled    22.3ms     26.1ms   39.7ms    26.2ms    1.17x    0.66x
PASSED
benchmark.py::test_zzz_summary 
Config          Mode     RE_Llama  RRE_Llama  RE_Ref   RRE_Ref   RRE/RE_L RRE/RE_R
------------------------------------------------------------------------------------------
B1_S16384_E1024 Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.36x    1.18x
B1_S16384_E2048 Eager        0.9ms      4.8ms    3.0ms     3.4ms    5.52x    1.15x
B1_S16384_E4096 Eager        1.7ms      9.7ms    6.1ms     6.9ms    5.61x    1.14x
B1_S2048_E1024  Eager        0.1ms      0.4ms    0.2ms     0.3ms    3.95x    1.19x
B1_S2048_E2048  Eager        0.1ms      0.7ms    0.4ms     0.5ms    4.62x    1.21x
B1_S2048_E4096  Eager        0.2ms      1.3ms    0.7ms     0.9ms    5.09x    1.23x
B1_S4096_E1024  Eager        0.1ms      0.7ms    0.4ms     0.5ms    4.62x    1.20x
B1_S4096_E2048  Eager        0.2ms      1.3ms    0.7ms     0.9ms    5.08x    1.20x
B1_S4096_E4096  Eager        0.5ms      2.5ms    1.5ms     1.8ms    5.36x    1.18x
B4_S16384_E1024 Eager        1.7ms      9.6ms    6.0ms     6.9ms    5.74x    1.14x
B4_S16384_E2048 Eager        3.3ms     19.1ms   12.3ms    13.6ms    5.77x    1.11x
B4_S16384_E4096 Eager        6.6ms     38.1ms   24.7ms    27.2ms    5.79x    1.10x
B4_S2048_E1024  Eager        0.2ms      1.3ms    0.7ms     0.9ms    5.15x    1.21x
B4_S2048_E2048  Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.05x    1.18x
B4_S2048_E4096  Eager        0.9ms      4.8ms    3.0ms     3.5ms    5.62x    1.15x
B4_S4096_E1024  Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.45x    1.14x
B4_S4096_E2048  Eager        0.9ms      4.8ms    3.0ms     3.5ms    5.63x    1.16x
B4_S4096_E4096  Eager        1.7ms      9.6ms    6.0ms     6.9ms    5.74x    1.14x
B8_S16384_E1024 Eager        3.3ms     19.1ms   12.3ms    13.6ms    5.77x    1.11x
B8_S16384_E2048 Eager        6.6ms     38.2ms   24.8ms    27.2ms    5.80x    1.10x
B8_S16384_E4096 Eager       13.1ms     76.4ms   50.1ms    54.6ms    5.82x    1.09x
B8_S2048_E1024  Eager        0.5ms      2.5ms    1.5ms     1.7ms    5.46x    1.19x
B8_S2048_E2048  Eager        0.9ms      4.8ms    3.0ms     3.4ms    5.62x    1.16x
B8_S2048_E4096  Eager        1.7ms      9.7ms    6.0ms     6.9ms    5.75x    1.14x
B8_S4096_E1024  Eager        0.9ms      4.8ms    3.0ms     3.4ms    5.63x    1.16x
B8_S4096_E2048  Eager        1.7ms      9.6ms    6.0ms     6.9ms    5.74x    1.14x
B8_S4096_E4096  Eager        3.3ms     19.1ms   12.2ms    13.6ms    5.77x    1.12x
B1_S16384_E1024 Compiled     0.9ms      0.9ms    1.4ms     0.9ms    1.05x    0.66x
B1_S16384_E2048 Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.14x    0.66x
B1_S16384_E4096 Compiled     2.9ms      3.3ms    5.1ms     3.3ms    1.14x    0.65x
B1_S2048_E1024  Compiled     0.2ms      0.2ms    0.2ms     0.2ms    1.01x    0.74x
B1_S2048_E2048  Compiled     0.3ms      0.3ms    0.4ms     0.3ms    1.07x    0.70x
B1_S2048_E4096  Compiled     0.5ms      0.5ms    0.7ms     0.5ms    1.07x    0.68x
B1_S4096_E1024  Compiled     0.3ms      0.3ms    0.4ms     0.3ms    1.03x    0.70x
B1_S4096_E2048  Compiled     0.4ms      0.5ms    0.7ms     0.5ms    1.11x    0.68x
B1_S4096_E4096  Compiled     0.8ms      0.9ms    1.3ms     0.9ms    1.10x    0.67x
B4_S16384_E1024 Compiled     2.9ms      3.4ms    5.1ms     3.4ms    1.14x    0.66x
B4_S16384_E2048 Compiled     5.6ms      6.6ms   10.0ms     6.6ms    1.17x    0.66x
B4_S16384_E4096 Compiled    11.2ms     13.1ms   20.0ms    13.1ms    1.17x    0.66x
B4_S2048_E1024  Compiled     0.5ms      0.5ms    0.7ms     0.5ms    1.06x    0.69x
B4_S2048_E2048  Compiled     0.8ms      0.9ms    1.3ms     0.9ms    1.14x    0.67x
B4_S2048_E4096  Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.15x    0.66x
B4_S4096_E1024  Compiled     0.8ms      0.9ms    1.3ms     0.9ms    1.07x    0.67x
B4_S4096_E2048  Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.16x    0.66x
B4_S4096_E4096  Compiled     2.9ms      3.3ms    5.1ms     3.4ms    1.16x    0.66x
B8_S16384_E1024 Compiled     5.7ms      6.6ms   10.0ms     6.6ms    1.15x    0.66x
B8_S16384_E2048 Compiled    11.2ms     13.1ms   20.0ms    13.1ms    1.17x    0.66x
B8_S16384_E4096 Compiled    22.3ms     26.1ms   39.7ms    26.2ms    1.17x    0.66x
B8_S2048_E1024  Compiled     0.9ms      0.9ms    1.3ms     0.9ms    1.06x    0.67x
B8_S2048_E2048  Compiled     1.5ms      1.7ms    2.6ms     1.7ms    1.15x    0.66x
B8_S2048_E4096  Compiled     2.9ms      3.3ms    5.1ms     3.4ms    1.16x    0.66x
B8_S4096_E1024  Compiled     1.6ms      1.7ms    2.6ms     1.7ms    1.11x    0.67x
B8_S4096_E2048  Compiled     2.9ms      3.3ms    5.1ms     3.4ms    1.16x    0.66x
B8_S4096_E4096  Compiled     5.7ms      6.6ms   10.0ms     6.6ms    1.17x    0.66x

Summary:
Eager    - RRE/RE Llama: 5.43x, RRE/RE Reference: 1.16x
Compiled - RRE/RE Llama: 1.12x, RRE/RE Reference: 0.67x

======================== 55 passed in 97.40s (0:01:37) =========================

Given this, I'd recommend using the RRE implementation as the baseline: even though it degrades performance with the "llama" ordering, it does so somewhat less than the reverse.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

cirquit added 2 commits June 20, 2025 19:20
* Added impl flag to choose between the llama and reference implementation
* Added impl flag to choose between the llama and reference implementation
@cirquit cirquit requested a review from cbalioglu as a code owner June 20, 2025 23:28
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 20, 2025

@djsaunde djsaunde left a comment


This looks good to me; just missing tests, but these come later based on the PR description.

theta: float = 10_000.0,
freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None = None,
device: Device | None = None,
impl: str = "llama"
@djsaunde djsaunde Jun 23, 2025


Nit: could use Literal["value1", "value2", ...] typing for better type checking and auto-completion.
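A minimal sketch of the suggested typing (hypothetical names; make_encoder is illustrative, not a fairseq2 API):

```python
from typing import Literal

RopeImpl = Literal["llama", "reference"]

def make_encoder(impl: RopeImpl = "llama") -> str:
    # Static checkers now flag e.g. make_encoder("lama") as a type error.
    return impl
```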
