Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions PWGJE/Core/JetTaggingUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,44 @@ namespace jettaggingutilities
{
const int cmTomum = 10000; // using cm -> #mum for impact parameter (dca)

// Jet-level features. getInputsForML copies them, in declaration order,
// into the first of the three ML input vectors.
struct BJetParams {
  float mJetpT{0.0f};   // jet pT
  float mJetEta{0.0f};  // jet eta
  float mJetPhi{0.0f};  // jet phi
  int mNTracks{-1};     // number of tracks; defaults to -1
  int mNSV{-1};         // number of secondary vertices; defaults to -1
  float mJetMass{0.0f}; // jet mass
};

// Per-track features of a jet constituent. Declaration order matters: the
// struct is aggregate-initialised positionally in analyzeJetTrackInfo4ML and
// flattened in the same order by getInputsForML.
struct BJetTrackParams {
  double mTrackpT{0.0};                // track pT
  double mTrackEta{0.0};               // track eta
  double mDotProdTrackJet{0.0};        // dot product of jet and track momentum vectors
  double mDotProdTrackJetOverJet{0.0}; // the dot product above divided by the jet momentum
  double mDeltaRJetTrack{0.0};         // deltaR between jet and track
  double mSignedIP2D{0.0};             // |dcaXY| multiplied by the geometric sign
  double mSignedIP2DSign{0.0};         // filled with sigmadcaXY() (an uncertainty, despite the name)
  double mSignedIP3D{0.0};             // |dcaXYZ| multiplied by the geometric sign
  double mSignedIP3DSign{0.0};         // filled with sigmadcaXYZ() (an uncertainty, despite the name)
  double mMomFraction{0.0};            // track momentum divided by jet momentum
  double mDeltaRTrackVertex{0.0};      // smallest deltaR between the track and any SV of the jet
};

// Per-secondary-vertex features. Declaration order matters: the struct is
// aggregate-initialised positionally in analyzeJetSVInfo4ML and flattened in
// the same order by getInputsForML.
struct BJetSVParams {
  double mSVpT{0.0};               // SV pT
  double mDeltaRSVJet{0.0};        // deltaR between jet and SV
  double mSVMass{0.0};             // SV mass
  double mSVfE{0.0};               // SV energy divided by jet energy
  double mIPXY{0.0};               // SV impactParameterXY()
  double mCPA{0.0};                // SV cpa()
  double mChi2PCA{0.0};            // SV chi2PCA()
  double mDispersion{0.0};         // SV dispersion()
  double mDecayLength2D{0.0};      // SV decayLengthXY()
  double mDecayLength2DError{0.0}; // SV errorDecayLengthXY()
  double mDecayLength3D{0.0};      // SV decayLength()
  double mDecayLength3DError{0.0}; // SV errorDecayLength()
};

//________________________________________________________________________
bool isBHadron(int pc)
{
Expand Down Expand Up @@ -802,6 +840,113 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
return n_vertices;
}

/// Builds the three flat input vectors handed to the ML model.
///
/// \param jetparams    jet-level features
/// \param tracksParams per-track features, in the order they should be fed to the model
/// \param svsParams    per-SV features, in the order they should be fed to the model
/// \param maxJetConst  number of track slots and SV slots to emit
/// \return {jet features (6 values),
///          track features (maxJetConst * 11 values),
///          SV features (maxJetConst * 12 values)}
///
/// Exactly maxJetConst slots are always written for tracks and for SVs. If a
/// vector holds fewer entries than maxJetConst, the missing slots are filled
/// with default-constructed (all-zero) parameters instead of reading past the
/// end of the vector; extra entries beyond maxJetConst are ignored. A
/// non-positive maxJetConst yields empty track and SV vectors.
///
/// Declared inline because this non-template function is defined in a header.
inline std::vector<std::vector<float>> getInputsForML(BJetParams jetparams, std::vector<BJetTrackParams>& tracksParams, std::vector<BJetSVParams>& svsParams, int maxJetConst = 10)
{
  constexpr int nTrackFeatures = 11; // number of members of BJetTrackParams
  constexpr int nSVFeatures = 12;    // number of members of BJetSVParams

  std::vector<float> jetInput = {jetparams.mJetpT, jetparams.mJetEta, jetparams.mJetPhi, static_cast<float>(jetparams.mNTracks), static_cast<float>(jetparams.mNSV), jetparams.mJetMass};
  std::vector<float> tracksInputFlat;
  std::vector<float> svsInputFlat;

  if (maxJetConst > 0) {
    const auto nSlots = static_cast<std::vector<float>::size_type>(maxJetConst);
    tracksInputFlat.reserve(nSlots * nTrackFeatures);
    svsInputFlat.reserve(nSlots * nSVFeatures);
  }

  // Zero padding used for slots that have no corresponding track / SV.
  const BJetTrackParams paddingTrack{};
  const BJetSVParams paddingSV{};
  const int nTracks = static_cast<int>(tracksParams.size());
  const int nSVs = static_cast<int>(svsParams.size());

  for (int iconstit = 0; iconstit < maxJetConst; ++iconstit) {
    const BJetTrackParams& trk = (iconstit < nTracks) ? tracksParams[iconstit] : paddingTrack;
    const BJetSVParams& sv = (iconstit < nSVs) ? svsParams[iconstit] : paddingSV;

    // The model consumes floats; narrow explicitly, keeping the member order.
    for (const double feature : {trk.mTrackpT, trk.mTrackEta, trk.mDotProdTrackJet, trk.mDotProdTrackJetOverJet, trk.mDeltaRJetTrack, trk.mSignedIP2D, trk.mSignedIP2DSign, trk.mSignedIP3D, trk.mSignedIP3DSign, trk.mMomFraction, trk.mDeltaRTrackVertex}) {
      tracksInputFlat.push_back(static_cast<float>(feature));
    }

    for (const double feature : {sv.mSVpT, sv.mDeltaRSVJet, sv.mSVMass, sv.mSVfE, sv.mIPXY, sv.mCPA, sv.mChi2PCA, sv.mDispersion, sv.mDecayLength2D, sv.mDecayLength2DError, sv.mDecayLength3D, sv.mDecayLength3DError}) {
      svsInputFlat.push_back(static_cast<float>(feature));
    }
  }

  std::vector<std::vector<float>> totalInput;
  totalInput.reserve(3);
  totalInput.push_back(std::move(jetInput));
  totalInput.push_back(std::move(tracksInputFlat));
  totalInput.push_back(std::move(svsInputFlat));

  return totalInput;
}

/// Appends the features of a jet's secondary vertices (SVs) to svsParams.
///
/// The jet's SVs are sorted by 2D decay-length significance
/// (decayLengthXY / errorDecayLengthXY) in descending order, SVs with
/// pT below svPtMin are skipped, and entries are appended only while
/// svsParams.size() is below svReductionFactor * (number of jet tracks).
///
/// \param myJet             jet whose SVs are analysed
/// \param svsParams         output vector; new entries are appended
/// \param svPtMin           minimum SV pT
/// \param svReductionFactor multiplier of the jet's track count giving the cap on svsParams.size()
/// The allTracks / allSVs arguments are unused; only their types are needed.
template <typename AnalysisJet, typename AnyTracks, typename SecondaryVertices>
void analyzeJetSVInfo4ML(AnalysisJet const& myJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<BJetSVParams>& svsParams, float svPtMin = 1.0, int svReductionFactor = 3)
{
  using SVType = typename SecondaryVertices::iterator;

  // Orders SVs by decreasing decay-length significance in the transverse plane.
  // NOTE(review): if decayLengthXY and errorDecayLengthXY can both be zero the
  // ratio would be NaN, which is not a valid ordering for std::sort - confirm.
  auto compare = [](SVType& sv1, SVType& sv2) {
    return (sv1.decayLengthXY() / sv1.errorDecayLengthXY()) > (sv2.decayLengthXY() / sv2.errorDecayLengthXY());
  };

  auto svs = myJet.template secondaryVertices_as<SecondaryVertices>();

  // Sorting is needed to keep the most displaced SVs, since some jets can have thousands of SVs.
  std::sort(svs.begin(), svs.end(), compare);

  // Loop invariants: the cap on the number of stored SVs and the jet energy.
  const auto maxSVs = svReductionFactor * myJet.template tracks_as<AnyTracks>().size();
  const auto jetEnergy = myJet.energy();

  for (const auto& candSV : svs) {

    // svsParams only grows, so once the cap is reached no later SV can be stored.
    if (!(svsParams.size() < maxSVs)) {
      break;
    }

    if (candSV.pt() < svPtMin) {
      continue;
    }

    double deltaRJetSV = jetutilities::deltaR(myJet, candSV);
    double massSV = candSV.m();
    double energySV = candSV.e();

    svsParams.emplace_back(BJetSVParams{candSV.pt(), deltaRJetSV, massSV, energySV / jetEnergy, candSV.impactParameterXY(), candSV.cpa(), candSV.chi2PCA(), candSV.dispersion(), candSV.decayLengthXY(), candSV.errorDecayLengthXY(), candSV.decayLength(), candSV.errorDecayLength()});
  }
}

// Looping over the track info and putting them in the input vector
template <typename AnalysisJet, typename AnyTracks, typename SecondaryVertices>
void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<BJetTrackParams>& tracksParams, float trackPtMin = 0.5)
{
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {

if (constituent.pt() < trackPtMin) {
continue;
}

double deltaRJetTrack = jetutilities::deltaR(analysisJet, constituent);
double dotProduct = RecoDecay::dotProd(std::array<float, 3>{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array<float, 3>{constituent.px(), constituent.py(), constituent.pz()});
int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);

float rClosestSV = 10.;
for (const auto& candSV : analysisJet.template secondaryVertices_as<SecondaryVertices>()) {
double deltaRTrackSV = jetutilities::deltaR(constituent, candSV);
if (deltaRTrackSV < rClosestSV) {
rClosestSV = deltaRTrackSV;
}
}

tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), rClosestSV});
}

auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
return (tr1.mSignedIP2D / tr1.mSignedIP2DSign) > (tr2.mSignedIP2D / tr2.mSignedIP2DSign);
};

// Sort the tracks based on their IP significance in descending order
std::sort(tracksParams.begin(), tracksParams.end(), compare);
}
}; // namespace jettaggingutilities

#endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_
73 changes: 67 additions & 6 deletions PWGJE/TableProducer/jetTaggerHF.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "PWGJE/DataModel/JetTagging.h"
#include "PWGJE/Core/JetTaggingUtilities.h"
#include "PWGJE/Core/JetDerivedDataUtilities.h"
#include "Tools/ML/MlResponse.h"

using namespace o2;
using namespace o2::framework;
Expand All @@ -36,6 +37,8 @@ using namespace o2::framework::expressions;
template <typename JetTableData, typename JetTableMCD, typename JetTaggingTableData, typename JetTaggingTableMCD>
struct JetTaggerHFTask {

static constexpr double DefaultCutsMl[1][2] = {{0.5, 0.5}};

Produces<JetTaggingTableData> taggingTableData;
Produces<JetTaggingTableMCD> taggingTableMCD;

Expand Down Expand Up @@ -68,8 +71,27 @@ struct JetTaggerHFTask {
Configurable<float> tagPointForSV{"tagPointForSV", 40, "tagging working point for SV"};
Configurable<float> tagPointForSVxyz{"tagPointForSVxyz", 40, "tagging working point for SV xyz"};

// axis spec
ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""};
// ML configuration
// nJetConst: tracksParams and svsParams are resized to this value and it is
// passed as maxJetConst to getInputsForML in analyzeJetAlgorithmML.
Configurable<int> nJetConst{"nJetConst", 10, "maximum number of jet consistuents to be used for ML evaluation"};
// Forwarded to analyzeJetTrackInfo4ML / analyzeJetSVInfo4ML as the pT thresholds.
Configurable<float> trackPtMin{"trackPtMin", 0.5, "minimum track pT"};
Configurable<float> svPtMin{"svPtMin", 0.5, "minimum SV pT"};

// NOTE(review): declared as float but forwarded to the `int svReductionFactor`
// parameter of analyzeJetSVInfo4ML, so any fractional part is dropped there.
Configurable<float> svReductionFactor{"svReductionFactor", 1.0, "factor for how many SVs to keep"};

// The next four are passed to bMlResponse.configure() in init().
Configurable<std::vector<double>> binsPtMl{"binsPtMl", std::vector<double>{5., 1000.}, "pT bin limits for ML application"};
Configurable<std::vector<int>> cutDirMl{"cutDirMl", std::vector<int>{cuts_ml::CutSmaller, cuts_ml::CutNot}, "Whether to reject score values greater or smaller than the threshold"};
Configurable<LabeledArray<double>> cutsMl{"cutsMl", {DefaultCutsMl[0], 1, 2, {"pT bin 0"}, {"score for default b-jet tagging", "uncer 1"}}, "ML selections per pT bin"};
Configurable<int> nClassesMl{"nClassesMl", 2, "Number of classes in ML model"};
// Only referenced by the commented-out cacheInputFeaturesIndices() call in init().
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};

