Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 101 additions & 97 deletions PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,28 @@
// -- v0builderopts ......: V0-specific building options (topological, deduplication, etc)
// -- cascadebuilderopts .: cascade-specific building options (topological, etc)

#include <string>
#include <vector>

#include "Framework/DataSpecUtils.h"
#include "Framework/runDataProcessing.h"
#include "Framework/AnalysisTask.h"
#include "Framework/AnalysisDataModel.h"
#include "Common/DataModel/PIDResponse.h"
#include "TableHelper.h"
#include "PWGLF/DataModel/LFStrangenessTables.h"

#include "PWGLF/DataModel/LFStrangenessPIDTables.h"
#include "PWGLF/DataModel/LFStrangenessTables.h"
#include "PWGLF/Utils/strangenessBuilderHelper.h"
#include "CCDB/BasicCCDBManager.h"
#include "DataFormatsParameters/GRPObject.h"
#include "DataFormatsParameters/GRPMagField.h"

#include "Common/Core/TPCVDriftManager.h"
#include "Common/DataModel/PIDResponse.h"
#include "Tools/ML/MlResponse.h"
#include "Tools/ML/model.h"

#include "CCDB/BasicCCDBManager.h"
#include "DataFormatsParameters/GRPMagField.h"
#include "DataFormatsParameters/GRPObject.h"
#include "Framework/AnalysisDataModel.h"
#include "Framework/AnalysisTask.h"
#include "Framework/DataSpecUtils.h"
#include "Framework/runDataProcessing.h"

#include <cmath>
#include <string>
#include <vector>

using namespace o2;
using namespace o2::framework;
using namespace o2::ml;
Expand Down Expand Up @@ -159,8 +162,8 @@
// helper object
o2::pwglf::strangenessBuilderHelper straHelper;

// ML model
o2::ml::OnnxModel deduplication_bdt;
// ML model
o2::ml::OnnxModel deduplication_bdt;

// table index : match order above
enum tableIndex { kV0Indices = 0,
Expand Down Expand Up @@ -302,27 +305,27 @@

// Configurables steering the V0 de-duplication step; all of them are
// grouped under the "DeduplicationOpts" prefix.
struct : ConfigurableGroup {
  std::string prefix = "DeduplicationOpts";

  // Strategy selector: the accepted values are enumerated in the help string.
  Configurable<int> deduplicationAlgorithm{"deduplicationAlgorithm", 1,
                                           "0: disabled;"
                                           "1: best pointing angle wins;"
                                           "2: best DCA daughters wins;"
                                           "3: best pointing and best DCA wins;"
                                           "4: best BDT score wins;"
                                           "5: selects on PA (not a winner takes it all approach!);"
                                           "6: selects on BDT score (not a winner takes it all approach!)"};

  // BDT settings
  Configurable<std::string> BDTLocalPath{"BDTLocalPath", "Deduplication_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
  Configurable<std::string> BDTPathCCDB{"BDTPathCCDB", "Users/g/gsetouel/MLModels2", "Path on CCDB"};
  Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
  Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
  Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};

