Conversation
Code Review
This pull request integrates the FlashQLA backend into the qwen3next model's chunk_gated_delta_rule operation, providing a high-performance alternative to FLA Triton kernels on SM90+ hardware. It includes compatibility detection logic and a new parity test suite with benchmarking. Feedback highlights that the version requirements for PyTorch (2.8) and CUDA (12.8) appear to be typos that would disable the backend on current environments, and suggests making the version parsing logic more robust. Additionally, it is recommended to explicitly calculate the scale parameter when it is None to ensure consistency with the fallback implementation.
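For context, "SM90+" here means Hopper-class GPUs, i.e. a CUDA compute capability of (9, 0) or higher. The snippet below is only an illustrative sketch of such a capability check; the helper name is hypothetical and not taken from the PR:

import torch

def _is_sm90_or_newer() -> bool:
    # Hypothetical helper: Hopper ("SM90") GPUs report a CUDA compute
    # capability of (9, 0) or newer via torch.cuda.get_device_capability().
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (9, 0)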
tv = torch.__version__.split("+")[0].split(".")
if (int(tv[0]), int(tv[1])) < (2, 8):
    return None
cv = torch.version.cuda
if cv is None:
    return None
cv_parts = cv.split(".")
if (int(cv_parts[0]), int(cv_parts[1])) < (12, 8):
    return None
The version requirements for PyTorch (2.8) and CUDA (12.8) appear to be typos, as these versions are either not yet released or do not exist (CUDA 12.8). This will cause the FlashQLA backend to be disabled on all current environments. Additionally, the parsing logic is fragile and may raise IndexError or ValueError depending on the version string format (e.g., if it contains non-numeric suffixes like rc1).
try:
    tv = torch.__version__.split("+")[0].split(".")
    if len(tv) < 2 or (int(tv[0]), int(tv[1])) < (2, 4):
        return None
    cv = torch.version.cuda
    if cv is None:
        return None
    cv_parts = cv.split(".")
    if len(cv_parts) < 2 or (int(cv_parts[0]), int(cv_parts[1])) < (12, 1):
        return None
except (ValueError, IndexError):
    return None
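As a more defensive alternative, the version strings could be compared with the packaging library instead of manual splitting, which also copes with suffixes like rc1 or +cu121. This is only a sketch under that assumption; _flashqla_supported is a hypothetical name and the 2.4 / 12.1 thresholds simply mirror the suggestion above, not the PR's code:

import torch
from packaging.version import InvalidVersion, Version

def _flashqla_supported() -> bool:
    # Hypothetical helper: packaging parses PEP 440 strings such as
    # "2.4.0rc1" or "2.4.1+cu121" without manual split()/int() handling.
    try:
        torch_ok = Version(torch.__version__).release[:2] >= (2, 4)
        cuda = torch.version.cuda
        cuda_ok = cuda is not None and Version(cuda).release[:2] >= (12, 1)
    except InvalidVersion:
        return False
    return torch_ok and cuda_ok
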
return flashqla_fn(
    q=q.contiguous(),
    k=k.contiguous(),
    v=v.contiguous(),
    g=g.contiguous(),
    beta=beta.contiguous(),
    scale=scale,
    initial_state=initial_state.contiguous() if initial_state is not None else None,
    output_final_state=output_final_state,
    cu_seqlens=cu_seqlens,
    head_first=head_first,
    use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
If scale is None, it is passed directly to flashqla_fn. The fallback Triton path explicitly calculates scale as k.shape[-1] ** -0.5. To ensure consistency and avoid potential issues if the flash_qla library does not handle None defaults, the scale should be explicitly provided.
return flashqla_fn(
    q=q.contiguous(),
    k=k.contiguous(),
    v=v.contiguous(),
    g=g.contiguous(),
    beta=beta.contiguous(),
    scale=scale if scale is not None else k.shape[-1] ** -0.5,
    initial_state=initial_state.contiguous() if initial_state is not None else None,
    output_final_state=output_final_state,
    cu_seqlens=cu_seqlens,
    head_first=head_first,
    use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
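Equivalently, the default could be computed once before the call; the one-line sketch below only restates the reviewer's point (the 1/sqrt(head_dim) default used by the FLA Triton fallback) and is not code from the PR:

if scale is None:
    scale = k.shape[-1] ** -0.5  # same default the Triton fallback computes

Either form keeps the two backends numerically consistent when callers omit scale.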