// Model location: with loadModelsFromCCDB set, init() initialises ccdbApi with ccdbUrl and
// calls setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
// otherwise it calls setModelPathsLocal(onnxFileNames).
Configurable<std::string> ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{"Users/h/hahassan"}, "Paths of models on CCDB"};
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ML_bjets/Models/LHC24g4_70_200/model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};

// ML response evaluated per jet in analyzeJetAlgorithmML, and the CCDB API handle used in init().
o2::analysis::MlResponse<float> bMlResponse;
o2::ccdb::CcdbApi ccdbApi;

using JetTagTracksData = soa::Join<aod::JetTracks, aod::JTrackExtras, aod::JTrackPIs>;
using JetTagTracksMCD = soa::Join<aod::JetTracksMCD, aod::JTrackExtras, aod::JTrackPIs>;
Expand Down Expand Up @@ -232,6 +254,45 @@ struct JetTaggerHFTask {
registry.add("h2_neg_track_probability_flavour", "negative track probability", {HistType::kTH2F, {{trackProbabilityAxis}, {jetFlavourAxis}}});
}
}

if (processDataAlgorithmML || processMCDAlgorithmML) {
bMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
if (loadModelsFromCCDB) {
ccdbApi.init(ccdbUrl);
bMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
} else {
bMlResponse.setModelPathsLocal(onnxFileNames);
}
// bMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
bMlResponse.init();
}
}

/// Evaluates the ML model for every jet and stores the first output value in
/// scoreML, indexed by the jet's global index.
///
/// For each jet the SV and track features are collected, padded/truncated to
/// nJetConst entries, flattened with getInputsForML and passed to bMlResponse.
///
/// \param alljets   jets to evaluate
/// \param allTracks track table (forwarded; used for its type)
/// \param allSVs    secondary-vertex table (forwarded; used for its type)
template <typename AnyJets, typename AnyTracks, typename SecondaryVertices>
void analyzeJetAlgorithmML(AnyJets const& alljets, AnyTracks const& allTracks, SecondaryVertices const& allSVs)
{
  for (const auto& analysisJet : alljets) {

    std::vector<BJetTrackParams> tracksParams;
    std::vector<BJetSVParams> svsParams;

    // NOTE(review): svReductionFactor is a Configurable<float> but the callee takes an int.
    analyzeJetSVInfo4ML(analysisJet, allTracks, allSVs, svsParams, svPtMin, svReductionFactor);
    analyzeJetTrackInfo4ML(analysisJet, allTracks, allSVs, tracksParams, trackPtMin);

    // Use the SV table type this function was instantiated with, so that the
    // same code works for both data and MCD jets.
    const int nSVs = analysisJet.template secondaryVertices_as<SecondaryVertices>().size();

    // The track multiplicity must be taken before the resize below.
    BJetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast<int>(tracksParams.size()), nSVs, analysisJet.mass()};
    tracksParams.resize(nJetConst); // resize to the number of inputs of the ML
    svsParams.resize(nJetConst);    // resize to the number of inputs of the ML

    auto inputML = getInputsForML(jetparam, tracksParams, svsParams, nJetConst);

    std::vector<float> output;
    // The selection decision is currently unused; only the score is stored.
    // bool isSelectedMl = bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output);
    bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output);

    // NOTE(review): assumes isSelectedMl always fills at least one output value - confirm.
    scoreML[analysisJet.globalIndex()] = output[0];
  }
}

void processDummy(aod::JetCollisions const&)
Expand Down Expand Up @@ -334,15 +395,15 @@ struct JetTaggerHFTask {
}
PROCESS_SWITCH(JetTaggerHFTask, processMCDWithSV, "Fill tagging decision for mcd jets with sv", false);

void processDataAlgorithmML(aod::JetCollision const& /*collision*/, soa::Join<JetTableData, aod::DataSecondaryVertex3ProngIndices> const& /*allJets*/, JetTagTracksData const& /*allTracks*/, aod::DataSecondaryVertex3Prongs const& allSVs)
// Computes the ML score for data jets by forwarding the subscribed tables to
// analyzeJetAlgorithmML.
void processDataAlgorithmML(soa::Join<JetTableData, aod::DataSecondaryVertex3ProngIndices> const& allJets, JetTagTracksData const& allTracks, aod::DataSecondaryVertex3Prongs const& allSVs)
{
  analyzeJetAlgorithmML(allJets, allTracks, allSVs);
}
PROCESS_SWITCH(JetTaggerHFTask, processDataAlgorithmML, "Fill ML evaluation score for data jets", false);

void processMCDAlgorithmML(aod::JetCollision const& /*collision*/, soa::Join<JetTableMCD, aod::ChargedMCDetectorLevelJetFlavourDef, aod::MCDSecondaryVertex3ProngIndices> const& /*allJets*/, JetTagTracksMCD const& /*allTracks*/, aod::MCDSecondaryVertex3Prongs const& allSVs)
// Computes the ML score for detector-level MC jets by forwarding the
// subscribed tables to analyzeJetAlgorithmML.
void processMCDAlgorithmML(soa::Join<JetTableMCD, aod::MCDSecondaryVertex3ProngIndices> const& allJets, JetTagTracksMCD const& allTracks, aod::MCDSecondaryVertex3Prongs const& allSVs)
{
  analyzeJetAlgorithmML(allJets, allTracks, allSVs);
}
PROCESS_SWITCH(JetTaggerHFTask, processMCDAlgorithmML, "Fill ML evaluation score for MCD jets", false);
};
Expand Down
Loading