Skip to content

Commit 9c681c2

Browse files
choich08365 and chchoi authored
[PWGJE] Add an extra table and vertexClustering function for GNN b-jet tagging (#8597)
Co-authored-by: chchoi <changhwan.choi@cern.ch>
1 parent 82f622f commit 9c681c2

File tree

2 files changed

+399
-1
lines changed

2 files changed

+399
-1
lines changed

PWGJE/Core/JetTaggingUtilities.h

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <algorithm>
2727
#include <functional>
2828
#include <memory>
29+
#include <string>
30+
#include <unordered_map>
2931

3032
#include "TF1.h"
3133
#include "Framework/Logger.h"
@@ -629,6 +631,177 @@ bool isTaggedJetSV(T const jet, U const& /*prongs*/, float const& prongChi2PCAMi
629631
return true;
630632
}
631633

634+
/**
635+
* Clusters jet constituent tracks into groups of tracks originating from same mcParticle position (trkVtxIndex), and finds each track origin (trkOrigin). (for GNN b-jet tagging)
636+
* @param trkLabels Track labels for GNN vertex and track origin predictions. trkVtxIndex: The index value of each vertex (cluster) which is determined by the function. trkOrigin: The category of the track origin (0: not physical primary, 1: charm, 2: beauty, 3: primary vertex, 4: other secondary vertex).
637+
* @param vtxResParam Vertex resolution parameter which determines the cluster size. (cm)
638+
* @param trackPtMin Minimum value of track pT.
639+
* @return The number of vertices (clusters) in the jet.
640+
*/
641+
template <typename AnyCollision, typename AnalysisJet, typename AnyTracks, typename AnyParticles, typename AnyOriginalParticles>
642+
int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyTracks const&, AnyParticles const& particles, AnyOriginalParticles const&, std::unordered_map<std::string, std::vector<int>>& trkLabels, bool searchUpToQuark, float vtxResParam = 0.01 /* 0.01cm = 100um */, float trackPtMin = 0.5)
643+
{
644+
const auto& tracks = jet.template tracks_as<AnyTracks>();
645+
const int n_trks = tracks.size();
646+
647+
// trkVtxIndex
648+
649+
std::vector<int> tempTrkVtxIndex;
650+
651+
int i = 0;
652+
for (const auto& constituent : tracks) {
653+
if (!constituent.has_mcParticle() || !constituent.template mcParticle_as<AnyParticles>().isPhysicalPrimary() || constituent.pt() < trackPtMin)
654+
tempTrkVtxIndex.push_back(-1);
655+
else
656+
tempTrkVtxIndex.push_back(i++);
657+
}
658+
tempTrkVtxIndex.push_back(i); // temporary index for PV
659+
if (n_trks < 1) { // the process should be done for n_trks == 1 as well
660+
trkLabels["trkVtxIndex"] = tempTrkVtxIndex;
661+
return n_trks;
662+
}
663+
664+
int n_pos = n_trks + 1;
665+
std::vector<float> dists(n_pos * (n_pos - 1) / 2);
666+
auto trk_pair_idx = [n_pos](int ti, int tj) {
667+
if (ti == tj || ti >= n_pos || tj >= n_pos || ti < 0 || tj < 0) {
668+
LOGF(info, "Track pair index out of range");
669+
return -1;
670+
} else {
671+
return (ti < tj) ? (ti * n_pos - (ti * (ti + 1)) / 2 + tj - ti - 1) : (tj * n_pos - (tj * (tj + 1)) / 2 + ti - tj - 1);
672+
}
673+
}; // index n_trks is for PV
674+
675+
for (int ti = 0; ti < n_pos - 1; ti++)
676+
for (int tj = ti + 1; tj < n_pos; tj++) {
677+
std::array<float, 3> posi, posj;
678+
679+
if (tj < n_trks) {
680+
if (tracks[tj].has_mcParticle()) {
681+
const auto& pj = tracks[tj].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>();
682+
posj = std::array<float, 3>{pj.vx(), pj.vy(), pj.vz()};
683+
} else {
684+
dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max();
685+
continue;
686+
}
687+
} else {
688+
posj = std::array<float, 3>{collision.posX(), collision.posY(), collision.posZ()};
689+
}
690+
691+
if (tracks[ti].has_mcParticle()) {
692+
const auto& pi = tracks[ti].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>();
693+
posi = std::array<float, 3>{pi.vx(), pi.vy(), pi.vz()};
694+
} else {
695+
dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max();
696+
continue;
697+
}
698+
699+
dists[trk_pair_idx(ti, tj)] = RecoDecay::distance(posi, posj);
700+
}
701+
702+
int clusteri = -1, clusterj = -1;
703+
float min_min_dist = -1.f; // If there is an not-merge-able min_dist pair, check the 2nd-min_dist pair.
704+
while (true) {
705+
706+
float min_dist = -1.f; // Get min_dist pair
707+
for (int ti = 0; ti < n_pos - 1; ti++)
708+
for (int tj = ti + 1; tj < n_pos; tj++)
709+
if (tempTrkVtxIndex[ti] != tempTrkVtxIndex[tj] && tempTrkVtxIndex[ti] >= 0 && tempTrkVtxIndex[tj] >= 0) {
710+
float dist = dists[trk_pair_idx(ti, tj)];
711+
if ((dist < min_dist || min_dist < 0.f) && dist > min_min_dist) {
712+
min_dist = dist;
713+
clusteri = ti;
714+
clusterj = tj;
715+
}
716+
}
717+
if (clusteri < 0 || clusterj < 0)
718+
break;
719+
720+
bool mrg = true; // Merge-ability check
721+
for (int ti = 0; ti < n_pos && mrg; ti++)
722+
if (tempTrkVtxIndex[ti] == tempTrkVtxIndex[clusteri] && tempTrkVtxIndex[ti] >= 0) {
723+
for (int tj = 0; tj < n_pos && mrg; tj++)
724+
if (tj != ti && tempTrkVtxIndex[tj] == tempTrkVtxIndex[clusterj] && tempTrkVtxIndex[tj] >= 0) {
725+
if (dists[trk_pair_idx(ti, tj)] > vtxResParam) { // If there is more distant pair compared to vtx_res between two clusters, they cannot be merged.
726+
mrg = false;
727+
min_min_dist = min_dist;
728+
}
729+
}
730+
}
731+
if (min_dist > vtxResParam || min_dist < 0.f)
732+
break;
733+
734+
if (mrg) { // Merge two clusters
735+
int old_index = tempTrkVtxIndex[clusterj];
736+
for (int t = 0; t < n_pos; t++)
737+
if (tempTrkVtxIndex[t] == old_index)
738+
tempTrkVtxIndex[t] = tempTrkVtxIndex[clusteri];
739+
}
740+
}
741+
742+
int n_vertices = 0;
743+
744+
// Sort the indices from PV (as 0) to the most distant SV (as 1~).
745+
int idxPV = tempTrkVtxIndex[n_trks];
746+
for (int t = 0; t < n_trks; t++)
747+
if (tempTrkVtxIndex[t] == idxPV) {
748+
tempTrkVtxIndex[t] = -2;
749+
n_vertices = 1; // There is a track originating from PV
750+
}
751+
752+
std::unordered_map<int, float> avgDistances;
753+
std::unordered_map<int, int> count;
754+
for (int t = 0; t < n_trks; t++) {
755+
if (tempTrkVtxIndex[t] >= 0) {
756+
avgDistances[tempTrkVtxIndex[t]] += dists[trk_pair_idx(t, n_trks)];
757+
count[tempTrkVtxIndex[t]]++;
758+
}
759+
}
760+
761+
trkLabels["trkVtxIndex"] = std::vector<int>(n_trks, -1);
762+
if (count.size() != 0) { // If there is any SV cluster not only PV cluster
763+
for (auto& [idx, avgDistance] : avgDistances)
764+
avgDistance /= count[idx];
765+
766+
n_vertices += avgDistances.size();
767+
768+
std::vector<std::pair<int, float>> sortedIndices(avgDistances.begin(), avgDistances.end());
769+
std::sort(sortedIndices.begin(), sortedIndices.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
770+
int rank = 1;
771+
for (const auto& [idx, avgDistance] : sortedIndices) {
772+
bool found = false;
773+
for (int t = 0; t < n_trks; t++)
774+
if (tempTrkVtxIndex[t] == idx) {
775+
trkLabels["trkVtxIndex"][t] = rank;
776+
found = true;
777+
}
778+
rank += found;
779+
}
780+
}
781+
782+
for (int t = 0; t < n_trks; t++)
783+
if (tempTrkVtxIndex[t] == -2)
784+
trkLabels["trkVtxIndex"][t] = 0;
785+
786+
// trkOrigin
787+
788+
int trkIdx = 0;
789+
for (auto& constituent : jet.template tracks_as<AnyTracks>()) {
790+
if (!constituent.has_mcParticle() || !constituent.template mcParticle_as<AnyParticles>().isPhysicalPrimary() || constituent.pt() < trackPtMin) {
791+
trkLabels["trkOrigin"].push_back(0);
792+
} else {
793+
const auto& particle = constituent.template mcParticle_as<AnyParticles>();
794+
int orig = RecoDecay::getParticleOrigin(particles, particle, searchUpToQuark);
795+
trkLabels["trkOrigin"].push_back((orig > 0) ? orig : (trkLabels["trkVtxIndex"][trkIdx] == 0) ? 3
796+
: 4);
797+
}
798+
799+
trkIdx++;
800+
}
801+
802+
return n_vertices;
803+
}
804+
632805
}; // namespace jettaggingutilities
633806

634807
#endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_

0 commit comments

Comments
 (0)