Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
target_compile_options(mlx PRIVATE -Wno-psabi)
endif()

if(MSVC)
# Some of CUDA's headers include windows.h, which defines min/max macros.
if(WIN32)
# windows.h defines min/max macros that conflict with std::min/std::max.
target_compile_definitions(mlx PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
endif()
if(MSVC)
# Unicode support in fmt does not compile in .cu files.
target_compile_definitions(mlx PRIVATE FMT_UNICODE=0)
# Disable some MSVC warnings to speed up compilation.
Expand Down
4 changes: 3 additions & 1 deletion mlx/backend/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
set(CLANG FALSE)
endif()

set(COMPILE_DEPS
Expand All @@ -17,7 +18,7 @@ set(COMPILE_DEPS
unary_ops.h
binary_ops.h)

if(MSVC)
if(WIN32)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
Expand All @@ -31,6 +32,7 @@ add_custom_command(
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
"${PREAMBLE_SIMD_FLAGS}"
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
${COMPILE_DEPS})

Expand Down
10 changes: 7 additions & 3 deletions mlx/backend/cpu/compiled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ static CompilerCache& cache() {
return cache_;
};

// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
// GPU compile is available through the GPU backend. CPU compile requires a
// usable C++ compiler (MSVC or clang-cl on Windows).
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
return device == Device::gpu || JitCompiler::available();
}

} // namespace detail
Expand Down Expand Up @@ -100,7 +100,11 @@ void* compile(
std::filesystem::create_directories(output_dir);
}

#ifdef _WIN32
std::string shared_lib_name = kernel_file_name + ".dll";
#else
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
#endif
auto shared_lib_path = (output_dir / shared_lib_name).string();
bool lib_exists = false;
{
Expand Down
176 changes: 153 additions & 23 deletions mlx/backend/cpu/jit_compiler.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2024-2026 Apple Inc.

#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/compiled_preamble.h"

#include <algorithm>
#include <cstdlib>
#include <sstream>
#include <vector>

#include <fmt/format.h>

namespace mlx::core {
#if defined(_MSC_VER) && \
(defined(_M_X64) || defined(_M_IX86) || defined(_M_AMD64))
#include <intrin.h>
#endif

#ifdef _MSC_VER
namespace mlx::core {

namespace {

#if defined(_MSC_VER)

// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
Expand All @@ -38,11 +44,16 @@ struct VisualStudioInfo {
// Get path of Visual Studio.
// Use -latest to get only the most recent installation when multiple
// versions are installed, avoiding path concatenation issues.
auto pf86 = std::getenv("ProgramFiles(x86)");
if (!pf86) {
throw std::runtime_error(
"ProgramFiles(x86) environment variable not set.");
}
std::string vs_path = JitCompiler::exec(
fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -latest -property installationPath",
std::getenv("ProgramFiles(x86)")));
" -latest -property installationPath 2>&1",
pf86));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
Expand All @@ -57,7 +68,7 @@ struct VisualStudioInfo {
// Read the envs from vcvarsall.
std::string envs = JitCompiler::exec(
fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL 2>&1 && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
Expand All @@ -67,15 +78,45 @@ struct VisualStudioInfo {
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
if (name == "INCLUDE") {
includepaths = str_split(value, ';');
} else if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir" || name == "VCTOOLSINSTALLDIR") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
msvc_cl = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}

// Check for clang-cl bundled with Visual Studio.
std::string clang_cl_path = fmt::format(
"{0}\\VC\\Tools\\Llvm\\{1}\\bin\\clang-cl.exe", vs_path, arch);
if (std::filesystem::exists(clang_cl_path)) {
clang_cl = clang_cl_path;
}
}

std::string compiler(bool use_include) const {
if (use_include) {
// With installed headers, either cl.exe or clang-cl can parse the source.
// Prefer cl.exe since it is part of the standard Visual Studio C++
// toolchain.
return !msvc_cl.empty() ? msvc_cl : clang_cl;
}

// Without installed headers, fall back to the embedded prebuilt preamble.
// It was generated by the build compiler and may contain compiler-specific
// builtins, so prefer the same compiler family when possible.
#ifdef __clang__
return !clang_cl.empty() ? clang_cl : msvc_cl;
#else
return !msvc_cl.empty() ? msvc_cl : clang_cl;
#endif
}

std::string arch;
std::string cl_exe;
std::string msvc_cl;
std::string clang_cl;
std::vector<std::string> includepaths;
std::vector<std::string> libpaths;
};

Expand All @@ -84,21 +125,63 @@ const VisualStudioInfo& GetVisualStudioInfo() {
return info;
}

} // namespace
#endif // defined(_MSC_VER)

bool supports_avx2_target() {
#if defined(_MSC_VER)

#if defined(_M_X64) || defined(_M_IX86) || defined(_M_AMD64)
int info[4];
__cpuid(info, 0);
if (info[0] < 7) {
return false;
}
__cpuid(info, 1);
bool os_avx = (info[2] & (1 << 27)) && (info[2] & (1 << 28));
bool fma = info[2] & (1 << 12);
bool f16c = info[2] & (1 << 29);
if (!os_avx || !fma || !f16c) {
return false;
}
unsigned long long xcr0 = _xgetbv(0);
if ((xcr0 & 0x6) != 0x6) {
return false;
}
__cpuidex(info, 7, 0);
return info[1] & (1 << 5);
#else
return false;
#endif // defined(_M_X64) || defined(_M_IX86) || defined(_M_AMD64)

#elif defined(__GNUC__) || defined(__clang__)

#if defined(__x86_64__) || defined(__i386__) || defined(__amd64__)
__builtin_cpu_init();
return __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma") &&
__builtin_cpu_supports("f16c");
#else
return false;
#endif

#else
return false;
#endif // defined(_MSC_VER)
}

#endif // _MSC_VER
} // namespace

