Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions PWGJE/Core/JetTaggingUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
double deltaRJetTrack = 0.0;
double signedIP2D = 0.0;
double signedIP2DSign = 0.0;
double signedIPz = 0.0;
double signedIPzSign = 0.0;
double signedIP3D = 0.0;
double signedIP3DSign = 0.0;
double momFraction = 0.0;
Expand Down Expand Up @@ -153,7 +155,7 @@

int motherStatusCode = std::abs(mother.getGenStatusCode());

if (motherStatusCode == 23 || motherStatusCode == 33 || motherStatusCode == 43 || motherStatusCode == 63) {

Check failure on line 158 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
return mother.globalIndex();
}
}
Expand Down Expand Up @@ -181,7 +183,7 @@

int motherStatusCode = std::abs(mother.getGenStatusCode());

if (motherStatusCode == 23 || motherStatusCode == 33 || motherStatusCode == 43 || motherStatusCode == 63 || (motherStatusCode == 51 && mother.template mothers_first_as<T>().pdgCode() == 21)) {

Check failure on line 186 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.

Check failure on line 186 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
return mother.globalIndex();
}
}
Expand Down Expand Up @@ -209,12 +211,12 @@
hasMcParticle = true;
auto const& particle = track.template mcParticle_as<V>();
origin = RecoDecay::getParticleOrigin(particles, particle, searchUpToQuark);
if (origin == 1 || origin == 2) { // 1=charm , 2=beauty

Check failure on line 214 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
hftrack = track;
if (origin == 1) {
return JetTaggingSpecies::charm;
}
if (origin == 2) {

Check failure on line 219 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
return JetTaggingSpecies::beauty;
}
}
Expand Down Expand Up @@ -242,7 +244,7 @@
for (const auto& particle : jet.template tracks_as<U>()) {
hfparticle = particle; // for init if origin is 1 or 2, the particle is not hfparticle
origin = RecoDecay::getParticleOrigin(particles, particle, searchUpToQuark);
if (origin == 1 || origin == 2) { // 1=charm , 2=beauty

Check failure on line 247 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
hfparticle = particle;
if (origin == 1) {
return JetTaggingSpecies::charm;
Expand Down Expand Up @@ -385,13 +387,13 @@
bool charmQuark = false;
for (auto const& mcpart : mcparticles) {
int pdgcode = mcpart.pdgCode();
if (std::abs(pdgcode) == 21 || (std::abs(pdgcode) >= 1 && std::abs(pdgcode) <= 5)) {

Check failure on line 390 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
double dR = jetutilities::deltaR(jet, mcpart);

if (dR < jet.r() / 100.f) {
if (std::abs(pdgcode) == 5) {

Check failure on line 394 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
return JetTaggingSpecies::beauty; // Beauty jet
} else if (std::abs(pdgcode) == 4) {

Check failure on line 396 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[pdg/explicit-code]

Avoid hard-coded PDG codes. Use named values from PDG_t or o2::constants::physics::Pdg instead.
charmQuark = true;
}
}
Expand Down Expand Up @@ -914,7 +916,7 @@

trkLabels["trkVtxIndex"] = std::vector<int>(nTrks, -1);
if (count.size() != 0) { // If there is any SV cluster not only PV cluster
for (auto& [idx, avgDistance] : avgDistances) // o2-linter: disable=const-ref-in-for-loop

Check failure on line 919 in PWGJE/Core/JetTaggingUtilities.h

View workflow job for this annotation

GitHub Actions / O2 linter

[const-ref-in-for-loop]

Use constant references for non-modified iterators in range-based for loops.
avgDistance /= count[idx];

nVertices += avgDistances.size();
Expand Down Expand Up @@ -1011,7 +1013,7 @@
}
}

tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), rClosestSV});
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), rClosestSV});
}

auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
Expand All @@ -1036,7 +1038,7 @@
double dotProduct = RecoDecay::dotProd(std::array<float, 3>{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array<float, 3>{constituent.px(), constituent.py(), constituent.pz()});
int sign = getGeoSign(analysisJet, constituent);

tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.0});
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.0});
}

auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
Expand Down
29 changes: 29 additions & 0 deletions PWGJE/Core/MlResponseHfTagging.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ enum class InputFeaturesBTag : uint8_t {
deltaRJetTrack,
signedIP2D,
signedIP2DSign,
signedIPz,
signedIPzSign,
signedIP3D,
signedIP3DSign,
momFraction,
Expand Down Expand Up @@ -148,6 +150,8 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
CHECK_AND_FILL_VEC_BTAG(trackInput, track, deltaRJetTrack)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP2D)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP2DSign)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIPz)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIPzSign)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP3D)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP3DSign)
CHECK_AND_FILL_VEC_BTAG(trackInput, track, momFraction)
Expand Down Expand Up @@ -192,6 +196,23 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
}
}

