Skip to content

Commit 83d0257

Browse files
committed
Limiting threads for ONNX evaluation
1 parent 19b5bd5 commit 83d0257

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

Common/ML/CMakeLists.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
# granted to it by virtue of its status as an Intergovernmental Organization
1010
# or submit itself to any jurisdiction.
1111

12-
# Pass ORT variables as a preprocessor definition
13-
add_compile_definitions(ORT_ROCM_BUILD=${ORT_ROCM_BUILD})
14-
add_compile_definitions(ORT_CUDA_BUILD=${ORT_CUDA_BUILD})
15-
add_compile_definitions(ORT_MIGRAPHX_BUILD=${ORT_MIGRAPHX_BUILD})
16-
add_compile_definitions(ORT_TENSORRT_BUILD=${ORT_TENSORRT_BUILD})
17-
1812
o2_add_library(ML
1913
SOURCES src/OrtInterface.cxx
2014
TARGETVARNAME targetName
2115
PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime)
16+
17+
# Pass ORT variables as a preprocessor definition
18+
target_compile_definitions(${targetName} PRIVATE
19+
ORT_ROCM_BUILD=$<BOOL:${ORT_ROCM_BUILD}>
20+
ORT_CUDA_BUILD=$<BOOL:${ORT_CUDA_BUILD}>
21+
ORT_MIGRAPHX_BUILD=$<BOOL:${ORT_MIGRAPHX_BUILD}>
22+
ORT_TENSORRT_BUILD=$<BOOL:${ORT_TENSORRT_BUILD}>)

Common/ML/include/ML/OrtInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ class OrtModel
6464
std::vector<std::string> getOutputNames() const { return mOutputNames; }
6565
Ort::SessionOptions& getSessionOptions();
6666
Ort::MemoryInfo& getMemoryInfo();
67+
int32_t getIntraOpNumThreads() const { return intraOpNumThreads; }
68+
int32_t getInterOpNumThreads() const { return interOpNumThreads; }
6769

6870
// Setters
6971
void setDeviceId(int32_t id) { deviceId = id; }
7072
void setIO();
7173
void setActiveThreads(int threads) { intraOpNumThreads = threads; }
74+
void setIntraOpNumThreads(int threads) { if(deviceType == "CPU") { intraOpNumThreads = threads; } }
75+
void setInterOpNumThreads(int threads) { if(deviceType == "CPU") { interOpNumThreads = threads; } }
7276

7377
// Conversion
7478
template <class I, class O>

GPU/GPUTracking/Global/GPUChainTrackingClusterizer.cxx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,7 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
627627
uint32_t maxClusters = 0;
628628
int32_t deviceId = -1;
629629
int32_t numLanes = GetProcessingSettings().nTPCClustererLanes;
630+
int32_t maxThreads = mRec->MemoryScalers()->nTPCdigits / 6000;
630631
for (uint32_t lane = 0; lane < NSECTORS; lane++) {
631632
maxClusters = std::max(maxClusters, processors()->tpcClusterer[lane].mNMaxClusters);
632633
}
@@ -635,16 +636,25 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
635636
if (nnApplications[lane].modelsUsed[0]) {
636637
SetONNXGPUStream((nnApplications[lane].model_class).getSessionOptions(), lane, &deviceId);
637638
(nnApplications[lane].model_class).setDeviceId(deviceId);
639+
if (nnApplications[lane].model_class.getIntraOpNumThreads() > maxThreads) {
640+
nnApplications[lane].model_class.setIntraOpNumThreads(maxThreads);
641+
}
638642
(nnApplications[lane].model_class).initEnvironment();
639643
}
640644
if (nnApplications[lane].modelsUsed[1]) {
641645
SetONNXGPUStream((nnApplications[lane].model_reg_1).getSessionOptions(), lane, &deviceId);
642646
(nnApplications[lane].model_reg_1).setDeviceId(deviceId);
647+
if (nnApplications[lane].model_reg_1.getIntraOpNumThreads() > maxThreads) {
648+
nnApplications[lane].model_reg_1.setIntraOpNumThreads(maxThreads);
649+
}
643650
(nnApplications[lane].model_reg_1).initEnvironment();
644651
}
645652
if (nnApplications[lane].modelsUsed[2]) {
646653
SetONNXGPUStream((nnApplications[lane].model_reg_2).getSessionOptions(), lane, &deviceId);
647654
(nnApplications[lane].model_reg_2).setDeviceId(deviceId);
655+
if (nnApplications[lane].model_reg_2.getIntraOpNumThreads() > maxThreads) {
656+
nnApplications[lane].model_reg_2.setIntraOpNumThreads(maxThreads);
657+
}
648658
(nnApplications[lane].model_reg_2).initEnvironment();
649659
}
650660
if (nn_settings.nnClusterizerVerbosity < 3) {

0 commit comments

Comments
 (0)