Skip to content

Commit 7edea5d

Browse files
authored
Merge branch 'AliceO2Group:dev' into cocktail
2 parents d52ac29 + 9424b41 commit 7edea5d

File tree

27 files changed

+702
-174
lines changed

27 files changed

+702
-174
lines changed

Common/ML/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,21 @@
99
# granted to it by virtue of its status as an Intergovernmental Organization
1010
# or submit itself to any jurisdiction.
1111

12+
# Pass ORT variables as a preprocessor definition
13+
if(DEFINED ENV{ORT_ROCM_BUILD})
14+
add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD})
15+
endif()
16+
if(DEFINED ENV{ORT_CUDA_BUILD})
17+
add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD})
18+
endif()
19+
if(DEFINED ENV{ORT_MIGRAPHX_BUILD})
20+
add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD})
21+
endif()
22+
if(DEFINED ENV{ORT_TENSORRT_BUILD})
23+
add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD})
24+
endif()
25+
1226
o2_add_library(ML
13-
SOURCES src/ort_interface.cxx
27+
SOURCES src/OrtInterface.cxx
1428
TARGETVARNAME targetName
1529
PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
1111

12-
/// \file ort_interface.h
12+
/// \file OrtInterface.h
1313
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
1414
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515

16-
#ifndef O2_ML_ONNX_INTERFACE_H
17-
#define O2_ML_ONNX_INTERFACE_H
16+
#ifndef O2_ML_ORTINTERFACE_H
17+
#define O2_ML_ORTINTERFACE_H
1818

1919
// C++ and system includes
2020
#include <vector>
@@ -89,4 +89,4 @@ class OrtModel
8989

9090
} // namespace o2
9191

92-
#endif // O2_ML_ORT_INTERFACE_H
92+
#endif // O2_ML_ORTINTERFACE_H
Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
1111

12-
/// \file ort_interface.cxx
12+
/// \file OrtInterface.cxx
1313
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
1414
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515

16-
#include "ML/ort_interface.h"
16+
#include "ML/OrtInterface.h"
1717
#include "ML/3rdparty/GPUORTFloat16.h"
1818

1919
// ONNX includes
@@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5050
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
5151
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
5252
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
53-
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
53+
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
5454
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
5555
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
5656

5757
std::string dev_mem_str = "Hip";
58-
#ifdef ORT_ROCM_BUILD
58+
#if defined(ORT_ROCM_BUILD)
59+
#if ORT_ROCM_BUILD == 1
5960
if (device == "ROCM") {
6061
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
6162
LOG(info) << "(ORT) ROCM execution provider set";
6263
}
6364
#endif
64-
#ifdef ORT_MIGRAPHX_BUILD
65+
#endif
66+
#if defined(ORT_MIGRAPHX_BUILD)
67+
#if ORT_MIGRAPHX_BUILD == 1
6568
if (device == "MIGRAPHX") {
6669
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
6770
LOG(info) << "(ORT) MIGraphX execution provider set";
6871
}
6972
#endif
70-
#ifdef ORT_CUDA_BUILD
73+
#endif
74+
#if defined(ORT_CUDA_BUILD)
75+
#if ORT_CUDA_BUILD == 1
7176
if (device == "CUDA") {
7277
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
7378
LOG(info) << "(ORT) CUDA execution provider set";
7479
dev_mem_str = "Cuda";
7580
}
81+
#endif
7682
#endif
7783

7884
if (allocateDeviceMemory) {
@@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
106112
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
107113
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
108114

109-
pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()));
115+
pImplOrt->env = std::make_shared<Ort::Env>(
116+
OrtLoggingLevel(loggingLevel),
117+
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
118+
// Integrate ORT logging into Fairlogger
119+
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
120+
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
121+
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
122+
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
123+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
124+
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
125+
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
126+
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
127+
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
128+
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
129+
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
130+
} else {
131+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
132+
}
133+
},
134+
(void*)3);
135+
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
110136
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
111137

112138
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
@@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
130156
[&](const std::string& str) { return str.c_str(); });
131157

132158
// Print names
133-
if (loggingLevel > 1) {
134-
LOG(info) << "Input Nodes:";
135-
for (size_t i = 0; i < mInputNames.size(); i++) {
136-
LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
137-
}
159+
LOG(info) << "\tInput Nodes:";
160+
for (size_t i = 0; i < mInputNames.size(); i++) {
161+
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
162+
}
138163

139-
LOG(info) << "Output Nodes:";
140-
for (size_t i = 0; i < mOutputNames.size(); i++) {
141-
LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
142-
}
164+
LOG(info) << "\tOutput Nodes:";
165+
for (size_t i = 0; i < mOutputNames.size(); i++) {
166+
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
143167
}
144168
}
145169