/// @brief Method to replace NaN and infinity values in a vector with a specified value
/// @param vec is the vector to be processed
/// @param value is the value to replace NaN values with
/// @return the number of NaN values replaced
template <typename T>
static int replaceNaN(std::vector<T>& vec, T value)
{
int numNaN = 0;
for (auto& el : vec) {
if (std::isnan(el) || std::isinf(el)) {
el = value;
++numNaN;
}
}
return numNaN;
}

/// Method to get the input features vector needed for ML inference in a 2D vector
/// \param jet is the b-jet candidate
/// \param tracks is the vector of tracks associated to the jet
Expand All @@ -209,6 +230,10 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>

std::vector<std::vector<float>> inputFeatures;

replaceNaN(jetInput, 0.f);
replaceNaN(trackInput, 0.f);
replaceNaN(svInput, 0.f);

inputFeatures.push_back(jetInput);
inputFeatures.push_back(trackInput);
inputFeatures.push_back(svInput);
Expand Down Expand Up @@ -237,6 +262,8 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
inputFeatures.insert(inputFeatures.end(), trackInput.begin(), trackInput.end());
inputFeatures.insert(inputFeatures.end(), svInput.begin(), svInput.end());

replaceNaN(inputFeatures, 0.f);

return inputFeatures;
}

Expand All @@ -261,6 +288,8 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
FILL_MAP_BJET(deltaRJetTrack),
FILL_MAP_BJET(signedIP2D),
FILL_MAP_BJET(signedIP2DSign),
FILL_MAP_BJET(signedIPz),
FILL_MAP_BJET(signedIPzSign),
FILL_MAP_BJET(signedIP3D),
FILL_MAP_BJET(signedIP3DSign),
FILL_MAP_BJET(momFraction),
Expand Down
10 changes: 6 additions & 4 deletions PWGJE/Tasks/bjetTreeCreator.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ DECLARE_SOA_COLUMN(DotProdTrackJetOverJet, trackdotjetoverjet, float); //! The d
DECLARE_SOA_COLUMN(DeltaRJetTrack, rjettrack, float); //! The DR jet-track
DECLARE_SOA_COLUMN(SignedIP2D, ip2d, float); //! The track signed 2D IP
DECLARE_SOA_COLUMN(SignedIP2DSign, ip2dsigma, float); //! The track signed 2D IP significance
DECLARE_SOA_COLUMN(SignedIP3D, ip3d, float); //! The track signed 3D IP
DECLARE_SOA_COLUMN(SignedIPz, ipz, float); //! The track signed z IP
DECLARE_SOA_COLUMN(SignedIPzSign, ipzsigma, float); //! The track signed z IP significance
DECLARE_SOA_COLUMN(SignedIP3DSign, ip3dsigma, float); //! The track signed 3D IP significance
DECLARE_SOA_COLUMN(MomFraction, momfraction, float); //! The track momentum fraction of the jets
DECLARE_SOA_COLUMN(DeltaRTrackVertex, rtrackvertex, float); //! DR between the track and the closest SV, to be decided whether to add to or not
Expand All @@ -108,7 +109,8 @@ DECLARE_SOA_TABLE(bjetTracksParams, "AOD", "BJETTRACKSPARAM",
trackInfo::DeltaRJetTrack,
trackInfo::SignedIP2D,
trackInfo::SignedIP2DSign,
trackInfo::SignedIP3D,
trackInfo::SignedIPz,
trackInfo::SignedIPzSign,
trackInfo::SignedIP3DSign,
trackInfo::MomFraction,
trackInfo::DeltaRTrackVertex);
Expand Down Expand Up @@ -460,7 +462,7 @@ struct BJetTreeCreator {
}

if (produceTree) {
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), RClosestSV);
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), RClosestSV);
}
trackIndices.push_back(bjetTracksParamsTable.lastIndex());
}
Expand Down Expand Up @@ -531,7 +533,7 @@ struct BJetTreeCreator {
}

if (produceTree) {
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.);
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.);
}
trackIndices.push_back(bjetTracksParamsTable.lastIndex());
}
Expand Down
Loading