  // Selection based duplicates removal
  Configurable<float> PAthreshold{"PAthreshold", 0.02, "PA cut to remove duplicates."};
  Configurable<float> BDTthreshold{"BDTthreshold", 0.7, "BDT score cut to remove duplicates."};
} DeduplicationOpts;

// V0 buffer for V0s used in cascades: master switch
Expand Down Expand Up @@ -435,19 +438,19 @@

/// Evaluate the K0Short parametrisation configured in preSelectOpts.massCutK0:
///   constant + linear * pt + expoConstant * exp(-pt / expoRelax)
/// \param pt value at which the parametrisation is evaluated
/// \return the parametrisation evaluated at pt
float getMassSigmaK0Short(float pt)
{
  auto& cut = preSelectOpts.massCutK0;
  // std::exp is used instead of TMath::Exp to satisfy the O2 linter rule [root/entity]
  return cut->get("constant") + pt * cut->get("linear") + cut->get("expoConstant") * std::exp(-pt / cut->get("expoRelax"));
}
/// Evaluate the Lambda parametrisation configured in preSelectOpts.massCutLambda:
///   constant + linear * pt + expoConstant * exp(-pt / expoRelax)
/// \param pt value at which the parametrisation is evaluated
/// \return the parametrisation evaluated at pt
float getMassSigmaLambda(float pt)
{
  auto& cut = preSelectOpts.massCutLambda;
  // std::exp is used instead of TMath::Exp to satisfy the O2 linter rule [root/entity]
  return cut->get("constant") + pt * cut->get("linear") + cut->get("expoConstant") * std::exp(-pt / cut->get("expoRelax"));
}
/// Evaluate the Xi parametrisation configured in preSelectOpts.massCutXi:
///   constant + linear * pt + expoConstant * exp(-pt / expoRelax)
/// \param pt value at which the parametrisation is evaluated
/// \return the parametrisation evaluated at pt
float getMassSigmaXi(float pt)
{
  auto& cut = preSelectOpts.massCutXi;
  // std::exp is used instead of TMath::Exp to satisfy the O2 linter rule [root/entity]
  return cut->get("constant") + pt * cut->get("linear") + cut->get("expoConstant") * std::exp(-pt / cut->get("expoRelax"));
}
/// Evaluate the Omega parametrisation configured in preSelectOpts.massCutOm:
///   constant + linear * pt + expoConstant * exp(-pt / expoRelax)
/// \param pt value at which the parametrisation is evaluated
/// \return the parametrisation evaluated at pt
float getMassSigmaOmega(float pt)
{
  auto& cut = preSelectOpts.massCutOm;
  // std::exp is used instead of TMath::Exp to satisfy the O2 linter rule [root/entity]
  return cut->get("constant") + pt * cut->get("linear") + cut->get("expoConstant") * std::exp(-pt / cut->get("expoRelax"));
}

o2::ccdb::CcdbApi ccdbApi;
Expand Down Expand Up @@ -545,14 +548,14 @@
int mcParticleBachelor;
};
mcCascinfo thisCascInfo;
//*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*
//*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*+-+*
// Helper structure to save v0 duplicates auxiliary info
struct V0DuplicateExtra {
struct V0DuplicateExtra {
bool isBestPA;
bool isBestDCADau;
bool isBestMLScore;
bool isBuildOk;
float PA;
float PA;
float V0DCAToPVz;
float V0zVtx;
float MLScore;
Expand Down Expand Up @@ -635,14 +638,14 @@
hFindable->GetXaxis()->SetBinLabel(6, "Cascades with collId -1");
}

if (DeduplicationOpts.deduplicationAlgorithm.value > 0){
if (DeduplicationOpts.deduplicationAlgorithm.value > 0) {
histos.add("DeduplicationQA/hMLScore", "hMLScore", kTH1F, {{200, 0.0f, 1.0f}});
histos.add("DeduplicationQA/hPA", "hPA", kTH1F, {{200, 0.0f, 0.4f}});
histos.add("DeduplicationQA/hBestPA", "hBestPA", kTH1F, {{200, 0.0f, 0.4f}});
histos.add("DeduplicationQA/hBestDCADau", "hBestDCADau", kTH1F, {{200, -10.0f, 10.0f}});
histos.add("DeduplicationQA/hBestMLScore", "hBestMLScore", kTH1F, {{200, 0.0f, 1.0f}});
}

auto hPrimaryV0s = histos.add<TH1>("hPrimaryV0s", "hPrimaryV0s", kTH1D, {{2, -0.5f, 1.5f}});
hPrimaryV0s->GetXaxis()->SetBinLabel(1, "All V0s");
hPrimaryV0s->GetXaxis()->SetBinLabel(2, "Primary V0s");
Expand Down Expand Up @@ -764,22 +767,21 @@
straHelper.cascadeselections.lambdaMassWindow = cascadeBuilderOpts.lambdaMassWindow;
straHelper.cascadeselections.maxDaughterEta = cascadeBuilderOpts.maxDaughterEta;

// Loading BDT model
if (DeduplicationOpts.deduplicationAlgorithm.value==4 || DeduplicationOpts.deduplicationAlgorithm.value==6){
if (DeduplicationOpts.loadModelsFromCCDB) {
// Loading BDT model
if (DeduplicationOpts.deduplicationAlgorithm.value == 4 || DeduplicationOpts.deduplicationAlgorithm.value == 6) {
if (DeduplicationOpts.loadModelsFromCCDB) {

/// Fetching model for specific timestamp
LOG(info) << "Fetching model for timestamp: " << DeduplicationOpts.timestampCCDB.value;

bool retrieveSuccess = ccdbApi.retrieveBlob(DeduplicationOpts.BDTPathCCDB.value, ".", metadata, DeduplicationOpts.timestampCCDB.value, false, DeduplicationOpts.BDTLocalPath.value);
if (retrieveSuccess) {
deduplication_bdt.initModel(DeduplicationOpts.BDTLocalPath.value, DeduplicationOpts.enableOptimizations.value);
} else {
LOG(fatal) << "Error encountered while fetching/loading the Gamma model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
}
}
else{
deduplication_bdt.initModel(DeduplicationOpts.BDTLocalPath.value, DeduplicationOpts.enableOptimizations.value);
} else {
deduplication_bdt.initModel(DeduplicationOpts.BDTLocalPath.value, DeduplicationOpts.enableOptimizations.value);
}
}
}
Expand All @@ -799,35 +801,36 @@
}

// Simple function to rank vectors based on values
std::vector<int> rankSort(const std::vector<float>& v_temp, bool descending = false) {
std::vector<int> rankSort(const std::vector<float>& v_temp, bool descending = false)
{
std::vector<std::pair<float, size_t>> v_sort(v_temp.size());

// Pair each value with its original index
for (size_t i = 0U; i < v_temp.size(); ++i) {
v_sort[i] = std::make_pair(v_temp[i], i);
v_sort[i] = std::make_pair(v_temp[i], i);
}

// Sort by value - ascending: lowest gets rank 1, descending: highest gets rank 1

if (descending) {
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
return a.first > b.first;
});
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
return a.first > b.first;
});
} else {
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
return a.first < b.first;
});
std::sort(v_sort.begin(), v_sort.end(), [](const auto& a, const auto& b) {
return a.first < b.first;
});
}