const std::tuple<bool, std::string, std::string>& JitCompiler::get_preamble() {
static auto preamble = []() -> std::tuple<bool, std::string, std::string> {
// Check whether the headers are shipped with the binary, if so use the
// preamble from the headers, otherwise use the prebuilt one embeded in
// preamble from the headers, otherwise use the prebuilt one embedded in
// binary, which may not work with all compilers.
auto root_dir = current_binary_dir();
#if !defined(_WIN32)
root_dir = root_dir.parent_path();
#endif
auto include_dir = root_dir / "include";
if (std::filesystem::exists(include_dir / "mlx")) {
if (std::filesystem::exists(
include_dir / "mlx/backend/cpu/compiled_preamble.h")) {
return std::make_tuple(
true,
include_dir.string(),
Expand All @@ -110,36 +193,83 @@ const std::tuple<bool, std::string, std::string>& JitCompiler::get_preamble() {
return preamble;
}

bool JitCompiler::available() {
#ifdef _MSC_VER
static bool result = [] {
try {
bool use_include = std::get<0>(get_preamble());
const auto& info = GetVisualStudioInfo();
return !info.compiler(use_include).empty();
} catch (...) {
return false;
}
}();
return result;
#else
#ifdef _WIN32
static int result = std::system("g++ --version > NUL 2>&1");
#else
static int result = std::system("g++ --version > /dev/null 2>&1");
#endif
return result == 0;
#endif
}

std::string JitCompiler::build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name) {
auto& [use_include, include_dir, preamble] = get_preamble();
#ifdef _MSC_VER
std::string extra_flags;
std::string compiler_flags;
if (use_include) {
extra_flags += fmt::format("/I \"{}\"", include_dir);
compiler_flags += fmt::format(" /I \"{}\"", include_dir);
}
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string compiler = info.compiler(use_include);
for (const std::string& include : info.includepaths) {
compiler_flags += fmt::format(" /I \"{}\"", include);
}
std::string libpaths;
for (const std::string& lib : info.libpaths) {
extra_flags += fmt::format(" /libpath:\"{}\"", lib);
libpaths += fmt::format(" /libpath:\"{}\"", lib);
}
// clang-cl accepts the same flags as cl.exe (/LD, /EHsc, etc.)
// but we add -Wno-everything to suppress warnings from system and shipped
// MLX headers.
if (!info.clang_cl.empty() && compiler == info.clang_cl) {
compiler_flags += " -Wno-everything";
}
#ifdef __AVX2__
Comment thread
zcbenz marked this conversation as resolved.
if (supports_avx2_target()) {
compiler_flags += " /arch:AVX2";
}
#endif
return fmt::format(
"\""
"cd /D \"{}\" && "
"\"{}\" /LD /EHsc /MD /Ox /nologo /std:c++17 {} \"{}\" "
"/link /out:\"{}\" 2>&1"
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17{2} \"{3}\" "
"/link /out:\"{4}\" {5} 2>&1"
"\"",
dir.string(),
info.cl_exe,
extra_flags,
compiler,
compiler_flags,
source_file_name,
shared_lib_name);
shared_lib_name,
libpaths);
#else
std::string extra_flags;
if (use_include) {
extra_flags = fmt::format("-I \"{}\"", include_dir);
}
#ifdef __AVX2__
if (supports_avx2_target()) {
if (!extra_flags.empty()) {
extra_flags += " ";
}
extra_flags += "-mavx2 -mfma -mf16c";
}
#endif
return fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared {} \"{}\" -o \"{}\" 2>&1",
extra_flags,
Expand All @@ -162,7 +292,7 @@ std::string JitCompiler::exec(const std::string& cmd) {
while (fgets(buffer, sizeof(buffer), pipe)) {
ret += buffer;
}
// Trim trailing spaces.
// Trim trailing whitespace.
ret.erase(
std::find_if(
ret.rbegin(),
Expand Down
8 changes: 7 additions & 1 deletion mlx/backend/cpu/jit_compiler.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2024-2026 Apple Inc.
#pragma once

#include <filesystem>
Expand All @@ -10,6 +10,12 @@ class JitCompiler {
// Return the includes that should be prepended to the source code.
static const std::tuple<bool, std::string, std::string>& get_preamble();

// Check if a JIT compiler is available on this system.
// On Windows, this probes for Visual Studio and a usable C++ compiler
// (MSVC cl.exe or clang-cl). On Linux/macOS, checks for g++ in PATH.
// Returns false (rather than throwing) if no compiler is found.
static bool available();

// Build a shell command that compiles a source code file to a shared library.
static std::string build_command(
const std::filesystem::path& dir,
Expand Down
Loading
Loading