|
| 1 | +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. |
| 2 | +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. |
| 3 | +// All rights not expressly granted are reserved. |
| 4 | +// |
| 5 | +// This software is distributed under the terms of the GNU General Public |
| 6 | +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". |
| 7 | +// |
| 8 | +// In applying this license CERN does not waive the privileges and immunities |
| 9 | +// granted to it by virtue of its status as an Intergovernmental Organization |
| 10 | +// or submit itself to any jurisdiction. |
| 11 | + |
| 12 | +/// \file convert_onnx_to_root_serialized.C |
| 13 | +/// \brief Utility functions to be executed as a ROOT macro for uploading ONNX models to CCDB as ROOT serialized objects and vice versa |
| 14 | +/// \author Christian Sonnabend <christian.sonnabend@cern.ch> |
| 15 | + |
| 16 | +// Example execution: root -l -b -q '/scratch/csonnabe/MyO2/O2/GPU/GPUTracking/utils/convert_onnx_to_root_serialized.C("/scratch/csonnabe/PhD/jobs/clusterization/NN/output/21082025_smallWindow_clean/SC/training_data_21082025_reco_noise_supressed_p3t6_CoGselected/SC/PbPb_24arp2/0_5/class1/regression/399_noMom/network/net_fp16.onnx", "", 1, 1, "nnCCDBLayerType=FC/nnCCDBWithMomentum=0/inputDType=FP16/nnCCDBInteractionRate=500/outputDType=FP16/nnCCDBEvalType=regression_c1/nnCCDBBeamType=pp/partName=blob/quality=3", 1, 4108971600000, "Users/c/csonnabe/TPC/Clusterization", "model.root")' |
| 17 | + |
| 18 | +#include "ORTRootSerializer.h" |
| 19 | +#include "CCDB/CcdbApi.h" |
| 20 | +#include "CCDB/CcdbObjectInfo.h" |
| 21 | +#include "TFile.h" |
| 22 | +#include <fstream> |
| 23 | +#include <stdexcept> |
| 24 | + |
| 25 | +o2::tpc::ORTRootSerializer serializer; |
| 26 | + |
| 27 | +/// Dumps the char* to a .onnx file -> Directly readable by ONNX runtime or Netron |
| 28 | +void dumpOnnxToFile(const char* modelBuffer, uint32_t size, const std::string outputPath) |
| 29 | +{ |
| 30 | + std::ofstream outFile(outputPath, std::ios::binary | std::ios::trunc); |
| 31 | + if (!outFile.is_open()) { |
| 32 | + throw std::runtime_error("Failed to open output ONNX file: " + outputPath); |
| 33 | + } |
| 34 | + outFile.write(modelBuffer, static_cast<std::streamsize>(size)); |
| 35 | + if (!outFile) { |
| 36 | + throw std::runtime_error("Failed while writing data to: " + outputPath); |
| 37 | + } |
| 38 | + outFile.close(); |
| 39 | +} |
| 40 | + |
| 41 | +/// Initialize the serialization from an ONNX file |
| 42 | +void readOnnxModelFromFile(const std::string modelPath) |
| 43 | +{ |
| 44 | + std::ifstream inFile(modelPath, std::ios::binary | std::ios::ate); |
| 45 | + if (!inFile.is_open()) { |
| 46 | + throw std::runtime_error("Could not open input ONNX file " + modelPath); |
| 47 | + } |
| 48 | + std::streamsize size = inFile.tellg(); |
| 49 | + std::vector<char> mModelBuffer(size); |
| 50 | + inFile.seekg(0, std::ios::beg); |
| 51 | + if (!inFile.read(mModelBuffer.data(), size)) { |
| 52 | + throw std::runtime_error("Could not read input ONNX file " + modelPath); |
| 53 | + } |
| 54 | + inFile.close(); |
| 55 | + serializer.setOnnxModel(mModelBuffer.data(), static_cast<uint32_t>(size)); |
| 56 | +} |
| 57 | + |
| 58 | +/// Initialize the serialization from a ROOT file |
| 59 | +void readRootModelFromFile(const std::string rootFilePath, std::string key) |
| 60 | +{ |
| 61 | + TFile inRootFile(rootFilePath.c_str()); |
| 62 | + if (inRootFile.IsZombie()) { |
| 63 | + throw std::runtime_error("Could not open input ROOT file " + rootFilePath); |
| 64 | + } |
| 65 | + auto* serPtr = inRootFile.Get<o2::tpc::ORTRootSerializer>(key.c_str()); |
| 66 | + if (!serPtr) { |
| 67 | + throw std::runtime_error("Could not find " + key + " in ROOT file " + rootFilePath); |
| 68 | + } |
| 69 | + serializer = *serPtr; |
| 70 | + inRootFile.Close(); |
| 71 | +} |
| 72 | + |
| 73 | +/// Serialize the ONNX model to a ROOT object and store to file |
| 74 | +void onnxToRoot(std::string infile, std::string outfile, std::string key) |
| 75 | +{ |
| 76 | + readOnnxModelFromFile(infile); |
| 77 | + TFile outRootFile(outfile.c_str(), "RECREATE"); |
| 78 | + if (outRootFile.IsZombie()) { |
| 79 | + throw std::runtime_error("Could not create output ROOT file " + outfile); |
| 80 | + } |
| 81 | + outRootFile.WriteObject(&serializer, key.c_str()); |
| 82 | + outRootFile.Close(); |
| 83 | +} |
| 84 | + |
| 85 | +/// Deserialize the ONNX model from a ROOT object and store to a .onnx file |
| 86 | +void rootToOnnx(std::string infile, std::string outfile, std::string key) |
| 87 | +{ |
| 88 | + TFile inRootFile(infile.c_str()); |
| 89 | + if (inRootFile.IsZombie()) { |
| 90 | + throw std::runtime_error("Could not open input ROOT file " + infile); |
| 91 | + } |
| 92 | + auto* serPtr = inRootFile.Get<o2::tpc::ORTRootSerializer>(key.c_str()); |
| 93 | + if (!serPtr) { |
| 94 | + throw std::runtime_error("Could not find " + key + " in ROOT file " + infile); |
| 95 | + } |
| 96 | + serializer = *serPtr; |
| 97 | + |
| 98 | + std::ofstream outFile(outfile, std::ios::binary | std::ios::trunc); |
| 99 | + if (!outFile.is_open()) { |
| 100 | + throw std::runtime_error("Failed to open output ONNX file: " + outfile); |
| 101 | + } |
| 102 | + outFile.write(serializer.getONNXModel(), static_cast<std::streamsize>(serializer.getONNXModelSize())); |
| 103 | + if (!outFile) { |
| 104 | + throw std::runtime_error("Failed while writing data to: " + outfile); |
| 105 | + } |
| 106 | + outFile.close(); |
| 107 | + |
| 108 | + inRootFile.Close(); |
| 109 | +} |
| 110 | + |
| 111 | +/// Upload the ONNX model to CCDB from an ONNX file |
| 112 | +/// !!! Adjust the metadata, path and validity !!! |
| 113 | +void uploadToCCDBFromONNX(std::string onnxFile, |
| 114 | + const std::map<std::string, std::string>& metadata, |
| 115 | + // { // some example metadata entries |
| 116 | + // "nnCCDBLayerType": "FC", |
| 117 | + // "nnCCDBWithMomentum": "0", |
| 118 | + // "inputDType": "FP16", |
| 119 | + // "nnCCDBInteractionRate": "500", |
| 120 | + // "outputDType": "FP16", |
| 121 | + // "nnCCDBEvalType": "regression_c1", |
| 122 | + // "nnCCDBBeamType": "pp", |
| 123 | + // "partName": "blob", |
| 124 | + // "quality": "3" |
| 125 | + // } |
| 126 | + long tsMin /* = 1 */, |
| 127 | + long tsMax /* = 4108971600000 */, |
| 128 | + std::string ccdbPath /* = "Users/c/csonnabe/TPC/Clusterization" */, |
| 129 | + std::string objname /* = "net_regression_r1.root" */, |
| 130 | + std::string ccdbUrl /* = "http://alice-ccdb.cern.ch" */) |
| 131 | +{ |
| 132 | + readOnnxModelFromFile(onnxFile); |
| 133 | + |
| 134 | + o2::ccdb::CcdbApi api; |
| 135 | + api.init(ccdbUrl); |
| 136 | + |
| 137 | + // build full CCDB path including filename |
| 138 | + const std::string fullPath = ccdbPath;//.back() == '/' ? (ccdbPath + objname) : (ccdbPath + "/" + objname); |
| 139 | + |
| 140 | + api.storeAsTFileAny(&serializer, fullPath, metadata, tsMin, tsMax); |
| 141 | +} |
| 142 | + |
| 143 | +/// Upload the ONNX model to CCDB from a ROOT file |
| 144 | +/// !!! Adjust the metadata, path and validity !!! |
| 145 | +void uploadToCCDBFromROOT(std::string rootFile, |
| 146 | + const std::map<std::string, std::string>& metadata, |
| 147 | + long tsMin /* = 1 */, |
| 148 | + long tsMax /* = 4108971600000 */, |
| 149 | + std::string ccdbPath /* = "Users/c/csonnabe/TPC/Clusterization" */, |
| 150 | + std::string objname /* = "net_regression_r1.root" */, |
| 151 | + std::string ccdbUrl /* = "http://alice-ccdb.cern.ch" */) |
| 152 | +{ |
| 153 | + // read ROOT file, extract ORTRootSerializer object and upload via storeAsTFileAny |
| 154 | + TFile inRootFile(rootFile.c_str()); |
| 155 | + if (inRootFile.IsZombie()) { |
| 156 | + throw std::runtime_error("Could not open input ROOT file " + rootFile); |
| 157 | + } |
| 158 | + |
| 159 | + // if objname is empty, fall back to default CCDB object key |
| 160 | + const std::string key = objname.empty() ? o2::ccdb::CcdbApi::CCDBOBJECT_ENTRY : objname; |
| 161 | + |
| 162 | + auto* serPtr = inRootFile.Get<o2::tpc::ORTRootSerializer>(key.c_str()); |
| 163 | + if (!serPtr) { |
| 164 | + inRootFile.Close(); |
| 165 | + throw std::runtime_error("Could not find " + key + " in ROOT file " + rootFile); |
| 166 | + } |
| 167 | + serializer = *serPtr; |
| 168 | + |
| 169 | + o2::ccdb::CcdbApi api; |
| 170 | + api.init(ccdbUrl); |
| 171 | + |
| 172 | + // build full CCDB path including filename |
| 173 | + const std::string fullPath = ccdbPath;//.back() == '/' ? (ccdbPath + objname) : (ccdbPath + "/" + objname); |
| 174 | + |
| 175 | + api.storeAsTFileAny(&serializer, fullPath, metadata, tsMin, tsMax); |
| 176 | + |
| 177 | + inRootFile.Close(); |
| 178 | +} |
| 179 | + |
| 180 | +void convert_onnx_to_root_serialized(const std::string& onnxFile, |
| 181 | + const std::string& rootFile, |
| 182 | + int mode = 0, |
| 183 | + int ccdbUpload = 0, |
| 184 | + const std::string& metadataStr = "nnCCDBLayerType=FC/nnCCDBWithMomentum=0/inputDType=FP16/nnCCDBInteractionRate=500/outputDType=FP16/nnCCDBEvalType=regression_c1/nnCCDBBeamType=pp/partName=blob/quality=3", |
| 185 | + long tsMin = 1, |
| 186 | + long tsMax = 4108971600000, |
| 187 | + std::string ccdbPath = "Users/c/csonnabe/TPC/Clusterization", |
| 188 | + std::string objname = "net_regression_r1.root", |
| 189 | + std::string ccdbUrl = "http://alice-ccdb.cern.ch") |
| 190 | +{ |
| 191 | + // parse metadataStr of the form key=value/key2=value2/... |
| 192 | + std::map<std::string, std::string> metadata; |
| 193 | + std::size_t start = 0; |
| 194 | + while (start < metadataStr.size()) { |
| 195 | + auto sep = metadataStr.find('/', start); |
| 196 | + auto token = metadataStr.substr(start, sep == std::string::npos ? std::string::npos : sep - start); |
| 197 | + if (!token.empty()) { |
| 198 | + auto eq = token.find('='); |
| 199 | + if (eq != std::string::npos && eq > 0 && eq + 1 < token.size()) { |
| 200 | + metadata.emplace(token.substr(0, eq), token.substr(eq + 1)); |
| 201 | + } |
| 202 | + } |
| 203 | + if (sep == std::string::npos) { |
| 204 | + break; |
| 205 | + } |
| 206 | + start = sep + 1; |
| 207 | + } |
| 208 | + |
| 209 | + if (ccdbUpload == 0){ |
| 210 | + if (mode == 0) |
| 211 | + onnxToRoot(onnxFile, rootFile, o2::ccdb::CcdbApi::CCDBOBJECT_ENTRY); |
| 212 | + else if (mode == 1) |
| 213 | + rootToOnnx(rootFile, onnxFile, o2::ccdb::CcdbApi::CCDBOBJECT_ENTRY); |
| 214 | + } else if (ccdbUpload == 1){ |
| 215 | + if (mode == 0) |
| 216 | + uploadToCCDBFromROOT(rootFile, metadata, tsMin, tsMax, ccdbPath, objname, ccdbUrl); |
| 217 | + else if (mode == 1) |
| 218 | + uploadToCCDBFromONNX(onnxFile, metadata, tsMin, tsMax, ccdbPath, objname, ccdbUrl); |
| 219 | + } |
| 220 | +} |
0 commit comments