Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 63 additions & 23 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,56 @@ if(NOT _gfx1250_idx EQUAL -1)
endif()
set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES})
endif()
set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA/3rdparty/aiter")
set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA")
set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter")
set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel")

set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include")
message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}")
if(NOT EXISTS "${CK_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find CK API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()

set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include")
message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}")
if(NOT EXISTS "${AITER_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find AITER API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()

if(NOT Python_EXECUTABLE)
find_package(Python COMPONENTS Interpreter QUIET)
endif()

# Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the
# QoLA-managed AITER source tree to that commit before any consumer reads it
# (header validation below, header includes for the .cpp build later, and
# QoLA's own kernel build if the prebuilt cache misses).
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")

if(Python_EXECUTABLE)
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
# Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so
# git operations inside the QoLA-managed AITER tree (and its recursive
# submodules) work in containerized builds where the bind-mounted .git is
# owned by a different UID than the build process. Mirrors the pattern in
# transformer_engine/common/CMakeLists.txt:get_git_commit().
execute_process(
COMMAND sh -c
"tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \
GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \
GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make sense to integrate safe.directory overriding to qola? The pattern with dubious ownership is probably not TE specific

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry that such behavior is a bit too authoritative for qola if that makes sense? My reasoning is that the permission scope here seems to be outside of qola and hence qola should not be the one in charge of it. I'm open to reconsidering that, but it's just my initial position.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is valid concern bearing in mind QoLA is intended to be reused by different components. May be make this behavior controllable then. It is OK to keep the things in TE, it will just require doing things this overriding twice - here and when build is called - BTW, there are comments there but not actual code change

--manifest '${__QOLA_MANIFEST}' \
--aiter-root '${__AITER_SOURCE_DIR}'; \
rc=$?; rm -f \"$tmp\"; exit $rc"
RESULT_VARIABLE AITER_CHECKOUT_RESULT
OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT
ERROR_VARIABLE AITER_CHECKOUT_ERROR
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_STRIP_TRAILING_WHITESPACE
)
if(NOT AITER_CHECKOUT_RESULT EQUAL 0)
message(FATAL_ERROR
"Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to "
"manifest-pinned commit ${AITER_SHA}.\n"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it also validate that actual commit matches one detected by prebuilt.cmake? If QoLA checkout AITER unconditionally, may be keep prebuilt.cmake as-is and where it is now? I.e. QoLA fetches AITER, prebuit.cmake checks for git commit as before. It will only loose AITER_SHA value in this error message

"${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}")
endif()
message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}")

execute_process(
COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py --mode both --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.."
COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py
--mode both
--te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.."
--aiter-root "${__AITER_SOURCE_DIR}"
RESULT_VARIABLE AITER_ARG_CHECK_RESULT
OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT
ERROR_VARIABLE AITER_ARG_CHECK_ERROR
Expand All @@ -64,7 +86,24 @@ if(Python_EXECUTABLE)
endif()
message(STATUS "AITER API validation passed via check_aiter_mha_args.py")
else()
message(WARNING "Python interpreter not found; skipping AITER API validation.")
message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.")
endif()

# Sanity-check the resolved include directories now that `qola checkout` has
# materialized the AITER tree.
message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}")
if(NOT EXISTS "${CK_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find CK API at ${CK_INCLUDE_DIR}. "
"Re-run the build to let `qola checkout` clone AITER and its "
"composable_kernel submodule.")
endif()

message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}")
if(NOT EXISTS "${AITER_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find AITER API at ${AITER_INCLUDE_DIR}. "
"Re-run the build to let `qola checkout` clone AITER.")
endif()

if(DEFINED AITER_MHA_PATH)
Expand All @@ -73,16 +112,16 @@ if(DEFINED AITER_MHA_PATH)
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
else()
set(__AITER_MHA_PATH "")
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
get_prebuilt_aiter(__AITER_MHA_PATH)

