Common/ML/include/ML/OrtInterface.h: 4 additions & 4 deletions
@@ -58,13 +58,13 @@ class OrtModel
   // Inferencing
-  template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
+  template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
-  void inference(I*, size_t, O*);
+  void inference(I*, size_t, O*, int32_t = -1);
   // template <class I, class T, class O> // class I is the input data type, class T the throughput data type and class O is the output data type
   // std::vector<O> inference(std::vector<I>&);
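For orientation, a minimal caller sketch of the extended interface, assuming an already-initialized OrtModel (namespace qualification omitted); the diff does not spell out what the trailing int32_t (default -1) means, so the value passed below is purely illustrative:

#include <cstdint>
#include <vector>
#include "ML/OrtInterface.h" // declares OrtModel (O2, Common/ML)

void runExample(OrtModel& model, std::vector<float>& in, std::vector<float>& out)
{
  // Existing call sites keep compiling: the new int32_t argument defaults to -1.
  model.inference<float, float>(in.data(), in.size(), out.data());
  // Passing the new argument explicitly (semantics assumed, e.g. a lane/stream id).
  model.inference<float, float>(in.data(), in.size(), out.data(), 0);
}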
@@ -92,7 +92,7 @@ class OrtModel
   // Environment settings
   bool mInitialized = false;
   std::string modelPath, device = "cpu", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
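As context for the device member, a sketch (not from the PR, and subject to ONNX Runtime version availability) of how such a device string is typically mapped onto execution providers; the provider option structs are default-initialized here and would need real device ids in practice:

#include <string>
#include <onnxruntime_cxx_api.h>

Ort::SessionOptions makeSessionOptions(const std::string& device)
{
  Ort::SessionOptions opts;
  if (device == "cuda") {
    OrtCUDAProviderOptions cudaOpts{}; // default device id 0
    opts.AppendExecutionProvider_CUDA(cudaOpts);
  } else if (device == "rocm") {
    OrtROCMProviderOptions rocmOpts{};
    opts.AppendExecutionProvider_ROCM(rocmOpts);
  } else if (device == "migraphx") {
    OrtMIGraphXProviderOptions mgxOpts{};
    opts.AppendExecutionProvider_MIGraphX(mgxOpts);
  }
  // "cpu": no explicit provider needed; ORT falls back to the CPU provider.
  return opts;
}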
-  (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is correct here
+  (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is always correct here
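For context on the TODO: in the ONNX Runtime C++ API, the argument after the input tensor pointer is the number of input tensors, so the hard-coded 1 is only correct while the model has exactly one input. A hedged sketch of a count-safe variant, assuming the inputs were collected into a std::vector<Ort::Value> named inputTensors (a name not present in the diff):

// Deriving the input count from the container instead of hard-coding 1;
// inputNamesChar must hold one name per tensor in inputTensors.
(pImplOrt->session)->Run(pImplOrt->runOptions,
                         inputNamesChar.data(), inputTensors.data(), inputTensors.size(),
                         outputNamesChar.data(), &outputTensor, outputNamesChar.size());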
   AddOption(nnInferenceAllocateDevMem, int, 0, "", 0, "(bool, default = 0), whether device memory should be allocated for inference")
-  AddOption(nnInferenceDtype, std::string, "fp32", "", 0, "(std::string) Specify the datatype for which inference is performed (fp32: default, fp16)") // fp32 or fp16
+  AddOption(nnInferenceInputDType, std::string, "FP32", "", 0, "(std::string) Specify the input datatype with which inference is performed (FP32: default, FP16)") // FP32 or FP16
+  AddOption(nnInferenceOutputDType, std::string, "FP32", "", 0, "(std::string) Specify the output datatype with which inference is performed (FP32: default, FP16)") // FP32 or FP16
   AddOption(nnInferenceIntraOpNumThreads, int, 1, "", 0, "Number of threads used to evaluate one neural network (ONNX: SetIntraOpNumThreads). 0 = auto-detect, can lead to problems on SLURM systems.")
   AddOption(nnInferenceInterOpNumThreads, int, 1, "", 0, "Number of threads used to evaluate one neural network (ONNX: SetInterOpNumThreads). 0 = auto-detect, can lead to problems on SLURM systems.")
   AddOption(nnInferenceEnableOrtOptimization, unsigned int, 99, "", 0, "Enables graph optimizations in ONNX Runtime. Can be [0, 1, 2, 99] -> see https://github.com/microsoft/onnxruntime/blob/3f71d637a83dc3540753a8bb06740f67e926dc13/include/onnxruntime/core/session/onnxruntime_c_api.h#L347")
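Splitting the single nnInferenceDtype option into input and output variants suggests mixed-precision dispatch. A minimal sketch (hypothetical, not part of the PR) of how the two strings could select the inference<I, O> instantiation, using the float16 type referenced in the header comment above:

#include <cstddef>
#include <string>

// Hypothetical dispatcher: picks the template instantiation from the two
// dtype option strings; only the combinations named in the option
// descriptions (FP32/FP16) are handled.
void dispatchInference(OrtModel& model, const std::string& inDType, const std::string& outDType,
                       float* in, size_t n, void* out)
{
  using Float16 = O2::gpu::OrtDataType::Float16_t; // from O2/GPU/GPUTracking/ML/convert_float16.h
  if (inDType == "FP32" && outDType == "FP16") {
    model.inference<float, Float16>(in, n, static_cast<Float16*>(out));
  } else { // default: FP32 in, FP32 out
    model.inference<float, float>(in, n, static_cast<float*>(out));
  }
}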
@@ -250,6 +251,16 @@ AddOption(nnClassificationPath, std::string, "network_class.onnx", "", 0, "The c
   AddOption(nnClassThreshold, float, 0.5, "", 0, "The cutoff at which clusters will be accepted / rejected.")
   AddOption(nnRegressionPath, std::string, "network_reg.onnx", "", 0, "The regression network path")
   AddOption(nnSigmoidTrafoClassThreshold, int, 1, "", 0, "If true (default), the classification threshold is transformed by an inverse sigmoid function. This depends on how the network was trained (with or without a sigmoid as activation function in the last layer).")
+  // CCDB
+  AddOption(nnLoadFromCCDB, int, 1, "", 0, "If 1, networks are fetched from the CCDB; otherwise they are loaded locally")
+  AddOption(nnCCDBURL, std::string, "http://ccdb-test.cern.ch:8080", "", 0, "The CCDB URL from which the network files are fetched")
+  AddOption(nnCCDBPath, std::string, "Users/c/csonnabe/TPC/Clusterization", "", 0, "Folder path containing the networks")
+  AddOption(nnCCDBFetchMode, std::string, "c1:r1", "", 0, "Concatenation of modes, e.g. c1:r1 (classification class 1, regression class 1)")
+  AddOption(nnCCDBWithMomentum, int, 1, "", 0, "Distinguishes between the networks with and without momentum output for the regression")
+  AddOption(nnCCDBClassificationLayerType, std::string, "FC", "", 0, "Distinguishes between networks with different layer types. Options: FC, CNN")
+  AddOption(nnCCDBRegressionLayerType, std::string, "CNN", "", 0, "Distinguishes between networks with different layer types. Options: FC, CNN")
+  AddOption(nnCCDBBeamType, std::string, "PbPb", "", 0, "Distinguishes between networks trained for different beam types. Options: PbPb, pp")
+  AddOption(nnCCDBInteractionRate, int, 50, "", 0, "Distinguishes between networks for different interaction rates [kHz].")
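To make the nnCCDBFetchMode format concrete, a small helper sketch (hypothetical, not in the PR) that splits the colon-concatenated mode string into its tokens; parseFetchModes("c1:r1") yields {"c1", "r1"}, i.e. one fetch per mode:

#include <sstream>
#include <string>
#include <vector>

// Splits e.g. "c1:r1" into {"c1", "r1"}
// ("c1" = classification class 1, "r1" = regression class 1).
std::vector<std::string> parseFetchModes(const std::string& fetchMode)
{
  std::vector<std::string> modes;
  std::stringstream ss(fetchMode);
  std::string token;
  while (std::getline(ss, token, ':')) {
    if (!token.empty()) {
      modes.push_back(token);
    }
  }
  return modes;
}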