Skip to content

Commit e7cd6fa

Browse files
committed
Working version of CCDB fetching and loading into ROOT class of std::vector<char>
1 parent 4fed621 commit e7cd6fa

File tree

11 files changed

+339
-74
lines changed

11 files changed

+339
-74
lines changed

Detectors/TPC/base/test/testTPCCDBInterface.cxx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
// o2 includes
2424
#include "TPCBase/CDBInterface.h"
25-
#include "TPCBase/CDBInterface.h"
2625
#include "TPCBase/CalArray.h"
2726
#include "TPCBase/CalDet.h"
2827
#include "TPCBase/Mapper.h"

GPU/GPUTracking/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ set(SRCS_DATATYPES
209209
DataTypes/TPCPadBitMap.cxx
210210
DataTypes/TPCZSLinkMapping.cxx
211211
DataTypes/CalibdEdxContainer.cxx
212+
DataTypes/ORTRootSerializer.cxx
212213
DataTypes/CalibdEdxTrackTopologyPol.cxx
213214
DataTypes/CalibdEdxTrackTopologySpline.cxx
214215
DataTypes/GPUTRDTrackO2.cxx)

GPU/GPUTracking/DataTypes/GPUDataTypes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class Cluster;
8585
namespace tpc
8686
{
8787
class CalibdEdxContainer;
88+
class ORTRootSerializer;
8889
} // namespace tpc
8990
} // namespace o2
9091

