-
Notifications
You must be signed in to change notification settings - Fork 743
[TEST] add test in /root/paddlejob/workspace/env_run/output/zkk/2026_04_17FL… #7766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,18 +23,286 @@ | |
|
|
||
# All paddle tensors created below default to bf16, matching the kernel's
# operand dtype (cutlass.BFloat16).
paddle.set_default_dtype("bfloat16")

# Early exit for DIRECT script execution only: the CuTe kernel below requires
# an SM100 (Blackwell, compute capability major == 10) GPU and the `cutlass`
# python package.
# NOTE(review): this guard does NOT protect pytest/unittest runs, where
# __name__ != "__main__" — the module-level cutlass imports that follow still
# execute unconditionally in that case (see reviewer comment in the PR).
if __name__ == "__main__":
    prop = paddle.device.cuda.get_device_properties()
    if prop.major != 10:
        exit(0)
    try:
        import cutlass
    except ImportError:
        exit(0)
|
|
||
|
|
||
| import cutlass.cute as cute | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Bug 模块级 cutlass imports 在 pytest 运行时无条件执行 此处 5 行 建议修复方式:将顶层 cutlass 导入改为条件导入,或使用 try:
import cutlass.cute as cute
import cutlass.pipeline as pipeline
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.nvgpu import cpasync, tcgen05
except ImportError:
cute = None # 后续在 two_invoke/DenseGemmKernel 中用 pytest.skip 守卫或在测试方法入口处使用: cutlass = pytest.importorskip("cutlass") |
||
| import cutlass.pipeline as pipeline | ||
| import cutlass.utils as utils | ||
| import cutlass.utils.blackwell_helpers as sm100_utils | ||
| from cutlass.cute.nvgpu import cpasync, tcgen05 | ||
|
|
||
|
|
||
class DenseGemmKernel:
    """Minimal SM100 (Blackwell) dense GEMM written in the CUTLASS CuTe DSL.

    Computes a single (128, 128, 64) MMA tile with bf16 operands and fp32
    accumulation using tcgen05 instructions, then casts the accumulator back
    to bf16 when storing into ``c``.  The host-side ``__call__`` builds the
    tiled-MMA / TMA metadata and launches ``kernel`` on the device.
    """

    def __init__(self):
        self.num_warps = 4
        # Number of tensor-memory (TMEM) columns requested from the allocator.
        self.num_tmem_alloc_cols = 512
        self.threads_per_cta = 128
        self.a_dtype = cutlass.BFloat16
        self.b_dtype = cutlass.BFloat16
        self.acc_dtype = cutlass.Float32

        # Single accumulator pipeline stage: one producer commit, one
        # consumer wait per launch.
        self.num_acc_stage = 1
        # When True the MMA spans two CTAs of a cluster; the cluster shapes
        # and tcgen05 CTA group below are all derived from this one flag.
        self.use_2cta_instrs = False
        self.cluster_shape_mnk = (2, 1, 1) if self.use_2cta_instrs else (1, 1, 1)
        self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE

        # (M, N, K) tile processed per launch; the whole problem must fit it.
        self.mma_tiler = (128, 128, 64)

    @cute.jit
    def __call__(
        self,
        a: cute.Tensor,
        b: cute.Tensor,
        c: cute.Tensor,
    ):
        """Host-side JIT entry point.

        Builds the trivial tiled MMA, staged shared-memory layouts and TMA
        atoms for A and B, then launches ``kernel`` with one cluster of
        ``cluster_shape_mnk`` CTAs, 128 threads each.
        """
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            cutlass.BFloat16,
            tcgen05.OperandMajorMode.K,
            tcgen05.OperandMajorMode.K,
            self.acc_dtype,
            self.cta_group,
            self.mma_tiler[:2],
        )
        # Number of CTAs cooperating in one MMA atom (1 or 2).
        self.atom_thr_size = cute.size(tiled_mma.thr_id.shape)

        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk),
            (tiled_mma.thr_id.shape,),
        )
        # ((2),1,1,1):((1),0,0,0)

        # Single-stage (num_stages=1) smem layouts for the A and B tiles.
        a_smem_layout_staged = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, cutlass.BFloat16, 1)
        b_smem_layout_staged = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, cutlass.BFloat16, 1)

        # Setup TMA load for A (the stage mode is sliced away for the atom).
        a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0))
        tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
            a_op,
            a,
            a_smem_layout,
            self.mma_tiler,
            tiled_mma,
            self.cluster_layout_vmnk.shape,
            internal_type=(cutlass.TFloat32 if a.element_type is cutlass.Float32 else None),
        )

        # Setup TMA load for B
        b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
        b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0))
        tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
            b_op,
            b,
            b_smem_layout,
            self.mma_tiler,
            tiled_mma,
            self.cluster_layout_vmnk.shape,
            internal_type=(cutlass.TFloat32 if b.element_type is cutlass.Float32 else None),
        )

        # Bytes one TMA transaction must observe (A + B, times the number of
        # CTAs in the MMA atom).  NOTE(review): computed but not passed to the
        # kernel here — presumably kept for a TMA-based load path; the kernel
        # below fills smem with plain per-thread copies instead.
        a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
        b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
        self.num_tma_load_bytes = (a_copy_size + b_copy_size) * self.atom_thr_size

        self.kernel(
            tiled_mma,
            a,
            b,
            c,
            a_smem_layout_staged,
            b_smem_layout_staged,
            tma_atom_a,
            tma_atom_b,
            self.cluster_layout_vmnk,
        ).launch(
            grid=self.cluster_shape_mnk,
            block=[128, 1, 1],
            cluster=self.cluster_shape_mnk,
        )

    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        tiled_mma,
        a,
        b,
        c,
        a_smem_layout_staged,
        b_smem_layout_staged,
        tma_atom_a,
        tma_atom_b,
        cluster_layout_vmnk: cute.Layout,
    ):
        """Device kernel: load A/B into smem, run the tcgen05 MMA over the K
        blocks, copy the TMEM accumulator to registers and store to ``c``.
        """
        warp_idx = cute.arch.warp_idx()
        warp_idx = cute.arch.make_warp_uniform(warp_idx)
        tidx = cute.arch.thread_idx()[0]

        # v-coordinate inside the MMA atom; CTA 0 of the atom is the leader
        # that issues the MMA and commits the accumulator barrier.
        bidx, bidy, bidz = cute.arch.block_idx()
        mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
        is_leader_cta = mma_tile_coord_v == 0

        if warp_idx == 0:
            cpasync.prefetch_descriptor(tma_atom_a)
            cpasync.prefetch_descriptor(tma_atom_b)

        @cute.struct
        class SharedStorage:
            # Full/empty mbarrier pair per accumulator stage.
            acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            tmem_dealloc_mbar: cutlass.Int64
            tmem_holding_buf: cutlass.Int32

        smem = utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)

        sA = smem.allocate_tensor(
            element_type=cutlass.BFloat16,
            layout=a_smem_layout_staged.outer,
            byte_alignment=128,
            swizzle=a_smem_layout_staged.inner,
        )

        sB = smem.allocate_tensor(
            element_type=cutlass.BFloat16,
            layout=b_smem_layout_staged.outer,
            byte_alignment=128,
            swizzle=b_smem_layout_staged.inner,
        )

        tmem_alloc_barrier = pipeline.NamedBarrier(barrier_id=0, num_threads=self.threads_per_cta)

        # Tensor memory dealloc barrier init
        tmem = utils.TmemAllocator(
            storage.tmem_holding_buf,
            barrier_for_retrieve=tmem_alloc_barrier,
            is_two_cta=self.use_2cta_instrs,
            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
        )

        # Alloc tensor memory buffer
        tmem.allocate(self.num_tmem_alloc_cols)
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)

        # Initialize acc_pipeline (barrier) and states
        acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
        acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_cta)
        acc_pipeline = pipeline.PipelineUmmaAsync.create(
            barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
            num_stages=self.num_acc_stage,
            producer_group=acc_pipeline_producer_group,
            consumer_group=acc_pipeline_consumer_group,
            cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=True,
        )

        acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
        acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)

        # Cooperative per-thread fill of smem (no TMA on this path).
        # In the 2-CTA case each CTA copies its own 64-row slice of A/B;
        # assumes a and b cover exactly one mma tile — TODO confirm with the
        # caller (two_invoke uses M=128, N=128, K=64).
        for i in cutlass.range(tidx, cute.cosize(sA), self.threads_per_cta):
            if self.use_2cta_instrs:
                sA[i] = a[bidx * 64 + i % 64, i // 64]
                sB[i] = b[bidx * 64 + i % 64, i // 64]
            else:
                sA[i] = a[i]
                sB[i] = b[i]

        # Make the smem fill visible to the MMA-issuing warp.
        pipeline.sync(barrier_id=1)

        tCrA = tiled_mma.make_fragment_A(sA)
        tCrB = tiled_mma.make_fragment_B(sB)
        acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
        # The "fake" fragment only supplies the layout; the real accumulator
        # lives in tensor memory at tmem_ptr.
        tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape)
        tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)

        # Only warp 0 of the leader CTA issues the MMA over the K blocks.
        # First iteration overwrites the accumulator (ACCUMULATE=False),
        # subsequent iterations accumulate.
        if warp_idx == 0 and is_leader_cta:
            blk_count = tCrA.shape[2]
            tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
            for i in cutlass.range_constexpr(blk_count):
                cute.gemm(tiled_mma, tCtAcc, tCrA[None, None, i, 0], tCrB[None, None, i, 0], tCtAcc)
                tiled_mma.set(tcgen05.Field.ACCUMULATE, True)

            acc_pipeline.producer_commit(acc_producer_state)

        # All threads wait until the accumulator is complete in TMEM.
        acc_pipeline.consumer_wait(acc_consumer_state)

        tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)

        # Drop the (MMA_M, MMA_N) modes: single tile per CTA.
        tCtAcc = tCtAcc[(None, None), 0, 0]

        tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tCtAcc)

        tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)
        tTR_tAcc = tmem_thr_copy.partition_S(tCtAcc)

        # Per-CTA tile: M is split across the CTAs of the MMA atom.
        mma_tiler = (self.mma_tiler[0] // tiled_mma.thr_id.shape, self.mma_tiler[1], self.mma_tiler[2])

        cS = cute.make_identity_tensor(cute.select(mma_tiler, mode=[0, 1]))

        tTR_tS = tmem_thr_copy.partition_D(cS)

        # TMEM -> register copy of this thread's accumulator slice.
        tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype)
        cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc)

        # Store to global, casting fp32 accumulator down to bf16.
        if self.use_2cta_instrs:
            for i in cutlass.range_constexpr(64):
                c[tidx % 64 + 64 * bidx, i + tidx // 64 * 64] = (cutlass.BFloat16)(tTR_rAcc[i])
        else:
            for i in cutlass.range_constexpr(128):
                c[tidx, i] = (cutlass.BFloat16)(tTR_rAcc[i])

        # Ensure every thread is done with TMEM before releasing it.
        pipeline.sync(barrier_id=2)
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)
|
|
||
|
|
||
| class TestDeepDenseGemm(unittest.TestCase): | ||
    def setUp(self):
        # No per-test fixtures are required; kept so the TestCase structure
        # stays explicit and a future fixture has an obvious home.
        pass
