Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,11 @@ if(NOT ${ARCHITECTURE} MATCHES "arm.*")
add_subdirectory(CpuMathNative)
add_subdirectory(FastTreeNative)
add_subdirectory(MklProxyNative)
# TODO: once we fix the 4 intel MKL methods, SymSgdNative will need to go back in.
add_subdirectory(SymSgdNative)
endif()
else()
add_subdirectory(MklImportsArm)
add_subdirectory(SymSgdNative)
endif()

if(${ARCHITECTURE} MATCHES "[xX].*64")
add_subdirectory(OneDalNative)
Expand Down
24 changes: 24 additions & 0 deletions src/Native/MklImportsArm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
project(MklImportsArm)

# On ARM platforms, Intel MKL is not available. This target provides
# a compatible libMklImports.so backed by the system BLAS (typically
# OpenBLAS) with stubs for MKL-specific sparse BLAS and FFT functions.

find_package(BLAS REQUIRED)

set(SOURCES
MklImportsArm.c
)

if(NOT WIN32)
list(APPEND SOURCES ${VERSION_FILE_PATH})
SET(CMAKE_SKIP_BUILD_RPATH FALSE)
SET(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE)
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "$ORIGIN/")
endif()

add_library(MklImports SHARED ${SOURCES} ${RESOURCES})
target_link_libraries(MklImports PUBLIC ${BLAS_LIBRARIES})

install_library_and_symbols(MklImports)
57 changes: 57 additions & 0 deletions src/Native/MklImportsArm/MklImportsArm.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

// ARM replacement for Intel MKL (libMklImports.so).
//
// Standard CBLAS functions (sgemm, sgemv, saxpy, sdot, etc.) are
// forwarded to OpenBLAS, which exports them with identical signatures.
//
// Sparse CBLAS extensions (saxpyi, sdoti) are provided here since
// OpenBLAS does not include them.
//
// MKL DFTI (FFT) functions are stubbed — they are referenced by the
// managed MKL Components initializer but not used by SymSGD. The stubs
// return error codes so any actual FFT call fails cleanly rather than
// crashing.

// --- Sparse BLAS (MKL extensions, not in OpenBLAS) ---

void cblas_saxpyi(const int nz, const float a,
const float *x, const int *indx, float *y)
{
for (int i = 0; i < nz; i++)
y[indx[i]] += a * x[i];
}

float cblas_sdoti(const int nz, const float *x,
const int *indx, const float *y)
{
float result = 0.0f;
for (int i = 0; i < nz; i++)
result += x[i] * y[indx[i]];
return result;
}

// --- DFTI (FFT) stubs ---

const char* DftiErrorMessage(long status)
{
return "DFTI not available (OpenBLAS arm64 build)";
}

long DftiCreateDescriptor(void **h, int precision, int domain, int dim, ...)
{
*h = (void*)0;
return -1;
}

long DftiSetValue(void *h, int param, ...)
{
return -1;
}

long DftiCommitDescriptor(void *h) { return -1; }
long DftiComputeForward(void *h, ...) { return -1; }
long DftiComputeBackward(void *h, ...) { return -1; }
long DftiFreeDescriptor(void **h) { return 0; }
6 changes: 5 additions & 1 deletion src/Native/SymSgdNative/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ else()
endif()
endif()

if(NOT ${ARCHITECTURE} MATCHES "arm.*")
if(${ARCHITECTURE} MATCHES "arm.*")
# On ARM, MklImports is built from MklImportsArm (OpenBLAS-backed).
# Link against the CMake target directly.
set(MKL_LIBRARY MklImports)
else()
find_library(MKL_LIBRARY MklImports HINTS ${MKL_LIB_PATH})
endif()

Expand Down
16 changes: 11 additions & 5 deletions src/Native/SymSgdNative/SparseBLAS.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
#pragma once
#include "../Stdafx.h"

extern "C" float __cdecl cblas_sdot(const int vecSize, const float* denseVecX, const int incX, const float* denseVecY, const int incY);
extern "C" float __cdecl cblas_sdoti(const int sparseVecSize, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec);
extern "C" void __cdecl cblas_saxpy(const int vecSize, const float coef, const float* denseVecX, const int incX, float* denseVecY, const int incY);
extern "C" void __cdecl cblas_saxpyi(const int sparseVecSize, const float coef, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec);
#ifdef _WIN32
#define CBLAS_CALLING_CONV __cdecl
#else
#define CBLAS_CALLING_CONV
#endif

extern "C" float CBLAS_CALLING_CONV cblas_sdot(const int vecSize, const float* denseVecX, const int incX, const float* denseVecY, const int incY);
extern "C" float CBLAS_CALLING_CONV cblas_sdoti(const int sparseVecSize, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec);
extern "C" void CBLAS_CALLING_CONV cblas_saxpy(const int vecSize, const float coef, const float* denseVecX, const int incX, float* denseVecY, const int incY);
extern "C" void CBLAS_CALLING_CONV cblas_saxpyi(const int sparseVecSize, const float coef, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec);

float SDOT(const int vecSize, const float* denseVecX, const float* denseVecY)
{
Expand All @@ -28,4 +34,4 @@ void SAXPY(const int vecSize, const float* denseVecX, float* denseVecY, float co
void SAXPYI(const int sparseVecSize, const int* sparseVecIndices, const float* sparseVecValues, float* denseVec, float coef)
{
cblas_saxpyi(sparseVecSize, coef, sparseVecValues, sparseVecIndices, denseVec);
}
}
Loading