Skip to content

Commit 4444597

Browse files
mcoquet642Maurice Coquetalibuild
authored
[PWGDQ] Usage of ML models for MFT-Muon matching (#13048)
Co-authored-by: Maurice Coquet <mcoquet@lxplus957.cern.ch> Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
1 parent 29d5b02 commit 4444597

File tree

2 files changed

+439
-29
lines changed

2 files changed

+439
-29
lines changed
Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file MuonMatchingMlResponse.h
13+
/// \brief Class to compute the ML response for MFT-Muon matching
14+
/// \author Maurice Coquet <maurice.louis.coquet@cern.ch>
15+
16+
#ifndef PWGDQ_CORE_MUONMATCHINGMLRESPONSE_H_
17+
#define PWGDQ_CORE_MUONMATCHINGMLRESPONSE_H_
18+
19+
#include "Tools/ML/MlResponse.h"
20+
21+
#include <map>
22+
#include <string>
23+
#include <vector>
24+
25+
// Fill the map of available input features
26+
// the key is the feature's name (std::string)
27+
// the value is the corresponding value in EnumInputFeatures
28+
#define FILL_MAP_MFTMUON_MATCH(FEATURE) \
29+
{ \
30+
#FEATURE, static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE) \
31+
}
32+
33+
// Check if the index of mCachedIndices (index associated to a FEATURE)
34+
// matches the entry in EnumInputFeatures associated to this FEATURE
35+
// if so, the inputFeatures vector is filled with the FEATURE's value
36+
// by calling the corresponding GETTER=FEATURE from track
37+
#define CHECK_AND_FILL_MUON_TRACK(FEATURE, GETTER) \
38+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
39+
inputFeature = muon.GETTER(); \
40+
break; \
41+
}
42+
43+
// Check if the index of mCachedIndices (index associated to a FEATURE)
44+
// matches the entry in EnumInputFeatures associated to this FEATURE
45+
// if so, the inputFeatures vector is filled with the FEATURE's value
46+
// by calling the corresponding GETTER=FEATURE from track
47+
#define CHECK_AND_FILL_MFT_TRACK(FEATURE, GETTER) \
48+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
49+
inputFeature = mft.GETTER(); \
50+
break; \
51+
}
52+
53+
// Check if the index of mCachedIndices (index associated to a FEATURE)
54+
// matches the entry in EnumInputFeatures associated to this FEATURE
55+
// if so, the inputFeatures vector is filled with the FEATURE's value
56+
// by calling the corresponding GETTER=FEATURE from track
57+
#define CHECK_AND_FILL_MUON_COV(FEATURE, GETTER) \
58+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
59+
inputFeature = muoncov.GETTER(); \
60+
break; \
61+
}
62+
63+
// Check if the index of mCachedIndices (index associated to a FEATURE)
64+
// matches the entry in EnumInputFeatures associated to this FEATURE
65+
// if so, the inputFeatures vector is filled with the FEATURE's value
66+
// by calling the corresponding GETTER=FEATURE from track
67+
#define CHECK_AND_FILL_MFT_COV(FEATURE, GETTER) \
68+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
69+
inputFeature = mftcov.GETTER(); \
70+
break; \
71+
}
72+
73+
// Check if the index of mCachedIndices (index associated to a FEATURE)
74+
// matches the entry in EnumInputFeatures associated to this FEATURE
75+
// if so, the inputFeatures vector is filled with the FEATURE's value
76+
// by calling the corresponding GETTER1 and GETTER2 from track.
77+
#define CHECK_AND_FILL_MFTMUON_DIFF(FEATURE, GETTER1, GETTER2) \
78+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
79+
inputFeature = (mft.GETTER2() - muon.GETTER1()); \
80+
break; \
81+
}
82+
83+
// Check if the index of mCachedIndices (index associated to a FEATURE)
84+
// matches the entry in EnumInputFeatures associated to this FEATURE
85+
// if so, the inputFeatures vector is filled with the FEATURE's value
86+
// by calling the corresponding GETTER=FEATURE from collision
87+
#define CHECK_AND_FILL_MFTMUON_COLLISION(GETTER) \
88+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::GETTER): { \
89+
inputFeature = collision.GETTER(); \
90+
break; \
91+
}
92+
93+
namespace o2::analysis
94+
{
95+
// possible input features for ML
96+
enum class InputFeaturesMFTMuonMatch : uint8_t {
97+
zMatching,
98+
xMFT,
99+
yMFT,
100+
qOverptMFT,
101+
tglMFT,
102+
phiMFT,
103+
dcaXY,
104+
dcaZ,
105+
chi2MFT,
106+
nClustersMFT,
107+
xMCH,
108+
yMCH,
109+
qOverptMCH,
110+
tglMCH,
111+
phiMCH,
112+
nClustersMCH,
113+
chi2MCH,
114+
pdca,
115+
cXXMFT,
116+
cXYMFT,
117+
cYYMFT,
118+
cPhiYMFT,
119+
cPhiXMFT,
120+
cPhiPhiMFT,
121+
cTglYMFT,
122+
cTglXMFT,
123+
cTglPhiMFT,
124+
cTglTglMFT,
125+
c1PtYMFT,
126+
c1PtXMFT,
127+
c1PtPhiMFT,
128+
c1PtTglMFT,
129+
c1Pt21Pt2MFT,
130+
cXXMCH,
131+
cXYMCH,
132+
cYYMCH,
133+
cPhiYMCH,
134+
cPhiXMCH,
135+
cPhiPhiMCH,
136+
cTglYMCH,
137+
cTglXMCH,
138+
cTglPhiMCH,
139+
cTglTglMCH,
140+
c1PtYMCH,
141+
c1PtXMCH,
142+
c1PtPhiMCH,
143+
c1PtTglMCH,
144+
c1Pt21Pt2MCH,
145+
deltaX,
146+
deltaY,
147+
deltaPhi,
148+
deltaEta,
149+
deltaPt,
150+
posX,
151+
posY,
152+
posZ,
153+
numContrib,
154+
trackOccupancyInTimeRange,
155+
ft0cOccupancyInTimeRange,
156+
multFT0A,
157+
multFT0C,
158+
multNTracksPV,
159+
multNTracksPVeta1,
160+
multNTracksPVetaHalf,
161+
isInelGt0,
162+
isInelGt1,
163+
multFT0M,
164+
centFT0M,
165+
centFT0A,
166+
centFT0C,
167+
chi2MCHMFT
168+
};
169+
170+
template <typename TypeOutputScore = float>
171+
class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
172+
{
173+
public:
174+
/// Default constructor
175+
MlResponseMFTMuonMatch() = default;
176+
/// Default destructor
177+
virtual ~MlResponseMFTMuonMatch() = default;
178+
179+
template <typename T1, typename T2, typename C1, typename C2, typename U>
180+
float returnFeature(uint8_t idx, T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
181+
{
182+
float inputFeature = 0.;
183+
switch (idx) {
184+
CHECK_AND_FILL_MFT_TRACK(zMatching, z);
185+
CHECK_AND_FILL_MFT_TRACK(xMFT, x);
186+
CHECK_AND_FILL_MFT_TRACK(yMFT, y);
187+
CHECK_AND_FILL_MFT_TRACK(qOverptMFT, signed1Pt);
188+
CHECK_AND_FILL_MFT_TRACK(tglMFT, tgl);
189+
CHECK_AND_FILL_MFT_TRACK(phiMFT, phi);
190+
CHECK_AND_FILL_MFT_TRACK(chi2MFT, chi2);
191+
CHECK_AND_FILL_MFT_TRACK(nClustersMFT, nClusters);
192+
CHECK_AND_FILL_MUON_TRACK(dcaXY, fwddcaXY);
193+
CHECK_AND_FILL_MUON_TRACK(dcaZ, fwddcaz);
194+
CHECK_AND_FILL_MUON_TRACK(xMCH, x);
195+
CHECK_AND_FILL_MUON_TRACK(yMCH, y);
196+
CHECK_AND_FILL_MUON_TRACK(qOverptMCH, signed1Pt);
197+
CHECK_AND_FILL_MUON_TRACK(tglMCH, tgl);
198+
CHECK_AND_FILL_MUON_TRACK(phiMCH, phi);
199+
CHECK_AND_FILL_MUON_TRACK(nClustersMCH, nClusters);
200+
CHECK_AND_FILL_MUON_TRACK(chi2MCH, chi2);
201+
CHECK_AND_FILL_MUON_TRACK(pdca, pDca);
202+
CHECK_AND_FILL_MFT_COV(cXXMFT, cXX);
203+
CHECK_AND_FILL_MFT_COV(cXYMFT, cXY);
204+
CHECK_AND_FILL_MFT_COV(cYYMFT, cYY);
205+
CHECK_AND_FILL_MFT_COV(cPhiYMFT, cPhiY);
206+
CHECK_AND_FILL_MFT_COV(cPhiXMFT, cPhiX);
207+
CHECK_AND_FILL_MFT_COV(cPhiPhiMFT, cPhiPhi);
208+
CHECK_AND_FILL_MFT_COV(cTglYMFT, cTglY);
209+
CHECK_AND_FILL_MFT_COV(cTglXMFT, cTglX);
210+
CHECK_AND_FILL_MFT_COV(cTglPhiMFT, cTglPhi);
211+
CHECK_AND_FILL_MFT_COV(cTglTglMFT, cTglTgl);
212+
CHECK_AND_FILL_MFT_COV(c1PtYMFT, c1PtY);
213+
CHECK_AND_FILL_MFT_COV(c1PtXMFT, c1PtX);
214+
CHECK_AND_FILL_MFT_COV(c1PtPhiMFT, c1PtPhi);
215+
CHECK_AND_FILL_MFT_COV(c1PtTglMFT, c1PtTgl);
216+
CHECK_AND_FILL_MFT_COV(c1Pt21Pt2MFT, c1Pt21Pt2);
217+
CHECK_AND_FILL_MUON_COV(cXXMCH, cXX);
218+
CHECK_AND_FILL_MUON_COV(cXYMCH, cXY);
219+
CHECK_AND_FILL_MUON_COV(cYYMCH, cYY);
220+
CHECK_AND_FILL_MUON_COV(cPhiYMCH, cPhiY);
221+
CHECK_AND_FILL_MUON_COV(cPhiXMCH, cPhiX);
222+
CHECK_AND_FILL_MUON_COV(cPhiPhiMCH, cPhiPhi);
223+
CHECK_AND_FILL_MUON_COV(cTglYMCH, cTglY);
224+
CHECK_AND_FILL_MUON_COV(cTglXMCH, cTglX);
225+
CHECK_AND_FILL_MUON_COV(cTglPhiMCH, cTglPhi);
226+
CHECK_AND_FILL_MUON_COV(cTglTglMCH, cTglTgl);
227+
CHECK_AND_FILL_MUON_COV(c1PtYMCH, c1PtY);
228+
CHECK_AND_FILL_MUON_COV(c1PtXMCH, c1PtX);
229+
CHECK_AND_FILL_MUON_COV(c1PtPhiMCH, c1PtPhi);
230+
CHECK_AND_FILL_MUON_COV(c1PtTglMCH, c1PtTgl);
231+
CHECK_AND_FILL_MUON_COV(c1Pt21Pt2MCH, c1Pt21Pt2);
232+
CHECK_AND_FILL_MFTMUON_COLLISION(posX);
233+
CHECK_AND_FILL_MFTMUON_COLLISION(posY);
234+
CHECK_AND_FILL_MFTMUON_COLLISION(posZ);
235+
CHECK_AND_FILL_MFTMUON_COLLISION(numContrib);
236+
CHECK_AND_FILL_MFTMUON_COLLISION(trackOccupancyInTimeRange);
237+
CHECK_AND_FILL_MFTMUON_COLLISION(ft0cOccupancyInTimeRange);
238+
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0A);
239+
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0C);
240+
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPV);
241+
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVeta1);
242+
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVetaHalf);
243+
CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt0);
244+
CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt1);
245+
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0M);
246+
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0M);
247+
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0A);
248+
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0C);
249+
CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT, chi2MatchMCHMFT);
250+
}
251+
return inputFeature;
252+
}
253+
254+
template <typename T1>
255+
float returnFeatureTest(uint8_t idx, T1 const& muon)
256+
{
257+
float inputFeature = 0.;
258+
switch (idx) {
259+
CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT, chi2MatchMCHMFT);
260+
}
261+
return inputFeature;
262+
}
263+
264+
/// Method to get the input features vector needed for ML inference
265+
/// \param track is the single track, \param collision is the collision
266+
/// \return inputFeatures vector
267+
template <typename T1, typename T2, typename C1, typename C2, typename U>
268+
std::vector<float> getInputFeatures(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
269+
{
270+
std::vector<float> inputFeatures;
271+
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
272+
float inputFeature = returnFeature(idx, muon, mft, muoncov, mftcov, collision);
273+
inputFeatures.emplace_back(inputFeature);
274+
}
275+
return inputFeatures;
276+
}
277+
278+
template <typename T1>
279+
std::vector<float> getInputFeaturesTest(T1 const& muon)
280+
{
281+
std::vector<float> inputFeatures;
282+
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
283+
float inputFeature = returnFeatureTest(idx, muon);
284+
inputFeatures.emplace_back(inputFeature);
285+
}
286+
return inputFeatures;
287+
}
288+
289+
/// Method to get the value of variable chosen for binning
290+
/// \param track is the single track, \param collision is the collision
291+
/// \return binning variable
292+
template <typename T1, typename T2, typename C1, typename C2, typename U>
293+
float getBinningFeature(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
294+
{
295+
return returnFeature(mCachedIndexBinning, muon, mft, muoncov, mftcov, collision);
296+
}
297+
298+
void cacheBinningIndex(std::string const& cfgBinningFeature)
299+
{
300+
setAvailableInputFeatures();
301+
if (MlResponse<TypeOutputScore>::mAvailableInputFeatures.count(cfgBinningFeature)) {
302+
mCachedIndexBinning = MlResponse<TypeOutputScore>::mAvailableInputFeatures[cfgBinningFeature];
303+
} else {
304+
LOG(fatal) << "Binning feature " << cfgBinningFeature << " not available! Please check your configurables.";
305+
}
306+
}
307+
308+
protected:
309+
/// Method to fill the map of available input features
310+
void setAvailableInputFeatures()
311+
{
312+
MlResponse<TypeOutputScore>::mAvailableInputFeatures = {
313+
FILL_MAP_MFTMUON_MATCH(zMatching),
314+
FILL_MAP_MFTMUON_MATCH(xMFT),
315+
FILL_MAP_MFTMUON_MATCH(yMFT),
316+
FILL_MAP_MFTMUON_MATCH(qOverptMFT),
317+
FILL_MAP_MFTMUON_MATCH(tglMFT),
318+
FILL_MAP_MFTMUON_MATCH(phiMFT),
319+
FILL_MAP_MFTMUON_MATCH(dcaXY),
320+
FILL_MAP_MFTMUON_MATCH(dcaZ),
321+
FILL_MAP_MFTMUON_MATCH(chi2MFT),
322+
FILL_MAP_MFTMUON_MATCH(nClustersMFT),
323+
FILL_MAP_MFTMUON_MATCH(xMCH),
324+
FILL_MAP_MFTMUON_MATCH(yMCH),
325+
FILL_MAP_MFTMUON_MATCH(qOverptMCH),
326+
FILL_MAP_MFTMUON_MATCH(tglMCH),
327+
FILL_MAP_MFTMUON_MATCH(phiMCH),
328+
FILL_MAP_MFTMUON_MATCH(nClustersMCH),
329+
FILL_MAP_MFTMUON_MATCH(chi2MCH),
330+
FILL_MAP_MFTMUON_MATCH(pdca),
331+
FILL_MAP_MFTMUON_MATCH(cXXMFT),
332+
FILL_MAP_MFTMUON_MATCH(cXYMFT),
333+
FILL_MAP_MFTMUON_MATCH(cYYMFT),
334+
FILL_MAP_MFTMUON_MATCH(cPhiYMFT),
335+
FILL_MAP_MFTMUON_MATCH(cPhiXMFT),
336+
FILL_MAP_MFTMUON_MATCH(cPhiPhiMFT),
337+
FILL_MAP_MFTMUON_MATCH(cTglYMFT),
338+
FILL_MAP_MFTMUON_MATCH(cTglXMFT),
339+
FILL_MAP_MFTMUON_MATCH(cTglPhiMFT),
340+
FILL_MAP_MFTMUON_MATCH(cTglTglMFT),
341+
FILL_MAP_MFTMUON_MATCH(c1PtYMFT),
342+
FILL_MAP_MFTMUON_MATCH(c1PtXMFT),
343+
FILL_MAP_MFTMUON_MATCH(c1PtPhiMFT),
344+
FILL_MAP_MFTMUON_MATCH(c1PtTglMFT),
345+
FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MFT),
346+
FILL_MAP_MFTMUON_MATCH(cXXMCH),
347+
FILL_MAP_MFTMUON_MATCH(cXYMCH),
348+
FILL_MAP_MFTMUON_MATCH(cYYMCH),
349+
FILL_MAP_MFTMUON_MATCH(cPhiYMCH),
350+
FILL_MAP_MFTMUON_MATCH(cPhiXMCH),
351+
FILL_MAP_MFTMUON_MATCH(cPhiPhiMCH),
352+
FILL_MAP_MFTMUON_MATCH(cTglYMCH),
353+
FILL_MAP_MFTMUON_MATCH(cTglXMCH),
354+
FILL_MAP_MFTMUON_MATCH(cTglPhiMCH),
355+
FILL_MAP_MFTMUON_MATCH(cTglTglMCH),
356+
FILL_MAP_MFTMUON_MATCH(c1PtYMCH),
357+
FILL_MAP_MFTMUON_MATCH(c1PtXMCH),
358+
FILL_MAP_MFTMUON_MATCH(c1PtPhiMCH),
359+
FILL_MAP_MFTMUON_MATCH(c1PtTglMCH),
360+
FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MCH),
361+
FILL_MAP_MFTMUON_MATCH(chi2MCHMFT)};
362+
}
363+
364+
uint8_t mCachedIndexBinning; // index correspondance between configurable and available input features
365+
};
366+
367+
} // namespace o2::analysis
368+
369+
#undef FILL_MAP_MFTMUON_MAP
370+
#undef CHECK_AND_FILL_MUON_TRACK
371+
#undef CHECK_AND_FILL_MFT_TRACK
372+
#undef CHECK_AND_FILL_MUON_COV
373+
#undef CHECK_AND_FILL_MFT_COV
374+
#undef CHECK_AND_FILL_MFTMUON_DIFF
375+
#undef CHECK_AND_FILL_MFTMUON_COLLISION
376+
377+
#endif // PWGDQ_CORE_MUONMATCHINGMLRESPONSE_H_

0 commit comments

Comments
 (0)