std::pair<float, size_t> rank_tracker = std::make_pair(std::numeric_limits<float>::quiet_NaN(), 0);
std::vector<int> result(v_temp.size());

for (size_t i = 0U; i < v_sort.size(); ++i) {
// Only update rank if value is different from previous
if (v_sort[i].first != rank_tracker.first) {
rank_tracker = std::make_pair(v_sort[i].first, i + 1); // +1 for 1-based rank
}
result[v_sort[i].second] = rank_tracker.second; // assign rank to original index
// Only update rank if value is different from previous
if (v_sort[i].first != rank_tracker.first) {
rank_tracker = std::make_pair(v_sort[i].first, i + 1); // +1 for 1-based rank
}
result[v_sort[i].second] = rank_tracker.second; // assign rank to original index
}

return result;
Expand Down Expand Up @@ -861,27 +864,27 @@

// Defining context variables
int NDuplicates = 0;
float AvgPA = 0.0f;
float AvgPA = 0.0f;

// Containers for ranking
std::vector<float> paVec(V0Grouped[iV0].collisionIds.size(), 999.f);
std::vector<float> v0zVec(V0Grouped[iV0].collisionIds.size(), 999.f);

// Auxiliary vector to store V0 duplicate info
std::vector<V0DuplicateExtra> V0DuplicateExtras;
std::vector<V0DuplicateExtra> V0DuplicateExtras;

