Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@
# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler
cmake_minimum_required(VERSION 3.22.1)

# On Windows with HIP backend, auto-detect compilers from ROCM_PATH before project()
if(WIN32 AND COMPUTE_BACKEND STREQUAL "hip")
if(DEFINED ENV{ROCM_PATH})
set(ROCM_PATH $ENV{ROCM_PATH})
endif()
if(ROCM_PATH AND NOT DEFINED CMAKE_CXX_COMPILER)
set(CMAKE_CXX_COMPILER "${ROCM_PATH}/lib/llvm/bin/clang++.exe")
endif()
if(ROCM_PATH AND NOT DEFINED CMAKE_HIP_COMPILER)
set(CMAKE_HIP_COMPILER "${ROCM_PATH}/lib/llvm/bin/clang++.exe")
endif()
endif()

project(bitsandbytes LANGUAGES CXX)

# If run without specifying a build type, default to using the Release configuration:
Expand Down Expand Up @@ -200,6 +213,20 @@ if(BUILD_CUDA)
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_HIP)
# Auto-detect GPU architecture on Windows using hipinfo.exe
if(WIN32 AND NOT DEFINED BNB_ROCM_ARCH AND NOT DEFINED AMDGPU_TARGETS AND NOT DEFINED CMAKE_HIP_ARCHITECTURES)
execute_process(
COMMAND hipinfo
OUTPUT_VARIABLE HIPINFO_OUTPUT
ERROR_QUIET
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(HIPINFO_OUTPUT MATCHES "gcnArchName:[ \t]*([a-z0-9]+)")
set(CMAKE_HIP_ARCHITECTURES "${CMAKE_MATCH_1}")
message(STATUS "Auto-detected HIP architecture: ${CMAKE_HIP_ARCHITECTURES}")
endif()
endif()

enable_language(HIP)
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
if(DEFINED BNB_ROCM_ARCH)
Expand Down Expand Up @@ -263,6 +290,8 @@ endif()
if(WIN32)
# Export all symbols
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
# Prevent Windows SDK min/max macros from conflicting with std::min/std::max
add_compile_definitions(NOMINMAX)
endif()

if(MSVC)
Expand Down Expand Up @@ -315,10 +344,11 @@ if(BUILD_CUDA)
)
endif()
if(BUILD_HIP)
if(NOT DEFINED ENV{ROCM_PATH})
set(ROCM_PATH /opt/rocm)
else()
# Determine ROCM_PATH from environment variable, fallback to /opt/rocm on Linux
if(DEFINED ENV{ROCM_PATH})
set(ROCM_PATH $ENV{ROCM_PATH})
else()
set(ROCM_PATH /opt/rocm)
endif()
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
macro(find_package_and_print_version PACKAGE_NAME)
Expand All @@ -330,14 +360,24 @@ if(BUILD_HIP)
find_package_and_print_version(hipsparse REQUIRED)

## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
## On Windows, we need to link amdhip64 explicitly
if(NOT WIN32)
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
endif()

target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)

# On Windows, link the HIP runtime and rocblas directly using full paths
if(WIN32)
target_link_libraries(bitsandbytes PUBLIC
"${ROCM_PATH}/lib/amdhip64.lib"
"${ROCM_PATH}/lib/rocblas.lib")
endif()

target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
Expand Down
31 changes: 26 additions & 5 deletions bitsandbytes/cuda_specs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
from functools import lru_cache
import logging
import platform
import re
import subprocess
from typing import Optional
Expand Down Expand Up @@ -83,10 +84,21 @@ def get_rocm_gpu_arch() -> str:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
# On Windows, use hipinfo.exe; on Linux, use rocminfo
if platform.system() == "Windows":
cmd = ["hipinfo.exe"]
arch_pattern = r"gcnArchName:\s+(gfx[a-zA-Z\d]+)"
else:
cmd = ["rocminfo"]
arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"

result = subprocess.run(cmd, capture_output=True, text=True)
match = re.search(arch_pattern, result.stdout)
if match:
return "gfx" + match.group(1)
if platform.system() == "Windows":
return match.group(1)
else:
return "gfx" + match.group(1)
else:
return "unknown"
else:
Expand All @@ -107,8 +119,17 @@ def get_rocm_warpsize() -> int:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
# On Windows, use hipinfo.exe; on Linux, use rocminfo
if platform.system() == "Windows":
cmd = ["hipinfo.exe"]
# hipinfo.exe output format: "warpSize: 32" or "warpSize: 64"
warp_pattern = r"warpSize:\s+(\d+)"
else:
cmd = ["rocminfo"]
warp_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)"

result = subprocess.run(cmd, capture_output=True, text=True)
match = re.search(warp_pattern, result.stdout)
if match:
return int(match.group(1))
else:
Expand Down
7 changes: 7 additions & 0 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include <cstdint>
#include <iostream>
#include <stdio.h>
#ifdef _WIN32
#include <io.h>
#include <process.h>
#include <windows.h>
#else
#include <unistd.h>
#endif

#include <common.h>
#include <cublasLt.h>
Expand Down
10 changes: 10 additions & 0 deletions csrc/ops_hip.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@
#include <cstdint>
#include <iostream>
#include <stdio.h>

#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <io.h>
#include <process.h>
#include <windows.h>
#else
#include <unistd.h>
#endif

#include <common.h>
#include <functional>
Expand Down