@@ -184,8 +185,7 @@ struct GPUCalibObjectsTemplate { // use only pointers on PODs or flat objects he
184185
typename S<o2::itsmft::TopologyDictionary>::type* itsPatternDict = nullptr;
185186

186187
// NN clusterizer objects
187-
char* nnClusterizerNetworks[3] = {nullptr, nullptr, nullptr}; // [c, r1, r2] networks as char arrays from CCDB
188-
uint32_t nnClusterizerNetworkSizes[3] = {0, 0, 0};
188+
typename S<o2::tpc::ORTRootSerializer>::type* nnClusterizerNetworks[3] = {nullptr, nullptr, nullptr};
189189
};
190190
typedef GPUCalibObjectsTemplate<DefaultPtr> GPUCalibObjects; // NOTE: These 2 must have identical layout since they are memcopied
191191
typedef GPUCalibObjectsTemplate<ConstPtr> GPUCalibObjectsConst;
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file ORTRootSerializer.cxx
13+
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
14+
15+
#include "ORTRootSerializer.h"
16+
#include <cstring>
17+
18+
using namespace o2::tpc;
19+
20+
/// Initialize the serialization from a char* buffer containing the model
21+
void ORTRootSerializer::setOnnxModel(const char* onnxModel, uint32_t size)
22+
{
23+
mModelBuffer.resize(size);
24+
std::memcpy(mModelBuffer.data(), onnxModel, size);
25+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file ORTRootSerializer.h
13+
/// \brief Class to serialize ONNX objects for ROOT snapshots of CCDB objects at runtime
14+
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
15+
16+
#ifndef ALICEO2_TPC_ORTROOTSERIALIZER_H_
17+
#define ALICEO2_TPC_ORTROOTSERIALIZER_H_
18+
19+
#include "GPUCommonRtypes.h"
20+
#include <vector>
21+
#include <string>
22+
23+
namespace o2::tpc
24+
{
25+
26+
class ORTRootSerializer
27+
{
28+
public:
29+
ORTRootSerializer() = default;
30+
~ORTRootSerializer() = default;
31+
32+
void setOnnxModel(const char* onnxModel, uint32_t size);
33+
const char* getONNXModel() const { return mModelBuffer.data(); }
34+
uint32_t getONNXModelSize() const { return static_cast<uint32_t>(mModelBuffer.size()); }
35+
36+
private:
37+
std::vector<char> mModelBuffer; ///< buffer for serialization
38+
ClassDefNV(ORTRootSerializer, 1);
39+
};
40+
41+
} // namespace o2::tpc
42+
43+
#endif // ALICEO2_TPC_ORTROOTSERIALIZER_H_

GPU/GPUTracking/Definitions/GPUSettingsList.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ AddOption(nnClusterizerBoundaryFillValue, int, -1, "", 0, "Fill value for the bo
277277
AddOption(nnClusterizerApplyNoiseSuppression, int, 1, "", 0, "Applies the NoiseSuppression kernel before the digits to the network are filled")
278278
AddOption(nnClusterizerSetDeconvolutionFlags, int, 1, "", 0, "Runs the deconvolution kernel without overwriting the charge in order to make cluster-to-track attachment identical to heuristic CF")
279279
AddOption(nnClassificationPath, std::string, "network_class.onnx", "", 0, "The classification network path")
280-
AddOption(nnClassThreshold, float, 0.5, "", 0, "The cutoff at which clusters will be accepted / rejected.")
281280
AddOption(nnRegressionPath, std::string, "network_reg.onnx", "", 0, "The regression network path")
281+
AddOption(nnClassThreshold, float, 0.5, "", 0, "The cutoff at which clusters will be accepted / rejected.")
282282
AddOption(nnSigmoidTrafoClassThreshold, int, 1, "", 0, "If true (default), then the classification threshold is transformed by an inverse sigmoid function. This depends on how the network was trained (with a sigmoid as acitvation function in the last layer or not).")
283283
AddOption(nnEvalMode, std::string, "c1:r1", "", 0, "Concatention of modes, e.g. c1:r1 (classification class 1, regression class 1)")
284284
AddOption(nnClusterizerUseClassification, int, 1, "", 0, "If 1, the classification output of the network is used to select clusters, else only the regression output is used and no clusters are rejected by classification")

GPU/GPUTracking/GPUTrackingLinkDef_O2_DataTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@
4343
#pragma link C++ class o2::tpc::CalibdEdxTrackTopologyPol + ;
4444
#pragma link C++ class o2::tpc::CalibdEdxTrackTopologySpline + ;
4545
#pragma link C++ struct o2::tpc::CalibdEdxTrackTopologyPolContainer + ;
46+
#pragma link C++ struct o2::tpc::ORTRootSerializer + ;
4647

4748
#endif

GPU/GPUTracking/Global/GPUChainTrackingClusterizer.cxx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#ifdef GPUCA_HAS_ONNX
4848
#include "GPUTPCNNClusterizerKernels.h"
4949
#include "GPUTPCNNClusterizerHost.h"
50+
#include "ORTRootSerializer.h"
5051
#endif
5152

5253
#ifdef GPUCA_O2_LIB
@@ -680,7 +681,7 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
680681
if (!nn_settings.nnLoadFromCCDB) {
681682
(nnApplications[lane].mModelClass).initSession(); // loads from file
682683
} else {
683-
(nnApplications[lane].mModelClass).initSessionFromBuffer(processors()->calibObjects.nnClusterizerNetworks[0], processors()->calibObjects.nnClusterizerNetworkSizes[0]); // loads from CCDB
684+
(nnApplications[lane].mModelClass).initSessionFromBuffer((processors()->calibObjects.nnClusterizerNetworks[0])->getONNXModel(), (processors()->calibObjects.nnClusterizerNetworks[0])->getONNXModelSize()); // loads from CCDB
684685
}
685686
}
686687
if (nnApplications[lane].mModelsUsed[1]) {
@@ -695,7 +696,7 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
695696
if (!nn_settings.nnLoadFromCCDB) {
696697
(nnApplications[lane].mModelReg1).initSession(); // loads from file
697698
} else {
698-
(nnApplications[lane].mModelReg1).initSessionFromBuffer(processors()->calibObjects.nnClusterizerNetworks[1], processors()->calibObjects.nnClusterizerNetworkSizes[1]); // loads from CCDB
699+
(nnApplications[lane].mModelReg1).initSessionFromBuffer((processors()->calibObjects.nnClusterizerNetworks[1])->getONNXModel(), (processors()->calibObjects.nnClusterizerNetworks[1])->getONNXModelSize()); // loads from CCDB
699700
}
700701
}
701702
if (nnApplications[lane].mModelsUsed[2]) {
@@ -710,7 +711,7 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
710711
if (!nn_settings.nnLoadFromCCDB) {
711712
(nnApplications[lane].mModelReg2).initSession(); // loads from file
712713
} else {
713-
(nnApplications[lane].mModelReg2).initSessionFromBuffer(processors()->calibObjects.nnClusterizerNetworks[2], processors()->calibObjects.nnClusterizerNetworkSizes[2]); // loads from CCDB
714+
(nnApplications[lane].mModelReg2).initSessionFromBuffer((processors()->calibObjects.nnClusterizerNetworks[2])->getONNXModel(), (processors()->calibObjects.nnClusterizerNetworks[2])->getONNXModelSize()); // loads from CCDB
714715
}
715716
}
716717
if (nn_settings.nnClusterizerVerbosity > 0) {
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file convert_onnx_to_root_serialized.C
13+
/// \brief Utility functions to be executed as a ROOT macro for uploading ONNX models to CCDB as ROOT serialized objects and vice versa
14+
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
15+
16+
// Example execution: root -l -b -q '/scratch/csonnabe/MyO2/O2/GPU/GPUTracking/utils/convert_onnx_to_root_serialized.C("/scratch/csonnabe/PhD/jobs/clusterization/NN/output/21082025_smallWindow_clean/SC/training_data_21082025_reco_noise_supressed_p3t6_CoGselected/SC/PbPb_24arp2/0_5/class1/regression/399_noMom/network/net_fp16.onnx", "", 1, 1, "nnCCDBLayerType=FC/nnCCDBWithMomentum=0/inputDType=FP16/nnCCDBInteractionRate=500/outputDType=FP16/nnCCDBEvalType=regression_c1/nnCCDBBeamType=pp/partName=blob/quality=3", 1, 4108971600000, "Users/c/csonnabe/TPC/Clusterization", "model.root")'
17+
18+
#include "ORTRootSerializer.h"
19+
#include "CCDB/CcdbApi.h"
20+
#include "CCDB/CcdbObjectInfo.h"
21+
#include "TFile.h"
22+
#include <fstream>
23+
#include <stdexcept>
24+
25+
o2::tpc::ORTRootSerializer serializer;
26+
27+
/// Dumps the char* to a .onnx file -> Directly readable by ONNX runtime or Netron
28+
void dumpOnnxToFile(const char* modelBuffer, uint32_t size, const std::string outputPath)
29+
{
30+
std::ofstream outFile(outputPath, std::ios::binary | std::ios::trunc);
31+
if (!outFile.is_open()) {
32+
throw std::runtime_error("Failed to open output ONNX file: " + outputPath);
33+
}
34+
outFile.write(modelBuffer, static_cast<std::streamsize>(size));
35+
if (!outFile) {
36+
throw std::runtime_error("Failed while writing data to: " + outputPath);
37+
}
38+
outFile.close();
39+
}
40+
41+
/// Initialize the serialization from an ONNX file
42+
void readOnnxModelFromFile(const std::string modelPath)
43+
{
44+
std::ifstream inFile(modelPath, std::ios::binary | std::ios::ate);
45+
if (!inFile.is_open()) {
46+
throw std::runtime_error("Could not open input ONNX file " + modelPath);
47+
}
48+
std::streamsize size = inFile.tellg();
49+
std::vector<char> mModelBuffer(size);
50+
inFile.seekg(0, std::ios::beg);
51+
if (!inFile.read(mModelBuffer.data(), size)) {
52+
throw std::runtime_error("Could not read input ONNX file " + modelPath);
53+
}
54+
inFile.close();
55+
serializer.setOnnxModel(mModelBuffer.data(), static_cast<uint32_t>(size));
56+
}
57+
58+
/// Initialize the serialization from a ROOT file
59+
void readRootModelFromFile(const std::string rootFilePath, std::string key)
60+
{
61+
TFile inRootFile(rootFilePath.c_str());
62+
if (inRootFile.IsZombie()) {
63+
throw std::runtime_error("Could not open input ROOT file " + rootFilePath);
64+
}
65+
auto* serPtr = inRootFile.Get<o2::tpc::ORTRootSerializer>(key.c_str());
66+
if (!serPtr) {
67+
throw std::runtime_error("Could not find " + key + " in ROOT file " + rootFilePath);
68+
}
69+
serializer = *serPtr;
70+
inRootFile.Close();
71+
}
72+
73+
/// Serialize the ONNX model to a ROOT object and store to file
74+
void onnxToRoot(std::string infile, std::string outfile, std::string key)
75+
{
76+
readOnnxModelFromFile(infile);
77+
TFile outRootFile(outfile.c_str(), "RECREATE");
78+
if (outRootFile.IsZombie()) {
79+
throw std::runtime_error("Could not create output ROOT file " + outfile);
80+
}
81+
outRootFile.WriteObject(&serializer, key.c_str());
82+
outRootFile.Close();
83+
}
84+
85+
/// Deserialize the ONNX model from a ROOT object and store to a .onnx file
86+
void rootToOnnx(std::string infile, std::string outfile, std::string key)
87+
{
88+
TFile inRootFile(infile.c_str());
89+
if (inRootFile.IsZombie()) {
90+
throw std::runtime_error("Could not open input ROOT file " + infile);
91+
}
92+
auto* serPtr = inRootFile.Get<o2::tpc::ORTRootSerializer>(key.c_str());
93+
if (!serPtr) {
94+
throw std::runtime_error("Could not find " + key + " in ROOT file " + infile);
95+
}
96+
serializer = *serPtr;
97+
98+
std::ofstream outFile(outfile, std::ios::binary | std::ios::trunc);
99+
if (!outFile.is_open()) {
100+
throw std::runtime_error("Failed to open output ONNX file: " + outfile);
101+
}
102+
outFile.write(serializer.getONNXModel(), static_cast<std::streamsize>(serializer.getONNXModelSize()));
103+
if (!outFile) {
104+
throw std::runtime_error("Failed while writing data to: " + outfile);
105+
}
106+
outFile.close();
107+
108+
inRootFile.Close();
109+
}
110+
111+
/// Upload the ONNX model to CCDB from an ONNX file
112+
/// !!! Adjust the metadata, path and validity !!!
113+
void uploadToCCDBFromONNX(std::string onnxFile,
114+
const std::map<std::string, std::string>& metadata,
115+
// { // some example metadata entries
116+
// "nnCCDBLayerType": "FC",
117+
// "nnCCDBWithMomentum": "0",
118+
// "inputDType": "FP16",
119+
// "nnCCDBInteractionRate": "500",
120+
// "outputDType": "FP16",
121+
// "nnCCDBEvalType": "regression_c1",
122+
// "nnCCDBBeamType": "pp",
123+
// "partName": "blob",
124+
// "quality": "3"
125+
// }
126+
long tsMin /* = 1 */,
127+
long tsMax /* = 4108971600000 */,
128+
std::string ccdbPath /* = "Users/c/csonnabe/TPC/Clusterization" */,
129+
std::string objname /* = "net_regression_r1.root" */,
130+
std::string ccdbUrl /* = "http://alice-ccdb.cern.ch" */)
131+
{
132+
readOnnxModelFromFile(onnxFile);
133+
134+
o2::ccdb::CcdbApi api;
135+
api.init(ccdbUrl);
136+
137+
// build full CCDB path including filename
138+
const std::string fullPath = ccdbPath;//.back() == '/' ? (ccdbPath + objname) : (ccdbPath + "/" + objname);
139+
140+
api.storeAsTFileAny(&serializer, fullPath, metadata, tsMin, tsMax);
141+
}
142+
143+
/// Upload the ONNX model to CCDB from a ROOT file
144+
/// !!! Adjust the metadata, path and validity !!!
145+
void uploadToCCDBFromROOT(std::string rootFile,
146+
const std::map<std::string, std::string>& metadata,
147+
long tsMin /* = 1 */,
148+
long tsMax /* = 4108971600000 */,
149+
std::string ccdbPath /* = "Users/c/csonnabe/TPC/Clusterization" */,
150+
std::string objname /* = "net_regression_r1.root" */,
151+
std::string ccdbUrl /* = "http://alice-ccdb.cern.ch" */)
152+
{
153+
// read ROOT file, extract ORTRootSerializer object and upload via storeAsTFileAny
154+
TFile inRootFile(rootFile.c_str());
155+
if (inRootFile.IsZombie()) {
156+
throw std::runtime_error("Could not open input ROOT file " + rootFile);
157+
}
158+
159+
// if objname is empty, fall back to default CCDB object key
160+
const std::string key = objname.empty() ? o2::ccdb::CcdbApi::CCDBOBJECT_ENTRY : objname;
161+
162+
auto* serPtr = inRootFile.Get<o2::tpc::ORTRootSerializer>(key.c_str());
163+
if (!serPtr) {
164+
inRootFile.Close();
165+
throw std::runtime_error("Could not find " + key + " in ROOT file " + rootFile);
166+
}
167+
serializer = *serPtr;
168+
169+
o2::ccdb::CcdbApi api;
170+
api.init(ccdbUrl);
171+
172+
// build full CCDB path including filename
173+
const std::string fullPath = ccdbPath;//.back() == '/' ? (ccdbPath + objname) : (ccdbPath + "/" + objname);
174+
175+
api.storeAsTFileAny(&serializer, fullPath, metadata, tsMin, tsMax);
176+
177+
inRootFile.Close();
178+
}
179+
180+
void convert_onnx_to_root_serialized(const std::string& onnxFile,
181+
const std::string& rootFile,
182+
int mode = 0,
183+
int ccdbUpload = 0,
184+
const std::string& metadataStr = "nnCCDBLayerType=FC/nnCCDBWithMomentum=0/inputDType=FP16/nnCCDBInteractionRate=500/outputDType=FP16/nnCCDBEvalType=regression_c1/nnCCDBBeamType=pp/partName=blob/quality=3",
185+
long tsMin = 1,
186+
long tsMax = 4108971600000,
187+
std::string ccdbPath = "Users/c/csonnabe/TPC/Clusterization",
188+
std::string objname = "net_regression_r1.root",
189+
std::string ccdbUrl = "http://alice-ccdb.cern.ch")
190+
{
191+
// parse metadataStr of the form key=value/key2=value2/...
192+
std::map<std::string, std::string> metadata;
193+
std::size_t start = 0;
194+
while (start < metadataStr.size()) {
195+
auto sep = metadataStr.find('/', start);
196+
auto token = metadataStr.substr(start, sep == std::string::npos ? std::string::npos : sep - start);
197+
if (!token.empty()) {
198+
auto eq = token.find('=');
199+
if (eq != std::string::npos && eq > 0 && eq + 1 < token.size()) {
200+
metadata.emplace(token.substr(0, eq), token.substr(eq + 1));
201+
}
202+
}
203+
if (sep == std::string::npos) {
204+
break;
205+
}
206+
start = sep + 1;
207+
}
208+
209+
if (ccdbUpload == 0){
210+
if (mode == 0)
211+
onnxToRoot(onnxFile, rootFile, o2::ccdb::CcdbApi::CCDBOBJECT_ENTRY);
212+
else if (mode == 1)
213+
rootToOnnx(rootFile, onnxFile, o2::ccdb::CcdbApi::CCDBOBJECT_ENTRY);
214+
} else if (ccdbUpload == 1){
215+
if (mode == 0)
216+
uploadToCCDBFromROOT(rootFile, metadata, tsMin, tsMax, ccdbPath, objname, ccdbUrl);
217+
else if (mode == 1)
218+
uploadToCCDBFromONNX(onnxFile, metadata, tsMin, tsMax, ccdbPath, objname, ccdbUrl);
219+
}
220+
}

0 commit comments

Comments
 (0)