|
| 1 | +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. |
| 2 | +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. |
| 3 | +// All rights not expressly granted are reserved. |
| 4 | +// |
| 5 | +// This software is distributed under the terms of the GNU General Public |
| 6 | +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". |
| 7 | +// |
| 8 | +// In applying this license CERN does not waive the privileges and immunities |
| 9 | +// granted to it by virtue of its status as an Intergovernmental Organization |
| 10 | +// or submit itself to any jurisdiction. |
| 11 | +// |
| 12 | +// *+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+* |
| 13 | +// Lambdakzero ML selection task |
| 14 | +// *+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+* |
| 15 | +// |
| 16 | +// Comments, questions, complaints, suggestions? |
| 17 | +// Please write to: |
| 18 | +// gianni.shigeru.setoue.liveraro@cern.ch |
| 19 | +// romain.schotter@cern.ch |
| 20 | +// david.dobrigkeit.chinellato@cern.ch |
| 21 | +// |
| 22 | + |
| 23 | +#include <Math/Vector4D.h> |
| 24 | +#include <cmath> |
| 25 | +#include <array> |
| 26 | +#include <cstdlib> |
| 27 | + |
| 28 | +#include "Framework/runDataProcessing.h" |
| 29 | +#include "Framework/AnalysisTask.h" |
| 30 | +#include "Framework/HistogramRegistry.h" |
| 31 | +#include "Framework/AnalysisDataModel.h" |
| 32 | +#include "Framework/ASoAHelpers.h" |
| 33 | +#include "Framework/ASoA.h" |
| 34 | +#include "ReconstructionDataFormats/Track.h" |
| 35 | +#include "Common/Core/RecoDecay.h" |
| 36 | +#include "Common/Core/trackUtilities.h" |
| 37 | +#include "PWGLF/DataModel/LFStrangenessTables.h" |
| 38 | +#include "PWGLF/DataModel/LFStrangenessPIDTables.h" |
| 39 | +#include "PWGLF/DataModel/LFStrangenessMLTables.h" |
| 40 | +#include "Common/Core/TrackSelection.h" |
| 41 | +#include "Common/DataModel/TrackSelectionTables.h" |
| 42 | +#include "Common/DataModel/EventSelection.h" |
| 43 | +#include "Common/DataModel/Centrality.h" |
| 44 | +#include "Common/DataModel/PIDResponse.h" |
| 45 | +#include "CCDB/BasicCCDBManager.h" |
| 46 | +#include <TFile.h> |
| 47 | +#include <TH2F.h> |
| 48 | +#include <TProfile.h> |
| 49 | +#include <TLorentzVector.h> |
| 50 | +#include <TPDGCode.h> |
| 51 | +#include <TDatabasePDG.h> |
| 52 | +#include "Tools/ML/MlResponse.h" |
| 53 | +#include "Tools/ML/model.h" |
| 54 | + |
| 55 | +using namespace o2; |
| 56 | +using namespace o2::analysis; |
| 57 | +using namespace o2::framework; |
| 58 | +using namespace o2::framework::expressions; |
| 59 | +using namespace o2::ml; |
| 60 | +using std::array; |
| 61 | +using std::cout; |
| 62 | +using std::endl; |
| 63 | + |
| 64 | +// For original data loops |
| 65 | +using CascOriginalDatas = soa::Join<aod::CascIndices, aod::CascCores>; |
| 66 | + |
| 67 | +// For derived data analysis |
| 68 | +using CascDerivedDatas = soa::Join<aod::CascCores, aod::CascExtras, aod::CascCollRefs>; |
| 69 | + |
| 70 | +struct cascademlselection { |
| 71 | + o2::ml::OnnxModel mlModelXiMinus; |
| 72 | + o2::ml::OnnxModel mlModelXiPlus; |
| 73 | + o2::ml::OnnxModel mlModelOmegaMinus; |
| 74 | + o2::ml::OnnxModel mlModelOmegaPlus; |
| 75 | + |
| 76 | + std::map<std::string, std::string> metadata; |
| 77 | + |
| 78 | + Produces<aod::CascXiMLScores> xiMLSelections; // optionally aggregate information from ML output for posterior analysis (derived data) |
| 79 | + Produces<aod::CascOmMLScores> omegaMLSelections; // optionally aggregate information from ML output for posterior analysis (derived data) |
| 80 | + |
| 81 | + HistogramRegistry histos{"Histos", {}, OutputObjHandlingPolicy::AnalysisObject}; |
| 82 | + |
| 83 | + // CCDB configuration |
| 84 | + o2::ccdb::CcdbApi ccdbApi; |
| 85 | + Service<o2::ccdb::BasicCCDBManager> ccdb; |
| 86 | + int mRunNumber; |
| 87 | + |
| 88 | + // CCDB options |
| 89 | + struct : ConfigurableGroup { |
| 90 | + Configurable<std::string> ccdburl{"ccdb-url", "http://alice-ccdb.cern.ch", "url of the ccdb repository"}; |
| 91 | + Configurable<std::string> grpPath{"grpPath", "GLO/GRP/GRP", "Path of the grp file"}; |
| 92 | + Configurable<std::string> grpmagPath{"grpmagPath", "GLO/Config/GRPMagField", "CCDB path of the GRPMagField object"}; |
| 93 | + Configurable<std::string> lutPath{"lutPath", "GLO/Param/MatLUT", "Path of the Lut parametrization"}; |
| 94 | + Configurable<std::string> geoPath{"geoPath", "GLO/Config/GeometryAligned", "Path of the geometry file"}; |
| 95 | + } ccdbConfigurations; |
| 96 | + |
| 97 | + // Machine learning evaluation for pre-selection and corresponding information generation |
| 98 | + struct : ConfigurableGroup { |
| 99 | + // ML classifiers: master flags to populate ML Selection tables |
| 100 | + Configurable<bool> calculateXiMinusScores{"mlConfigurations.calculateXiMinusScores", true, "calculate XiMinus ML scores"}; |
| 101 | + Configurable<bool> calculateXiPlusScores{"mlConfigurations.calculateXiPlusScores", true, "calculate XiPlus ML scores"}; |
| 102 | + Configurable<bool> calculateOmegaMinusScores{"mlConfigurations.calculateOmegaMinusScores", true, "calculate OmegaMinus ML scores"}; |
| 103 | + Configurable<bool> calculateOmegaPlusScores{"mlConfigurations.calculateOmegaPlusScores", true, "calculate OmegaPlus ML scores"}; |
| 104 | + |
| 105 | + // ML input for ML calculation |
| 106 | + Configurable<std::string> modelPathCCDB{"mlConfigurations.modelPathCCDB", "", "ML Model path in CCDB"}; |
| 107 | + Configurable<int64_t> timestampCCDB{"mlConfigurations.timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"}; |
| 108 | + Configurable<bool> loadModelsFromCCDB{"mlConfigurations.loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; |
| 109 | + Configurable<bool> enableOptimizations{"mlConfigurations.enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"}; |
| 110 | + |
| 111 | + // Local paths for test purposes |
| 112 | + Configurable<std::string> localModelPathXiMinus{"mlConfigurations.localModelPathXiMinus", "XiMinus_BDTModel.onnx", "(std::string) Path to the local .onnx file."}; |
| 113 | + Configurable<std::string> localModelPathXiPlus{"mlConfigurations.localModelPathXiPlus", "XiPlus_BDTModel.onnx", "(std::string) Path to the local .onnx file."}; |
| 114 | + Configurable<std::string> localModelPathOmegaMinus{"mlConfigurations.localModelPathOmegaMinus", "OmegaMinus_BDTModel.onnx", "(std::string) Path to the local .onnx file."}; |
| 115 | + Configurable<std::string> localModelPathOmegaPlus{"mlConfigurations.localModelPathOmegaPlus", "OmegaPlus_BDTModel.onnx", "(std::string) Path to the local .onnx file."}; |
| 116 | + |
| 117 | + // Thresholds for choosing to populate V0Cores tables with pre-selections |
| 118 | + Configurable<float> thresholdXiMinus{"mlConfigurations.thresholdXiMinus", -1.0f, "Threshold to keep XiMinus candidates"}; |
| 119 | + Configurable<float> thresholdXiPlus{"mlConfigurations.thresholdXiPlus", -1.0f, "Threshold to keep XiPlus candidates"}; |
| 120 | + Configurable<float> thresholdOmegaMinus{"mlConfigurations.thresholdOmegaMinus", -1.0f, "Threshold to keep OmegaMinus candidates"}; |
| 121 | + Configurable<float> thresholdOmegaPlus{"mlConfigurations.thresholdOmegaPlus", -1.0f, "Threshold to keep OmegaPlus candidates"}; |
| 122 | + } mlConfigurations; |
| 123 | + |
| 124 | + // Axis |
| 125 | + // base properties |
| 126 | + ConfigurableAxis vertexZ{"vertexZ", {30, -15.0f, 15.0f}, ""}; |
| 127 | + |
| 128 | + int nCandidates = 0; |
| 129 | + |
| 130 | + template <typename TCollision> |
| 131 | + void initCCDB(TCollision const& collision) |
| 132 | + { |
| 133 | + int64_t timeStampML = 0; |
| 134 | + if constexpr (requires { collision.timestamp(); }) { // we are in derived data |
| 135 | + if (mRunNumber == collision.runNumber()) { |
| 136 | + return; |
| 137 | + } |
| 138 | + mRunNumber = collision.runNumber(); |
| 139 | + timeStampML = collision.timestamp(); |
| 140 | + } |
| 141 | + if constexpr (requires { collision.template bc_as<aod::BCsWithTimestamps>(); }) { // we are in original data |
| 142 | + auto bc = collision.template bc_as<aod::BCsWithTimestamps>(); |
| 143 | + if (mRunNumber == bc.runNumber()) { |
| 144 | + return; |
| 145 | + } |
| 146 | + mRunNumber = bc.runNumber(); |
| 147 | + timeStampML = bc.timestamp(); |
| 148 | + } |
| 149 | + |
| 150 | + // machine learning initialization if requested |
| 151 | + if (mlConfigurations.calculateXiMinusScores || |
| 152 | + mlConfigurations.calculateXiPlusScores || |
| 153 | + mlConfigurations.calculateOmegaMinusScores || |
| 154 | + mlConfigurations.calculateOmegaPlusScores) { |
| 155 | + if (mlConfigurations.timestampCCDB.value != -1) |
| 156 | + timeStampML = mlConfigurations.timestampCCDB.value; |
| 157 | + LoadMachines(timeStampML); |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + // function to load models for ML-based classifiers |
| 162 | + void LoadMachines(int64_t timeStampML) |
| 163 | + { |
| 164 | + if (mlConfigurations.loadModelsFromCCDB) { |
| 165 | + ccdbApi.init(ccdbConfigurations.ccdburl); |
| 166 | + LOG(info) << "Fetching cascade models for timestamp: " << timeStampML; |
| 167 | + |
| 168 | + if (mlConfigurations.calculateXiMinusScores) { |
| 169 | + bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathXiMinus.value); |
| 170 | + if (retrieveSuccess) { |
| 171 | + mlModelXiMinus.initModel(mlConfigurations.localModelPathXiMinus.value, mlConfigurations.enableOptimizations.value); |
| 172 | + } else { |
| 173 | + LOG(fatal) << "Error encountered while fetching/loading the XiMinus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?"; |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + if (mlConfigurations.calculateXiPlusScores) { |
| 178 | + bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathXiPlus.value); |
| 179 | + if (retrieveSuccess) { |
| 180 | + mlModelXiPlus.initModel(mlConfigurations.localModelPathXiPlus.value, mlConfigurations.enableOptimizations.value); |
| 181 | + } else { |
| 182 | + LOG(fatal) << "Error encountered while fetching/loading the XiPlus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?"; |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + if (mlConfigurations.calculateOmegaMinusScores) { |
| 187 | + bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathOmegaMinus.value); |
| 188 | + if (retrieveSuccess) { |
| 189 | + mlModelOmegaMinus.initModel(mlConfigurations.localModelPathOmegaMinus.value, mlConfigurations.enableOptimizations.value); |
| 190 | + } else { |
| 191 | + LOG(fatal) << "Error encountered while fetching/loading the OmegaMinus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?"; |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + if (mlConfigurations.calculateOmegaPlusScores) { |
| 196 | + bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathOmegaPlus.value); |
| 197 | + if (retrieveSuccess) { |
| 198 | + mlModelOmegaPlus.initModel(mlConfigurations.localModelPathOmegaPlus.value, mlConfigurations.enableOptimizations.value); |
| 199 | + } else { |
| 200 | + LOG(fatal) << "Error encountered while fetching/loading the OmegaPlus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?"; |
| 201 | + } |
| 202 | + } |
| 203 | + } else { |
| 204 | + if (mlConfigurations.calculateXiMinusScores) |
| 205 | + mlModelXiMinus.initModel(mlConfigurations.localModelPathXiMinus.value, mlConfigurations.enableOptimizations.value); |
| 206 | + if (mlConfigurations.calculateXiPlusScores) |
| 207 | + mlModelXiPlus.initModel(mlConfigurations.localModelPathXiPlus.value, mlConfigurations.enableOptimizations.value); |
| 208 | + if (mlConfigurations.calculateOmegaMinusScores) |
| 209 | + mlModelOmegaMinus.initModel(mlConfigurations.localModelPathOmegaMinus.value, mlConfigurations.enableOptimizations.value); |
| 210 | + if (mlConfigurations.calculateOmegaPlusScores) |
| 211 | + mlModelOmegaPlus.initModel(mlConfigurations.localModelPathOmegaPlus.value, mlConfigurations.enableOptimizations.value); |
| 212 | + } |
| 213 | + LOG(info) << "Cascade ML Models loaded."; |
| 214 | + } |
| 215 | + |
| 216 | + void init(InitContext const&) |
| 217 | + { |
| 218 | + // Histograms |
| 219 | + histos.add("hEventVertexZ", "hEventVertexZ", kTH1F, {vertexZ}); |
| 220 | + |
| 221 | + ccdb->setURL(ccdbConfigurations.ccdburl); |
| 222 | + } |
| 223 | + |
| 224 | + // Process candidate and store properties in object |
| 225 | + template <typename TCascObject> |
| 226 | + void processCandidate(TCascObject const& cand) |
| 227 | + { |
| 228 | + // Select features |
| 229 | + // FIXME THIS NEEDS ADJUSTING |
| 230 | + std::vector<float> inputFeatures{0.0f, 0.0f, |
| 231 | + 0.0f, 0.0f}; |
| 232 | + |
| 233 | + // calculate scores |
| 234 | + if (cand.sign() < 0) { |
| 235 | + if (mlConfigurations.calculateXiMinusScores) { |
| 236 | + float* xiMinusProbability = mlModelXiMinus.evalModel(inputFeatures); |
| 237 | + xiMLSelections(xiMinusProbability[1]); |
| 238 | + } else { |
| 239 | + xiMLSelections(-1); |
| 240 | + } |
| 241 | + if (mlConfigurations.calculateOmegaMinusScores) { |
| 242 | + float* omegaMinusProbability = mlModelOmegaMinus.evalModel(inputFeatures); |
| 243 | + omegaMLSelections(omegaMinusProbability[1]); |
| 244 | + } else { |
| 245 | + omegaMLSelections(-1); |
| 246 | + } |
| 247 | + } |
| 248 | + if (cand.sign() > 0) { |
| 249 | + if (mlConfigurations.calculateXiPlusScores) { |
| 250 | + float* xiPlusProbability = mlModelXiPlus.evalModel(inputFeatures); |
| 251 | + xiMLSelections(xiPlusProbability[1]); |
| 252 | + } else { |
| 253 | + xiMLSelections(-1); |
| 254 | + } |
| 255 | + if (mlConfigurations.calculateOmegaPlusScores) { |
| 256 | + float* omegaPlusProbability = mlModelOmegaPlus.evalModel(inputFeatures); |
| 257 | + omegaMLSelections(omegaPlusProbability[1]); |
| 258 | + } else { |
| 259 | + omegaMLSelections(-1); |
| 260 | + } |
| 261 | + } |
| 262 | + } |
| 263 | + |
| 264 | + void processDerivedData(soa::Join<aod::StraCollisions, aod::StraStamps>::iterator const& collision, CascDerivedDatas const& cascades) |
| 265 | + { |
| 266 | + initCCDB(collision); |
| 267 | + |
| 268 | + histos.fill(HIST("hEventVertexZ"), collision.posZ()); |
| 269 | + for (auto& casc : cascades) { |
| 270 | + nCandidates++; |
| 271 | + if (nCandidates % 50000 == 0) { |
| 272 | + LOG(info) << "Candidates processed: " << nCandidates; |
| 273 | + } |
| 274 | + processCandidate(casc); |
| 275 | + } |
| 276 | + } |
| 277 | + void processStandardData(aod::Collision const& collision, CascOriginalDatas const& cascades) |
| 278 | + { |
| 279 | + initCCDB(collision); |
| 280 | + |
| 281 | + histos.fill(HIST("hEventVertexZ"), collision.posZ()); |
| 282 | + for (auto& casc : cascades) { |
| 283 | + nCandidates++; |
| 284 | + if (nCandidates % 50000 == 0) { |
| 285 | + LOG(info) << "Candidates processed: " << nCandidates; |
| 286 | + } |
| 287 | + processCandidate(casc); |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + PROCESS_SWITCH(cascademlselection, processStandardData, "Process standard data", false); |
| 292 | + PROCESS_SWITCH(cascademlselection, processDerivedData, "Process derived data", true); |
| 293 | +}; |
| 294 | + |
| 295 | +WorkflowSpec defineDataProcessing(ConfigContext const& cfgc) |
| 296 | +{ |
| 297 | + return WorkflowSpec{adaptAnalysisTask<cascademlselection>(cfgc)}; |
| 298 | +} |
0 commit comments