if(__AITER_MHA_PATH STREQUAL "")
# If not available, fallback: Build from source via QoLA
list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR)
message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.")
set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA")
set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build")
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
# Same GIT_CONFIG_GLOBAL trick as the earlier `qola.cli checkout` call:
# `qola.cli build` re-invokes ensure_aiter_commit internally and will hit
# the same dubious-ownership trap without it.
execute_process(
COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}"
${Python_EXECUTABLE} -m qola.cli build
Expand Down Expand Up @@ -124,7 +163,8 @@ endforeach()
add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES})
set(CK_FUSED_ATTN_COMPILE_OPTIONS)
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS
-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})
-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}
-DENABLE_CK=1)

# Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include
# (emitted by qola.cli build, or copied from the QoLA build dir above for the
Expand Down
18 changes: 16 additions & 2 deletions transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,22 @@ string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT)
string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}")
string(REGEX MATCH "^[0-9]+" ROCM_VER_MAJOR "${ROCM_VER}")

# AITER commit
get_git_commit("${__AITER_SOURCE_DIR}" AITER_SHA)
# AITER commit — read from the QoLA manifest so the cache key tracks the
# commit QoLA will actually check out and build, not whatever happens to be
# the submodule's current HEAD at configure time.
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${__QOLA_MANIFEST}")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AITER_SHA is not cached variable. Why is CMAKE_CONFIGURE_DEPENDS needed?

file(STRINGS "${__QOLA_MANIFEST}" __AITER_COMMIT_LINES
REGEX "^[ \t]*aiter_commit[ \t]*=[ \t]*\"[^\"]+\"")
list(LENGTH __AITER_COMMIT_LINES __AITER_COMMIT_COUNT)
if(NOT __AITER_COMMIT_COUNT EQUAL 1)
message(FATAL_ERROR
"Expected exactly one 'aiter_commit = \"...\"' line in "
"${__QOLA_MANIFEST}, found ${__AITER_COMMIT_COUNT}.")
endif()
list(GET __AITER_COMMIT_LINES 0 __AITER_COMMIT_LINE)
string(REGEX MATCH "\"([^\"]+)\"" _UNUSED "${__AITER_COMMIT_LINE}")
set(AITER_SHA "${CMAKE_MATCH_1}")

# Cache key & local paths
set(AITER_CACHE_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../build/aiter-prebuilts")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def parse_with_skip_comments(buffer, line, regex, outputs):


def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*|\{[^;]*\})?;\s*$")
struct_end_re = re.compile(r"^\s*};\s*$")

struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b")
Expand Down Expand Up @@ -64,11 +64,14 @@ def main() -> int:
parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both")
parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine")
parser.add_argument("--aiter-root", type=Path, default=None,
help="AITER source tree root. Defaults to <te-dir>/3rdparty/aiter.")
args = parser.parse_args()
aiter_root = args.aiter_root if args.aiter_root else args.te_dir / "3rdparty/aiter"
modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
mismatch = 0
for mode in modes:
header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h"
header_path = aiter_root / f"csrc/include/mha_{mode}.h"
source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp"
header_text = header_path.read_text(encoding="utf-8")
source_text = source_path.read_text(encoding="utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ struct CkAttnBwdArgs : CKAttnCommonArgs {
// dQ
void* dq_ptr = nullptr;
uint64_t stride_b_dq = 0, stride_h_dq = 0, stride_s_dq = 0;
void* dq_acc_ptr = nullptr;

// dK / dV expanded (MQA/GQA reduction inputs; null when h==hg)
void* dk_expanded_ptr = nullptr;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[qola]
aiter_commit = "33f2e6af5f39379c739720080ed0033d533f5cb2" # pinned AITER submodule commit
aiter_commit = "e3940660b40f4764cdf09147af96a2a764f264be" # pinned AITER submodule commit
namespace = "te"
rocm_versions = ["7.2"]

Expand All @@ -9,9 +9,11 @@ architectures = ["gfx950", "gfx942"]
[[modules]]
name = "libmha_fwd"
mode = "cpp_itfs"
receipt = 700
drop_srcs = ["mha_fwd_split.cu", "mha_fwd_batch_prefill.cu"]
drop_directions = ["fwd_splitkv", "batch_prefill"]

[[modules]]
name = "libmha_bwd"
mode = "cpp_itfs"
receipt = 700
Loading
Loading