Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions build_tools/build_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -25,7 +25,6 @@
cmake_bin,
debug_build_enabled,
found_ninja,
get_frameworks,
nvcc_path,
get_max_jobs_for_parallel_build,
)
Expand Down Expand Up @@ -158,8 +157,11 @@ def run(self) -> None:
def build_extensions(self):
# For core lib + JAX install, fix build_ext from pybind11.setup_helpers
# to handle CUDA files correctly.
# Upstream uses get_frameworks() here which is incorrectly works when install from
# release (sdist) wheel on a system with both frameworks installed.
ext_names = [ext.name for ext in self.extensions]
if "transformer_engine_pytorch" not in ext_names:
if ("transformer_engine_torch" not in ext_names and
"transformer_engine_rocm_torch" not in ext_names):
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
for ext in self.extensions:
Expand Down
20 changes: 18 additions & 2 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -208,6 +208,7 @@ def rocm_build() -> bool:
# If neither ROCm nor CUDA is detected, raise an error
raise FileNotFoundError("Could not detect ROCm or CUDA platform")


@functools.lru_cache(maxsize=None)
def rocm_path() -> Tuple[str, str]:
"""ROCm root path and HIPCC binary path as a tuple"""
Expand All @@ -227,6 +228,18 @@ def rocm_path() -> Tuple[str, str]:
return rocm_home, hipcc_bin


def rocm_version() -> Tuple[int, ...]:
"""ROCm version as a (major, minor) tuple.
Try to get ROCm version by parsing .info/version.
"""
rocm_home, _ = rocm_path()
try:
with open(rocm_home / ".info" / "version", "r") as f:
rocm_version= f.read().strip().split('.')[:2]
return tuple(int(v) for v in rocm_version)
except FileNotFoundError:
raise RuntimeError("Could not determine ROCm version.")


