
support prefill cudagraph for gdn #1294

Merged
hiworldwzj merged 4 commits into main from pr-cudagraph on May 8, 2026
Conversation

@WANDY666 (Contributor) commented May 8, 2026

No description provided.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a wrapper method _gdn_prefill_wrapper_run in the Qwen3Next transformer layer to handle prefill operations, adding support for CUDA graph capture by managing stable tensor allocations and by scheduling a CPU-side function to run alongside the captured graph segments. Review feedback identifies a likely typo in the method name prefill_cuda_graph_add_cpu_runnning_func and suggests optimizing the handling of the z tensor during graph capture by taking a view instead of performing a redundant allocation and copy.
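
For orientation, the following is a minimal, self-contained sketch of the general pattern the wrapper follows. It is plain PyTorch, not the lightllm implementation: graph capture with stable input/output buffers, plus a host-side function that runs at replay time (the role gdn_prefill_func plays when registered via prefill_cuda_graph_add_cpu_runnning_func):

    import torch

    # CUDA graphs replay fixed device pointers, so inputs and outputs must
    # live in pre-allocated tensors that are only ever updated in place.
    static_in = torch.randn(8, 64, device="cuda")
    static_out = torch.empty(8, 64, device="cuda")

    # Warm up once outside capture so no lazy initialization (allocator,
    # kernel selection) is recorded into the graph.
    static_out.copy_(static_in * 2)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # GPU work recorded into the graph; results land in static_out.
        static_out.copy_(static_in * 2)

    def host_side_step():
        # Stand-in for CPU-side work that cannot be captured (e.g. cache
        # bookkeeping); the PR schedules gdn_prefill_func in this spirit.
        pass

    static_in.uniform_()  # refresh inputs in place; pointers stay fixed
    graph.replay()
    host_side_step()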

                _o.copy_(tmp_o)
                return

            infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph)
Severity: high

There is a likely typo in the method name prefill_cuda_graph_add_cpu_runnning_func: 'runnning' contains three 'n's. Please verify whether this matches the intended method name in the base class or on the infer_state object.

Comment on lines +298 to +316
            z_shape = o_shape

            infer_state.prefill_cuda_graph_create_graph_obj()
            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
            _o = tensor_to_no_ref_tensor(o)
            z = torch.empty(z_shape, dtype=o_dtype, device=o_device)
            _z = tensor_to_no_ref_tensor(z)

            def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
                conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
                mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
                _z.copy_(tmp_z)
                tmp_o = self._gdn_prefill_kernel(
                    mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
                )
                tmp_o = tmp_o.view(_o.shape)
                _o.copy_(tmp_o)
                return
Severity: medium

The allocation and copy of the z tensor in the CUDA graph capture path are redundant. Since z is a slice of mixed_qkvzba (which is an output of the previous graph segment), you can obtain z as a view directly during capture. This avoids an unnecessary allocation and a GPU-to-GPU copy inside the host node during replay. The subsequent z.contiguous() call in gdn_forward will handle contiguity if required by the norm kernel, and that copy will be efficiently captured in the graph.

            infer_state.prefill_cuda_graph_create_graph_obj()
            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
            _o = tensor_to_no_ref_tensor(o)
            _, z, _, _ = self._split_qkvzba(mixed_qkvzba, is_decode=False)

            def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
                conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
                mixed_qkv, _, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
                tmp_o = self._gdn_prefill_kernel(
                    mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
                )
                _o.copy_(tmp_o.view(_o.shape))
                return

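To see why the suggested view is effectively free: torch.split and basic slicing return views that share storage with their input, so taking z during capture records no kernel and allocates nothing, whereas the original pattern paid an allocation plus a device-to-device copy on every replay of the host function. A small standalone illustration (the shapes and split sizes below are invented for the example, not those of _split_qkvzba):

    import torch

    qkvzba = torch.randn(16, 256, device="cuda")
    # torch.split returns views into the same storage: no new allocation,
    # no copy kernel is launched.
    qkv, z, b, a = torch.split(qkvzba, [128, 64, 32, 32], dim=-1)
    assert z.data_ptr() == qkvzba.data_ptr() + 128 * qkvzba.element_size()

    # The pattern the review replaces: a fresh buffer plus a GPU-to-GPU
    # copy that would otherwise run on every replay.
    z_buf = torch.empty_like(z)
    z_buf.copy_(z)
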
@hiworldwzj merged commit e1f8723 into main on May 8, 2026. 1 check passed.
@hiworldwzj deleted the pr-cudagraph branch on May 8, 2026 at 09:21.