Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("pre_cache_len_concat",
&PreCacheLenConcat,
"pre_cache len concat function");

/**
* moe/fused_moe/fused_moe.cu
* fused_moe
Expand Down Expand Up @@ -1596,7 +1595,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

#ifdef ENABLE_SM80_EXT_OPS
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);

m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);

m.def("MoeWna16MarlinGemmApi",
Expand Down Expand Up @@ -1632,6 +1630,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");

#ifdef ENABLE_SM75_EXT_OPS
/**
* cutlass_scaled_mm.cu
* cutlass_scaled_mm
Expand Down Expand Up @@ -1669,6 +1668,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("input"),
py::arg("scales"),
py::arg("scale_ub"));
#endif
#ifdef ENABLE_SM80_EXT_OPS
m.def("decode_mla_write_cache",
&DecodeMLAWriteCacheKernel,
Expand Down Expand Up @@ -1885,6 +1885,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("custom_numpy_to_tensor",
&CustomNumpyToTensor,
"custom_numpy_to_tensor function");
#ifdef ENABLE_SM80_EXT_OPS
m.def("prefill_permute_to_masked_gemm",
&PrefillPermuteToMaskedGemm,
py::arg("x"),
Expand Down Expand Up @@ -1919,4 +1920,5 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_group_fp8_quant",
&PerTokenGroupQuantFp8,
"per_token_group_quant_fp8");
#endif
}
43 changes: 9 additions & 34 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,32 +179,6 @@ def get_gencode_flags(archs):
return flags


def get_compile_parallelism():
    """
    Decide safe compile parallelism for both build workers and nvcc threads.

    The number of build workers comes from the ``MAX_JOBS`` environment
    variable when it is set. Otherwise it defaults to the CPU count capped
    at 32, and that default is written back to ``MAX_JOBS``.

    Returns:
        tuple: ``(max_jobs, nvcc_threads)`` where ``nvcc_threads`` is
        ``min(max_jobs, 4)``.

    Raises:
        ValueError: If ``MAX_JOBS`` is set but is not a positive integer.
    """
    max_jobs_env = os.getenv("MAX_JOBS")

    if max_jobs_env is None:
        # Cap default build workers to avoid OOM in high-core CI runners.
        max_jobs = min(os.cpu_count() or 1, 32)
        # Export the chosen default so anything reading MAX_JOBS later sees it.
        os.environ["MAX_JOBS"] = str(max_jobs)
    else:
        invalid_msg = f"Invalid MAX_JOBS={max_jobs_env!r}, expected a positive integer."
        # Keep parsing and range validation separate so an out-of-range value
        # is reported directly instead of being chained from an empty,
        # self-raised ValueError.
        try:
            max_jobs = int(max_jobs_env)
        except ValueError as exc:
            raise ValueError(invalid_msg) from exc
        if max_jobs < 1:
            raise ValueError(invalid_msg)

    # Limit nvcc internal threads to avoid resource exhaustion when Paddle's
    # ThreadPoolExecutor also launches many parallel compilations.
    # Total threads ~= (number of parallel compile jobs) * nvcc_threads.
    nvcc_threads = min(max_jobs, 4)
    return max_jobs, nvcc_threads


def find_end_files(directory, end_str):
"""
Find files with end str in directory.
Expand Down Expand Up @@ -339,8 +313,8 @@ def find_end_files(directory, end_str):
"gpu_ops/reasoning_phase_token_constraint.cu",
"gpu_ops/get_attn_mask_q.cu",
]

sm_versions = get_sm_version(archs)
# Some kernels in this file require SM75+ instructions. Exclude them when building SM70 (V100).
disable_gelu_tanh = 70 in sm_versions
if disable_gelu_tanh:
sources = [s for s in sources if s != "gpu_ops/gelu_tanh.cu"]
Expand Down Expand Up @@ -397,8 +371,10 @@ def find_end_files(directory, end_str):
"-Igpu_ops",
"-Ithird_party/nlohmann_json/include",
]
max_jobs, nvcc_threads = get_compile_parallelism()
print(f"MAX_JOBS = {max_jobs}, nvcc -t = {nvcc_threads}")
# Limit nvcc internal threads to avoid resource exhaustion when Paddle's
# ThreadPoolExecutor also launches many parallel compilations.
# Total threads ≈ (number of parallel compile jobs) × nvcc_threads, so cap nvcc_threads at 4.
nvcc_threads = min(os.cpu_count() or 1, 4)
nvcc_compile_args += ["-t", str(nvcc_threads)]

nvcc_version = get_nvcc_version()
Expand Down Expand Up @@ -428,14 +404,12 @@ def find_end_files(directory, end_str):
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu",
"gpu_ops/quantization/common.cu",
# cpp_extensions.cc always registers these two ops; include their kernels on SM75 as well.
# deepgemm permute/depermute can compile on SM75 (no BF16 dependency).
"gpu_ops/moe/moe_deepgemm_permute.cu",
"gpu_ops/moe/moe_deepgemm_depermute.cu",
]

if cc >= 80:
cc_compile_args += ["-DENABLE_SM80_EXT_OPS"]
nvcc_compile_args += ["-DENABLE_SM80_EXT_OPS"]
# append_attention
os.system(
"python utils/auto_gen_template_instantiation.py --config gpu_ops/append_attn/template_config.json --output gpu_ops/append_attn/template_instantiation/autogen"
Expand All @@ -451,7 +425,9 @@ def find_end_files(directory, end_str):
# speculate_decoding
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
cc_compile_args += ["-DENABLE_SM80_EXT_OPS"]
nvcc_compile_args += ["-DENABLE_BF16"]
nvcc_compile_args += ["-DENABLE_SM80_EXT_OPS"]
# moe
os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py")
os.system(
Expand Down Expand Up @@ -558,8 +534,7 @@ def find_end_files(directory, end_str):
sources += find_end_files("gpu_ops/machete", ".cu")
cc_compile_args += ["-DENABLE_MACHETE"]

# Deduplicate translation units while preserving order. Some files are
# appended explicitly for SM75 and also discovered by later directory globs.
# Deduplicate translation units while preserving order.
sources = list(dict.fromkeys(sources))

setup(
Expand Down
Loading