def cuda_toolkit_include_path() -> Tuple[str, str]:
"""Returns root path for cuda toolkit includes.
Expand Down Expand Up @@ -495,10 +508,13 @@ def uninstall_te_wheel_packages():
"pip",
"uninstall",
"-y",
"transformer_engine_rocm", # te_cuda_vers for ROCm build
"transformer_engine",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_jax",
"transformer_engine_rocm",
"transformer_engine_rocm_jax",
"transformer_engine_rocm_torch",
]
)

Expand Down
4 changes: 2 additions & 2 deletions build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

Expand Down Expand Up @@ -45,4 +45,4 @@ COPY build_wheels.sh /
WORKDIR /TransformerEngine/
RUN git clone https://github.com/ROCm/TransformerEngine.git /TransformerEngine

CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"]
CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "false", "true", "true", "true"]
106 changes: 60 additions & 46 deletions build_tools/wheel_utils/build_wheels.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -14,11 +14,13 @@ BUILD_JAX=${5:-true}

export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-}
mkdir -p /wheelhouse/logs

WHEEL_ROOT=${WHEEL_ROOT:-/wheelhouse}
mkdir -p $WHEEL_ROOT/logs

# Generate wheels for common library.
git config --global --add safe.directory /TransformerEngine
cd /TransformerEngine
TE_ROOT=${TE_ROOT:-/TransformerEngine}
cd $TE_ROOT

#If there is default Python installation, use it
PYTHON=`which python || true`
Expand All @@ -29,90 +31,102 @@ else
fi

ROCM_BUILD=`${PYBINDIR}python -c "import build_tools.utils as u; print(int(u.rocm_build()))"`

if [ "$LOCAL_TREE_BUILD" != "1" ]; then
if [ "$ROCM_BUILD" = "1" ]; then
git pull
fi
git checkout $TARGET_BRANCH
git submodule update --init --recursive
if [ "$ROCM_BUILD" = "1" ]; then
ROCM_BUILD=true
else
git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory /TransformerEngine/_SUB_
ROCM_BUILD=false
fi

if [ "$ROCM_BUILD" = "1" ]; then
${PYBINDIR}pip install setuptools wheel
if [ "$LOCAL_TREE_BUILD" != "1" ]; then
git config --global --add safe.directory $TE_ROOT
if [ "$NO_REPO_UPDATE" = "1" ]; then
git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory $TE_ROOT/_SUB_
else
if [ $ROCM_BUILD ]; then
git pull
fi
git checkout $TARGET_BRANCH
git submodule update --init --recursive
fi
fi

# Install deps
if [ "$ROCM_BUILD" = "1" ]; then
${PYBINDIR}pip install pybind11[global] ninja
if [ $ROCM_BUILD ]; then
${PYBINDIR}pip install setuptools wheel pybind11[global] ninja
else
${PYBINDIR}pip install cmake pybind11[global] ninja
fi

if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
if [ "$ROCM_BUILD" != "1" ]; then
cd $TE_ROOT
if [ ! $ROCM_BUILD ]; then
PYBINDIR=/opt/python/cp310-cp310/bin/
fi
NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
mv dist/* /wheelhouse/
NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee $WHEEL_ROOT/logs/metapackage.txt
mv dist/* $WHEEL_ROOT/
fi

if $BUILD_COMMON ; then
if $BUILD_COMMON -a $ROCM_BUILD; then
VERSION=`cat build_tools/VERSION.txt`
WHL_BASE="transformer_engine_rocm-${VERSION}"
#dataclasses, psutil are needed for AITER
${PYBINDIR}pip install dataclasses psutil
#hipify expects python in PATH, also ninja may be installed to python bindir
test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true

# Create the wheel.
${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee $WHEEL_ROOT/logs/common.txt

# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
mv dist/*.whl $WHEEL_ROOT/"$whl_name_target"

elif $BUILD_COMMON; then
VERSION=`cat build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
if [ "$ROCM_BUILD" = "1" ]; then
TE_CUDA_VERS="rocm"
#dataclasses, psutil are needed for AITER
${PYBINDIR}pip install dataclasses psutil
#hipify expects python in PATH, also ninja may be installed to python bindir
test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true
else
TE_CUDA_VERS="cu12"
PYBINDIR=/opt/python/cp38-cp38/bin/
fi
PYBINDIR=/opt/python/cp38-cp38/bin/

# Create the wheel.
${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee $WHEEL_ROOT/logs/common.txt

# Repack the wheel for cuda specific package, i.e. cu12.
${PYBINDIR}wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
${PYBINDIR}wheel pack ${WHL_BASE}

# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}_${TE_CUDA_VERS}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
rm -rf $WHL_BASE dist
mv *.whl /wheelhouse/"$whl_name_target"
mv *.whl $WHEEL_ROOT/"$whl_name_target"
fi

if $BUILD_PYTORCH ; then
cd /TransformerEngine/transformer_engine/pytorch
if [ "$ROCM_BUILD" = "1" ]; then
${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3
cd $TE_ROOT/transformer_engine/pytorch
if [ $ROCM_BUILD ]; then
${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/cpu
else
PYBINDIR=/opt/python/cp38-cp38/bin/
${PYBINDIR}pip install torch
fi
${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
cp dist/* /wheelhouse/
${PYBINDIR}python setup.py sdist 2>&1 | tee $WHEEL_ROOT/logs/torch.txt
cp dist/* $WHEEL_ROOT/
fi

if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax
if [ "$ROCM_BUILD" = "1" ]; then
cd $TE_ROOT/transformer_engine/jax
if [ $ROCM_BUILD ]; then
${PYBINDIR}pip install jax
else
PYBINDIR=/opt/python/cp310-cp310/bin/
${PYBINDIR}pip install "jax[cuda12_local]" jaxlib
fi
${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
${PYBINDIR}python setup.py sdist 2>&1 | tee $WHEEL_ROOT/logs/jax.txt
cp dist/* $WHEEL_ROOT/
fi
19 changes: 15 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,11 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
assert bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
te_cuda_vers = "rocm" if rocm_build() else "cu12"
ext_modules = []
cmdclass = {}
package_data = {}
include_package_data = False
install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],)
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
Expand Down Expand Up @@ -222,9 +221,21 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
)
)

PACKAGE_NAME="transformer_engine"
if rocm_build():
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
install_requires = ([f"transformer_engine_rocm=={__version__}"],)
else:
PACKAGE_NAME="transformer_engine_rocm"
#On ROCm make add extra to core package too so it can be installed w/o metapackage
extras_require = {
"pytorch": [f"transformer_engine_rocm_torch=={__version__}"],
"jax": [f"transformer_engine_rocm_jax=={__version__}"],
}
# Configure package
setuptools.setup(
name="transformer_engine",
name=PACKAGE_NAME,
version=__version__,
packages=setuptools.find_packages(
include=[
Expand All @@ -239,7 +250,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8",
python_requires=">=3.9",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
license_files=("LICENSE",),
Expand Down
9 changes: 7 additions & 2 deletions transformer_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -83,4 +83,9 @@
category=RuntimeWarning,
)

__version__ = str(metadata.version("transformer_engine"))
try:
__version__ = str(metadata.version("transformer_engine"))
except metadata.PackageNotFoundError:
if not transformer_engine.common.te_rocm_build:
raise
__version__ = str(metadata.version("transformer_engine_rocm"))
4 changes: 3 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -473,9 +473,11 @@ if (USE_ROCM)
file(READ "${ROCM_PATH}/.info/version" ROCM_VER)
string(STRIP "${ROCM_VER}" ROCM_VER)
string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}")
get_git_commit("${TE}" TE_COMMIT_ID)
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt"
"ROCM_VERSION: ${ROCM_VER}\n"
"GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n"
"COMMIT_ID: ${TE_COMMIT_ID}\n"
)
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/")
endif()
16 changes: 8 additions & 8 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -133,7 +133,7 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch":
extra_dep_name = "pytorch"

te_cuda_vers = "rocm" if te_rocm_build else "cu12"
te_core_tag = "rocm" if te_rocm_build else "cu12"

# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
Expand All @@ -143,24 +143,24 @@ def load_framework_extension(framework: str) -> None:
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
f"transformer_engine_{te_cuda_vers}"
), f"Could not find `transformer-engine-{te_cuda_vers}`."
f"transformer_engine_{te_core_tag}"
), f"Could not find `transformer-engine-{te_core_tag}`."
assert (
version(module_name)
== version("transformer-engine")
== version(f"transformer-engine-{te_cuda_vers}")
== version(f"transformer-engine-{te_core_tag}")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}"
f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using "
f" v{version('transformer-engine')}, and transformer-engine-{te_core_tag}"
f" v{version(f'transformer-engine-{te_core_tag}')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)

# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"):
if _is_pip_package_installed(f"transformer-engine-{te_core_tag}"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
Expand Down
Loading