Conversation
Code Review
This pull request implements a FlashQLA backend dispatch for chunk_gated_delta_rule in the Qwen3Next model, including a comprehensive parity test and benchmark suite. Key feedback highlights that the PyTorch and CUDA version checks (>= 2.8 and >= 12.8) appear to be placeholders for unreleased versions and should be corrected, and that the scale parameter should be assigned its default value before the FlashQLA dispatch to prevent errors if the backend receives None.
    if (int(tv[0]), int(tv[1])) < (2, 8):
        return None
    cv = torch.version.cuda
    if cv is None:
        return None
    cv_parts = cv.split(".")
    if (int(cv_parts[0]), int(cv_parts[1])) < (12, 8):
The version checks for PyTorch (>= 2.8) and CUDA (>= 12.8) appear to be placeholders or typos, as these versions are not yet released (current stable versions are typically PyTorch 2.5/2.6 and CUDA 12.4/12.6). As written, this logic will disable the FlashQLA backend in almost all current environments. Please verify whether these should be lower versions (e.g., PyTorch 2.4 and CUDA 12.1).
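For reference, a version gate along the lines suggested above might look like the sketch below. The (2, 4) and (12, 1) minimums and the helper name _flashqla_supported are illustrative assumptions, not values confirmed by the FlashQLA project.

    import torch

    def _flashqla_supported(min_torch=(2, 4), min_cuda=(12, 1)):
        # Illustrative thresholds only; confirm against FlashQLA's actual requirements.
        # Parse e.g. "2.4.1+cu121" -> (2, 4), dropping any local-build suffix.
        tv = torch.__version__.split("+")[0].split(".")
        if (int(tv[0]), int(tv[1])) < min_torch:
            return False
        cv = torch.version.cuda  # None on CPU-only builds
        if cv is None:
            return False
        cv_parts = cv.split(".")
        return (int(cv_parts[0]), int(cv_parts[1])) >= min_cuda

Whatever the final thresholds turn out to be, keeping them in one place makes them easy to adjust once FlashQLA's requirements are confirmed.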
    flashqla_fn = _flashqla_chunk_gated_delta_rule()
    if flashqla_fn is not None and not head_first:
        return flashqla_fn(
            q=q.contiguous(),
            k=k.contiguous(),
            v=v.contiguous(),
            g=g.contiguous(),
            beta=beta.contiguous(),
            scale=scale,
            initial_state=initial_state.contiguous() if initial_state is not None else None,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
The scale parameter is passed to flashqla_fn before it is assigned its default value. In the fallback path (lines 257-258), scale defaults to k.shape[-1] ** -0.5 if it is None. If flash_qla.chunk_gated_delta_rule does not handle None for the scale argument, this will lead to incorrect results or a crash. You should move the default scale calculation before the FlashQLA dispatch logic.
Suggested change:

    if scale is None:
        scale = k.shape[-1] ** -0.5
    flashqla_fn = _flashqla_chunk_gated_delta_rule()
    if flashqla_fn is not None and not head_first:
        return flashqla_fn(
            q=q.contiguous(),
            k=k.contiguous(),
            v=v.contiguous(),
            g=g.contiguous(),
            beta=beta.contiguous(),
            scale=scale,
            initial_state=initial_state.contiguous() if initial_state is not None else None,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
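An equivalent way to address this, sketched here purely as an illustration (the helper name _resolve_scale is not from the PR), is to resolve the default once near the top of chunk_gated_delta_rule so that the FlashQLA path and the fallback both receive the same concrete value:

    def _resolve_scale(scale, k):
        # Mirror the fallback default (k.shape[-1] ** -0.5) so no backend
        # ever receives scale=None.
        return k.shape[-1] ** -0.5 if scale is None else scale

Either placement works; the important property is that the default is computed before the first backend dispatch.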