// Loop over duplicates
// Loop over duplicates
for (size_t ic = 0; ic < V0Grouped[iV0].collisionIds.size(); ic++) {

// Helper structure to save duplicates info - initializing with dummy values
V0DuplicateExtra v0DuplicateInfo;
V0DuplicateExtra v0DuplicateInfo;
v0DuplicateInfo.isBestPA = false;
v0DuplicateInfo.isBestDCADau = false;
v0DuplicateInfo.isBestMLScore = false;
v0DuplicateInfo.isBuildOk = false;
v0DuplicateInfo.PA = 10;
v0DuplicateInfo.V0DCAToPVz = 999.f;
v0DuplicateInfo.V0zVtx = 999.f;
v0DuplicateInfo.PA = 10;
v0DuplicateInfo.V0DCAToPVz = 999.f;
v0DuplicateInfo.V0zVtx = 999.f;
v0DuplicateInfo.MLScore = -1;

// We include V0DuplicateExtra info in the vector at this point to avoid indexing issues later
Expand Down Expand Up @@ -916,7 +919,7 @@
// <false>: do not apply selections: do as much as possible to preserve
// candidate at this level and do not select with topo selections
if (straHelper.buildV0Candidate<false>(V0Grouped[iV0].collisionIds[ic], collision.posX(), collision.posY(), collision.posZ(), pTrack, nTrack, posTrackPar, negTrackPar, true, false, true)) {

// candidate built, check pointing angle
if (straHelper.v0.pointingAngle < bestPointingAngle) {
bestPointingAngle = straHelper.v0.pointingAngle;
Expand All @@ -928,33 +931,32 @@
}

// Calculating features for ML Analysis
if (DeduplicationOpts.deduplicationAlgorithm.value==4 || DeduplicationOpts.deduplicationAlgorithm.value==6){
AvgPA += straHelper.v0.pointingAngle;
if (DeduplicationOpts.deduplicationAlgorithm.value == 4 || DeduplicationOpts.deduplicationAlgorithm.value == 6) {
AvgPA += straHelper.v0.pointingAngle;
paVec[ic] = straHelper.v0.pointingAngle;
v0zVec[ic] = std::abs(straHelper.v0.position[2]);
v0zVec[ic] = std::abs(straHelper.v0.position[2]);
NDuplicates++;
}

// Updating values in the struct
V0DuplicateExtras[ic].isBuildOk = true;
V0DuplicateExtras[ic].PA = straHelper.v0.pointingAngle;
V0DuplicateExtras[ic].V0DCAToPVz = std::abs(straHelper.v0.v0DCAToPVz);
V0DuplicateExtras[ic].V0zVtx = std::abs(straHelper.v0.position[2]);
V0DuplicateExtras[ic].isBuildOk = true;
V0DuplicateExtras[ic].PA = straHelper.v0.pointingAngle;
V0DuplicateExtras[ic].V0DCAToPVz = std::abs(straHelper.v0.v0DCAToPVz);
V0DuplicateExtras[ic].V0zVtx = std::abs(straHelper.v0.position[2]);
} // end build V0
} // end candidate loop


// Additional loop to perform ML Analysis if requested
if (DeduplicationOpts.deduplicationAlgorithm.value==4 || DeduplicationOpts.deduplicationAlgorithm.value==6){
// Preparing features
if (DeduplicationOpts.deduplicationAlgorithm.value == 4 || DeduplicationOpts.deduplicationAlgorithm.value == 6) {

// Preparing features
if (NDuplicates > 0)
AvgPA /= NDuplicates;
AvgPA /= NDuplicates;

// Get vector of ranks
std::vector<int> paRanks = rankSort(paVec, false);
std::vector<int> v0zRanks = rankSort(v0zVec, false);

// Fill the ML score for all candidates
for (size_t ic = 0; ic < V0Grouped[iV0].collisionIds.size(); ic++) {

Expand All @@ -963,28 +965,27 @@
continue;

// Input vector for BDT
std::vector<float> inputFeatures{V0DuplicateExtras[ic].V0DCAToPVz, // 1. V0DCAToPVz
V0DuplicateExtras[ic].PA, // 2. Pointing Angle
V0DuplicateExtras[ic].V0zVtx, // 3. V0 Vtx z-position
static_cast<float>(paRanks[ic]), // 4. Pointing Angle Rank
static_cast<float>(NDuplicates), // 5. N. of Duplicates
AvgPA, // 6. Avg Pointing Angle
static_cast<float>(v0zRanks[ic])}; // 7. V0 Vtx z Rank
std::vector<float> inputFeatures{V0DuplicateExtras[ic].V0DCAToPVz, // 1. V0DCAToPVz
V0DuplicateExtras[ic].PA, // 2. Pointing Angle
V0DuplicateExtras[ic].V0zVtx, // 3. V0 Vtx z-position
static_cast<float>(paRanks[ic]), // 4. Pointing Angle Rank
static_cast<float>(NDuplicates), // 5. N. of Duplicates
AvgPA, // 6. Avg Pointing Angle
static_cast<float>(v0zRanks[ic])}; // 7. V0 Vtx z Rank

float* BDTProbability = deduplication_bdt.evalModel(inputFeatures);

float* BDTProbability = deduplication_bdt.evalModel(inputFeatures);

if (BDTProbability[1] > bestMLScore) {
bestMLScore = BDTProbability[1];
bestMLScoreIndex = ic;
}
}

// QA histo
histos.fill(HIST("DeduplicationQA/hMLScore"), BDTProbability[1]);
histos.fill(HIST("DeduplicationQA/hMLScore"), BDTProbability[1]);
histos.fill(HIST("DeduplicationQA/hPA"), V0DuplicateExtras[ic].PA);

// Updating BDT score info in the struct
V0DuplicateExtras[ic].MLScore = BDTProbability[1];
// Updating BDT score info in the struct
V0DuplicateExtras[ic].MLScore = BDTProbability[1];
}
}

Expand All @@ -993,12 +994,15 @@
histos.fill(HIST("DeduplicationQA/hBestMLScore"), bestMLScore);

// Final step: Defining the winners:
if (bestPointingAngleIndex != static_cast<size_t>(-1)) V0DuplicateExtras[bestPointingAngleIndex].isBestPA = true;
if (bestDCADaughtersIndex != static_cast<size_t>(-1)) V0DuplicateExtras[bestDCADaughtersIndex].isBestDCADau = true;
if (bestMLScoreIndex != static_cast<size_t>(-1)) V0DuplicateExtras[bestMLScoreIndex].isBestMLScore = true;

// return vector with duplicates info
return V0DuplicateExtras;
if (bestPointingAngleIndex != static_cast<size_t>(-1))
V0DuplicateExtras[bestPointingAngleIndex].isBestPA = true;
if (bestDCADaughtersIndex != static_cast<size_t>(-1))
V0DuplicateExtras[bestDCADaughtersIndex].isBestDCADau = true;
if (bestMLScoreIndex != static_cast<size_t>(-1))
V0DuplicateExtras[bestMLScoreIndex].isBestMLScore = true;

// return vector with duplicates info
return V0DuplicateExtras;
}

template <typename TCollisions>
Expand Down Expand Up @@ -1177,7 +1181,7 @@

// skip if empty
if (deduplicationOutput.empty()) {
continue;
continue;
}

// mark de-duplicated candidates
Expand Down Expand Up @@ -1250,10 +1254,10 @@

bool trackIsInteresting = false;
if (
(originParticle.pdgCode() == 310 && v0BuilderOpts.mc_addGeneratedK0Short.value > 0) ||

Check failure on line 1257 in PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
(originParticle.pdgCode() == 3122 && v0BuilderOpts.mc_addGeneratedLambda.value > 0) ||

Check failure on line 1258 in PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
(originParticle.pdgCode() == -3122 && v0BuilderOpts.mc_addGeneratedAntiLambda.value > 0) ||

Check failure on line 1259 in PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
(originParticle.pdgCode() == 22 && v0BuilderOpts.mc_addGeneratedGamma.value > 0)) {

Check failure on line 1260 in PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
trackIsInteresting = true;
}
if (!trackIsInteresting) {
Expand Down Expand Up @@ -1303,7 +1307,7 @@
currentV0Entry.pdgCode = positiveTrackIndex.pdgCode;
currentV0Entry.particleId = positiveTrackIndex.originId;
currentV0Entry.isCollinearV0 = false;
if (v0BuilderOpts.mc_addGeneratedGammaMakeCollinear.value && currentV0Entry.pdgCode == 22) {

Check failure on line 1310 in PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
currentV0Entry.isCollinearV0 = true;
}
currentV0Entry.found = false;
Expand Down Expand Up @@ -1745,7 +1749,7 @@
straHelper.v0.daughterDCA,
straHelper.v0.positiveDCAxy,
straHelper.v0.negativeDCAxy,
TMath::Cos(straHelper.v0.pointingAngle),

Check failure on line 1752 in PWGLF/TableProducer/Strangeness/strangenessbuilder.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[root/entity]

Replace ROOT entities with equivalents from standard C++ or from O2.
straHelper.v0.dcaToPV,
v0.v0Type);
products.v0dataLink(products.v0cores.lastIndex(), -1);
Expand Down
2 changes: 1 addition & 1 deletion PWGLF/Utils/strangenessBuilderHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class strangenessBuilderHelper

// propagate to collision vertex
o2::base::Propagator::Instance()->propagateToDCABxByBz({pvX, pvY, pvZ}, V0Temp, 2.f, fitter.getMatCorrType(), &dcaV0Info);
v0.v0DCAToPVxy = dcaV0Info[0];
v0.v0DCAToPVxy = dcaV0Info[0];
v0.v0DCAToPVz = dcaV0Info[1];

v0.positiveTrackX = fitter.getTrack(0).getX();
Expand Down
Loading