Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ option(WITH_NVIDIA "Enable CUDA backend" OFF)
option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF)
option(WITH_METAX "Enable MetaX backend" OFF)
option(WITH_CAMBRICON "Enable Cambricon backend" OFF)
option(WITH_MOORE "Enable Moore backend" OFF)

option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
Expand Down Expand Up @@ -61,6 +62,34 @@ if(AUTO_DETECT_DEVICES)
set(WITH_CAMBRICON ON)
message(STATUS "Auto-detected Cambricon environment.")
endif()

if(DEFINED ENV{MUSA_ROOT} OR DEFINED ENV{MUSA_HOME} OR DEFINED ENV{MUSA_PATH})
set(WITH_MOORE ON)
set(WITH_MOORE ON CACHE BOOL "Enable Moore backend" FORCE)
message(STATUS "Auto-detected Moore environment.")
else()
set(WITH_MOORE OFF)
set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE)
endif()
endif()

if(WITH_MOORE)
set(MUSA_ROOT $ENV{MUSA_ROOT} $ENV{MUSA_HOME} $ENV{MUSA_PATH})
list(FILTER MUSA_ROOT EXCLUDE REGEX "^$")
list(GET MUSA_ROOT 0 MUSA_ROOT)
if(NOT MUSA_ROOT)
message(FATAL_ERROR "`WITH_MOORE` is `ON` but `MUSA_ROOT`/`MUSA_HOME`/`MUSA_PATH` is not set.")
endif()
message(STATUS "Using Moore from `${MUSA_ROOT}`.")
list(PREPEND CMAKE_MODULE_PATH "${MUSA_ROOT}/cmake")
set(MUSA_TOOLKIT_ROOT_DIR "${MUSA_ROOT}" CACHE PATH "Toolkit location." FORCE)
find_package(MUSA REQUIRED)
add_compile_definitions(WITH_MOORE=1)
include_directories("${MUSA_ROOT}/include")
link_directories("${MUSA_ROOT}/lib")
find_library(MUSA_LIB NAMES musa HINTS "${MUSA_ROOT}/lib" REQUIRED)
find_library(MUSART_LIB NAMES musart HINTS "${MUSA_ROOT}/lib" REQUIRED)
find_library(MUBLAS_LIB NAMES mublas HINTS "${MUSA_ROOT}/lib" REQUIRED)
endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
Expand Down Expand Up @@ -127,7 +156,7 @@ if(WITH_CAMBRICON)
endif()

# If all other platforms are not enabled, CPU is enabled by default.
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX)
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE)
add_compile_definitions(WITH_CPU=1)
endif()

Expand Down
3 changes: 3 additions & 0 deletions examples/gemm/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#if WITH_CAMBRICON
#include "cambricon/gemm/cnblas.h"
#endif
#if WITH_MOORE
#include "moore/gemm/mublas.h"
#endif

#include "runtime_api.h"
#include "tensor.h"
Expand Down
9 changes: 9 additions & 0 deletions examples/runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@
#define DEVICE_MEMCPY_HOST_TO_DEVICE cnrtMemcpyHostToDev
#define DEVICE_MEMCPY_DEVICE_TO_HOST cnrtMemcpyDevToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kCambricon
#elif WITH_MOORE
#include <musa_runtime_api.h>
#define DEVICE_MALLOC musaMalloc
#define DEVICE_FREE musaFree
#define DEVICE_MEMCPY musaMemcpy
#define DEVICE_MEMSET musaMemset
#define DEVICE_MEMCPY_HOST_TO_DEVICE musaMemcpyHostToDevice
#define DEVICE_MEMCPY_DEVICE_TO_HOST musaMemcpyDeviceToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kMoore
#elif WITH_CPU
#include <cstdlib>
#include <cstring>
Expand Down
10 changes: 9 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,19 @@ if(WITH_CAMBRICON)
list(APPEND DEVICE_LIST "cambricon")
endif()

if(WITH_MOORE)
target_compile_definitions(infiniops PUBLIC WITH_MOORE=1)
target_include_directories(infiniops PUBLIC "${MUSA_ROOT}/include")
target_link_libraries(infiniops PUBLIC ${MUSA_LIB} ${MUSART_LIB} ${MUBLAS_LIB})
list(APPEND DEVICE_LIST "moore")
endif()

target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

if(GENERATE_PYTHON_BINDINGS)
find_package(Python COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND python ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
RESULT_VARIABLE script_result
)
Expand Down
20 changes: 15 additions & 5 deletions src/cuda/gemm/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,30 @@ class Blas : public Gemm {
const auto& trans_b_value{trans_b.value_or(trans_b_)};
auto op_a{GetOpA(trans_a_value, trans_b_value)};
auto op_b{GetOpB(trans_a_value, trans_b_value)};
const void* alpha_ptr{GetAlphaPtr(alpha_value, c.dtype())};
const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())};

