|
26 | 26 | #include <algorithm> |
27 | 27 | #include <functional> |
28 | 28 | #include <memory> |
| 29 | +#include <string> |
| 30 | +#include <unordered_map> |
29 | 31 |
|
30 | 32 | #include "TF1.h" |
31 | 33 | #include "Framework/Logger.h" |
@@ -629,6 +631,177 @@ bool isTaggedJetSV(T const jet, U const& /*prongs*/, float const& prongChi2PCAMi |
629 | 631 | return true; |
630 | 632 | } |
631 | 633 |
|
| 634 | +/** |
| 635 | + * Clusters jet constituent tracks into groups of tracks originating from same mcParticle position (trkVtxIndex), and finds each track origin (trkOrigin). (for GNN b-jet tagging) |
| 636 | + * @param trkLabels Track labels for GNN vertex and track origin predictions. trkVtxIndex: The index value of each vertex (cluster) which is determined by the function. trkOrigin: The category of the track origin (0: not physical primary, 1: charm, 2: beauty, 3: primary vertex, 4: other secondary vertex). |
| 637 | + * @param vtxResParam Vertex resolution parameter which determines the cluster size. (cm) |
| 638 | + * @param trackPtMin Minimum value of track pT. |
| 639 | + * @return The number of vertices (clusters) in the jet. |
| 640 | + */ |
| 641 | +template <typename AnyCollision, typename AnalysisJet, typename AnyTracks, typename AnyParticles, typename AnyOriginalParticles> |
| 642 | +int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyTracks const&, AnyParticles const& particles, AnyOriginalParticles const&, std::unordered_map<std::string, std::vector<int>>& trkLabels, bool searchUpToQuark, float vtxResParam = 0.01 /* 0.01cm = 100um */, float trackPtMin = 0.5) |
| 643 | +{ |
| 644 | + const auto& tracks = jet.template tracks_as<AnyTracks>(); |
| 645 | + const int n_trks = tracks.size(); |
| 646 | + |
| 647 | + // trkVtxIndex |
| 648 | + |
| 649 | + std::vector<int> tempTrkVtxIndex; |
| 650 | + |
| 651 | + int i = 0; |
| 652 | + for (const auto& constituent : tracks) { |
| 653 | + if (!constituent.has_mcParticle() || !constituent.template mcParticle_as<AnyParticles>().isPhysicalPrimary() || constituent.pt() < trackPtMin) |
| 654 | + tempTrkVtxIndex.push_back(-1); |
| 655 | + else |
| 656 | + tempTrkVtxIndex.push_back(i++); |
| 657 | + } |
| 658 | + tempTrkVtxIndex.push_back(i); // temporary index for PV |
| 659 | + if (n_trks < 1) { // the process should be done for n_trks == 1 as well |
| 660 | + trkLabels["trkVtxIndex"] = tempTrkVtxIndex; |
| 661 | + return n_trks; |
| 662 | + } |
| 663 | + |
| 664 | + int n_pos = n_trks + 1; |
| 665 | + std::vector<float> dists(n_pos * (n_pos - 1) / 2); |
| 666 | + auto trk_pair_idx = [n_pos](int ti, int tj) { |
| 667 | + if (ti == tj || ti >= n_pos || tj >= n_pos || ti < 0 || tj < 0) { |
| 668 | + LOGF(info, "Track pair index out of range"); |
| 669 | + return -1; |
| 670 | + } else { |
| 671 | + return (ti < tj) ? (ti * n_pos - (ti * (ti + 1)) / 2 + tj - ti - 1) : (tj * n_pos - (tj * (tj + 1)) / 2 + ti - tj - 1); |
| 672 | + } |
| 673 | + }; // index n_trks is for PV |
| 674 | + |
| 675 | + for (int ti = 0; ti < n_pos - 1; ti++) |
| 676 | + for (int tj = ti + 1; tj < n_pos; tj++) { |
| 677 | + std::array<float, 3> posi, posj; |
| 678 | + |
| 679 | + if (tj < n_trks) { |
| 680 | + if (tracks[tj].has_mcParticle()) { |
| 681 | + const auto& pj = tracks[tj].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>(); |
| 682 | + posj = std::array<float, 3>{pj.vx(), pj.vy(), pj.vz()}; |
| 683 | + } else { |
| 684 | + dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max(); |
| 685 | + continue; |
| 686 | + } |
| 687 | + } else { |
| 688 | + posj = std::array<float, 3>{collision.posX(), collision.posY(), collision.posZ()}; |
| 689 | + } |
| 690 | + |
| 691 | + if (tracks[ti].has_mcParticle()) { |
| 692 | + const auto& pi = tracks[ti].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>(); |
| 693 | + posi = std::array<float, 3>{pi.vx(), pi.vy(), pi.vz()}; |
| 694 | + } else { |
| 695 | + dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max(); |
| 696 | + continue; |
| 697 | + } |
| 698 | + |
| 699 | + dists[trk_pair_idx(ti, tj)] = RecoDecay::distance(posi, posj); |
| 700 | + } |
| 701 | + |
| 702 | + int clusteri = -1, clusterj = -1; |
| 703 | + float min_min_dist = -1.f; // If there is an not-merge-able min_dist pair, check the 2nd-min_dist pair. |
| 704 | + while (true) { |
| 705 | + |
| 706 | + float min_dist = -1.f; // Get min_dist pair |
| 707 | + for (int ti = 0; ti < n_pos - 1; ti++) |
| 708 | + for (int tj = ti + 1; tj < n_pos; tj++) |
| 709 | + if (tempTrkVtxIndex[ti] != tempTrkVtxIndex[tj] && tempTrkVtxIndex[ti] >= 0 && tempTrkVtxIndex[tj] >= 0) { |
| 710 | + float dist = dists[trk_pair_idx(ti, tj)]; |
| 711 | + if ((dist < min_dist || min_dist < 0.f) && dist > min_min_dist) { |
| 712 | + min_dist = dist; |
| 713 | + clusteri = ti; |
| 714 | + clusterj = tj; |
| 715 | + } |
| 716 | + } |
| 717 | + if (clusteri < 0 || clusterj < 0) |
| 718 | + break; |
| 719 | + |
| 720 | + bool mrg = true; // Merge-ability check |
| 721 | + for (int ti = 0; ti < n_pos && mrg; ti++) |
| 722 | + if (tempTrkVtxIndex[ti] == tempTrkVtxIndex[clusteri] && tempTrkVtxIndex[ti] >= 0) { |
| 723 | + for (int tj = 0; tj < n_pos && mrg; tj++) |
| 724 | + if (tj != ti && tempTrkVtxIndex[tj] == tempTrkVtxIndex[clusterj] && tempTrkVtxIndex[tj] >= 0) { |
| 725 | + if (dists[trk_pair_idx(ti, tj)] > vtxResParam) { // If there is more distant pair compared to vtx_res between two clusters, they cannot be merged. |
| 726 | + mrg = false; |
| 727 | + min_min_dist = min_dist; |
| 728 | + } |
| 729 | + } |
| 730 | + } |
| 731 | + if (min_dist > vtxResParam || min_dist < 0.f) |
| 732 | + break; |
| 733 | + |
| 734 | + if (mrg) { // Merge two clusters |
| 735 | + int old_index = tempTrkVtxIndex[clusterj]; |
| 736 | + for (int t = 0; t < n_pos; t++) |
| 737 | + if (tempTrkVtxIndex[t] == old_index) |
| 738 | + tempTrkVtxIndex[t] = tempTrkVtxIndex[clusteri]; |
| 739 | + } |
| 740 | + } |
| 741 | + |
| 742 | + int n_vertices = 0; |
| 743 | + |
| 744 | + // Sort the indices from PV (as 0) to the most distant SV (as 1~). |
| 745 | + int idxPV = tempTrkVtxIndex[n_trks]; |
| 746 | + for (int t = 0; t < n_trks; t++) |
| 747 | + if (tempTrkVtxIndex[t] == idxPV) { |
| 748 | + tempTrkVtxIndex[t] = -2; |
| 749 | + n_vertices = 1; // There is a track originating from PV |
| 750 | + } |
| 751 | + |
| 752 | + std::unordered_map<int, float> avgDistances; |
| 753 | + std::unordered_map<int, int> count; |
| 754 | + for (int t = 0; t < n_trks; t++) { |
| 755 | + if (tempTrkVtxIndex[t] >= 0) { |
| 756 | + avgDistances[tempTrkVtxIndex[t]] += dists[trk_pair_idx(t, n_trks)]; |
| 757 | + count[tempTrkVtxIndex[t]]++; |
| 758 | + } |
| 759 | + } |
| 760 | + |
| 761 | + trkLabels["trkVtxIndex"] = std::vector<int>(n_trks, -1); |
| 762 | + if (count.size() != 0) { // If there is any SV cluster not only PV cluster |
| 763 | + for (auto& [idx, avgDistance] : avgDistances) |
| 764 | + avgDistance /= count[idx]; |
| 765 | + |
| 766 | + n_vertices += avgDistances.size(); |
| 767 | + |
| 768 | + std::vector<std::pair<int, float>> sortedIndices(avgDistances.begin(), avgDistances.end()); |
| 769 | + std::sort(sortedIndices.begin(), sortedIndices.end(), [](const auto& a, const auto& b) { return a.second < b.second; }); |
| 770 | + int rank = 1; |
| 771 | + for (const auto& [idx, avgDistance] : sortedIndices) { |
| 772 | + bool found = false; |
| 773 | + for (int t = 0; t < n_trks; t++) |
| 774 | + if (tempTrkVtxIndex[t] == idx) { |
| 775 | + trkLabels["trkVtxIndex"][t] = rank; |
| 776 | + found = true; |
| 777 | + } |
| 778 | + rank += found; |
| 779 | + } |
| 780 | + } |
| 781 | + |
| 782 | + for (int t = 0; t < n_trks; t++) |
| 783 | + if (tempTrkVtxIndex[t] == -2) |
| 784 | + trkLabels["trkVtxIndex"][t] = 0; |
| 785 | + |
| 786 | + // trkOrigin |
| 787 | + |
| 788 | + int trkIdx = 0; |
| 789 | + for (auto& constituent : jet.template tracks_as<AnyTracks>()) { |
| 790 | + if (!constituent.has_mcParticle() || !constituent.template mcParticle_as<AnyParticles>().isPhysicalPrimary() || constituent.pt() < trackPtMin) { |
| 791 | + trkLabels["trkOrigin"].push_back(0); |
| 792 | + } else { |
| 793 | + const auto& particle = constituent.template mcParticle_as<AnyParticles>(); |
| 794 | + int orig = RecoDecay::getParticleOrigin(particles, particle, searchUpToQuark); |
| 795 | + trkLabels["trkOrigin"].push_back((orig > 0) ? orig : (trkLabels["trkVtxIndex"][trkIdx] == 0) ? 3 |
| 796 | + : 4); |
| 797 | + } |
| 798 | + |
| 799 | + trkIdx++; |
| 800 | + } |
| 801 | + |
| 802 | + return n_vertices; |
| 803 | +} |
| 804 | + |
632 | 805 | }; // namespace jettaggingutilities |
633 | 806 |
|
634 | 807 | #endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_ |
0 commit comments