1717#ifndef PWGDQ_CORE_DQMLRESPONSE_H_
1818#define PWGDQ_CORE_DQMLRESPONSE_H_
1919
20+ #include " Tools/ML/MlResponse.h"
21+
2022#include < map>
2123#include < string>
2224#include < vector>
2325
24- #include " Tools/ML/MlResponse.h"
25-
2626// Fill the map of available input features
2727// the key is the feature's name (std::string)
2828// the value is the corresponding value in EnumInputFeatures
29- #define FILL_MAP (FEATURE ) \
30- { \
29+ #define FILL_MAP (FEATURE ) \
30+ { \
3131 #FEATURE, static_cast <uint8_t >(InputFeatures::FEATURE) \
3232 }
3333
34-
3534namespace o2 ::analysis
3635{
3736
38- enum class InputFeatures : uint8_t { // refer to DielectronsAll
37+ enum class InputFeatures : uint8_t { // refer to DielectronsAll
3938 fMass = 0 ,
4039 fPt ,
4140 fEta ,
@@ -69,37 +68,36 @@ enum class InputFeatures : uint8_t { //refer to DielectronsAll
6968};
7069
7170static const std::map<InputFeatures, std::string> gFeatureNameMap = {
72- {InputFeatures::fMass , " fMass" },
73- {InputFeatures::fPt , " fPt" },
74- {InputFeatures::fEta , " fEta" },
75- {InputFeatures::fPhi , " fPhi" },
76- {InputFeatures::fPt1 , " fPt1" },
77- {InputFeatures::fITSChi2NCl1 , " fITSChi2NCl1" },
78- {InputFeatures::fTPCNClsCR1 , " fTPCNClsCR1" },
79- {InputFeatures::fTPCNClsFound1 , " fTPCNClsFound1" },
80- {InputFeatures::fTPCChi2NCl1 , " fTPCChi2NCl1" },
81- {InputFeatures::fDcaXY1 , " fDcaXY1" },
82- {InputFeatures::fDcaZ1 , " fDcaZ1" },
83- {InputFeatures::fTPCNSigmaEl1 , " fTPCNSigmaEl1" },
84- {InputFeatures::fTPCNSigmaPi1 , " fTPCNSigmaPi1" },
85- {InputFeatures::fTPCNSigmaPr1 , " fTPCNSigmaPr1" },
86- {InputFeatures::fTOFNSigmaEl1 , " fTOFNSigmaEl1" },
87- {InputFeatures::fTOFNSigmaPi1 , " fTOFNSigmaPi1" },
88- {InputFeatures::fTOFNSigmaPr1 , " fTOFNSigmaPr1" },
89- {InputFeatures::fPt2 , " fPt2" },
90- {InputFeatures::fITSChi2NCl2 , " fITSChi2NCl2" },
91- {InputFeatures::fTPCNClsCR2 , " fTPCNClsCR2" },
92- {InputFeatures::fTPCNClsFound2 , " fTPCNClsFound2" },
93- {InputFeatures::fTPCChi2NCl2 , " fTPCChi2NCl2" },
94- {InputFeatures::fDcaXY2 , " fDcaXY2" },
95- {InputFeatures::fDcaZ2 , " fDcaZ2" },
96- {InputFeatures::fTPCNSigmaEl2 , " fTPCNSigmaEl2" },
97- {InputFeatures::fTPCNSigmaPi2 , " fTPCNSigmaPi2" },
98- {InputFeatures::fTPCNSigmaPr2 , " fTPCNSigmaPr2" },
99- {InputFeatures::fTOFNSigmaEl2 , " fTOFNSigmaEl2" },
100- {InputFeatures::fTOFNSigmaPi2 , " fTOFNSigmaPi2" },
101- {InputFeatures::fTOFNSigmaPr2 , " fTOFNSigmaPr2" }
102- };
71+ {InputFeatures::fMass , " fMass" },
72+ {InputFeatures::fPt , " fPt" },
73+ {InputFeatures::fEta , " fEta" },
74+ {InputFeatures::fPhi , " fPhi" },
75+ {InputFeatures::fPt1 , " fPt1" },
76+ {InputFeatures::fITSChi2NCl1 , " fITSChi2NCl1" },
77+ {InputFeatures::fTPCNClsCR1 , " fTPCNClsCR1" },
78+ {InputFeatures::fTPCNClsFound1 , " fTPCNClsFound1" },
79+ {InputFeatures::fTPCChi2NCl1 , " fTPCChi2NCl1" },
80+ {InputFeatures::fDcaXY1 , " fDcaXY1" },
81+ {InputFeatures::fDcaZ1 , " fDcaZ1" },
82+ {InputFeatures::fTPCNSigmaEl1 , " fTPCNSigmaEl1" },
83+ {InputFeatures::fTPCNSigmaPi1 , " fTPCNSigmaPi1" },
84+ {InputFeatures::fTPCNSigmaPr1 , " fTPCNSigmaPr1" },
85+ {InputFeatures::fTOFNSigmaEl1 , " fTOFNSigmaEl1" },
86+ {InputFeatures::fTOFNSigmaPi1 , " fTOFNSigmaPi1" },
87+ {InputFeatures::fTOFNSigmaPr1 , " fTOFNSigmaPr1" },
88+ {InputFeatures::fPt2 , " fPt2" },
89+ {InputFeatures::fITSChi2NCl2 , " fITSChi2NCl2" },
90+ {InputFeatures::fTPCNClsCR2 , " fTPCNClsCR2" },
91+ {InputFeatures::fTPCNClsFound2 , " fTPCNClsFound2" },
92+ {InputFeatures::fTPCChi2NCl2 , " fTPCChi2NCl2" },
93+ {InputFeatures::fDcaXY2 , " fDcaXY2" },
94+ {InputFeatures::fDcaZ2 , " fDcaZ2" },
95+ {InputFeatures::fTPCNSigmaEl2 , " fTPCNSigmaEl2" },
96+ {InputFeatures::fTPCNSigmaPi2 , " fTPCNSigmaPi2" },
97+ {InputFeatures::fTPCNSigmaPr2 , " fTPCNSigmaPr2" },
98+ {InputFeatures::fTOFNSigmaEl2 , " fTOFNSigmaEl2" },
99+ {InputFeatures::fTOFNSigmaPi2 , " fTOFNSigmaPi2" },
100+ {InputFeatures::fTOFNSigmaPr2 , " fTOFNSigmaPr2" }};
103101
104102template <typename TypeOutputScore = float >
105103class DQMlResponse : public MlResponse <TypeOutputScore>
@@ -111,67 +109,66 @@ class DQMlResponse : public MlResponse<TypeOutputScore>
111109 virtual ~DQMlResponse () = default ;
112110
113111 // / Method to get the input features vector needed for ML inference
114- // / \return inputFeatures vector
115- template <typename T1, typename T2, typename TValues>
116- std::vector<float > getInputFeatures (const T1& t1,
117- const T2& t2,
118- const TValues& fg) const
119- {
120- using Accessor = std::function<float (const T1&, const T2&, const TValues&)>;
121- static const std::unordered_map<std::string, Accessor> featureMap{
122- {" fMass" , [](auto const &, auto const &, auto const & v){ return v[VarManager::kMass ]; }},
123- {" fPt" , [](auto const &, auto const &, auto const & v){ return v[VarManager::kPt ]; }},
124- {" fEta" , [](auto const &, auto const &, auto const & v){ return v[VarManager::kEta ]; }},
125- {" fPhi" , [](auto const &, auto const &, auto const & v){ return v[VarManager::kPhi ]; }},
126-
127- {" fPt1" , [](auto const & t1, auto const &, auto const &) { return t1.pt (); }},
128- {" fITSChi2NCl1" , [](auto const & t1, auto const &, auto const &) { return t1.itsChi2NCl (); }},
129- {" fTPCNClsCR1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNClsCrossedRows (); }},
112+ // / \return inputFeatures vector
113+ template <typename T1, typename T2, typename TValues>
114+ std::vector<float > getInputFeatures (const T1& t1,
115+ const T2& t2,
116+ const TValues& fg) const
117+ {
118+ using Accessor = std::function<float (const T1&, const T2&, const TValues&)>;
119+ static const std::unordered_map<std::string, Accessor> featureMap{
120+ {" fMass" , [](auto const &, auto const &, auto const & v) { return v[VarManager::kMass ]; }},
121+ {" fPt" , [](auto const &, auto const &, auto const & v) { return v[VarManager::kPt ]; }},
122+ {" fEta" , [](auto const &, auto const &, auto const & v) { return v[VarManager::kEta ]; }},
123+ {" fPhi" , [](auto const &, auto const &, auto const & v) { return v[VarManager::kPhi ]; }},
124+
125+ {" fPt1" , [](auto const & t1, auto const &, auto const &) { return t1.pt (); }},
126+ {" fITSChi2NCl1" , [](auto const & t1, auto const &, auto const &) { return t1.itsChi2NCl (); }},
127+ {" fTPCNClsCR1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNClsCrossedRows (); }},
130128 {" fTPCNClsFound1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNClsFound (); }},
131- {" fTPCChi2NCl1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcChi2NCl (); }},
132- {" fDcaXY1" , [](auto const & t1, auto const &, auto const &) { return t1.dcaXY (); }},
133- {" fDcaZ1" , [](auto const & t1, auto const &, auto const &) { return t1.dcaZ (); }},
134- {" fTPCNSigmaEl1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNSigmaEl (); }},
135- {" fTPCNSigmaPi1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNSigmaPi (); }},
136- {" fTPCNSigmaPr1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNSigmaPr (); }},
137- {" fTOFNSigmaEl1" , [](auto const & t1, auto const &, auto const &) { return t1.tofNSigmaEl (); }},
138- {" fTOFNSigmaPi1" , [](auto const & t1, auto const &, auto const &) { return t1.tofNSigmaPi (); }},
139- {" fTOFNSigmaPr1" , [](auto const & t1, auto const &, auto const &) { return t1.tofNSigmaPr (); }},
140-
141- {" fPt2" , [](auto const &, auto const & t2, auto const &) { return t2.pt (); }},
142- {" fITSChi2NCl2" , [](auto const &, auto const & t2, auto const &) { return t2.itsChi2NCl (); }},
143- {" fTPCNClsCR2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNClsCrossedRows (); }},
129+ {" fTPCChi2NCl1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcChi2NCl (); }},
130+ {" fDcaXY1" , [](auto const & t1, auto const &, auto const &) { return t1.dcaXY (); }},
131+ {" fDcaZ1" , [](auto const & t1, auto const &, auto const &) { return t1.dcaZ (); }},
132+ {" fTPCNSigmaEl1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNSigmaEl (); }},
133+ {" fTPCNSigmaPi1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNSigmaPi (); }},
134+ {" fTPCNSigmaPr1" , [](auto const & t1, auto const &, auto const &) { return t1.tpcNSigmaPr (); }},
135+ {" fTOFNSigmaEl1" , [](auto const & t1, auto const &, auto const &) { return t1.tofNSigmaEl (); }},
136+ {" fTOFNSigmaPi1" , [](auto const & t1, auto const &, auto const &) { return t1.tofNSigmaPi (); }},
137+ {" fTOFNSigmaPr1" , [](auto const & t1, auto const &, auto const &) { return t1.tofNSigmaPr (); }},
138+
139+ {" fPt2" , [](auto const &, auto const & t2, auto const &) { return t2.pt (); }},
140+ {" fITSChi2NCl2" , [](auto const &, auto const & t2, auto const &) { return t2.itsChi2NCl (); }},
141+ {" fTPCNClsCR2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNClsCrossedRows (); }},
144142 {" fTPCNClsFound2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNClsFound (); }},
145- {" fTPCChi2NCl2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcChi2NCl (); }},
146- {" fDcaXY2" , [](auto const &, auto const & t2, auto const &) { return t2.dcaXY (); }},
147- {" fDcaZ2" , [](auto const &, auto const & t2, auto const &) { return t2.dcaZ (); }},
148- {" fTPCNSigmaEl2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNSigmaEl (); }},
149- {" fTPCNSigmaPi2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNSigmaPi (); }},
150- {" fTPCNSigmaPr2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNSigmaPr (); }},
151- {" fTOFNSigmaEl2" , [](auto const &, auto const & t2, auto const &) { return t2.tofNSigmaEl (); }},
152- {" fTOFNSigmaPi2" , [](auto const &, auto const & t2, auto const &) { return t2.tofNSigmaPi (); }},
153- {" fTOFNSigmaPr2" , [](auto const &, auto const & t2, auto const &) { return t2.tofNSigmaPr (); }}
154- };
155-
156- std::vector< float > dqInputFeatures;
157- dqInputFeatures. reserve (MlResponse<TypeOutputScore>:: mCachedIndices . size ());
158-
159- for ( auto idx : MlResponse<TypeOutputScore>:: mCachedIndices ) {
160- auto enumIdx = static_cast <InputFeatures>(idx );
161- const auto & name = gFeatureNameMap . at (enumIdx);
162-
163- auto acc = featureMap.find (name);
164- if (acc == featureMap. end ()) {
165- LOG (error) << " Missing accessor for " << name ;
166- continue ;
167- } else {
168- dqInputFeatures. push_back (acc-> second (t1, t2, fg));
143+ {" fTPCChi2NCl2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcChi2NCl (); }},
144+ {" fDcaXY2" , [](auto const &, auto const & t2, auto const &) { return t2.dcaXY (); }},
145+ {" fDcaZ2" , [](auto const &, auto const & t2, auto const &) { return t2.dcaZ (); }},
146+ {" fTPCNSigmaEl2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNSigmaEl (); }},
147+ {" fTPCNSigmaPi2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNSigmaPi (); }},
148+ {" fTPCNSigmaPr2" , [](auto const &, auto const & t2, auto const &) { return t2.tpcNSigmaPr (); }},
149+ {" fTOFNSigmaEl2" , [](auto const &, auto const & t2, auto const &) { return t2.tofNSigmaEl (); }},
150+ {" fTOFNSigmaPi2" , [](auto const &, auto const & t2, auto const &) { return t2.tofNSigmaPi (); }},
151+ {" fTOFNSigmaPr2" , [](auto const &, auto const & t2, auto const &) { return t2.tofNSigmaPr (); }}};
152+
153+ std::vector< float > dqInputFeatures;
154+ dqInputFeatures. reserve (MlResponse<TypeOutputScore>:: mCachedIndices . size ()) ;
155+
156+ for ( auto idx : MlResponse<TypeOutputScore>:: mCachedIndices ) {
157+ auto enumIdx = static_cast <InputFeatures>(idx);
158+ const auto & name = gFeatureNameMap . at (enumIdx );
159+
160+ auto acc = featureMap. find (name);
161+ if ( acc == featureMap.end ()) {
162+ LOG (error) << " Missing accessor for " << name;
163+ continue ;
164+ } else {
165+ dqInputFeatures. push_back (acc-> second (t1, t2, fg));
166+ }
169167 }
168+ return dqInputFeatures;
170169 }
171- return dqInputFeatures;
172- }
173170
174- protected:
171+ protected:
175172 void setAvailableInputFeatures ()
176173 {
177174 MlResponse<TypeOutputScore>::mAvailableInputFeatures = {
@@ -212,4 +209,4 @@ std::vector<float> getInputFeatures(const T1& t1,
212209
213210#undef FILL_MAP
214211
215- #endif // PWGDQ_CORE_DQMLRESPONSE_H_
212+ #endif // PWGDQ_CORE_DQMLRESPONSE_H_
0 commit comments