
Add mHC (Manifold-Constrained Hyper-Connections) fused kernels to Liger-Kernel #1066

@yukiu00

Description


🚀 The feature, motivation and pitch

Background (Paper)

mHC: Manifold-Constrained Hyper-Connections (arXiv:2512.24880v2)
https://arxiv.org/abs/2512.24880

Paper alignment (what this proposal implements)

  • Constrain the residual mapping H_res via Sinkhorn-Knopp onto the doubly-stochastic set (Birkhoff polytope), restoring identity-mapping stability while keeping multi-stream residual benefits.
  • Follow the paper’s fused-kernel decomposition and recompute strategy (Sec. 4.3.1, Eq. (14)–(19)), including mixed precision (e.g., activations x in BF16/FP16; projections/coefficients in FP32/TF32).
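To make the first bullet concrete, here is a minimal, framework-agnostic sketch of the Sinkhorn-Knopp iteration that projects a matrix toward the doubly-stochastic set (Birkhoff polytope). This is only the reference math for the constraint, not the paper's fused Triton kernel or its exact parameterization; the function name and `n_iters` parameter are illustrative.

```python
import numpy as np

def sinkhorn_knopp(logits, n_iters=50):
    """Reference (unfused) Sinkhorn-Knopp projection: push exp(logits)
    toward the doubly-stochastic set by alternating row/column
    normalization. In mHC this constrains the residual mapping H_res."""
    M = np.exp(logits)  # exp ensures strict positivity
    for _ in range(n_iters):
        M = M / M.sum(axis=1, keepdims=True)  # rows sum to 1
        M = M / M.sum(axis=0, keepdims=True)  # cols sum to 1
    return M
```

A doubly-stochastic H_res preserves the total "mass" carried across residual streams, which is what restores identity-mapping stability while still allowing cross-stream mixing.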

Why this matters for Liger-Kernel

  • Strong fit with Liger-Kernel’s goals: fusion, bandwidth reduction, and activation memory efficiency.
  • mHC is a paper-defined, kernelization-friendly target with practical value for large-scale LM training.
  • Extends coverage beyond single-stream residual style blocks.

Proposal

#1065

  • Add Triton fused kernels for mHC (coeffs / Sinkhorn / apply) with forward + backward.
  • Add LigerMHC module + liger_mhc_* functional APIs following existing Liger naming.
  • Add allow_fp32 as opt-in (default remains BF16/FP16 mixed precision; intended for specific/debug use cases).
  • Add correctness tests (ops + transformer-level + convergence) and benchmarks.
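For the "apply" step in the list above, a conceptual (unfused) reference of the multi-stream residual mixing that the fused kernel would compute might look like the following. This is a sketch under assumptions: the function name, the `[HC, T, C]` stream layout, and the plain einsum formulation are illustrative and not the paper's exact fused equations.

```python
import numpy as np

def mhc_residual_mix(streams, H_res):
    """Conceptual reference for the mHC 'apply' step: mix HC residual
    streams (shape [HC, T, C]) with a doubly-stochastic H_res
    (shape [HC, HC]). The proposed Triton kernel would fuse this with
    the surrounding coeff/Sinkhorn steps instead of materializing it."""
    return np.einsum("ij,jtc->itc", H_res, streams)
```

When H_res is the identity, this reduces to ordinary single-stream residual behavior, which is the identity-mapping property the Sinkhorn constraint is meant to preserve.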

Repro / Environment

  • GPU: RTX 3090 (CUDA)
  • torch: 2.10.0+cu128, CUDA: 12.8
  • Benchmark measurement: warmup + median over 20 runs
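The measurement protocol above (warmup, then median over 20 runs) can be sketched as a small harness. This is an illustrative stdlib-only version, not the actual benchmark script; for CUDA kernels you would additionally call `torch.cuda.synchronize()` around the timed region so launches don't return before the kernel finishes.

```python
import time
import statistics

def bench(fn, warmup=3, runs=20):
    """Warm up `fn`, then return the median wall-clock time (seconds)
    over `runs` calls. For GPU work, synchronize before reading the
    clock; perf_counter alone only times the kernel launch."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(runs):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return statistics.median(times)
```

The median is preferred over the mean here because it is robust to one-off stalls (clock ramping, other processes touching the GPU).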

Benchmark Results

Environment

| Item | Value |
| --- | --- |
| GPU | NVIDIA GeForce RTX 3090 |
| CUDA | 12.8 |
| PyTorch | 2.10.0+cu128 |
| Triton | 3.6.0 |
| Python | 3.12.4 |
| OS | Linux 6.8.0-90-generic (x86_64) |

Micro-Benchmarks (B=4, HC=4, C=4096, tmax=20, BF16, T=128~2048)

| Kernel | Speedup | Memory |
| --- | --- | --- |
| coeffs | 3.5x faster | 77% less |
| pre | 2.2x faster | 33% less |
| post_res | 1.9x faster | 31% less |

(Per-kernel Forward / Backward / Full / Memory plots were attached as images.)

End-to-End LM Benchmark (B=2, T=256, HC=4, layers=2, heads=8, vocab=4096, BF16, hidden_size=256~1024)

| Kernel | Speedup | Memory |
| --- | --- | --- |
| mhc_llama_like_lm | 1.5x faster | 18% less |

(Forward / Backward / Full / Memory plots were attached as images.)

