Skip to content

Commit 42adf65

Browse files
[PWGLF] Add task for a posteriori cascade ML score calculation (#9398)
1 parent 614bf09 commit 42adf65

File tree

2 files changed

+303
-0
lines changed

2 files changed

+303
-0
lines changed

PWGLF/TableProducer/Strangeness/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ o2physics_add_dpl_workflow(lambdakzeromlselection
138138
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::MLCore
139139
COMPONENT_NAME Analysis)
140140

141+
# Workflow built from cascademlselection.cxx: task for a posteriori cascade ML score calculation.
o2physics_add_dpl_workflow(cascademlselection
142+
SOURCES cascademlselection.cxx
143+
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::MLCore
144+
COMPONENT_NAME Analysis)
145+
141146
o2physics_add_dpl_workflow(sigma0builder
142147
SOURCES sigma0builder.cxx
143148
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::MLCore
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
//
12+
// *+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*
13+
// Cascade ML selection task
14+
// *+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*
15+
//
16+
// Comments, questions, complaints, suggestions?
17+
// Please write to:
18+
// gianni.shigeru.setoue.liveraro@cern.ch
19+
// romain.schotter@cern.ch
20+
// david.dobrigkeit.chinellato@cern.ch
21+
//
22+
23+
#include <Math/Vector4D.h>
24+
#include <cmath>
25+
#include <array>
26+
#include <cstdlib>
27+
28+
#include "Framework/runDataProcessing.h"
29+
#include "Framework/AnalysisTask.h"
30+
#include "Framework/HistogramRegistry.h"
31+
#include "Framework/AnalysisDataModel.h"
32+
#include "Framework/ASoAHelpers.h"
33+
#include "Framework/ASoA.h"
34+
#include "ReconstructionDataFormats/Track.h"
35+
#include "Common/Core/RecoDecay.h"
36+
#include "Common/Core/trackUtilities.h"
37+
#include "PWGLF/DataModel/LFStrangenessTables.h"
38+
#include "PWGLF/DataModel/LFStrangenessPIDTables.h"
39+
#include "PWGLF/DataModel/LFStrangenessMLTables.h"
40+
#include "Common/Core/TrackSelection.h"
41+
#include "Common/DataModel/TrackSelectionTables.h"
42+
#include "Common/DataModel/EventSelection.h"
43+
#include "Common/DataModel/Centrality.h"
44+
#include "Common/DataModel/PIDResponse.h"
45+
#include "CCDB/BasicCCDBManager.h"
46+
#include <TFile.h>
47+
#include <TH2F.h>
48+
#include <TProfile.h>
49+
#include <TLorentzVector.h>
50+
#include <TPDGCode.h>
51+
#include <TDatabasePDG.h>
52+
#include "Tools/ML/MlResponse.h"
53+
#include "Tools/ML/model.h"
54+
55+
using namespace o2;
56+
using namespace o2::analysis;
57+
using namespace o2::framework;
58+
using namespace o2::framework::expressions;
59+
using namespace o2::ml;
60+
using std::array;
61+
using std::cout;
62+
using std::endl;
63+
64+
// For original data loops
65+
// Join of the cascade index and core tables; this is the candidate table type taken by processStandardData.
using CascOriginalDatas = soa::Join<aod::CascIndices, aod::CascCores>;
66+
67+
// For derived data analysis
68+
// Join of the cascade core, extra and collision-reference tables; this is the candidate table type taken by processDerivedData.
using CascDerivedDatas = soa::Join<aod::CascCores, aod::CascExtras, aod::CascCollRefs>;
69+
70+
struct cascademlselection {
71+
o2::ml::OnnxModel mlModelXiMinus;
72+
o2::ml::OnnxModel mlModelXiPlus;
73+
o2::ml::OnnxModel mlModelOmegaMinus;
74+
o2::ml::OnnxModel mlModelOmegaPlus;
75+
76+
std::map<std::string, std::string> metadata;
77+
78+
Produces<aod::CascXiMLScores> xiMLSelections; // optionally aggregate information from ML output for posterior analysis (derived data)
79+
Produces<aod::CascOmMLScores> omegaMLSelections; // optionally aggregate information from ML output for posterior analysis (derived data)
80+
81+
HistogramRegistry histos{"Histos", {}, OutputObjHandlingPolicy::AnalysisObject};
82+
83+
// CCDB configuration
84+
o2::ccdb::CcdbApi ccdbApi;
85+
Service<o2::ccdb::BasicCCDBManager> ccdb;
86+
int mRunNumber;
87+
88+
// CCDB options
89+
struct : ConfigurableGroup {
90+
Configurable<std::string> ccdburl{"ccdb-url", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
91+
Configurable<std::string> grpPath{"grpPath", "GLO/GRP/GRP", "Path of the grp file"};
92+
Configurable<std::string> grpmagPath{"grpmagPath", "GLO/Config/GRPMagField", "CCDB path of the GRPMagField object"};
93+
Configurable<std::string> lutPath{"lutPath", "GLO/Param/MatLUT", "Path of the Lut parametrization"};
94+
Configurable<std::string> geoPath{"geoPath", "GLO/Config/GeometryAligned", "Path of the geometry file"};
95+
} ccdbConfigurations;
96+
97+
// Machine learning evaluation for pre-selection and corresponding information generation
98+
struct : ConfigurableGroup {
99+
// ML classifiers: master flags to populate ML Selection tables
100+
Configurable<bool> calculateXiMinusScores{"mlConfigurations.calculateXiMinusScores", true, "calculate XiMinus ML scores"};
101+
Configurable<bool> calculateXiPlusScores{"mlConfigurations.calculateXiPlusScores", true, "calculate XiPlus ML scores"};
102+
Configurable<bool> calculateOmegaMinusScores{"mlConfigurations.calculateOmegaMinusScores", true, "calculate OmegaMinus ML scores"};
103+
Configurable<bool> calculateOmegaPlusScores{"mlConfigurations.calculateOmegaPlusScores", true, "calculate OmegaPlus ML scores"};
104+
105+
// ML input for ML calculation
106+
Configurable<std::string> modelPathCCDB{"mlConfigurations.modelPathCCDB", "", "ML Model path in CCDB"};
107+
Configurable<int64_t> timestampCCDB{"mlConfigurations.timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
108+
Configurable<bool> loadModelsFromCCDB{"mlConfigurations.loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
109+
Configurable<bool> enableOptimizations{"mlConfigurations.enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
110+
111+
// Local paths for test purposes
112+
Configurable<std::string> localModelPathXiMinus{"mlConfigurations.localModelPathXiMinus", "XiMinus_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
113+
Configurable<std::string> localModelPathXiPlus{"mlConfigurations.localModelPathXiPlus", "XiPlus_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
114+
Configurable<std::string> localModelPathOmegaMinus{"mlConfigurations.localModelPathOmegaMinus", "OmegaMinus_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
115+
Configurable<std::string> localModelPathOmegaPlus{"mlConfigurations.localModelPathOmegaPlus", "OmegaPlus_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
116+
117+
// Thresholds for choosing to populate V0Cores tables with pre-selections
118+
Configurable<float> thresholdXiMinus{"mlConfigurations.thresholdXiMinus", -1.0f, "Threshold to keep XiMinus candidates"};
119+
Configurable<float> thresholdXiPlus{"mlConfigurations.thresholdXiPlus", -1.0f, "Threshold to keep XiPlus candidates"};
120+
Configurable<float> thresholdOmegaMinus{"mlConfigurations.thresholdOmegaMinus", -1.0f, "Threshold to keep OmegaMinus candidates"};
121+
Configurable<float> thresholdOmegaPlus{"mlConfigurations.thresholdOmegaPlus", -1.0f, "Threshold to keep OmegaPlus candidates"};
122+
} mlConfigurations;
123+
124+
// Axis
125+
// base properties
126+
ConfigurableAxis vertexZ{"vertexZ", {30, -15.0f, 15.0f}, ""};
127+
128+
int nCandidates = 0;
129+
130+
template <typename TCollision>
131+
void initCCDB(TCollision const& collision)
132+
{
133+
int64_t timeStampML = 0;
134+
if constexpr (requires { collision.timestamp(); }) { // we are in derived data
135+
if (mRunNumber == collision.runNumber()) {
136+
return;
137+
}
138+
mRunNumber = collision.runNumber();
139+
timeStampML = collision.timestamp();
140+
}
141+
if constexpr (requires { collision.template bc_as<aod::BCsWithTimestamps>(); }) { // we are in original data
142+
auto bc = collision.template bc_as<aod::BCsWithTimestamps>();
143+
if (mRunNumber == bc.runNumber()) {
144+
return;
145+
}
146+
mRunNumber = bc.runNumber();
147+
timeStampML = bc.timestamp();
148+
}
149+
150+
// machine learning initialization if requested
151+
if (mlConfigurations.calculateXiMinusScores ||
152+
mlConfigurations.calculateXiPlusScores ||
153+
mlConfigurations.calculateOmegaMinusScores ||
154+
mlConfigurations.calculateOmegaPlusScores) {
155+
if (mlConfigurations.timestampCCDB.value != -1)
156+
timeStampML = mlConfigurations.timestampCCDB.value;
157+
LoadMachines(timeStampML);
158+
}
159+
}
160+
161+
// function to load models for ML-based classifiers
162+
void LoadMachines(int64_t timeStampML)
163+
{
164+
if (mlConfigurations.loadModelsFromCCDB) {
165+
ccdbApi.init(ccdbConfigurations.ccdburl);
166+
LOG(info) << "Fetching cascade models for timestamp: " << timeStampML;
167+
168+
if (mlConfigurations.calculateXiMinusScores) {
169+
bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathXiMinus.value);
170+
if (retrieveSuccess) {
171+
mlModelXiMinus.initModel(mlConfigurations.localModelPathXiMinus.value, mlConfigurations.enableOptimizations.value);
172+
} else {
173+
LOG(fatal) << "Error encountered while fetching/loading the XiMinus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
174+
}
175+
}
176+
177+
if (mlConfigurations.calculateXiPlusScores) {
178+
bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathXiPlus.value);
179+
if (retrieveSuccess) {
180+
mlModelXiPlus.initModel(mlConfigurations.localModelPathXiPlus.value, mlConfigurations.enableOptimizations.value);
181+
} else {
182+
LOG(fatal) << "Error encountered while fetching/loading the XiPlus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
183+
}
184+
}
185+
186+
if (mlConfigurations.calculateOmegaMinusScores) {
187+
bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathOmegaMinus.value);
188+
if (retrieveSuccess) {
189+
mlModelOmegaMinus.initModel(mlConfigurations.localModelPathOmegaMinus.value, mlConfigurations.enableOptimizations.value);
190+
} else {
191+
LOG(fatal) << "Error encountered while fetching/loading the OmegaMinus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
192+
}
193+
}
194+
195+
if (mlConfigurations.calculateOmegaPlusScores) {
196+
bool retrieveSuccess = ccdbApi.retrieveBlob(mlConfigurations.modelPathCCDB, ".", metadata, timeStampML, false, mlConfigurations.localModelPathOmegaPlus.value);
197+
if (retrieveSuccess) {
198+
mlModelOmegaPlus.initModel(mlConfigurations.localModelPathOmegaPlus.value, mlConfigurations.enableOptimizations.value);
199+
} else {
200+
LOG(fatal) << "Error encountered while fetching/loading the OmegaPlus model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
201+
}
202+
}
203+
} else {
204+
if (mlConfigurations.calculateXiMinusScores)
205+
mlModelXiMinus.initModel(mlConfigurations.localModelPathXiMinus.value, mlConfigurations.enableOptimizations.value);
206+
if (mlConfigurations.calculateXiPlusScores)
207+
mlModelXiPlus.initModel(mlConfigurations.localModelPathXiPlus.value, mlConfigurations.enableOptimizations.value);
208+
if (mlConfigurations.calculateOmegaMinusScores)
209+
mlModelOmegaMinus.initModel(mlConfigurations.localModelPathOmegaMinus.value, mlConfigurations.enableOptimizations.value);
210+
if (mlConfigurations.calculateOmegaPlusScores)
211+
mlModelOmegaPlus.initModel(mlConfigurations.localModelPathOmegaPlus.value, mlConfigurations.enableOptimizations.value);
212+
}
213+
LOG(info) << "Cascade ML Models loaded.";
214+
}
215+
216+
void init(InitContext const&)
217+
{
218+
// Histograms
219+
histos.add("hEventVertexZ", "hEventVertexZ", kTH1F, {vertexZ});
220+
221+
ccdb->setURL(ccdbConfigurations.ccdburl);
222+
}
223+
224+
// Process candidate and store properties in object
225+
template <typename TCascObject>
226+
void processCandidate(TCascObject const& cand)
227+
{
228+
// Select features
229+
// FIXME THIS NEEDS ADJUSTING
230+
std::vector<float> inputFeatures{0.0f, 0.0f,
231+
0.0f, 0.0f};
232+
233+
// calculate scores
234+
if (cand.sign() < 0) {
235+
if (mlConfigurations.calculateXiMinusScores) {
236+
float* xiMinusProbability = mlModelXiMinus.evalModel(inputFeatures);
237+
xiMLSelections(xiMinusProbability[1]);
238+
} else {
239+
xiMLSelections(-1);
240+
}
241+
if (mlConfigurations.calculateOmegaMinusScores) {
242+
float* omegaMinusProbability = mlModelOmegaMinus.evalModel(inputFeatures);
243+
omegaMLSelections(omegaMinusProbability[1]);
244+
} else {
245+
omegaMLSelections(-1);
246+
}
247+
}
248+
if (cand.sign() > 0) {
249+
if (mlConfigurations.calculateXiPlusScores) {
250+
float* xiPlusProbability = mlModelXiPlus.evalModel(inputFeatures);
251+
xiMLSelections(xiPlusProbability[1]);
252+
} else {
253+
xiMLSelections(-1);
254+
}
255+
if (mlConfigurations.calculateOmegaPlusScores) {
256+
float* omegaPlusProbability = mlModelOmegaPlus.evalModel(inputFeatures);
257+
omegaMLSelections(omegaPlusProbability[1]);
258+
} else {
259+
omegaMLSelections(-1);
260+
}
261+
}
262+
}
263+
264+
void processDerivedData(soa::Join<aod::StraCollisions, aod::StraStamps>::iterator const& collision, CascDerivedDatas const& cascades)
265+
{
266+
initCCDB(collision);
267+
268+
histos.fill(HIST("hEventVertexZ"), collision.posZ());
269+
for (auto& casc : cascades) {
270+
nCandidates++;
271+
if (nCandidates % 50000 == 0) {
272+
LOG(info) << "Candidates processed: " << nCandidates;
273+
}
274+
processCandidate(casc);
275+
}
276+
}
277+
void processStandardData(aod::Collision const& collision, CascOriginalDatas const& cascades)
278+
{
279+
initCCDB(collision);
280+
281+
histos.fill(HIST("hEventVertexZ"), collision.posZ());
282+
for (auto& casc : cascades) {
283+
nCandidates++;
284+
if (nCandidates % 50000 == 0) {
285+
LOG(info) << "Candidates processed: " << nCandidates;
286+
}
287+
processCandidate(casc);
288+
}
289+
}
290+
291+
PROCESS_SWITCH(cascademlselection, processStandardData, "Process standard data", false);
292+
PROCESS_SWITCH(cascademlselection, processDerivedData, "Process derived data", true);
293+
};
294+
295+
// Returns a WorkflowSpec holding the single cascademlselection analysis task,
// adapted from the given configuration context.
WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)
{
  WorkflowSpec workflow{adaptAnalysisTask<cascademlselection>(cfgc)};
  return workflow;
}

0 commit comments

Comments
 (0)