Skip to content

Commit 6368dce

Browse files
authored
[PWGHF] Enable ML selection in KF-based LcToPKPi reconstruction (#11489)
1 parent a1ea21e commit 6368dce

File tree

3 files changed

+194
-44
lines changed

3 files changed

+194
-44
lines changed

PWGHF/Core/HfMlResponseLcToPKPi.h

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
#ifndef PWGHF_CORE_HFMLRESPONSELCTOPKPI_H_
1717
#define PWGHF_CORE_HFMLRESPONSELCTOPKPI_H_
1818

19+
#include <map>
20+
#include <string>
1921
#include <vector>
2022

23+
#include "PWGHF/DataModel/CandidateReconstructionTables.h"
24+
2125
#include "PWGHF/Core/HfMlResponse.h"
2226

2327
// Fill the map of available input features
@@ -129,10 +133,22 @@ enum class InputFeaturesLcToPKPi : uint8_t {
129133
tofNSigmaPrExpPr0,
130134
tofNSigmaPiExpPi2,
131135
tpcTofNSigmaPrExpPr0,
132-
tpcTofNSigmaPiExpPi2
136+
tpcTofNSigmaPiExpPi2,
137+
kfChi2PrimProton,
138+
kfChi2PrimKaon,
139+
kfChi2PrimPion,
140+
kfChi2GeoKaonPion,
141+
kfChi2GeoProtonPion,
142+
kfChi2GeoProtonKaon,
143+
kfDcaKaonPion,
144+
kfDcaProtonPion,
145+
kfDcaProtonKaon,
146+
kfChi2Geo,
147+
kfChi2Topo,
148+
kfDecayLengthNormalised
133149
};
134150

135-
template <typename TypeOutputScore = float>
151+
template <typename TypeOutputScore = float, aod::hf_cand::VertexerType reconstructionType = aod::hf_cand::VertexerType::DCAFitter>
136152
class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
137153
{
138154
public:
@@ -179,8 +195,6 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
179195
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcNSigmaPr2, nSigTpcPr2);
180196
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcNSigmaKa2, nSigTpcKa2);
181197
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcNSigmaPi2, nSigTpcPi2);
182-
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong0, prong2, tpcNSigmaPrExpPr0, tpcNSigmaPr);
183-
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong2, prong0, tpcNSigmaPiExpPi2, tpcNSigmaPi);
184198
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcNSigmaPrExpPr0, nSigTpcPr0, nSigTpcPr2);
185199
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcNSigmaPiExpPi2, nSigTpcPi2, nSigTpcPi0);
186200
// TOF PID variables
@@ -193,8 +207,6 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
193207
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tofNSigmaPr2, nSigTofPr2);
194208
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tofNSigmaKa2, nSigTofKa2);
195209
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tofNSigmaPi2, nSigTofPi2);
196-
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong0, prong2, tofNSigmaPrExpPr0, tofNSigmaPr);
197-
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong2, prong0, tofNSigmaPiExpPi2, tofNSigmaPi);
198210
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tofNSigmaPrExpPr0, nSigTofPr0, nSigTofPr2);
199211
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tofNSigmaPiExpPi2, nSigTofPi2, nSigTofPi0);
200212
// Combined PID variables
@@ -207,13 +219,29 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
207219
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcTofNSigmaPr0, tpcTofNSigmaPr0);
208220
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcTofNSigmaPr1, tpcTofNSigmaPr1);
209221
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcTofNSigmaPr2, tpcTofNSigmaPr2);
210-
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong0, prong2, tpcTofNSigmaPrExpPr0, tpcTofNSigmaPr);
211-
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong2, prong0, tpcTofNSigmaPiExpPi2, tpcTofNSigmaPi);
212222
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcTofNSigmaPrExpPr0, tpcTofNSigmaPr0, tpcTofNSigmaPr2);
213223
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcTofNSigmaPiExpPi2, tpcTofNSigmaPi2, tpcTofNSigmaPi0);
214224
}
225+
if constexpr (reconstructionType == aod::hf_cand::VertexerType::KfParticle) {
226+
switch (idx) {
227+
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2PrimProton, kfChi2PrimProng0, kfChi2PrimProng2);
228+
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, kfChi2PrimKaon, kfChi2PrimProng1);
229+
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2PrimPion, kfChi2PrimProng2, kfChi2PrimProng0);
230+
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2GeoKaonPion, kfChi2GeoProng1Prong2, kfChi2GeoProng0Prong1);
231+
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, kfChi2GeoProtonPion, kfChi2GeoProng0Prong2);
232+
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2GeoProtonKaon, kfChi2GeoProng0Prong1, kfChi2GeoProng1Prong2);
233+
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfDcaKaonPion, kfDcaProng1Prong2, kfDcaProng0Prong1);
234+
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, kfDcaProtonPion, kfDcaProng0Prong2);
235+
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfDcaProtonKaon, kfDcaProng0Prong1, kfDcaProng1Prong2);
236+
CHECK_AND_FILL_VEC_LCTOPKPI(kfChi2Geo);
237+
CHECK_AND_FILL_VEC_LCTOPKPI(kfChi2Topo);
238+
case static_cast<uint8_t>(InputFeaturesLcToPKPi::kfDecayLengthNormalised): {
239+
inputFeatures.emplace_back(candidate.kfDecayLength() / candidate.kfDecayLengthError());
240+
break;
241+
}
242+
}
243+
}
215244
}
216-
217245
return inputFeatures;
218246
}
219247

