Conversation
This reverts commit 86fbbac.
tests/pytorch/test_numerics.py
Outdated
| delay_wgrad_compute, | ||
| ): | ||
| os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" | ||
| if IS_HIP_EXTENSION: |
There was a problem hiding this comment.
Is our CK grouped gemm a drop-in replacement for NV upstream's CUTLASS grouped gemm? If so, we can share the same env variable. It's like cublaslt vs hipblaslt...
There was a problem hiding this comment.
It mostly is a drop-in replacement for upstream, so I changed the envs to the upstream versions in 259645c
| struct TileCfg_basic { | ||
| static constexpr ck_tile::index_t M_Tile = 256; | ||
| static constexpr ck_tile::index_t N_Tile = 128; | ||
| static constexpr ck_tile::index_t K_Tile = 64; | ||
|
|
||
| static constexpr ck_tile::index_t M_Warp = 2; | ||
| static constexpr ck_tile::index_t N_Warp = 2; | ||
| static constexpr ck_tile::index_t K_Warp = 1; | ||
|
|
||
| static constexpr ck_tile::index_t M_Warp_Tile = 32; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 32; | ||
| static constexpr ck_tile::index_t K_Warp_Tile = 16; | ||
|
|
||
| static constexpr bool kPadM = true; | ||
| static constexpr bool kPadN = true; | ||
| static constexpr bool kPadK = true; | ||
|
|
||
| static constexpr bool DoubleSmemBuffer = false; | ||
|
|
||
| static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; | ||
| static constexpr ck_tile::index_t TilePartitionerM01 = 1; | ||
| }; | ||
|
|
||
| template <typename AType, typename BType, typename CType, | ||
| typename ALayout, typename BLayout, typename CLayout, | ||
| typename TileCfg, ck_tile::memory_operation_enum MemOp, | ||
| typename AccType = float> | ||
| class Runner{ | ||
| public: | ||
| using GemmShape = ck_tile::TileGemmShape< | ||
| ck_tile::sequence<TileCfg::M_Tile, TileCfg::N_Tile, TileCfg::K_Tile>, | ||
| ck_tile::sequence<TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::K_Warp>, | ||
| ck_tile::sequence<TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile>>; | ||
|
|
||
| using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< | ||
| GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; | ||
|
|
||
| using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< | ||
| TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, | ||
| TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; | ||
|
|
||
| static constexpr ck_tile::GemmPipelineScheduler Scheduler = | ||
| ck_tile::GemmPipelineScheduler::Intrawave; | ||
|
|
||
| using Problem = ck_tile::UniversalGemmPipelineProblem< | ||
| AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; | ||
|
|
||
| using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>; | ||
|
|
||
| using Epilogue = ck_tile::CShuffleEpilogue< | ||
| ck_tile::CShuffleEpilogueProblem< | ||
| AType, BType, ck_tile::tuple<>, AccType, | ||
| CType, ck_tile::tuple<>, CLayout, | ||
| ck_tile::element_wise::PassThrough, | ||
| Partitioner::MPerBlock, Partitioner::NPerBlock, | ||
| TileCfg::M_Warp, TileCfg::N_Warp, | ||
| TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, | ||
| Problem::TransposeC, MemOp>>; | ||
|
|
||
| using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>; | ||
| }; |
There was a problem hiding this comment.
Is this code from the CK repo? If so, can you add a comment pointing to the reference?
There was a problem hiding this comment.
I can see the comment with the reference to CK repo, so I am resolving this.
| std::vector<ck_tile::GroupedGemmHostArgs<0>> descs; | ||
| descs.reserve(group_num); |
There was a problem hiding this comment.
Why not put group_num inside the desc vector definition?
There was a problem hiding this comment.
I used reserve() here instead of std::vector<ck_tile::GroupedGemmHostArgs<0>> descs(group_num); to avoid default-constructing GroupedGemmHostArgs objects that are immediately overwritten, to reduce construction overhead.
| using R = Runner<T, T, T, ALayout, BLayout, CLayout, TileCfg_basic, MemOp>; | ||
| using Kernel = typename R::Kernel; |
There was a problem hiding this comment.
This R is not used anywhere else
There was a problem hiding this comment.
I merged R into the next line in fac7c11.
| if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) { | ||
| NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); | ||
| return false; | ||
| } |
There was a problem hiding this comment.
Does grouped gemm support generalized matrices from high-dimensional tensors? Regular gemm supports that, and TE treats the last dim as the columns, with all other dimensions as the rows:
TransformerEngine/transformer_engine/common/common.h
Lines 238 to 262 in 9d6b0e5
There was a problem hiding this comment.
I added (untested) support for higher-dim tensors in dd3ed2f
| } | ||
| } | ||
|
|
||
| bool grouped_gemm_ck_tile(const NVTETensor* A, |
There was a problem hiding this comment.
Why do we overload this function? In cublaslt_gemm.cu, it's only called by this signature. Perhaps we can rename the grouped_gemm_ck_tile in line 255
There was a problem hiding this comment.
I simplified this in 259645c so that there is no more overload (only this signature remains).
| transformer_engine::getenv<bool>("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false); | ||
|
|
||
| auto is_supported_dtype = [&]() -> bool { | ||
| auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); |
There was a problem hiding this comment.
Is it possible that num_group=0, so that the A[0] access is not valid?
There was a problem hiding this comment.
| set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) | ||
|
|
||
| target_include_directories(transformer_engine | ||
| BEFORE PRIVATE |
There was a problem hiding this comment.
Why use the BEFORE keyword in this target_include_directories? Is it because cmake will not be able to find the correct header files without prioritizing the ck include dirs?
There was a problem hiding this comment.
I removed BEFORE in 259645c, compilation still seems to work fine.
| target_include_directories(transformer_engine PUBLIC | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/include") | ||
|
|
||
| set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) |
There was a problem hiding this comment.
CMAKE_SOURCE_DIR --> CMAKE_CURRENT_SOURCE_DIR? Not sure whether other upstream libs will depend on us but let's make it future proof
There was a problem hiding this comment.
Changed to CMAKE_CURRENT_SOURCE_DIR in 259645c.
| #include "common/util/cuda_runtime.h" | ||
| #include "common/util/system.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "cutlass_grouped_gemm.cuh" |
There was a problem hiding this comment.
NV upstream made another .cu file for their cutlass_grouped_gemm and compiled it separately. Maybe we can follow their structure for better isolation (to avoid macros defined by CK contaminating our cublaslt_gemm.cu)
There was a problem hiding this comment.
I restructured this to a cpp file and a header file in 259645c.
2095d3f to
ebc005f
Compare
d1ab38e to
0b16287
Compare
Description
See https://github.com/ROCm/frameworks-internal/issues/13792 for context.
TODOs:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: