Skip to content

Commit 46d1da1

Browse files
Merge pull request #32 from alibuild/alibot-cleanup-12242
Please consider the following formatting changes to #12242
2 parents 3ceb3ec + ca352c8 commit 46d1da1

File tree

2 files changed

+102
-98
lines changed

2 files changed

+102
-98
lines changed

PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

Lines changed: 101 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,28 @@
3232
// -- v0builderopts ......: V0-specific building options (topological, deduplication, etc)
3333
// -- cascadebuilderopts .: cascade-specific building options (topological, etc)
3434

35-
#include <string>
36-
#include <vector>
37-
38-
#include "Framework/DataSpecUtils.h"
39-
#include "Framework/runDataProcessing.h"
40-
#include "Framework/AnalysisTask.h"
41-
#include "Framework/AnalysisDataModel.h"
42-
#include "Common/DataModel/PIDResponse.h"
4335
#include "TableHelper.h"
44-
#include "PWGLF/DataModel/LFStrangenessTables.h"
36+
4537
#include "PWGLF/DataModel/LFStrangenessPIDTables.h"
38+
#include "PWGLF/DataModel/LFStrangenessTables.h"
4639
#include "PWGLF/Utils/strangenessBuilderHelper.h"
47-
#include "CCDB/BasicCCDBManager.h"
48-
#include "DataFormatsParameters/GRPObject.h"
49-
#include "DataFormatsParameters/GRPMagField.h"
40+
5041
#include "Common/Core/TPCVDriftManager.h"
42+
#include "Common/DataModel/PIDResponse.h"
5143
#include "Tools/ML/MlResponse.h"
5244
#include "Tools/ML/model.h"
5345