|
|
||
| def two_invoke(self, M, N, K): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议
def two_invoke(self, M, N, K):
prop = paddle.device.cuda.get_device_properties()
if prop.major != 10:
return
try:
import cutlass
except ImportError:
return
... |
||
|
|
||
| a = paddle.randn([M, K]) | ||
| b = paddle.randn([N, K]) | ||
| baseline_out = paddle.matmul(a, b, False, True) | ||
|
|
||
| my_tensor = paddle.empty_like(baseline_out) | ||
|
|
||
| mm = DenseGemmKernel() | ||
| from cutlass.cute.runtime import from_dlpack | ||
|
|
||
| my_a = from_dlpack(a, assumed_align=16) | ||
| my_b = from_dlpack(b, assumed_align=16) | ||
| my_res = from_dlpack(my_tensor, assumed_align=16) | ||
|
|
||
| compiled_mm = cute.compile( | ||
| mm, | ||
| my_a, | ||
| my_b, | ||
| my_res, | ||
| options="--opt-level 2", | ||
| ) | ||
| compiled_mm(my_a, my_b, my_res) | ||
|
|
||
| print(my_tensor) | ||
|
|
||
| print(my_tensor - baseline_out) | ||
| assert (my_tensor - baseline_out).abs().max().item() == 0.0 | ||
|
|
||
| def one_invoke(self, M, N, K): | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| prop = paddle.device.cuda.get_device_properties() | ||
| if prop.major != 10: | ||
| try: | ||
| import deep_gemm | ||
| except: | ||
| return | ||
|
|
||
| import deep_gemm | ||
|
|
||
| block_size = 128 | ||
|
|
||
| raw_x = paddle.randn([M, K], dtype="bfloat16").cast(paddle.float8_e4m3fn) | ||
|
|
@@ -83,6 +351,7 @@ def one_invoke(self, M, N, K): | |
| ) | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 当前 assert (baseline_out - deepgemm_output).abs().max().item() < 0.1 |
||
| print(baseline_out - deepgemm_output) | ||
| # assert (baseline_out - deepgemm_output).abs().max().item() < 0.1 | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| def test_main(self): | ||
| # import paddle.profiler as profiler | ||
|
|
@@ -96,6 +365,8 @@ def test_main(self): | |
| self.one_invoke(128 * 20, 2048, 4096) | ||
| self.one_invoke(128 * 20, 2048, 2048) | ||
|
|
||
| self.two_invoke(128, 128, 64) | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| # p.stop() | ||
|
|
||
|
|
||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.