|
| 1 | +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. |
| 2 | +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. |
| 3 | +// All rights not expressly granted are reserved. |
| 4 | +// |
| 5 | +// This software is distributed under the terms of the GNU General Public |
| 6 | +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". |
| 7 | +// |
| 8 | +// In applying this license CERN does not waive the privileges and immunities |
| 9 | +// granted to it by virtue of its status as an Intergovernmental Organization |
| 10 | +// or submit itself to any jurisdiction. |
| 11 | + |
| 12 | +/// \file MuonMatchingMlResponse.h |
| 13 | +/// \brief Class to compute the ML response for MFT-Muon matching |
| 14 | +/// \author Maurice Coquet <maurice.louis.coquet@cern.ch> |
| 15 | + |
| 16 | +#ifndef PWGDQ_CORE_MUONMATCHINGMLRESPONSE_H_ |
| 17 | +#define PWGDQ_CORE_MUONMATCHINGMLRESPONSE_H_ |
| 18 | + |
| 19 | +#include "Tools/ML/MlResponse.h" |
| 20 | + |
| 21 | +#include <map> |
| 22 | +#include <string> |
| 23 | +#include <vector> |
| 24 | + |
| 25 | +// Fill the map of available input features |
| 26 | +// the key is the feature's name (std::string) |
| 27 | +// the value is the corresponding value in EnumInputFeatures |
| 28 | +#define FILL_MAP_MFTMUON_MATCH(FEATURE) \ |
| 29 | + { \ |
| 30 | + #FEATURE, static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE) \ |
| 31 | + } |
| 32 | + |
| 33 | +// Check if the index of mCachedIndices (index associated to a FEATURE) |
| 34 | +// matches the entry in EnumInputFeatures associated to this FEATURE |
| 35 | +// if so, the inputFeatures vector is filled with the FEATURE's value |
| 36 | +// by calling the corresponding GETTER=FEATURE from track |
| 37 | +#define CHECK_AND_FILL_MUON_TRACK(FEATURE, GETTER) \ |
| 38 | + case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \ |
| 39 | + inputFeature = muon.GETTER(); \ |
| 40 | + break; \ |
| 41 | + } |
| 42 | + |
| 43 | +// Check if the index of mCachedIndices (index associated to a FEATURE) |
| 44 | +// matches the entry in EnumInputFeatures associated to this FEATURE |
| 45 | +// if so, the inputFeatures vector is filled with the FEATURE's value |
| 46 | +// by calling the corresponding GETTER=FEATURE from track |
| 47 | +#define CHECK_AND_FILL_MFT_TRACK(FEATURE, GETTER) \ |
| 48 | + case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \ |
| 49 | + inputFeature = mft.GETTER(); \ |
| 50 | + break; \ |
| 51 | + } |
| 52 | + |
| 53 | +// Check if the index of mCachedIndices (index associated to a FEATURE) |
| 54 | +// matches the entry in EnumInputFeatures associated to this FEATURE |
| 55 | +// if so, the inputFeatures vector is filled with the FEATURE's value |
| 56 | +// by calling the corresponding GETTER=FEATURE from track |
| 57 | +#define CHECK_AND_FILL_MUON_COV(FEATURE, GETTER) \ |
| 58 | + case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \ |
| 59 | + inputFeature = muoncov.GETTER(); \ |
| 60 | + break; \ |
| 61 | + } |
| 62 | + |
| 63 | +// Check if the index of mCachedIndices (index associated to a FEATURE) |
| 64 | +// matches the entry in EnumInputFeatures associated to this FEATURE |
| 65 | +// if so, the inputFeatures vector is filled with the FEATURE's value |
| 66 | +// by calling the corresponding GETTER=FEATURE from track |
| 67 | +#define CHECK_AND_FILL_MFT_COV(FEATURE, GETTER) \ |
| 68 | + case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \ |
| 69 | + inputFeature = mftcov.GETTER(); \ |
| 70 | + break; \ |
| 71 | + } |
| 72 | + |
| 73 | +// Check if the index of mCachedIndices (index associated to a FEATURE) |
| 74 | +// matches the entry in EnumInputFeatures associated to this FEATURE |
| 75 | +// if so, the inputFeatures vector is filled with the FEATURE's value |
| 76 | +// by calling the corresponding GETTER1 and GETTER2 from track. |
| 77 | +#define CHECK_AND_FILL_MFTMUON_DIFF(FEATURE, GETTER1, GETTER2) \ |
| 78 | + case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \ |
| 79 | + inputFeature = (mft.GETTER2() - muon.GETTER1()); \ |
| 80 | + break; \ |
| 81 | + } |
| 82 | + |
| 83 | +// Check if the index of mCachedIndices (index associated to a FEATURE) |
| 84 | +// matches the entry in EnumInputFeatures associated to this FEATURE |
| 85 | +// if so, the inputFeatures vector is filled with the FEATURE's value |
| 86 | +// by calling the corresponding GETTER=FEATURE from collision |
| 87 | +#define CHECK_AND_FILL_MFTMUON_COLLISION(GETTER) \ |
| 88 | + case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::GETTER): { \ |
| 89 | + inputFeature = collision.GETTER(); \ |
| 90 | + break; \ |
| 91 | + } |
| 92 | + |
| 93 | +namespace o2::analysis |
| 94 | +{ |
| 95 | +// possible input features for ML |
| 96 | +enum class InputFeaturesMFTMuonMatch : uint8_t { |
| 97 | + zMatching, |
| 98 | + xMFT, |
| 99 | + yMFT, |
| 100 | + qOverptMFT, |
| 101 | + tglMFT, |
| 102 | + phiMFT, |
| 103 | + dcaXY, |
| 104 | + dcaZ, |
| 105 | + chi2MFT, |
| 106 | + nClustersMFT, |
| 107 | + xMCH, |
| 108 | + yMCH, |
| 109 | + qOverptMCH, |
| 110 | + tglMCH, |
| 111 | + phiMCH, |
| 112 | + nClustersMCH, |
| 113 | + chi2MCH, |
| 114 | + pdca, |
| 115 | + cXXMFT, |
| 116 | + cXYMFT, |
| 117 | + cYYMFT, |
| 118 | + cPhiYMFT, |
| 119 | + cPhiXMFT, |
| 120 | + cPhiPhiMFT, |
| 121 | + cTglYMFT, |
| 122 | + cTglXMFT, |
| 123 | + cTglPhiMFT, |
| 124 | + cTglTglMFT, |
| 125 | + c1PtYMFT, |
| 126 | + c1PtXMFT, |
| 127 | + c1PtPhiMFT, |
| 128 | + c1PtTglMFT, |
| 129 | + c1Pt21Pt2MFT, |
| 130 | + cXXMCH, |
| 131 | + cXYMCH, |
| 132 | + cYYMCH, |
| 133 | + cPhiYMCH, |
| 134 | + cPhiXMCH, |
| 135 | + cPhiPhiMCH, |
| 136 | + cTglYMCH, |
| 137 | + cTglXMCH, |
| 138 | + cTglPhiMCH, |
| 139 | + cTglTglMCH, |
| 140 | + c1PtYMCH, |
| 141 | + c1PtXMCH, |
| 142 | + c1PtPhiMCH, |
| 143 | + c1PtTglMCH, |
| 144 | + c1Pt21Pt2MCH, |
| 145 | + deltaX, |
| 146 | + deltaY, |
| 147 | + deltaPhi, |
| 148 | + deltaEta, |
| 149 | + deltaPt, |
| 150 | + posX, |
| 151 | + posY, |
| 152 | + posZ, |
| 153 | + numContrib, |
| 154 | + trackOccupancyInTimeRange, |
| 155 | + ft0cOccupancyInTimeRange, |
| 156 | + multFT0A, |
| 157 | + multFT0C, |
| 158 | + multNTracksPV, |
| 159 | + multNTracksPVeta1, |
| 160 | + multNTracksPVetaHalf, |
| 161 | + isInelGt0, |
| 162 | + isInelGt1, |
| 163 | + multFT0M, |
| 164 | + centFT0M, |
| 165 | + centFT0A, |
| 166 | + centFT0C, |
| 167 | + chi2MCHMFT |
| 168 | +}; |
| 169 | + |
| 170 | +template <typename TypeOutputScore = float> |
| 171 | +class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore> |
| 172 | +{ |
| 173 | + public: |
| 174 | + /// Default constructor |
| 175 | + MlResponseMFTMuonMatch() = default; |
| 176 | + /// Default destructor |
| 177 | + virtual ~MlResponseMFTMuonMatch() = default; |
| 178 | + |
| 179 | + template <typename T1, typename T2, typename C1, typename C2, typename U> |
| 180 | + float returnFeature(uint8_t idx, T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision) |
| 181 | + { |
| 182 | + float inputFeature = 0.; |
| 183 | + switch (idx) { |
| 184 | + CHECK_AND_FILL_MFT_TRACK(zMatching, z); |
| 185 | + CHECK_AND_FILL_MFT_TRACK(xMFT, x); |
| 186 | + CHECK_AND_FILL_MFT_TRACK(yMFT, y); |
| 187 | + CHECK_AND_FILL_MFT_TRACK(qOverptMFT, signed1Pt); |
| 188 | + CHECK_AND_FILL_MFT_TRACK(tglMFT, tgl); |
| 189 | + CHECK_AND_FILL_MFT_TRACK(phiMFT, phi); |
| 190 | + CHECK_AND_FILL_MFT_TRACK(chi2MFT, chi2); |
| 191 | + CHECK_AND_FILL_MFT_TRACK(nClustersMFT, nClusters); |
| 192 | + CHECK_AND_FILL_MUON_TRACK(dcaXY, fwddcaXY); |
| 193 | + CHECK_AND_FILL_MUON_TRACK(dcaZ, fwddcaz); |
| 194 | + CHECK_AND_FILL_MUON_TRACK(xMCH, x); |
| 195 | + CHECK_AND_FILL_MUON_TRACK(yMCH, y); |
| 196 | + CHECK_AND_FILL_MUON_TRACK(qOverptMCH, signed1Pt); |
| 197 | + CHECK_AND_FILL_MUON_TRACK(tglMCH, tgl); |
| 198 | + CHECK_AND_FILL_MUON_TRACK(phiMCH, phi); |
| 199 | + CHECK_AND_FILL_MUON_TRACK(nClustersMCH, nClusters); |
| 200 | + CHECK_AND_FILL_MUON_TRACK(chi2MCH, chi2); |
| 201 | + CHECK_AND_FILL_MUON_TRACK(pdca, pDca); |
| 202 | + CHECK_AND_FILL_MFT_COV(cXXMFT, cXX); |
| 203 | + CHECK_AND_FILL_MFT_COV(cXYMFT, cXY); |
| 204 | + CHECK_AND_FILL_MFT_COV(cYYMFT, cYY); |
| 205 | + CHECK_AND_FILL_MFT_COV(cPhiYMFT, cPhiY); |
| 206 | + CHECK_AND_FILL_MFT_COV(cPhiXMFT, cPhiX); |
| 207 | + CHECK_AND_FILL_MFT_COV(cPhiPhiMFT, cPhiPhi); |
| 208 | + CHECK_AND_FILL_MFT_COV(cTglYMFT, cTglY); |
| 209 | + CHECK_AND_FILL_MFT_COV(cTglXMFT, cTglX); |
| 210 | + CHECK_AND_FILL_MFT_COV(cTglPhiMFT, cTglPhi); |
| 211 | + CHECK_AND_FILL_MFT_COV(cTglTglMFT, cTglTgl); |
| 212 | + CHECK_AND_FILL_MFT_COV(c1PtYMFT, c1PtY); |
| 213 | + CHECK_AND_FILL_MFT_COV(c1PtXMFT, c1PtX); |
| 214 | + CHECK_AND_FILL_MFT_COV(c1PtPhiMFT, c1PtPhi); |
| 215 | + CHECK_AND_FILL_MFT_COV(c1PtTglMFT, c1PtTgl); |
| 216 | + CHECK_AND_FILL_MFT_COV(c1Pt21Pt2MFT, c1Pt21Pt2); |
| 217 | + CHECK_AND_FILL_MUON_COV(cXXMCH, cXX); |
| 218 | + CHECK_AND_FILL_MUON_COV(cXYMCH, cXY); |
| 219 | + CHECK_AND_FILL_MUON_COV(cYYMCH, cYY); |
| 220 | + CHECK_AND_FILL_MUON_COV(cPhiYMCH, cPhiY); |
| 221 | + CHECK_AND_FILL_MUON_COV(cPhiXMCH, cPhiX); |
| 222 | + CHECK_AND_FILL_MUON_COV(cPhiPhiMCH, cPhiPhi); |
| 223 | + CHECK_AND_FILL_MUON_COV(cTglYMCH, cTglY); |
| 224 | + CHECK_AND_FILL_MUON_COV(cTglXMCH, cTglX); |
| 225 | + CHECK_AND_FILL_MUON_COV(cTglPhiMCH, cTglPhi); |
| 226 | + CHECK_AND_FILL_MUON_COV(cTglTglMCH, cTglTgl); |
| 227 | + CHECK_AND_FILL_MUON_COV(c1PtYMCH, c1PtY); |
| 228 | + CHECK_AND_FILL_MUON_COV(c1PtXMCH, c1PtX); |
| 229 | + CHECK_AND_FILL_MUON_COV(c1PtPhiMCH, c1PtPhi); |
| 230 | + CHECK_AND_FILL_MUON_COV(c1PtTglMCH, c1PtTgl); |
| 231 | + CHECK_AND_FILL_MUON_COV(c1Pt21Pt2MCH, c1Pt21Pt2); |
| 232 | + CHECK_AND_FILL_MFTMUON_COLLISION(posX); |
| 233 | + CHECK_AND_FILL_MFTMUON_COLLISION(posY); |
| 234 | + CHECK_AND_FILL_MFTMUON_COLLISION(posZ); |
| 235 | + CHECK_AND_FILL_MFTMUON_COLLISION(numContrib); |
| 236 | + CHECK_AND_FILL_MFTMUON_COLLISION(trackOccupancyInTimeRange); |
| 237 | + CHECK_AND_FILL_MFTMUON_COLLISION(ft0cOccupancyInTimeRange); |
| 238 | + CHECK_AND_FILL_MFTMUON_COLLISION(multFT0A); |
| 239 | + CHECK_AND_FILL_MFTMUON_COLLISION(multFT0C); |
| 240 | + CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPV); |
| 241 | + CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVeta1); |
| 242 | + CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVetaHalf); |
| 243 | + CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt0); |
| 244 | + CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt1); |
| 245 | + CHECK_AND_FILL_MFTMUON_COLLISION(multFT0M); |
| 246 | + CHECK_AND_FILL_MFTMUON_COLLISION(centFT0M); |
| 247 | + CHECK_AND_FILL_MFTMUON_COLLISION(centFT0A); |
| 248 | + CHECK_AND_FILL_MFTMUON_COLLISION(centFT0C); |
| 249 | + CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT, chi2MatchMCHMFT); |
| 250 | + } |
| 251 | + return inputFeature; |
| 252 | + } |
| 253 | + |
| 254 | + template <typename T1> |
| 255 | + float returnFeatureTest(uint8_t idx, T1 const& muon) |
| 256 | + { |
| 257 | + float inputFeature = 0.; |
| 258 | + switch (idx) { |
| 259 | + CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT, chi2MatchMCHMFT); |
| 260 | + } |
| 261 | + return inputFeature; |
| 262 | + } |
| 263 | + |
| 264 | + /// Method to get the input features vector needed for ML inference |
| 265 | + /// \param track is the single track, \param collision is the collision |
| 266 | + /// \return inputFeatures vector |
| 267 | + template <typename T1, typename T2, typename C1, typename C2, typename U> |
| 268 | + std::vector<float> getInputFeatures(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision) |
| 269 | + { |
| 270 | + std::vector<float> inputFeatures; |
| 271 | + for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) { |
| 272 | + float inputFeature = returnFeature(idx, muon, mft, muoncov, mftcov, collision); |
| 273 | + inputFeatures.emplace_back(inputFeature); |
| 274 | + } |
| 275 | + return inputFeatures; |
| 276 | + } |
| 277 | + |
| 278 | + template <typename T1> |
| 279 | + std::vector<float> getInputFeaturesTest(T1 const& muon) |
| 280 | + { |
| 281 | + std::vector<float> inputFeatures; |
| 282 | + for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) { |
| 283 | + float inputFeature = returnFeatureTest(idx, muon); |
| 284 | + inputFeatures.emplace_back(inputFeature); |
| 285 | + } |
| 286 | + return inputFeatures; |
| 287 | + } |
| 288 | + |
| 289 | + /// Method to get the value of variable chosen for binning |
| 290 | + /// \param track is the single track, \param collision is the collision |
| 291 | + /// \return binning variable |
| 292 | + template <typename T1, typename T2, typename C1, typename C2, typename U> |
| 293 | + float getBinningFeature(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision) |
| 294 | + { |
| 295 | + return returnFeature(mCachedIndexBinning, muon, mft, muoncov, mftcov, collision); |
| 296 | + } |
| 297 | + |
| 298 | + void cacheBinningIndex(std::string const& cfgBinningFeature) |
| 299 | + { |
| 300 | + setAvailableInputFeatures(); |
| 301 | + if (MlResponse<TypeOutputScore>::mAvailableInputFeatures.count(cfgBinningFeature)) { |
| 302 | + mCachedIndexBinning = MlResponse<TypeOutputScore>::mAvailableInputFeatures[cfgBinningFeature]; |
| 303 | + } else { |
| 304 | + LOG(fatal) << "Binning feature " << cfgBinningFeature << " not available! Please check your configurables."; |
| 305 | + } |
| 306 | + } |
| 307 | + |
| 308 | + protected: |
| 309 | + /// Method to fill the map of available input features |
| 310 | + void setAvailableInputFeatures() |
| 311 | + { |
| 312 | + MlResponse<TypeOutputScore>::mAvailableInputFeatures = { |
| 313 | + FILL_MAP_MFTMUON_MATCH(zMatching), |
| 314 | + FILL_MAP_MFTMUON_MATCH(xMFT), |
| 315 | + FILL_MAP_MFTMUON_MATCH(yMFT), |
| 316 | + FILL_MAP_MFTMUON_MATCH(qOverptMFT), |
| 317 | + FILL_MAP_MFTMUON_MATCH(tglMFT), |
| 318 | + FILL_MAP_MFTMUON_MATCH(phiMFT), |
| 319 | + FILL_MAP_MFTMUON_MATCH(dcaXY), |
| 320 | + FILL_MAP_MFTMUON_MATCH(dcaZ), |
| 321 | + FILL_MAP_MFTMUON_MATCH(chi2MFT), |
| 322 | + FILL_MAP_MFTMUON_MATCH(nClustersMFT), |
| 323 | + FILL_MAP_MFTMUON_MATCH(xMCH), |
| 324 | + FILL_MAP_MFTMUON_MATCH(yMCH), |
| 325 | + FILL_MAP_MFTMUON_MATCH(qOverptMCH), |
| 326 | + FILL_MAP_MFTMUON_MATCH(tglMCH), |
| 327 | + FILL_MAP_MFTMUON_MATCH(phiMCH), |
| 328 | + FILL_MAP_MFTMUON_MATCH(nClustersMCH), |
| 329 | + FILL_MAP_MFTMUON_MATCH(chi2MCH), |
| 330 | + FILL_MAP_MFTMUON_MATCH(pdca), |
| 331 | + FILL_MAP_MFTMUON_MATCH(cXXMFT), |
| 332 | + FILL_MAP_MFTMUON_MATCH(cXYMFT), |
| 333 | + FILL_MAP_MFTMUON_MATCH(cYYMFT), |
| 334 | + FILL_MAP_MFTMUON_MATCH(cPhiYMFT), |
| 335 | + FILL_MAP_MFTMUON_MATCH(cPhiXMFT), |
| 336 | + FILL_MAP_MFTMUON_MATCH(cPhiPhiMFT), |
| 337 | + FILL_MAP_MFTMUON_MATCH(cTglYMFT), |
| 338 | + FILL_MAP_MFTMUON_MATCH(cTglXMFT), |
| 339 | + FILL_MAP_MFTMUON_MATCH(cTglPhiMFT), |
| 340 | + FILL_MAP_MFTMUON_MATCH(cTglTglMFT), |
| 341 | + FILL_MAP_MFTMUON_MATCH(c1PtYMFT), |
| 342 | + FILL_MAP_MFTMUON_MATCH(c1PtXMFT), |
| 343 | + FILL_MAP_MFTMUON_MATCH(c1PtPhiMFT), |
| 344 | + FILL_MAP_MFTMUON_MATCH(c1PtTglMFT), |
| 345 | + FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MFT), |
| 346 | + FILL_MAP_MFTMUON_MATCH(cXXMCH), |
| 347 | + FILL_MAP_MFTMUON_MATCH(cXYMCH), |
| 348 | + FILL_MAP_MFTMUON_MATCH(cYYMCH), |
| 349 | + FILL_MAP_MFTMUON_MATCH(cPhiYMCH), |
| 350 | + FILL_MAP_MFTMUON_MATCH(cPhiXMCH), |
| 351 | + FILL_MAP_MFTMUON_MATCH(cPhiPhiMCH), |
| 352 | + FILL_MAP_MFTMUON_MATCH(cTglYMCH), |
| 353 | + FILL_MAP_MFTMUON_MATCH(cTglXMCH), |
| 354 | + FILL_MAP_MFTMUON_MATCH(cTglPhiMCH), |
| 355 | + FILL_MAP_MFTMUON_MATCH(cTglTglMCH), |
| 356 | + FILL_MAP_MFTMUON_MATCH(c1PtYMCH), |
| 357 | + FILL_MAP_MFTMUON_MATCH(c1PtXMCH), |
| 358 | + FILL_MAP_MFTMUON_MATCH(c1PtPhiMCH), |
| 359 | + FILL_MAP_MFTMUON_MATCH(c1PtTglMCH), |
| 360 | + FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MCH), |
| 361 | + FILL_MAP_MFTMUON_MATCH(chi2MCHMFT)}; |
| 362 | + } |
| 363 | + |
| 364 | + uint8_t mCachedIndexBinning; // index correspondance between configurable and available input features |
| 365 | +}; |
| 366 | + |
| 367 | +} // namespace o2::analysis |
| 368 | + |
| 369 | +#undef FILL_MAP_MFTMUON_MAP |
| 370 | +#undef CHECK_AND_FILL_MUON_TRACK |
| 371 | +#undef CHECK_AND_FILL_MFT_TRACK |
| 372 | +#undef CHECK_AND_FILL_MUON_COV |
| 373 | +#undef CHECK_AND_FILL_MFT_COV |
| 374 | +#undef CHECK_AND_FILL_MFTMUON_DIFF |
| 375 | +#undef CHECK_AND_FILL_MFTMUON_COLLISION |
| 376 | + |
| 377 | +#endif // PWGDQ_CORE_MUONMATCHINGMLRESPONSE_H_ |
0 commit comments