@@ -273,6 +301,23 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
273301
FILL_MAP_LCTOPKPI(tpcTofNSigmaPr2),
274302
FILL_MAP_LCTOPKPI(tpcTofNSigmaPrExpPr0),
275303
FILL_MAP_LCTOPKPI(tpcTofNSigmaPiExpPi2)};
304+
if constexpr (reconstructionType == aod::hf_cand::VertexerType::KfParticle) {
305+
std::map<std::string, uint8_t> mapKfFeatures{
306+
// KFParticle variables
307+
FILL_MAP_LCTOPKPI(kfChi2PrimProton),
308+
FILL_MAP_LCTOPKPI(kfChi2PrimKaon),
309+
FILL_MAP_LCTOPKPI(kfChi2PrimPion),
310+
FILL_MAP_LCTOPKPI(kfChi2GeoKaonPion),
311+
FILL_MAP_LCTOPKPI(kfChi2GeoProtonPion),
312+
FILL_MAP_LCTOPKPI(kfChi2GeoProtonKaon),
313+
FILL_MAP_LCTOPKPI(kfDcaKaonPion),
314+
FILL_MAP_LCTOPKPI(kfDcaProtonPion),
315+
FILL_MAP_LCTOPKPI(kfDcaProtonKaon),
316+
FILL_MAP_LCTOPKPI(kfChi2Geo),
317+
FILL_MAP_LCTOPKPI(kfChi2Topo),
318+
FILL_MAP_LCTOPKPI(kfDecayLengthNormalised)};
319+
MlResponse<TypeOutputScore>::mAvailableInputFeatures.insert(mapKfFeatures.begin(), mapKfFeatures.end());
320+
}
276321
}
277322
};
278323