Backend::blasGemmStridedBatchedEx(
handle_, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_,
k_, &alpha_value, swap_a_and_b_ ? b.data() : a.data(),
k_, alpha_ptr, swap_a_and_b_ ? b.data() : a.data(),
Backend::GetDataType(swap_a_and_b_ ? b.dtype() : a.dtype()),
swap_a_and_b_ ? ldb_ : lda_,
swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_,
swap_a_and_b_ ? a.data() : b.data(),
Backend::GetDataType(swap_a_and_b_ ? a.dtype() : b.dtype()),
swap_a_and_b_ ? lda_ : ldb_,
swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_, &beta_value,
c.data(), Backend::GetDataType(c.dtype()), ldc_, batch_stride_c_,
batch_count_, Backend::GetComputeType(c.dtype()),
Backend::BLAS_GEMM_DEFAULT);
swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_, beta_ptr, c.data(),
Backend::GetDataType(c.dtype()), ldc_, batch_stride_c_, batch_count_,
Backend::GetComputeType(c.dtype()), Backend::BLAS_GEMM_DEFAULT);
}

protected:
virtual const void* GetAlphaPtr(const float& alpha, DataType) const {
return &alpha;
}

virtual const void* GetBetaPtr(const float& beta, DataType) const {
return &beta;
}

private:
Expand Down
89 changes: 89 additions & 0 deletions src/moore/gemm/mublas.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#ifndef INFINI_OPS_MOORE_GEMM_MUBLAS_H_
#define INFINI_OPS_MOORE_GEMM_MUBLAS_H_

#include <mublas.h>
#include <musa_runtime_api.h>

#include <utility>

#include "cuda/gemm/blas.h"

namespace infini::ops {

namespace gemm {

struct MooreBackend {
using blasHandle_t = mublasHandle_t;

using stream_t = musaStream_t;

static constexpr auto BLAS_OP_N = MUBLAS_OP_N;

static constexpr auto BLAS_OP_T = MUBLAS_OP_T;

static constexpr auto R_16F = MUSA_R_16F;

static constexpr auto R_16BF = MUSA_R_16BF;

static constexpr auto R_32F = MUSA_R_32F;

static constexpr auto BLAS_GEMM_DEFAULT = MUBLAS_GEMM_DEFAULT;

static constexpr auto blasCreate = mublasCreate;

static constexpr auto blasSetStream = mublasSetStream;

static constexpr auto blasDestroy = mublasDestroy;

static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) {
return mublasGemmStridedBatchedEx(std::forward<decltype(args)>(args)...);
};

static musaDataType_t GetDataType(DataType dtype) {
if (dtype == DataType::kFloat16) return R_16F;
if (dtype == DataType::kBFloat16) return R_16BF;
return R_32F;
}

static mublasComputeType_t GetComputeType(DataType dtype) {
if (dtype == DataType::kFloat16) return MUBLAS_COMPUTE_16F;
if (dtype == DataType::kBFloat16) return MUBLAS_COMPUTE_32F;
return MUBLAS_COMPUTE_32F;
}
};

} // namespace gemm

template <>
class Operator<Gemm, Device::Type::kMoore> : public Blas<gemm::MooreBackend> {
public:
using Blas<gemm::MooreBackend>::Blas;

protected:
const void* GetAlphaPtr(const float& alpha, DataType dtype) const override {
if (gemm::MooreBackend::GetComputeType(dtype) == MUBLAS_COMPUTE_16F) {
alpha_fp16_ = Float16::FromFloat(alpha);
return &alpha_fp16_;
}

return &alpha;
}

const void* GetBetaPtr(const float& beta, DataType dtype) const override {
if (gemm::MooreBackend::GetComputeType(dtype) == MUBLAS_COMPUTE_16F) {
beta_fp16_ = Float16::FromFloat(beta);
return &beta_fp16_;
}

return &beta;
}

private:
mutable Float16 alpha_fp16_{};

mutable Float16 beta_fp16_{};
};

} // namespace infini::ops

#endif
16 changes: 13 additions & 3 deletions tests/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,17 @@ def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None)

return c

if a.ndim == 2:
return torch.addmm(c, a, b, beta=beta, alpha=alpha, out=c)
# Some backends (e.g. `torch_musa`) may reject `addmm`/`baddbmm(out=...)`
# for certain strided outputs. Fall back to `matmul` plus fused `alpha`/`beta`
# update to keep reference coverage.
try:
if a.ndim == 2:
return torch.addmm(c, a, b, beta=beta, alpha=alpha, out=c)

return torch.baddbmm(c, a, b, beta=beta, alpha=alpha, out=c)
except RuntimeError:
c_original = c.clone()
torch.matmul(a, b, out=c)
c.mul_(alpha).add_(c_original, alpha=beta)

return torch.baddbmm(c, a, b, beta=beta, alpha=alpha, out=c)
return c
3 changes: 3 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def get_available_devices():
if hasattr(torch, "mlu") and torch.mlu.is_available():
devices.append("mlu")

if hasattr(torch, "musa") and torch.musa.is_available():
devices.append("musa")

return tuple(devices)


Expand Down