Skip to content

Commit 2e3f7ae

Browse files
committed
Remove redundant configurables in D* selector
1 parent 5589eab commit 2e3f7ae

File tree

1 file changed

+13
-38
lines changed

1 file changed

+13
-38
lines changed

PWGHF/TableProducer/candidateSelectorDstarToD0Pi.cxx

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,11 @@ struct HfCandidateSelectorDstarToD0Pi {
9999
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};
100100
// ML inference D0
101101
Configurable<bool> applyMlD0Daug{"applyMlD0Daug", false, "Flag to apply ML selections on D0 daughter"};
102-
Configurable<std::vector<double>> binsPtMlD0Daug{"binsPtMlD0Daug", std::vector<double>{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application on D0 daughter"};
103-
Configurable<std::vector<int>> cutDirMlD0Daug{"cutDirMlD0Daug", std::vector<int>{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold on D0 daughter"};
104-
Configurable<LabeledArray<double>> cutsMlD0Daug{"cutsMlD0Daug", {hf_cuts_ml::Cuts[0], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin on D0 daughter"};
105-
Configurable<int> nClassesMlD0Daug{"nClassesMlD0Daug", static_cast<int>(hf_cuts_ml::NCutScores), "Number of classes in ML model on D0 daughter"};
106-
Configurable<std::vector<std::string>> namesInputFeaturesD0Daug{"namesInputFeaturesD0Daug", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features on D0 daughter"};
107102

108103
// CCDB configuration
109104
Configurable<std::string> ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
110105
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{""}, "Paths of models on CCDB"};
111-
Configurable<std::vector<std::string>> modelPathsCCDBD0Daug{"modelPathsCCDBD0Daug", std::vector<std::string>{""}, "Paths of models on CCDB for D0 daughter"};
112106
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"Model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
113-
Configurable<std::vector<std::string>> onnxFileNamesD0Daug{"onnxFileNamesD0Daug", std::vector<std::string>{"Model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path) for D0 daughter"};
114107
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
115108
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
116109

@@ -119,16 +112,13 @@ struct HfCandidateSelectorDstarToD0Pi {
119112

120113
HfHelper hfHelper;
121114
o2::analysis::HfMlResponseDstarToD0Pi<float> hfMlResponse;
122-
o2::analysis::HfMlResponseDstarToD0Pi<float> hfMlResponseD0Daughter;
123115
std::vector<float> outputMlDstarToD0Pi = {};
124-
std::vector<float> outputMlD0ToKPi = {};
125116
o2::ccdb::CcdbApi ccdbApi;
126117

127118
TrackSelectorPi selectorPion;
128119
TrackSelectorKa selectorKaon;
129120

130121
using TracksSel = soa::Join<aod::TracksWDcaExtra, aod::TracksPidPi, aod::PidTpcTofFullPi, aod::TracksPidKa, aod::PidTpcTofFullKa>;
131-
// using TracksSel = soa::Join<aod::Tracks, aod::TracksPidPi, aod::TracksPidKa>;
132122
using HfFullDstarCandidate = soa::Join<aod::HfD0FromDstar, aod::HfCandDstarsWPid>;
133123

134124
AxisSpec axisBdtScore{100, 0.f, 1.f};
@@ -165,14 +155,14 @@ struct HfCandidateSelectorDstarToD0Pi {
165155
registry.get<TH2>(HIST("QA/hSelections"))->GetXaxis()->SetBinLabel(iBin + 1, labels[iBin].data());
166156
}
167157

168-
if (applyMl) {
158+
if (applyMl || applyMlD0Daug) {
169159
registry.add("QA/hBdtScore1VsStatus", ";BDT score", {HistType::kTH1F, {axisBdtScore}});
170160
registry.add("QA/hBdtScore2VsStatus", ";BDT score", {HistType::kTH1F, {axisBdtScore}});
171161
registry.add("QA/hBdtScore3VsStatus", ";BDT score", {HistType::kTH1F, {axisBdtScore}});
172162
}
173163
}
174164

175-
if (applyMl) {
165+
if (applyMl || applyMlD0Daug) {
176166
hfMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
177167
if (loadModelsFromCCDB) {
178168
ccdbApi.init(ccdbUrl);
@@ -183,18 +173,6 @@ struct HfCandidateSelectorDstarToD0Pi {
183173
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
184174
hfMlResponse.init();
185175
}
186-
187-
if (applyMlD0Daug) {
188-
hfMlResponseD0Daughter.configure(binsPtMlD0Daug, cutsMlD0Daug, cutDirMlD0Daug, nClassesMlD0Daug);
189-
if (loadModelsFromCCDB) {
190-
ccdbApi.init(ccdbUrl);
191-
hfMlResponseD0Daughter.setModelPathsCCDB(onnxFileNamesD0Daug, ccdbApi, modelPathsCCDBD0Daug, timestampCCDB);
192-
} else {
193-
hfMlResponseD0Daughter.setModelPathsLocal(onnxFileNamesD0Daug);
194-
}
195-
hfMlResponseD0Daughter.cacheInputFeaturesIndices(namesInputFeaturesD0Daug);
196-
hfMlResponseD0Daughter.init();
197-
}
198176
}
199177

200178
/// Conjugate-independent topological cuts on D0
@@ -258,14 +236,6 @@ struct HfCandidateSelectorDstarToD0Pi {
258236
return false;
259237
}
260238

261-
if (applyMlD0Daug) {
262-
outputMlD0ToKPi.clear();
263-
std::vector<float> inputFeaturesD0 = hfMlResponseD0Daughter.getInputFeaturesTrigger(candidate);
264-
bool isSelectedMlD0 = hfMlResponseD0Daughter.isSelectedMl(inputFeaturesD0, candpT, outputMlD0ToKPi);
265-
if (!isSelectedMlD0) {
266-
return false;
267-
}
268-
}
269239
return true;
270240
}
271241

@@ -404,7 +374,7 @@ struct HfCandidateSelectorDstarToD0Pi {
404374

405375
if (!TESTBIT(candDstar.hfflag(), aod::hf_cand_2prong::DecayType::D0ToPiK)) {
406376
hfSelDstarCandidate(statusDstar, statusD0Flag, statusTopol, statusCand, statusPID);
407-
if (applyMl) {
377+
if (applyMl || applyMlD0Daug) {
408378
hfMlDstarCandidate(outputMlDstarToD0Pi);
409379
}
410380
if (activateQA) {
@@ -420,7 +390,7 @@ struct HfCandidateSelectorDstarToD0Pi {
420390

421391
if (!selectionDstar(candDstar)) {
422392
hfSelDstarCandidate(statusDstar, statusD0Flag, statusTopol, statusCand, statusPID);
423-
if (applyMl) {
393+
if (applyMl || applyMlD0Daug) {
424394
hfMlDstarCandidate(outputMlDstarToD0Pi);
425395
}
426396
continue;
@@ -433,7 +403,7 @@ struct HfCandidateSelectorDstarToD0Pi {
433403
bool topoDstar = selectionTopolConjugate(candDstar);
434404
if (!topoDstar) {
435405
hfSelDstarCandidate(statusDstar, statusD0Flag, statusTopol, statusCand, statusPID);
436-
if (applyMl) {
406+
if (applyMl || applyMlD0Daug) {
437407
hfMlDstarCandidate(outputMlDstarToD0Pi);
438408
}
439409
continue;
@@ -481,7 +451,7 @@ struct HfCandidateSelectorDstarToD0Pi {
481451

482452
if (pidDstar == 0) {
483453
hfSelDstarCandidate(statusDstar, statusD0Flag, statusTopol, statusCand, statusPID);
484-
if (applyMl) {
454+
if (applyMl || applyMlD0Daug) {
485455
hfMlDstarCandidate(outputMlDstarToD0Pi);
486456
}
487457
continue;
@@ -496,11 +466,16 @@ struct HfCandidateSelectorDstarToD0Pi {
496466
}
497467
statusPID = true;
498468

499-
if (applyMl) {
469+
if (applyMl || applyMlD0Daug) {
500470
// ML selections
501471
bool isSelectedMlDstar = false;
502472

503-
std::vector<float> inputFeatures = hfMlResponse.getInputFeatures(candDstar);
473+
std::vector<float> inputFeatures{};
474+
if (applyMlD0Daug) {
475+
inputFeatures = hfMlResponse.getInputFeaturesTrigger(candDstar);
476+
} else {
477+
inputFeatures = hfMlResponse.getInputFeatures(candDstar);
478+
}
504479
isSelectedMlDstar = hfMlResponse.isSelectedMl(inputFeatures, ptCand, outputMlDstarToD0Pi);
505480

506481
hfMlDstarCandidate(outputMlDstarToD0Pi);

0 commit comments

Comments
 (0)