PWGHF/TableProducer/candidateSelectorLc.cxx

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ struct HfCandidateSelectorLc {
9393
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
9494

9595
HfHelper hfHelper;
96-
o2::analysis::HfMlResponseLcToPKPi<float> hfMlResponse;
96+
o2::analysis::HfMlResponseLcToPKPi<float, aod::hf_cand::VertexerType::DCAFitter> hfMlResponseDCA;
97+
o2::analysis::HfMlResponseLcToPKPi<float, aod::hf_cand::VertexerType::KfParticle> hfMlResponseKF;
9798
std::vector<float> outputMlLcToPKPi = {};
9899
std::vector<float> outputMlLcToPiKP = {};
99100
o2::ccdb::CcdbApi ccdbApi;
@@ -142,15 +143,28 @@ struct HfCandidateSelectorLc {
142143
}
143144

144145
if (applyMl) {
145-
hfMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
146-
if (loadModelsFromCCDB) {
147-
ccdbApi.init(ccdbUrl);
148-
hfMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
149-
} else {
150-
hfMlResponse.setModelPathsLocal(onnxFileNames);
146+
if (doprocessNoBayesPidWithDCAFitterN || doprocessBayesPidWithDCAFitterN) {
147+
hfMlResponseDCA.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
148+
if (loadModelsFromCCDB) {
149+
ccdbApi.init(ccdbUrl);
150+
hfMlResponseDCA.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
151+
} else {
152+
hfMlResponseDCA.setModelPathsLocal(onnxFileNames);
153+
}
154+
hfMlResponseDCA.cacheInputFeaturesIndices(namesInputFeatures);
155+
hfMlResponseDCA.init();
156+
}
157+
if (doprocessNoBayesPidWithKFParticle || doprocessBayesPidWithKFParticle) {
158+
hfMlResponseKF.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
159+
if (loadModelsFromCCDB) {
160+
ccdbApi.init(ccdbUrl);
161+
hfMlResponseKF.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
162+
} else {
163+
hfMlResponseKF.setModelPathsLocal(onnxFileNames);
164+
}
165+
hfMlResponseKF.cacheInputFeaturesIndices(namesInputFeatures);
166+
hfMlResponseKF.init();
151167
}
152-
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
153-
hfMlResponse.init();
154168
}
155169

156170
massK0Star892 = o2::constants::physics::MassK0Star892;
@@ -273,7 +287,7 @@ struct HfCandidateSelectorLc {
273287
return false;
274288
}
275289

276-
float massLc, massKPi;
290+
float massLc{0.f}, massKPi{0.f};
277291
if constexpr (reconstructionType == aod::hf_cand::VertexerType::DCAFitter) {
278292
if (trackProton.globalIndex() == candidate.prong0Id()) {
279293
massLc = hfHelper.invMassLcToPKPi(candidate);
@@ -553,13 +567,24 @@ struct HfCandidateSelectorLc {
553567
isSelectedMlLcToPKPi = false;
554568
isSelectedMlLcToPiKP = false;
555569

556-
if (pidLcToPKPi == 1 && pidBayesLcToPKPi == 1 && topolLcToPKPi) {
557-
std::vector<float> inputFeaturesLcToPKPi = hfMlResponse.getInputFeatures(candidate, true);
558-
isSelectedMlLcToPKPi = hfMlResponse.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlLcToPKPi);
559-
}
560-
if (pidLcToPiKP == 1 && pidBayesLcToPiKP == 1 && topolLcToPiKP) {
561-
std::vector<float> inputFeaturesLcToPiKP = hfMlResponse.getInputFeatures(candidate, false);
562-
isSelectedMlLcToPiKP = hfMlResponse.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlLcToPiKP);
570+
if constexpr (reconstructionType == aod::hf_cand::VertexerType::DCAFitter) {
571+
if (pidLcToPKPi == 1 && pidBayesLcToPKPi == 1 && topolLcToPKPi) {
572+
std::vector<float> inputFeaturesLcToPKPi = hfMlResponseDCA.getInputFeatures(candidate, true);
573+
isSelectedMlLcToPKPi = hfMlResponseDCA.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlLcToPKPi);
574+
}
575+
if (pidLcToPiKP == 1 && pidBayesLcToPiKP == 1 && topolLcToPiKP) {
576+
std::vector<float> inputFeaturesLcToPiKP = hfMlResponseDCA.getInputFeatures(candidate, false);
577+
isSelectedMlLcToPiKP = hfMlResponseDCA.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlLcToPiKP);
578+
}
579+
} else {
580+
if (pidLcToPKPi == 1 && pidBayesLcToPKPi == 1 && topolLcToPKPi) {
581+
std::vector<float> inputFeaturesLcToPKPi = hfMlResponseKF.getInputFeatures(candidate, true);
582+
isSelectedMlLcToPKPi = hfMlResponseKF.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlLcToPKPi);
583+
}
584+
if (pidLcToPiKP == 1 && pidBayesLcToPiKP == 1 && topolLcToPiKP) {
585+
std::vector<float> inputFeaturesLcToPiKP = hfMlResponseKF.getInputFeatures(candidate, false);
586+
isSelectedMlLcToPiKP = hfMlResponseKF.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlLcToPiKP);
587+
}
563588
}
564589

565590
hfMlLcToPKPiCandidate(outputMlLcToPKPi, outputMlLcToPiKP);

0 commit comments

Comments
 (0)