Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions kernel_tuner/backends/nvcuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@
"""
kernel_string = kernel_instance.kernel_string
kernel_name = kernel_instance.name

# mimic pycuda behavior to wrap kernel_string in extern "C" if not in kernel_string already
if 'extern "C"' not in kernel_string:
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
expression_name = str.encode(kernel_name)

compiler_options = self.compiler_options_bytes
if not any([b"--std=" in opt for opt in compiler_options]):
Expand All @@ -171,20 +168,37 @@

err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], [])
try:
# Add the kernel as an expression. This is necessary for templated kernels to ensure that the
# compiler actually instantiates the kernel that we want to compile.
cuda_error_check(err)
err = nvrtc.nvrtcAddNameExpression(program, expression_name)

Check warning on line 175 in kernel_tuner/backends/nvcuda.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the useless trailing whitespaces at the end of this line.

See more on https://sonarcloud.io/project/issues?id=KernelTuner_kernel_tuner&issues=AZz3Uk0mrZDh_eVvPbm9&open=AZz3Uk0mrZDh_eVvPbm9&pullRequest=367
# Compile the program
cuda_error_check(err)
err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options)

Check warning on line 179 in kernel_tuner/backends/nvcuda.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the useless trailing whitespaces at the end of this line.

See more on https://sonarcloud.io/project/issues?id=KernelTuner_kernel_tuner&issues=AZz3Uk0mrZDh_eVvPbm-&open=AZz3Uk0mrZDh_eVvPbm-&pullRequest=367
# Get the PTX
cuda_error_check(err)
err, size = nvrtc.nvrtcGetPTXSize(program)
cuda_error_check(err)
buff = b" " * size
err = nvrtc.nvrtcGetPTX(program, buff)
cuda_error_check(err)

Check warning on line 187 in kernel_tuner/backends/nvcuda.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the useless trailing whitespaces at the end of this line.

See more on https://sonarcloud.io/project/issues?id=KernelTuner_kernel_tuner&issues=AZz3Uk0mrZDh_eVvPbm_&open=AZz3Uk0mrZDh_eVvPbm_&pullRequest=367
# Load the module
err, self.current_module = driver.cuModuleLoadData(np.char.array(buff))
if err == driver.CUresult.CUDA_ERROR_INVALID_PTX:
raise SkippableFailure("uses too much shared data")
else:
cuda_error_check(err)
err, self.func = driver.cuModuleGetFunction(self.current_module, str.encode(kernel_name))

Check warning on line 194 in kernel_tuner/backends/nvcuda.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the useless trailing whitespaces at the end of this line.

See more on https://sonarcloud.io/project/issues?id=KernelTuner_kernel_tuner&issues=AZz3Uk0mrZDh_eVvPbnA&open=AZz3Uk0mrZDh_eVvPbnA&pullRequest=367
# First, get the "lowered" name of the kernel (i.e., the name inside the PTX).
# Afterwards, we can use the lowered name to look up the kernel in the module.

Check warning on line 196 in kernel_tuner/backends/nvcuda.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the useless trailing whitespaces at the end of this line.

See more on https://sonarcloud.io/project/issues?id=KernelTuner_kernel_tuner&issues=AZz3Uk0mrZDh_eVvPbnB&open=AZz3Uk0mrZDh_eVvPbnB&pullRequest=367
err, lowered_name = nvrtc.nvrtcGetLoweredName(program, expression_name)
cuda_error_check(err)
err, self.func = driver.cuModuleGetFunction(
self.current_module, lowered_name
)
cuda_error_check(err)

# get the number of registers per thread used in this kernel
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose)
)

# check for templated kernel
if kernel_source.lang in ["CUDA", "NVCUDA", "HIP"] and "<" in name and ">" in name:
if kernel_source.lang in ["CUDA", "HIP"] and "<" in name and ">" in name:
kernel_string, name = wrap_templated_kernel(kernel_string, name)

# Preprocess GPU arguments. Required for handling `Tunable` arguments.
Expand Down
Loading