Refactor(linear): split LinearBackward kernel into 3 independent kernels#142
Open
chen2021673 wants to merge 6 commits into master from
Conversation
283d083 to 23d301b
Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
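A minimal sketch of the resulting autograd-layer dispatch, assuming a PyTorch-style context object. The kernel names (LinearBackwardInput/Weight/Bias) come from the commit messages above; Tensor, AutogradContext, and the accessor/signature details are illustrative assumptions, not the project's actual API:

```cpp
// Hypothetical sketch: the autograd layer gates each gradient, so the backward
// kernels stay pure compute functions with no grad flags.
std::vector<Tensor> LinearFunction::Backward(AutogradContext& ctx,
                                             const Tensor& grad_output) {
  const Tensor& input = ctx.SavedTensor(0);   // assumed accessor
  const Tensor& weight = ctx.SavedTensor(1);  // assumed accessor
  std::vector<Tensor> grads(3);               // {d_input, d_weight, d_bias}
  if (ctx.NeedsInputGrad(0)) {
    grads[0] = LinearBackwardInput(grad_output, weight);
  }
  if (ctx.NeedsInputGrad(1)) {
    grads[1] = LinearBackwardWeight(grad_output, input);
  }
  if (ctx.NeedsInputGrad(2)) {
    grads[2] = LinearBackwardBias(grad_output);
  }
  return grads;  // entries for ungated inputs stay empty
}
```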
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream() shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther for semantic clarity matching MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths); keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
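A hedged sketch of the shared GEMM primitive described in the first bullet above. GemmParams, GemmCuda, and the batch_count-based choice between cublasGemmEx and cublasGemmStridedBatchedEx come from the commit message; the field names, the column-major non-transposed layout, and the fixed fp32 alpha/beta are assumptions:

```cpp
#include <cublas_v2.h>
#include <cstdint>

// Sketch only: a pure problem description plus a single dispatch function.
struct GemmParams {
  int64_t m, n, k;
  int64_t batch_count;  // 1 -> cublasGemmEx, >1 -> cublasGemmStridedBatchedEx
  const void* a;
  const void* b;
  void* c;
  cudaDataType_t type_a, type_b, type_c;
  cublasComputeType_t compute_type;  // e.g. CUBLAS_COMPUTE_32F even for bf16 inputs
};

inline cublasStatus_t GemmCuda(cublasHandle_t handle, const GemmParams& p) {
  const float alpha = 1.0f, beta = 0.0f;  // assumes a 32-bit float compute type
  if (p.batch_count > 1) {
    // Densely packed, column-major, non-transposed batches assumed.
    return cublasGemmStridedBatchedEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, p.m, p.n, p.k, &alpha,
        p.a, p.type_a, p.m, p.m * p.k,
        p.b, p.type_b, p.k, p.k * p.n,
        &beta, p.c, p.type_c, p.m, p.m * p.n,
        static_cast<int>(p.batch_count), p.compute_type, CUBLAS_GEMM_DEFAULT);
  }
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, p.m, p.n, p.k, &alpha,
                      p.a, p.type_a, p.m, p.b, p.type_b, p.k,
                      &beta, p.c, p.type_c, p.m,
                      p.compute_type, CUBLAS_GEMM_DEFAULT);
}
```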
ae80cec to 88579ba
…es in linear kernels
88579ba to 252e6cd
Contributor
Also, I see the CPU-side changes are quite extensive. I don't spot any problems there, but please also take the trouble to verify that numerical accuracy is unaffected.
Chamberlain0w0 requested changes on Apr 28, 2026
…s to designated initializers

- Save input1_dims_/input2_dims_ in Matmul::SetupContext to avoid Dims() calls on potentially-null saved tensors in Backward
- Get device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu, outer.cu to C++20 designated initializer form
…evice param

GemmParams and SgemvParams are pure problem descriptions and should not carry runtime state. Move handle acquisition into GemmCuda/SgemvCuda via a device parameter, and inline the dynamic_cast directly. Remove the public GetCublasHandle/GetCudaStream helpers from gemm.cuh.
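A call-site sketch of the style these last two commits converge on: the params struct stays a pure problem description built with C++20 designated initializers, and GemmCuda resolves the cuBLAS handle internally from a device argument. The GemmParams fields follow the sketch above; Device, the wrapper function, and the dimension mapping are placeholders, not the project's real types:

```cpp
// Sketch only: assumed post-refactor signature; GemmCuda resolves the cuBLAS
// handle/stream internally (dynamic_cast of `device` to the CUDA device type).
struct Device;  // placeholder for the project's device base class
cublasStatus_t GemmCuda(const Device* device, const GemmParams& p);

void MatmulForwardCudaSketch(const Device* device, const void* a, const void* b,
                             void* c, int64_t m, int64_t n, int64_t k,
                             cudaDataType_t io_type) {
  // C++20 designated initializers: the struct carries only the problem
  // description, no handles or streams.
  GemmCuda(device, GemmParams{
                       .m = m,
                       .n = n,
                       .k = k,
                       .batch_count = 1,
                       .a = a,
                       .b = b,
                       .c = c,
                       .type_a = io_type,
                       .type_b = io_type,
                       .type_c = io_type,
                       .compute_type = CUBLAS_COMPUTE_32F,
                   });
}
```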
Chamberlain0w0 approved these changes on May 7, 2026
Contributor
Author
kilinchange requested changes on May 7, 2026
Collaborator
Please put both gemm.cuh and gemm.cu under the src/kernel/cuda/common/ directory. In principle, the include directory should only hold interfaces exposed to the outside. Some of the earliest files weren't split very carefully; I'll clean up the legacy ones later, but new files should still follow this principle.
const cudaDataType_t type_c = ToCudaDataType(p.output_dtype);
// Always use CUBLAS_COMPUTE_32F: required for bf16/fp16 correctness,
// and fine for fp32 (same compute path).
const cublasComputeType_t ctype = CUBLAS_COMPUTE_32F;
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
if (bs == 1 && dtype == DataType::kFLOAT32) {
  SgemvCuda(device, SgemvParams{
Collaborator
I see that fp32 originally always went through sgemm, and now it goes through sgemv when bs == 1 && fp32. Have you measured any performance gain? If there is no clear benefit yet, I suggest keeping the original logic in this PR and introducing the gemv path later as part of a separate matmul performance optimization (which will probably need a more detailed case-by-case analysis).
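For context on the trade-off being discussed, a minimal illustration of the gemv path: for a single fp32 matrix-vector product, cublasSgemv replaces a full GEMM launch. The dimension/layout mapping here is a generic assumption, not the project's actual SgemvParams:

```cpp
#include <cublas_v2.h>

// Illustration only: a single fp32 matrix-vector product y = A * x, with
// column-major A of shape (m, n); lda/inc values are generic assumptions.
void SgemvExample(cublasHandle_t handle, const float* A, const float* x,
                  float* y, int m, int n) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasSgemv(handle, CUBLAS_OP_N, m, n, &alpha, A, /*lda=*/m,
              x, /*incx=*/1, &beta, y, /*incy=*/1);
}
```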
const std::vector<int64_t> weight_dims
    = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
auto compute_dtype = weight->Dtype();
Collaborator
Here compute_dtype is no longer obtained through input/weight type promotion; will that cause problems?
// Compute dtype determined by saved tensors (forward compute dtype), not grad_output
DataType compute_dtype = PromoteDataTypes(input_dtype, weight_dtype);
// For bf16 compute, accumulate in fp32 to preserve precision.
auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
Collaborator
Writing it this way is a bit of a hack. Leave a FIXME for now; we can decide on the proper fix later when we rework autograd/autocast.
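A possible shape for that FIXME; the wording is an assumption, and the surrounding lines simply mirror the diff above:

```cpp
// FIXME: choosing the backward compute/accumulation dtype here is a stopgap;
// revisit once autograd/autocast own the dtype-promotion policy.
DataType compute_dtype = PromoteDataTypes(input_dtype, weight_dtype);
auto output_dtype =
    (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
```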
Collaborator
Please attach a screenshot showing that the tests pass.

Summary
Architecture refactoring of Linear/Matmul/Outer kernels.
The core idea is separation of concerns: the decision of whether each gradient should be computed moves out of the kernels and up to the autograd layer, leaving the kernels as pure compute functions. At the same time, unified GEMM/SGEMV primitives are abstracted at the bottom layer to eliminate duplicated cuBLAS boilerplate.
Changes