46+
#include "CCDB/BasicCCDBManager.h"
47+
#include "DataFormatsParameters/GRPMagField.h"
48+
#include "DataFormatsParameters/GRPObject.h"
49+
#include "Framework/AnalysisDataModel.h"
50+
#include "Framework/AnalysisTask.h"
51+
#include "Framework/DataSpecUtils.h"
52+
#include "Framework/runDataProcessing.h"
53+
54+
#include <string>
55+
#include <vector>
56+
5457
using namespace o2;
5558
using namespace o2::framework;
5659
using namespace o2::ml;
@@ -159,8 +162,8 @@ struct StrangenessBuilder {
159162
// helper object
160163
o2::pwglf::strangenessBuilderHelper straHelper;
161164

162-
// ML model
163-
o2::ml::OnnxModel deduplication_bdt;
165+
// ML model
166+
o2::ml::OnnxModel deduplication_bdt;
164167

165168
// table index : match order above
166169
enum tableIndex { kV0Indices = 0,
@@ -302,27 +305,27 @@ struct StrangenessBuilder {
302305

303306
struct : ConfigurableGroup {
304307
std::string prefix = "DeduplicationOpts";
305-
308+
306309
Configurable<int> deduplicationAlgorithm{"deduplicationAlgorithm", 1,
307-
"0: disabled;"
308-
"1: best pointing angle wins;"
309-
"2: best DCA daughters wins;"
310-
"3: best pointing and best DCA wins;"
311-
"4: best BDT score wins;"
312-
"5: selects on PA (not a winner takes it all approach!);"
313-
"6: selects on BDT score (not a winner takes it all approach!)"};
314-
315-
// BDT settings
310+
"0: disabled;"
311+
"1: best pointing angle wins;"
312+
"2: best DCA daughters wins;"
313+
"3: best pointing and best DCA wins;"
314+
"4: best BDT score wins;"
315+
"5: selects on PA (not a winner takes it all approach!);"
316+
"6: selects on BDT score (not a winner takes it all approach!)"};
317+
318+
// BDT settings
316319
Configurable<std::string> BDTLocalPath{"BDTLocalPath", "Deduplication_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
317320
Configurable<std::string> BDTPathCCDB{"BDTPathCCDB", "Users/g/gsetouel/MLModels2", "Path on CCDB"};
318321
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
319322
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
320-
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
321-
323+
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
324+
322325
// Selection based duplicates removal
323326
Configurable<float> PAthreshold{"PAthreshold", 0.02, "PA cut to remove duplicates."};
324327
Configurable<float> BDTthreshold{"BDTthreshold", 0.7, "BDT score cut to remove duplicates."};
325-
328+
326329
} DeduplicationOpts;
327330

328331
// V0 buffer for V0s used in cascades: master switch
@@ -545,14 +548,14 @@ struct StrangenessBuilder {
545548
int mcParticleBachelor;
546549
};
547550
mcCascinfo thisCascInfo;
548-
//*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*
551+
//*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*
549552
// Helper structure to save v0 duplicates auxiliary info
550-
struct V0DuplicateExtra {
553+
struct V0DuplicateExtra {
551554
bool isBestPA;
552555
bool isBestDCADau;
553556
bool isBestMLScore;
554557
bool isBuildOk;
555-
float PA;
558+
float PA;
556559
float V0DCAToPVz;
557560
float V0zVtx;
558561
float MLScore;
@@ -635,14 +638,14 @@ struct StrangenessBuilder {
635638
hFindable->GetXaxis()->SetBinLabel(6, "Cascades with collId -1");
636639
}
637640

638-
if (DeduplicationOpts.deduplicationAlgorithm.value > 0){
641+
if (DeduplicationOpts.deduplicationAlgorithm.value > 0) {
639642
histos.add("DeduplicationQA/hMLScore", "hMLScore", kTH1F, {{200, 0.0f, 1.0f}});
640643
histos.add("DeduplicationQA/hPA", "hPA", kTH1F, {{200, 0.0f, 0.4f}});
641644
histos.add("DeduplicationQA/hBestPA", "hBestPA", kTH1F, {{200, 0.0f, 0.4f}});
642645
histos.add("DeduplicationQA/hBestDCADau", "hBestDCADau", kTH1F, {{200, -10.0f, 10.0f}});
643646
histos.add("DeduplicationQA/hBestMLScore", "hBestMLScore", kTH1F, {{200, 0.0f, 1.0f}});
644647
}
645-
648+
646649
auto hPrimaryV0s = histos.add<TH1>("hPrimaryV0s", "hPrimaryV0s", kTH1D, {{2, -0.5f, 1.5f}});
647650
hPrimaryV0s->GetXaxis()->SetBinLabel(1, "All V0s");
648651
hPrimaryV0s->GetXaxis()->SetBinLabel(2, "Primary V0s");
@@ -764,22 +767,21 @@ struct StrangenessBuilder {
764767
straHelper.cascadeselections.lambdaMassWindow = cascadeBuilderOpts.lambdaMassWindow;
765768
straHelper.cascadeselections.maxDaughterEta = cascadeBuilderOpts.maxDaughterEta;
766769

767-
// Loading BDT model
768-
if (DeduplicationOpts.deduplicationAlgorithm.value==4 || DeduplicationOpts.deduplicationAlgorithm.value==6){
769-
if (DeduplicationOpts.loadModelsFromCCDB) {
770-
770+
// Loading BDT model
771+
if (DeduplicationOpts.deduplicationAlgorithm.value == 4 || DeduplicationOpts.deduplicationAlgorithm.value == 6) {
772+
if (DeduplicationOpts.loadModelsFromCCDB) {
773+
771774
/// Fetching model for specific timestamp
772775
LOG(info) << "Fetching model for timestamp: " << DeduplicationOpts.timestampCCDB.value;
773-
776+
774777
bool retrieveSuccess = ccdbApi.retrieveBlob(DeduplicationOpts.BDTPathCCDB.value, ".", metadata, DeduplicationOpts.timestampCCDB.value, false, DeduplicationOpts.BDTLocalPath.value);
775778
if (retrieveSuccess) {
776779
deduplication_bdt.initModel(DeduplicationOpts.BDTLocalPath.value, DeduplicationOpts.enableOptimizations.value);
777780
} else {
778781
LOG(fatal) << "Error encountered while fetching/loading the Gamma model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
779782
}
780-
}
781-
else{
782-
deduplication_bdt.initModel(DeduplicationOpts.BDTLocalPath.value, DeduplicationOpts.enableOptimizations.value);
783+
} else {
784+
deduplication_bdt.initModel(DeduplicationOpts.BDTLocalPath.value, DeduplicationOpts.enableOptimizations.value);
783785
}
784786
}
785787
}
@@ -799,35 +801,36 @@ struct StrangenessBuilder {
799801
}
800802

801803
// Simple function to rank vectors based on values
802-
std::vector<int> rankSort(const std::vector<float>& v_temp, bool descending = false) {
804+
std::vector<int> rankSort(const std::vector<float>& v_temp, bool descending = false)
805+
{
803806
std::vector<std::pair<float, size_t>> v_sort(v_temp.size());
804807

805808
// Pair each value with its original index
806809
for (size_t i = 0U; i < v_temp.size(); ++i) {
807-
v_sort[i] = std::make_pair(v_temp[i], i);
810+
v_sort[i] = std::make_pair(v_temp[i], i);
808811
}
809812

810813
// Sort by value - ascending: lowest gets rank 1, descending: highest gets rank 1
811814

812815
if (descending) {
813-
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
814-
return a.first > b.first;
815-
});
816+
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
817+
return a.first > b.first;
818+
});
816819
} else {
817-
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
818-
return a.first < b.first;
819-
});
820+
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
821+
return a.first < b.first;
822+
});
820823
}
821824

822825
std::pair<float, size_t> rank_tracker = std::make_pair(std::numeric_limits<float>::quiet_NaN(), 0);
823826
std::vector<int> result(v_temp.size());
824827

825828
for (size_t i = 0U; i < v_sort.size(); ++i) {
826-
// Only update rank if value is different from previous
827-
if (v_sort[i].first != rank_tracker.first) {
828-
rank_tracker = std::make_pair(v_sort[i].first, i + 1); // +1 for 1-based rank
829-
}
830-
result[v_sort[i].second] = rank_tracker.second; // assign rank to original index
829+
// Only update rank if value is different from previous
830+
if (v_sort[i].first != rank_tracker.first) {
831+
rank_tracker = std::make_pair(v_sort[i].first, i + 1); // +1 for 1-based rank
832+
}
833+
result[v_sort[i].second] = rank_tracker.second; // assign rank to original index
831834
}
832835

833836
return result;
@@ -861,27 +864,27 @@ struct StrangenessBuilder {
861864

862865
// Defining context variables
863866
int NDuplicates = 0;
864-
float AvgPA = 0.0f;
867+
float AvgPA = 0.0f;
865868

866869
// Containers for ranking
867870
std::vector<float> paVec(V0Grouped[iV0].collisionIds.size(), 999.f);
868871
std::vector<float> v0zVec(V0Grouped[iV0].collisionIds.size(), 999.f);
869872

870873
// Auxiliary vector to store V0 duplicate info
871-
std::vector<V0DuplicateExtra> V0DuplicateExtras;
874+
std::vector<V0DuplicateExtra> V0DuplicateExtras;
872875

873-
// Loop over duplicates
876+
// Loop over duplicates
874877
for (size_t ic = 0; ic < V0Grouped[iV0].collisionIds.size(); ic++) {
875878

876879
// Helper structure to save duplicates info - initializing with dummy values
877-
V0DuplicateExtra v0DuplicateInfo;
880+
V0DuplicateExtra v0DuplicateInfo;
878881
v0DuplicateInfo.isBestPA = false;
879882
v0DuplicateInfo.isBestDCADau = false;
880883
v0DuplicateInfo.isBestMLScore = false;
881884
v0DuplicateInfo.isBuildOk = false;
882-
v0DuplicateInfo.PA = 10;
883-
v0DuplicateInfo.V0DCAToPVz = 999.f;
884-
v0DuplicateInfo.V0zVtx = 999.f;
885+
v0DuplicateInfo.PA = 10;
886+
v0DuplicateInfo.V0DCAToPVz = 999.f;
887+
v0DuplicateInfo.V0zVtx = 999.f;
885888
v0DuplicateInfo.MLScore = -1;
886889

887890
// We include V0DuplicateExtra info in the vector at this point to avoid indexing issues later
@@ -916,7 +919,7 @@ struct StrangenessBuilder {
916919
// <false>: do not apply selections: do as much as possible to preserve
917920
// candidate at this level and do not select with topo selections
918921
if (straHelper.buildV0Candidate<false>(V0Grouped[iV0].collisionIds[ic], collision.posX(), collision.posY(), collision.posZ(), pTrack, nTrack, posTrackPar, negTrackPar, true, false, true)) {
919-
922+
920923
// candidate built, check pointing angle
921924
if (straHelper.v0.pointingAngle < bestPointingAngle) {
922925
bestPointingAngle = straHelper.v0.pointingAngle;
@@ -928,33 +931,32 @@ struct StrangenessBuilder {
928931
}
929932

930933
// Calculating features for ML Analysis
931-
if (DeduplicationOpts.deduplicationAlgorithm.value==4 || DeduplicationOpts.deduplicationAlgorithm.value==6){
932-
AvgPA += straHelper.v0.pointingAngle;
934+
if (DeduplicationOpts.deduplicationAlgorithm.value == 4 || DeduplicationOpts.deduplicationAlgorithm.value == 6) {
935+
AvgPA += straHelper.v0.pointingAngle;
933936
paVec[ic] = straHelper.v0.pointingAngle;
934-
v0zVec[ic] = std::abs(straHelper.v0.position[2]);
937+
v0zVec[ic] = std::abs(straHelper.v0.position[2]);
935938
NDuplicates++;
936939
}
937-
940+
938941
// Updating values in the struct
939-
V0DuplicateExtras[ic].isBuildOk = true;
940-
V0DuplicateExtras[ic].PA = straHelper.v0.pointingAngle;
941-
V0DuplicateExtras[ic].V0DCAToPVz = std::abs(straHelper.v0.v0DCAToPVz);
942-
V0DuplicateExtras[ic].V0zVtx = std::abs(straHelper.v0.position[2]);
942+
V0DuplicateExtras[ic].isBuildOk = true;
943+
V0DuplicateExtras[ic].PA = straHelper.v0.pointingAngle;
944+
V0DuplicateExtras[ic].V0DCAToPVz = std::abs(straHelper.v0.v0DCAToPVz);
945+
V0DuplicateExtras[ic].V0zVtx = std::abs(straHelper.v0.position[2]);
943946
} // end build V0
944947
} // end candidate loop
945-
946948

947949
// Additional loop to perform ML Analysis if requested
948-
if (DeduplicationOpts.deduplicationAlgorithm.value==4 || DeduplicationOpts.deduplicationAlgorithm.value==6){
949-
950-
// Preparing features
950+
if (DeduplicationOpts.deduplicationAlgorithm.value == 4 || DeduplicationOpts.deduplicationAlgorithm.value == 6) {
951+
952+
// Preparing features
951953
if (NDuplicates > 0)
952-
AvgPA /= NDuplicates;
953-
954+
AvgPA /= NDuplicates;
955+
954956
// Get vector of ranks
955957
std::vector<int> paRanks = rankSort(paVec, false);
956958
std::vector<int> v0zRanks = rankSort(v0zVec, false);
957-
959+
958960
// Fill the ML score for all candidates
959961
for (size_t ic = 0; ic < V0Grouped[iV0].collisionIds.size(); ic++) {
960962

@@ -963,28 +965,27 @@ struct StrangenessBuilder {
963965
continue;
964966

965967
// Input vector for BDT
966-
std::vector<float> inputFeatures{V0DuplicateExtras[ic].V0DCAToPVz, // 1. V0DCAToPVz
967-
V0DuplicateExtras[ic].PA, // 2. Pointing Angle
968-
V0DuplicateExtras[ic].V0zVtx, // 3. V0 Vtx z-position
969-
static_cast<float>(paRanks[ic]), // 4. Pointing Angle Rank
970-
static_cast<float>(NDuplicates), // 5. N. of Duplicates
971-
AvgPA, // 6. Avg Pointing Angle
972-
static_cast<float>(v0zRanks[ic])}; // 7. V0 Vtx z Rank
968+
std::vector<float> inputFeatures{V0DuplicateExtras[ic].V0DCAToPVz, // 1. V0DCAToPVz
969+
V0DuplicateExtras[ic].PA, // 2. Pointing Angle
970+
V0DuplicateExtras[ic].V0zVtx, // 3. V0 Vtx z-position
971+
static_cast<float>(paRanks[ic]), // 4. Pointing Angle Rank
972+
static_cast<float>(NDuplicates), // 5. N. of Duplicates
973+
AvgPA, // 6. Avg Pointing Angle
974+
static_cast<float>(v0zRanks[ic])}; // 7. V0 Vtx z Rank
973975

976+
float* BDTProbability = deduplication_bdt.evalModel(inputFeatures);
974977

975-
float* BDTProbability = deduplication_bdt.evalModel(inputFeatures);
976-
977978
if (BDTProbability[1] > bestMLScore) {
978979
bestMLScore = BDTProbability[1];
979980
bestMLScoreIndex = ic;
980-
}
981-
981+
}
982+
982983
// QA histo
983-
histos.fill(HIST("DeduplicationQA/hMLScore"), BDTProbability[1]);
984+
histos.fill(HIST("DeduplicationQA/hMLScore"), BDTProbability[1]);
984985
histos.fill(HIST("DeduplicationQA/hPA"), V0DuplicateExtras[ic].PA);
985986

986-
// Updating BDT score info in the struct
987-
V0DuplicateExtras[ic].MLScore = BDTProbability[1];
987+
// Updating BDT score info in the struct
988+
V0DuplicateExtras[ic].MLScore = BDTProbability[1];
988989
}
989990
}
990991

@@ -993,12 +994,15 @@ struct StrangenessBuilder {
993994
histos.fill(HIST("DeduplicationQA/hBestMLScore"), bestMLScore);
994995

995996
// Final step: Defining the winners:
996-
if (bestPointingAngleIndex != static_cast<size_t>(-1)) V0DuplicateExtras[bestPointingAngleIndex].isBestPA = true;
997-
if (bestDCADaughtersIndex != static_cast<size_t>(-1)) V0DuplicateExtras[bestDCADaughtersIndex].isBestDCADau = true;
998-
if (bestMLScoreIndex != static_cast<size_t>(-1)) V0DuplicateExtras[bestMLScoreIndex].isBestMLScore = true;
999-
1000-
// return vector with duplicates info
1001-
return V0DuplicateExtras;
997+
if (bestPointingAngleIndex != static_cast<size_t>(-1))
998+
V0DuplicateExtras[bestPointingAngleIndex].isBestPA = true;
999+
if (bestDCADaughtersIndex != static_cast<size_t>(-1))
1000+
V0DuplicateExtras[bestDCADaughtersIndex].isBestDCADau = true;
1001+
if (bestMLScoreIndex != static_cast<size_t>(-1))
1002+
V0DuplicateExtras[bestMLScoreIndex].isBestMLScore = true;
1003+
1004+
// return vector with duplicates info
1005+
return V0DuplicateExtras;
10021006
}
10031007

10041008
template <typename TCollisions>
@@ -1177,7 +1181,7 @@ struct StrangenessBuilder {
11771181

11781182
// skip if empty
11791183
if (deduplicationOutput.empty()) {
1180-
continue;
1184+
continue;
11811185
}
11821186

11831187
// mark de-duplicated candidates

PWGLF/Utils/strangenessBuilderHelper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ class strangenessBuilderHelper
355355

356356
// propagate to collision vertex
357357
o2::base::Propagator::Instance()->propagateToDCABxByBz({pvX, pvY, pvZ}, V0Temp, 2.f, fitter.getMatCorrType(), &dcaV0Info);
358-
v0.v0DCAToPVxy = dcaV0Info[0];
358+
v0.v0DCAToPVxy = dcaV0Info[0];
359359
v0.v0DCAToPVz = dcaV0Info[1];
360360

361361
v0.positiveTrackX = fitter.getTrack(0).getX();

0 commit comments

Comments
 (0)