
Refactor(linear): split LinearBackward kernel into 3 independent kernels#142

Open
chen2021673 wants to merge 6 commits into master from split_linear_backward

Conversation


@chen2021673 chen2021673 commented Apr 10, 2026

Summary

Architecture refactoring of Linear/Matmul/Outer kernels.

The core idea is separation of concerns: the decision of whether a gradient should be computed moves from the kernel layer up to the autograd layer, making kernels pure compute functions. At the same time, unified GEMM/SGEMV primitives are introduced at the lowest layer to eliminate duplicated cuBLAS boilerplate.

Changes

  • Autograd layer: LinearBackward and MatmulBackward are each decomposed into multiple independent Dispatcher calls. The needs_input_grad checks happen at the autograd layer, invoking only the kernels actually needed.
  • Kernel layer: The monolithic LinearBackward is split into LinearBackwardInput / LinearBackwardWeight / LinearBackwardBias; MatmulBackward is split into MatmulBackwardInput / MatmulBackwardOther, with naming aligned to MatmulForward(input, other).
  • File split: Matmul kernels are extracted from linear.cc / linear.cu into dedicated cpu/matmul.cc and cuda/matmul.cu, giving each file a single responsibility.
  • GEMM primitive: New gemm.cuh / gemm.cu define the GemmParams struct and GemmCuda(), providing a unified wrapper over cublasGemmEx and cublasGemmStridedBatchedEx branching logic. GetCublasHandle() / GetCudaStream() are centrally defined and shared across linear.cu, matmul.cu, and outer.cu, eliminating duplicate definitions.
  • SGEMV primitive: New SgemvParams struct and SgemvCuda() wrap the cublasSgemv call. LinearForward and LinearBackwardInput in linear.cu take the SGEMV path when bs==1 and fp32 (more efficient for matrix-vector shapes); bf16 falls back to GemmCuda since cublasSgemv does not support it. The fp32 backward path in outer.cu is migrated to SgemvCuda as well, eliminating inline cublasSgemv calls.
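The autograd-layer dispatch in the first bullet can be sketched in plain C++ (a minimal sketch: the kernel names follow the PR description, but the dispatcher plumbing here is hypothetical, not the actual infini_train Dispatcher API):

```cpp
#include <array>
#include <cassert>
#include <string>
#include <vector>

// needs_input_grad is checked once, at the autograd layer; each flag
// maps to one independent pure-compute kernel. The strings stand in
// for real Dispatcher calls.
std::vector<std::string> LinearBackwardDispatch(const std::array<bool, 3> &needs_input_grad) {
    std::vector<std::string> launched;
    if (needs_input_grad[0]) { launched.push_back("LinearBackwardInput"); }
    if (needs_input_grad[1]) { launched.push_back("LinearBackwardWeight"); }
    if (needs_input_grad[2]) { launched.push_back("LinearBackwardBias"); }
    return launched;
}
```

With this split, a frozen weight (flag false) simply skips its kernel launch instead of being filtered inside a monolithic LinearBackward.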

@chen2021673 chen2021673 force-pushed the split_linear_backward branch 3 times, most recently from 283d083 to 23d301b Compare April 15, 2026 01:58
@chen2021673 chen2021673 requested a review from kilinchange April 15, 2026 02:08
Move grad_flags logic from kernel to autograd layer. The
monolithic LinearBackward kernel is replaced by LinearBackwardInput,
LinearBackwardWeight, and LinearBackwardBias — each a pure compute
operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel
is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or
  cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream()
  shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated
  matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther
  for semantic clarity matching MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths);
  keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
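The batch-based branching this commit centralizes in GemmCuda() can be sketched as follows (illustrative host-side logic only; the real wrapper forwards full problem geometry to cuBLAS rather than returning strings):

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Pick the cuBLAS entry point the way GemmCuda() is described to:
// a strided-batched call only when there is a real batch dimension.
std::string SelectCublasApi(std::int64_t batch_count) {
    return batch_count > 1 ? "cublasGemmStridedBatchedEx" : "cublasGemmEx";
}

// For the strided-batched path, per-matrix strides are the flat sizes
// of each operand (A: m*k, B: k*n, C: m*n for contiguous batches).
struct BatchStrides {
    std::int64_t a, b, c;
};
BatchStrides ComputeStrides(std::int64_t m, std::int64_t n, std::int64_t k) {
    return {m * k, k * n, m * n};
}
```

Keeping this branch in one place is what lets linear.cu, matmul.cu, and outer.cu share a single wrapper instead of each repeating the cublasGemmEx / cublasGemmStridedBatchedEx boilerplate.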
@chen2021673 chen2021673 force-pushed the split_linear_backward branch 2 times, most recently from ae80cec to 88579ba Compare April 28, 2026 09:06
@chen2021673 chen2021673 force-pushed the split_linear_backward branch from 88579ba to 252e6cd Compare April 28, 2026 09:21
@Chamberlain0w0

Also, the CPU changes are fairly extensive. I can't spot any problems in them, but please also verify that precision is unaffected.

Comment threads:
  • infini_train/src/kernels/cuda/linear.cu (outdated)
  • infini_train/src/autograd/matmul.cc
  • infini_train/src/autograd/matmul.cc (outdated)
  • infini_train/include/common/cuda/gemm.cuh (outdated)
…s to designated initializers

- Save input1_dims_/input2_dims_ in Matmul::SetupContext to avoid Dims()
  calls on potentially-null saved tensors in Backward
- Get device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu,
  outer.cu to C++20 designated initializer form
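The C++20 designated-initializer style this commit adopts can be illustrated with a simplified GemmParams (field names here are illustrative, not the full real struct):

```cpp
#include <cassert>
#include <cstdint>

struct GemmParams {
    std::int64_t m = 0;
    std::int64_t n = 0;
    std::int64_t k = 0;
    std::int64_t batch_count = 1;
    bool transa = false;
    bool transb = false;
};

constexpr GemmParams MakeExampleParams() {
    // Named fields make call sites self-documenting; omitted fields keep
    // their defaults (batch_count == 1, no transposes).
    return GemmParams{.m = 128, .n = 64, .k = 256, .transb = true};
}
```

Compared with positional aggregate initialization, a reordered or forgotten field becomes a compile error or a visible default rather than a silently transposed GEMM.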
@kilinchange kilinchange requested a review from Chamberlain0w0 May 7, 2026 02:17
…evice param

GemmParams and SgemvParams are pure problem descriptions and should not
carry runtime state. Move handle acquisition into GemmCuda/SgemvCuda via
a device parameter, inline the dynamic_cast directly. Remove the public
GetCublasHandle/GetCudaStream helpers from gemm.cuh.
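The design point of this commit, params as a pure problem description with runtime state acquired inside the wrapper, can be sketched like this (Device, GemmDesc, and the string-returning functions are stand-ins; the real code looks up a per-device cublasHandle_t):

```cpp
#include <cassert>
#include <cstdint>
#include <string>

struct Device {
    int index = 0;
};

// Problem description only: no handle, no stream.
struct GemmDesc {
    std::int64_t m = 0, n = 0, k = 0;
};

std::string AcquireHandleTag(const Device &device) {
    // Stand-in for a per-device cublasHandle_t lookup.
    return "cublas_handle_dev" + std::to_string(device.index);
}

std::string GemmCuda(const Device &device, const GemmDesc &p) {
    // Handle acquisition happens here, inside the wrapper, so callers
    // never thread cuBLAS state through the params struct.
    return AcquireHandleTag(device) + ":gemm_" + std::to_string(p.m);
}
```

This keeps the params structs trivially copyable and comparable, and removes any reason to expose GetCublasHandle/GetCudaStream from a public header.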
@chen2021673

> Also, the CPU changes are fairly extensive. I can't spot any problems in them, but please also verify that precision is unaffected.

[screenshot: img_v3_0211f_9143a2bd-b6f7-431d-95e5-ea1fc7536c0g] Verified; the results align exactly.

Collaborator

Please put gemm.cuh and gemm.cu under the src/kernel/cuda/common/ directory. In principle, the include directory should only contain externally exposed interfaces. Some early files weren't split carefully; I'll clean up those leftovers in one pass later, but new files should follow this principle.

```cpp
const cudaDataType_t type_c = ToCudaDataType(p.output_dtype);
// Always use CUBLAS_COMPUTE_32F: required for bf16/fp16 correctness,
// and fine for fp32 (same compute path).
const cublasComputeType_t ctype = CUBLAS_COMPUTE_32F;
```
Collaborator

Name it compute_type.

```cpp
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
if (bs == 1 && dtype == DataType::kFLOAT32) {
    SgemvCuda(device, SgemvParams{
```
Collaborator

Previously fp32 always went through sgemm; now bs == 1 && fp32 goes through sgemv. Have you measured a performance gain? If there is no clear benefit yet, I suggest this PR keep the original logic and introduce the gemv path later, as part of a dedicated matmul performance optimization (which may require a more elaborate case analysis).


```cpp
const std::vector<int64_t> weight_dims
    = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
auto compute_dtype = weight->Dtype();
```
Collaborator

Here compute_dtype is no longer obtained via input/weight type promotion. Could that cause problems?

```cpp
// Compute dtype determined by saved tensors (forward compute dtype), not grad_output
DataType compute_dtype = PromoteDataTypes(input_dtype, weight_dtype);
// For bf16 compute, accumulate in fp32 to preserve precision.
auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
```
Collaborator

As written, this is a bit hacky. Leave a FIXME for now; we'll work out the right fix when we later rework autograd/autocast.

@kilinchange

Please attach a screenshot showing the tests pass.