DataFormats/Detectors/CTP/src/Scalers.cxx

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -674,35 +674,35 @@ std::pair<double, double> CTPRunScalers::getRate(uint32_t orbit, int classindex,
674674

675675
// then we can use binary search to find the right entries
676676
auto iter = std::lower_bound(mScalerRecordO2.begin(), mScalerRecordO2.end(), orbit, [&](CTPScalerRecordO2 const& a, uint32_t value) { return a.intRecord.orbit <= value; });
677-
auto nextindex = iter - mScalerRecordO2.begin(); // this points to the first index that has orbit greater or equal to given orbit
677+
auto nextindex = std::distance(mScalerRecordO2.begin(), iter); // this points to the first index that has orbit greater or equal to given orbit
678678

679679
auto calcRate = [&](auto index1, auto index2) -> double {
680-
auto next = &mScalerRecordO2[index2];
681-
auto prev = &mScalerRecordO2[index1];
682-
auto timedelta = (next->intRecord.orbit - prev->intRecord.orbit) * 88.e-6; // converts orbits into time
680+
const auto& snext = mScalerRecordO2[index2];
681+
const auto& sprev = mScalerRecordO2[index1];
682+
auto timedelta = (snext.intRecord.orbit - sprev.intRecord.orbit) * 88.e-6; // converts orbits into time
683683
if (type < 7) {
684-
auto s0 = &(prev->scalers[classindex]); // type CTPScalerO2*
685-
auto s1 = &(next->scalers[classindex]);
684+
const auto& s0 = sprev.scalers[classindex]; // type CTPScalerO2*
685+
const auto& s1 = snext.scalers[classindex];
686686
switch (type) {
687687
case 1:
688-
return (s1->lmBefore - s0->lmBefore) / timedelta;
688+
return (s1.lmBefore - s0.lmBefore) / timedelta;
689689
case 2:
690-
return (s1->lmAfter - s0->lmAfter) / timedelta;
690+
return (s1.lmAfter - s0.lmAfter) / timedelta;
691691
case 3:
692-
return (s1->l0Before - s0->l0Before) / timedelta;
692+
return (s1.l0Before - s0.l0Before) / timedelta;
693693
case 4:
694-
return (s1->l0After - s0->l0After) / timedelta;
694+
return (s1.l0After - s0.l0After) / timedelta;
695695
case 5:
696-
return (s1->l1Before - s0->l1Before) / timedelta;
696+
return (s1.l1Before - s0.l1Before) / timedelta;
697697
case 6:
698-
return (s1->l1After - s0->l1After) / timedelta;
698+
return (s1.l1After - s0.l1After) / timedelta;
699699
default:
700700
LOG(error) << "Wrong type:" << type;
701701
return -1; // wrong type
702702
}
703703
} else if (type == 7) {
704-
auto s0 = &(prev->scalersInps[classindex]); // type CTPScalerO2*
705-
auto s1 = &(next->scalersInps[classindex]);
704+
auto s0 = sprev.scalersInps[classindex]; // type CTPScalerO2*
705+
auto s1 = snext.scalersInps[classindex];
706706
return (s1 - s0) / timedelta;
707707
} else {
708708
LOG(error) << "Wrong type:" << type;
@@ -738,37 +738,37 @@ std::pair<double, double> CTPRunScalers::getRateGivenT(double timestamp, int cla
738738
// this points to the first index that has orbit greater to given orbit;
739739
// If this is 0, it means that the above condition was false from the beginning, basically saying that the timestamp is below any of the ScalerRecords' orbits.
740740
// If this is mScalerRecordO2.size(), it means mScalerRecordO2.end() was returned, condition was met throughout all ScalerRecords, basically saying the timestamp is above any of the ScalarRecordss orbits.
741-
auto nextindex = iter - mScalerRecordO2.begin();
741+
auto nextindex = std::distance(mScalerRecordO2.begin(), iter);
742742

743743
auto calcRate = [&](auto index1, auto index2) -> double {
744-
auto next = &mScalerRecordO2[index2];
745-
auto prev = &mScalerRecordO2[index1];
746-
auto timedelta = (next->intRecord.orbit - prev->intRecord.orbit) * 88.e-6; // converts orbits into time
744+
const auto& snext = mScalerRecordO2[index2];
745+
const auto& sprev = mScalerRecordO2[index1];
746+
auto timedelta = (snext.intRecord.orbit - sprev.intRecord.orbit) * 88.e-6; // converts orbits into time
747747
// std::cout << "timedelta:" << timedelta << std::endl;
748748
if (type < 7) {
749-
auto s0 = &(prev->scalers[classindex]); // type CTPScalerO2*
750-
auto s1 = &(next->scalers[classindex]);
749+
const auto& s0 = sprev.scalers[classindex]; // type CTPScalerO2*
750+
const auto& s1 = snext.scalers[classindex];
751751
switch (type) {
752752
case 1:
753-
return (s1->lmBefore - s0->lmBefore) / timedelta;
753+
return (s1.lmBefore - s0.lmBefore) / timedelta;
754754
case 2:
755-
return (s1->lmAfter - s0->lmAfter) / timedelta;
755+
return (s1.lmAfter - s0.lmAfter) / timedelta;
756756
case 3:
757-
return (s1->l0Before - s0->l0Before) / timedelta;
757+
return (s1.l0Before - s0.l0Before) / timedelta;
758758
case 4:
759-
return (s1->l0After - s0->l0After) / timedelta;
759+
return (s1.l0After - s0.l0After) / timedelta;
760760
case 5:
761-
return (s1->l1Before - s0->l1Before) / timedelta;
761+
return (s1.l1Before - s0.l1Before) / timedelta;
762762
case 6:
763-
return (s1->l1After - s0->l1After) / timedelta;
763+
return (s1.l1After - s0.l1After) / timedelta;
764764
default:
765765
LOG(error) << "Wrong type:" << type;
766766
return -1; // wrong type
767767
}
768768
} else if (type == 7) {
769769
// LOG(info) << "doing input:";
770-
auto s0 = prev->scalersInps[classindex]; // type CTPScalerO2*
771-
auto s1 = next->scalersInps[classindex];
770+
auto s0 = sprev.scalersInps[classindex]; // type CTPScalerO2*
771+
auto s1 = snext.scalersInps[classindex];
772772
return (s1 - s0) / timedelta;
773773
} else {
774774
LOG(error) << "Wrong type:" << type;

0 commit comments

Comments
 (0)