support prefill cudagraph for gdn #1294
Conversation
Code Review
This pull request introduces a wrapper method `_gdn_prefill_wrapper_run` in the Qwen3Next transformer layer to handle prefill operations, adding support for CUDA graph capture by managing tensor allocations and registering CPU-side functions that execute alongside the graph. Review feedback identifies a typo in the method name `prefill_cuda_graph_add_cpu_runnning_func` ("runnning" has a triple "n") and suggests optimizing the handling of the `z` tensor during graph capture by taking a view instead of performing a redundant allocation and copy.
Reviewed code from the PR diff:

```python
z_shape = o_shape

infer_state.prefill_cuda_graph_create_graph_obj()
infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
_o = tensor_to_no_ref_tensor(o)
z = torch.empty(z_shape, dtype=o_dtype, device=o_device)
_z = tensor_to_no_ref_tensor(z)

def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
    conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
    mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
    _z.copy_(tmp_z)
    tmp_o = self._gdn_prefill_kernel(
        mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
    )
    tmp_o = tmp_o.view(_o.shape)
    _o.copy_(tmp_o)
    return

infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph)
```
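The pattern above interleaves captured graph segments with host-side ("CPU running") functions that re-execute on every replay, writing results into preallocated buffers. A toy, pure-Python sketch of that control flow (the `PrefillGraph` class and its method names are illustrative stand-ins, not LightLLM's actual API):

```python
class PrefillGraph:
    """Toy model of segmented capture: GPU graph segments are interleaved
    with host-side functions that rerun on every replay, e.g. to
    repartition a fused buffer and copy slices into stable output tensors."""

    def __init__(self):
        self._steps = []  # ordered list of replayable callables

    def add_cpu_running_func(self, func):
        # Analogue of prefill_cuda_graph_add_cpu_runnning_func: the
        # function is registered once and re-executed at each replay.
        self._steps.append(func)

    def replay(self, state):
        for fn in self._steps:
            fn(state)


g = PrefillGraph()
buf = {"o": None}  # stands in for the preallocated _o buffer
g.add_cpu_running_func(lambda s: buf.update(o=s["x"] * 2))
g.replay({"x": 21})
print(buf["o"])  # 42
```

The key property being modeled: the buffer (`buf`, like `_o`) has a fixed identity across replays, while the registered function refreshes its contents each time the graph runs.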
The allocation and copy of the `z` tensor in the CUDA graph capture path are redundant. Since `z` is a slice of `mixed_qkvzba` (which is an output of the previous graph segment), you can obtain `z` as a view directly during capture. This avoids an unnecessary allocation and a GPU-to-GPU copy inside the host node during replay. The subsequent `z.contiguous()` call in `gdn_forward` will handle contiguity if the norm kernel requires it, and that copy will be captured efficiently in the graph.
Suggested change:

```python
infer_state.prefill_cuda_graph_create_graph_obj()
infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
_o = tensor_to_no_ref_tensor(o)
_, z, _, _ = self._split_qkvzba(mixed_qkvzba, is_decode=False)

def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
    conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
    mixed_qkv, _, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
    tmp_o = self._gdn_prefill_kernel(
        mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
    )
    _o.copy_(tmp_o.view(_o.shape))
    return
```
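The suggestion rests on standard PyTorch semantics: slicing returns a view that shares storage with the source tensor, and `.contiguous()` materializes a copy only when the layout actually requires it. A quick CPU-side illustration (the shapes here are arbitrary, not the model's):

```python
import torch

# A fused buffer standing in for mixed_qkvzba (shape chosen for illustration).
mixed = torch.arange(24, dtype=torch.float32).view(4, 6)

# Slicing out the "z" columns yields a view: no new allocation, no copy.
z = mixed[:, 2:4]
mixed[0, 2] = 100.0
print(z[0, 0].item())  # 100.0 -> z shares storage with mixed

# The column slice is non-contiguous; contiguous() copies only in that case,
# and during capture that copy becomes a graph node rather than host-side work.
print(z.is_contiguous())               # False
print(z.contiguous().is_contiguous())  # True
```

Because the view shares the storage produced by the previous graph segment, replaying that segment automatically refreshes what `z` observes, with no host-node `copy_` needed.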