Skip to content

Commit e54cf0d

Browse files
author
Maurice Coquet
committed
Usage of ML models for MFT-Muon matching
1 parent 1016b75 commit e54cf0d

File tree

2 files changed

+409
-0
lines changed

2 files changed

+409
-0
lines changed
Lines changed: 378 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file MuonMatchingMlResponse.h
13+
/// \brief Class to compute the ML response for MFT-Muon matching
14+
15+
#ifndef PWGDQ_DILEPTON_UTILS_MLRESPONSEMFTMUONMATCHING_H_
16+
#define PWGDQ_DILEPTON_UTILS_MLRESPONSEMFTMUONMATCHING_H_
17+
18+
#include "Tools/ML/MlResponse.h"
19+
20+
#include <map>
21+
#include <string>
22+
#include <vector>
23+
24+
// Fill the map of available input features
25+
// the key is the feature's name (std::string)
26+
// the value is the corresponding value in EnumInputFeatures
27+
#define FILL_MAP_MFTMUON_MATCH(FEATURE) \
28+
{ \
29+
#FEATURE, static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE) \
30+
}
31+
32+
// Check if the index of mCachedIndices (index associated to a FEATURE)
33+
// matches the entry in EnumInputFeatures associated to this FEATURE
34+
// if so, the inputFeatures vector is filled with the FEATURE's value
35+
// by calling the corresponding GETTER=FEATURE from track
36+
#define CHECK_AND_FILL_MUON_TRACK(FEATURE, GETTER) \
37+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
38+
inputFeature = muon.GETTER(); \
39+
break; \
40+
}
41+
42+
// Check if the index of mCachedIndices (index associated to a FEATURE)
43+
// matches the entry in EnumInputFeatures associated to this FEATURE
44+
// if so, the inputFeatures vector is filled with the FEATURE's value
45+
// by calling the corresponding GETTER=FEATURE from track
46+
#define CHECK_AND_FILL_MFT_TRACK(FEATURE, GETTER) \
47+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
48+
inputFeature = mft.GETTER(); \
49+
break; \
50+
}
51+
52+
// Check if the index of mCachedIndices (index associated to a FEATURE)
53+
// matches the entry in EnumInputFeatures associated to this FEATURE
54+
// if so, the inputFeatures vector is filled with the FEATURE's value
55+
// by calling the corresponding GETTER=FEATURE from track
56+
#define CHECK_AND_FILL_MUON_COV(FEATURE, GETTER) \
57+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
58+
inputFeature = muoncov.GETTER(); \
59+
break; \
60+
}
61+
62+
// Check if the index of mCachedIndices (index associated to a FEATURE)
63+
// matches the entry in EnumInputFeatures associated to this FEATURE
64+
// if so, the inputFeatures vector is filled with the FEATURE's value
65+
// by calling the corresponding GETTER=FEATURE from track
66+
#define CHECK_AND_FILL_MFT_COV(FEATURE, GETTER) \
67+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
68+
inputFeature = mftcov.GETTER(); \
69+
break; \
70+
}
71+
72+
// Check if the index of mCachedIndices (index associated to a FEATURE)
73+
// matches the entry in EnumInputFeatures associated to this FEATURE
74+
// if so, the inputFeatures vector is filled with the FEATURE's value
75+
// by calling the corresponding GETTER1 and GETTER2 from track.
76+
#define CHECK_AND_FILL_MFTMUON_DIFF(FEATURE, GETTER1, GETTER2) \
77+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
78+
inputFeature = (mft.GETTER2() - muon.GETTER1()); \
79+
break; \
80+
}
81+
82+
// Check if the index of mCachedIndices (index associated to a FEATURE)
83+
// matches the entry in EnumInputFeatures associated to this FEATURE
84+
// if so, the inputFeatures vector is filled with the FEATURE's value
85+
// by calling the corresponding GETTER=FEATURE from collision
86+
#define CHECK_AND_FILL_MFTMUON_COLLISION(GETTER) \
87+
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::GETTER): { \
88+
inputFeature = collision.GETTER(); \
89+
break; \
90+
}
91+
92+
namespace o2::analysis
93+
{
94+
// possible input features for ML
95+
enum class InputFeaturesMFTMuonMatch : uint8_t {
96+
zMatching,
97+
xMFT,
98+
yMFT,
99+
qOverptMFT,
100+
tglMFT,
101+
phiMFT,
102+
dcaXY,
103+
dcaZ,
104+
chi2MFT,
105+
nClustersMFT,
106+
xMCH,
107+
yMCH,
108+
qOverptMCH,
109+
tglMCH,
110+
phiMCH,
111+
nClustersMCH,
112+
chi2MCH,
113+
pdca,
114+
cXXMFT,
115+
cXYMFT,
116+
cYYMFT,
117+
cPhiYMFT,
118+
cPhiXMFT,
119+
cPhiPhiMFT,
120+
cTglYMFT,
121+
cTglXMFT,
122+
cTglPhiMFT,
123+
cTglTglMFT,
124+
c1PtYMFT,
125+
c1PtXMFT,
126+
c1PtPhiMFT,
127+
c1PtTglMFT,
128+
c1Pt21Pt2MFT,
129+
cXXMCH,
130+
cXYMCH,
131+
cYYMCH,
132+
cPhiYMCH,
133+
cPhiXMCH,
134+
cPhiPhiMCH,
135+
cTglYMCH,
136+
cTglXMCH,
137+
cTglPhiMCH,
138+
cTglTglMCH,
139+
c1PtYMCH,
140+
c1PtXMCH,
141+
c1PtPhiMCH,
142+
c1PtTglMCH,
143+
c1Pt21Pt2MCH,
144+
deltaX,
145+
deltaY,
146+
deltaPhi,
147+
deltaEta,
148+
deltaPt,
149+
posX,
150+
posY,
151+
posZ,
152+
numContrib,
153+
trackOccupancyInTimeRange,
154+
ft0cOccupancyInTimeRange,
155+
multFT0A,
156+
multFT0C,
157+
multNTracksPV,
158+
multNTracksPVeta1,
159+
multNTracksPVetaHalf,
160+
isInelGt0,
161+
isInelGt1,
162+
multFT0M,
163+
centFT0M,
164+
centFT0A,
165+
centFT0C,
166+
chi2MCHMFT
167+
};
168+
169+
template <typename TypeOutputScore = float>
170+
class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
171+
{
172+
public:
173+
/// Default constructor
174+
MlResponseMFTMuonMatch() = default;
175+
/// Default destructor
176+
virtual ~MlResponseMFTMuonMatch() = default;
177+
178+
template <typename T1, typename T2, typename C1, typename C2, typename U>
179+
float return_feature(uint8_t idx, T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
180+
{
181+
float inputFeature = 0.;
182+
switch (idx) {
183+
CHECK_AND_FILL_MFT_TRACK(zMatching,z);
184+
CHECK_AND_FILL_MFT_TRACK(xMFT,x);
185+
CHECK_AND_FILL_MFT_TRACK(yMFT,y);
186+
CHECK_AND_FILL_MFT_TRACK(qOverptMFT,signed1Pt);
187+
CHECK_AND_FILL_MFT_TRACK(tglMFT,tgl);
188+
CHECK_AND_FILL_MFT_TRACK(phiMFT,phi);
189+
CHECK_AND_FILL_MFT_TRACK(chi2MFT,chi2);
190+
CHECK_AND_FILL_MFT_TRACK(nClustersMFT,nClusters);
191+
CHECK_AND_FILL_MUON_TRACK(dcaXY, fwddcaXY);
192+
CHECK_AND_FILL_MUON_TRACK(dcaZ, fwddcaz);
193+
CHECK_AND_FILL_MUON_TRACK(xMCH,x);
194+
CHECK_AND_FILL_MUON_TRACK(yMCH,y);
195+
CHECK_AND_FILL_MUON_TRACK(qOverptMCH,signed1Pt);
196+
CHECK_AND_FILL_MUON_TRACK(tglMCH,tgl);
197+
CHECK_AND_FILL_MUON_TRACK(phiMCH,phi);
198+
CHECK_AND_FILL_MUON_TRACK(nClustersMCH,nClusters);
199+
CHECK_AND_FILL_MUON_TRACK(chi2MCH,chi2);
200+
CHECK_AND_FILL_MUON_TRACK(pdca,pDca);
201+
CHECK_AND_FILL_MFT_COV(cXXMFT,cXX);
202+
CHECK_AND_FILL_MFT_COV(cXYMFT,cXY);
203+
CHECK_AND_FILL_MFT_COV(cYYMFT,cYY);
204+
CHECK_AND_FILL_MFT_COV(cPhiYMFT,cPhiY);
205+
CHECK_AND_FILL_MFT_COV(cPhiXMFT,cPhiX);
206+
CHECK_AND_FILL_MFT_COV(cPhiPhiMFT,cPhiPhi);
207+
CHECK_AND_FILL_MFT_COV(cTglYMFT,cTglY);
208+
CHECK_AND_FILL_MFT_COV(cTglXMFT,cTglX);
209+
CHECK_AND_FILL_MFT_COV(cTglPhiMFT,cTglPhi);
210+
CHECK_AND_FILL_MFT_COV(cTglTglMFT,cTglTgl);
211+
CHECK_AND_FILL_MFT_COV(c1PtYMFT,c1PtY);
212+
CHECK_AND_FILL_MFT_COV(c1PtXMFT,c1PtX);
213+
CHECK_AND_FILL_MFT_COV(c1PtPhiMFT,c1PtPhi);
214+
CHECK_AND_FILL_MFT_COV(c1PtTglMFT,c1PtTgl);
215+
CHECK_AND_FILL_MFT_COV(c1Pt21Pt2MFT,c1Pt21Pt2);
216+
CHECK_AND_FILL_MUON_COV(cXXMCH,cXX);
217+
CHECK_AND_FILL_MUON_COV(cXYMCH,cXY);
218+
CHECK_AND_FILL_MUON_COV(cYYMCH,cYY);
219+
CHECK_AND_FILL_MUON_COV(cPhiYMCH,cPhiY);
220+
CHECK_AND_FILL_MUON_COV(cPhiXMCH,cPhiX);
221+
CHECK_AND_FILL_MUON_COV(cPhiPhiMCH,cPhiPhi);
222+
CHECK_AND_FILL_MUON_COV(cTglYMCH,cTglY);
223+
CHECK_AND_FILL_MUON_COV(cTglXMCH,cTglX);
224+
CHECK_AND_FILL_MUON_COV(cTglPhiMCH,cTglPhi);
225+
CHECK_AND_FILL_MUON_COV(cTglTglMCH,cTglTgl);
226+
CHECK_AND_FILL_MUON_COV(c1PtYMCH,c1PtY);
227+
CHECK_AND_FILL_MUON_COV(c1PtXMCH,c1PtX);
228+
CHECK_AND_FILL_MUON_COV(c1PtPhiMCH,c1PtPhi);
229+
CHECK_AND_FILL_MUON_COV(c1PtTglMCH,c1PtTgl);
230+
CHECK_AND_FILL_MUON_COV(c1Pt21Pt2MCH,c1Pt21Pt2);
231+
CHECK_AND_FILL_MFTMUON_COLLISION(posX);
232+
CHECK_AND_FILL_MFTMUON_COLLISION(posY);
233+
CHECK_AND_FILL_MFTMUON_COLLISION(posZ);
234+
CHECK_AND_FILL_MFTMUON_COLLISION(numContrib);
235+
CHECK_AND_FILL_MFTMUON_COLLISION(trackOccupancyInTimeRange);
236+
CHECK_AND_FILL_MFTMUON_COLLISION(ft0cOccupancyInTimeRange);
237+
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0A);
238+
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0C);
239+
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPV);
240+
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVeta1);
241+
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVetaHalf);
242+
CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt0);
243+
CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt1);
244+
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0M);
245+
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0M);
246+
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0A);
247+
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0C);
248+
CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT,chi2MatchMCHMFT);
249+
}
250+
return inputFeature;
251+
}
252+
253+
template <typename T1>
254+
float return_featureTest(uint8_t idx, T1 const& muon)
255+
{
256+
float inputFeature = 0.;
257+
switch (idx) {
258+
CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT,chi2MatchMCHMFT);
259+
}
260+
return inputFeature;
261+
}
262+
263+
264+
/// Method to get the input features vector needed for ML inference
265+
/// \param track is the single track, \param collision is the collision
266+
/// \return inputFeatures vector
267+
template <typename T1, typename T2, typename C1, typename C2, typename U>
268+
std::vector<float> getInputFeatures(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
269+
{
270+
std::vector<float> inputFeatures;
271+
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
272+
float inputFeature = return_feature(idx, muon, mft, muoncov, mftcov, collision);
273+
inputFeatures.emplace_back(inputFeature);
274+
}
275+
return inputFeatures;
276+
}
277+
278+
template <typename T1>
279+
std::vector<float> getInputFeaturesTest(T1 const& muon)
280+
{
281+
std::vector<float> inputFeatures;
282+
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
283+
float inputFeature = return_featureTest(idx, muon);
284+
inputFeatures.emplace_back(inputFeature);
285+
}
286+
return inputFeatures;
287+
}
288+
289+
/// Method to get the value of variable chosen for binning
290+
/// \param track is the single track, \param collision is the collision
291+
/// \return binning variable
292+
template <typename T1, typename T2, typename C1, typename C2, typename U>
293+
float getBinningFeature(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
294+
{
295+
return return_feature(mCachedIndexBinning, muon, mft, muoncov, mftcov, collision);
296+
}
297+
298+
void cacheBinningIndex(std::string const& cfgBinningFeature)
299+
{
300+
setAvailableInputFeatures();
301+
if (MlResponse<TypeOutputScore>::mAvailableInputFeatures.count(cfgBinningFeature)) {
302+
mCachedIndexBinning = MlResponse<TypeOutputScore>::mAvailableInputFeatures[cfgBinningFeature];
303+
} else {
304+
LOG(fatal) << "Binning feature " << cfgBinningFeature << " not available! Please check your configurables.";
305+
}
306+
}
307+
308+
protected:
309+
/// Method to fill the map of available input features
310+
void setAvailableInputFeatures()
311+
{
312+
MlResponse<TypeOutputScore>::mAvailableInputFeatures = {
313+
FILL_MAP_MFTMUON_MATCH(zMatching),
314+
FILL_MAP_MFTMUON_MATCH(xMFT),
315+
FILL_MAP_MFTMUON_MATCH(yMFT),
316+
FILL_MAP_MFTMUON_MATCH(qOverptMFT),
317+
FILL_MAP_MFTMUON_MATCH(tglMFT),
318+
FILL_MAP_MFTMUON_MATCH(phiMFT),
319+
FILL_MAP_MFTMUON_MATCH(dcaXY),
320+
FILL_MAP_MFTMUON_MATCH(dcaZ),
321+
FILL_MAP_MFTMUON_MATCH(chi2MFT),
322+
FILL_MAP_MFTMUON_MATCH(nClustersMFT),
323+
FILL_MAP_MFTMUON_MATCH(xMCH),
324+
FILL_MAP_MFTMUON_MATCH(yMCH),
325+
FILL_MAP_MFTMUON_MATCH(qOverptMCH),
326+
FILL_MAP_MFTMUON_MATCH(tglMCH),
327+
FILL_MAP_MFTMUON_MATCH(phiMCH),
328+
FILL_MAP_MFTMUON_MATCH(nClustersMCH),
329+
FILL_MAP_MFTMUON_MATCH(chi2MCH),
330+
FILL_MAP_MFTMUON_MATCH(pdca),
331+
FILL_MAP_MFTMUON_MATCH(cXXMFT),
332+
FILL_MAP_MFTMUON_MATCH(cXYMFT),
333+
FILL_MAP_MFTMUON_MATCH(cYYMFT),
334+
FILL_MAP_MFTMUON_MATCH(cPhiYMFT),
335+
FILL_MAP_MFTMUON_MATCH(cPhiXMFT),
336+
FILL_MAP_MFTMUON_MATCH(cPhiPhiMFT),
337+
FILL_MAP_MFTMUON_MATCH(cTglYMFT),
338+
FILL_MAP_MFTMUON_MATCH(cTglXMFT),
339+
FILL_MAP_MFTMUON_MATCH(cTglPhiMFT),
340+
FILL_MAP_MFTMUON_MATCH(cTglTglMFT),
341+
FILL_MAP_MFTMUON_MATCH(c1PtYMFT),
342+
FILL_MAP_MFTMUON_MATCH(c1PtXMFT),
343+
FILL_MAP_MFTMUON_MATCH(c1PtPhiMFT),
344+
FILL_MAP_MFTMUON_MATCH(c1PtTglMFT),
345+
FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MFT),
346+
FILL_MAP_MFTMUON_MATCH(cXXMCH),
347+
FILL_MAP_MFTMUON_MATCH(cXYMCH),
348+
FILL_MAP_MFTMUON_MATCH(cYYMCH),
349+
FILL_MAP_MFTMUON_MATCH(cPhiYMCH),
350+
FILL_MAP_MFTMUON_MATCH(cPhiXMCH),
351+
FILL_MAP_MFTMUON_MATCH(cPhiPhiMCH),
352+
FILL_MAP_MFTMUON_MATCH(cTglYMCH),
353+
FILL_MAP_MFTMUON_MATCH(cTglXMCH),
354+
FILL_MAP_MFTMUON_MATCH(cTglPhiMCH),
355+
FILL_MAP_MFTMUON_MATCH(cTglTglMCH),
356+
FILL_MAP_MFTMUON_MATCH(c1PtYMCH),
357+
FILL_MAP_MFTMUON_MATCH(c1PtXMCH),
358+
FILL_MAP_MFTMUON_MATCH(c1PtPhiMCH),
359+
FILL_MAP_MFTMUON_MATCH(c1PtTglMCH),
360+
FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MCH),
361+
FILL_MAP_MFTMUON_MATCH(chi2MCHMFT)
362+
};
363+
}
364+
365+
uint8_t mCachedIndexBinning; // index correspondance between configurable and available input features
366+
};
367+
368+
} // namespace o2::analysis
369+
370+
#undef FILL_MAP_MFTMUON_MAP
371+
#undef CHECK_AND_FILL_MUON_TRACK
372+
#undef CHECK_AND_FILL_MFT_TRACK
373+
#undef CHECK_AND_FILL_MUON_COV
374+
#undef CHECK_AND_FILL_MFT_COV
375+
#undef CHECK_AND_FILL_MFTMUON_DIFF
376+
#undef CHECK_AND_FILL_MFTMUON_COLLISION
377+
378+
#endif // PWGDQ_DILEPTON_UTILS_MLRESPONSEMFTMUONMATCHING_H_

0 commit comments

Comments
 (0)