Skip to content

Commit 4bb0979

Browse files
authored
[PWGJE] Adding the ability to match subjets (#10267)
1 parent 4cea901 commit 4bb0979

40 files changed

+2332
-267
lines changed

PWGJE/Core/JetCandidateUtilities.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,24 @@ auto slicedPerCandidateCollision(T const& table, U const& candidates, V const& c
255255
}
256256
}
257257

258+
/**
259+
* returns a slice of the table depending on the index of the candidate
260+
* @param CandidateTable candidtae table type
261+
* @param jet jet that the slice is based on
262+
* @param table the table to be sliced
263+
*/
264+
template <typename CandidateTable, typename T, typename U, typename V, typename M, typename N, typename O, typename P>
265+
auto slicedPerJet(T const& table, U const& jet, V const& perD0Jet, M const& perDplusJet, N const& perLcJet, O const& perBplusJet, P const& perDielectronJet)
266+
{
267+
if constexpr (jethfutilities::isHFTable<CandidateTable>() || jethfutilities::isHFMcTable<CandidateTable>()) {
268+
return jethfutilities::slicedPerHFJet<CandidateTable>(table, jet, perD0Jet, perDplusJet, perLcJet, perBplusJet);
269+
} else if constexpr (jetdqutilities::isDielectronTable<CandidateTable>() || jetdqutilities::isDielectronMcTable<CandidateTable>()) {
270+
return jetdqutilities::slicedPerDielectronJet<CandidateTable>(table, jet, perDielectronJet);
271+
} else {
272+
return table;
273+
}
274+
}
275+
258276
/**
259277
* returns the candidate collision Id of candidate based on type of candidate
260278
*

PWGJE/Core/JetDQUtilities.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,22 @@ auto slicedPerDielectronCollision(T const& table, U const& /*candidates*/, V con
180180
}
181181
}
182182

183+
/**
184+
* returns a slice of the table depending on the index of the Dielectron jet
185+
* @param DielectronTable dielectron table type
186+
* @param jet jet that the slice is based on
187+
* @param table the table to be sliced
188+
*/
189+
template <typename DielectronTable, typename T, typename U, typename V>
190+
auto slicedPerDielectronJet(T const& table, U const& jet, V const& perDielectronJet)
191+
{
192+
if constexpr (isDielectronTable<DielectronTable>() || isDielectronMcTable<DielectronTable>()) {
193+
return table.sliceBy(perDielectronJet, jet.globalIndex());
194+
} else {
195+
return table;
196+
}
197+
}
198+
183199
/**
184200
* returns the Dielectron collision Id of candidate based on type of Dielectron candidate
185201
*

PWGJE/Core/JetHFUtilities.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,29 @@ auto slicedPerHFCollision(T const& table, U const& /*candidates*/, V const& coll
440440
}
441441
}
442442

443+
/**
444+
* returns a slice of the table depending on the index of the HF candidate
445+
*
446+
* @param HFTable HF table type
447+
* @param jet jet that is being sliced based on
448+
* @param table the table to be sliced
449+
*/
450+
template <typename HFTable, typename T, typename U, typename V, typename M, typename N, typename O>
451+
auto slicedPerHFJet(T const& table, U const& jet, V const& perD0Jet, M const& perDplusJet, N const& perLcJet, O const& perBplusJet)
452+
{
453+
if constexpr (isD0Table<HFTable>() || isD0McTable<HFTable>()) {
454+
return table.sliceBy(perD0Jet, jet.globalIndex());
455+
} else if constexpr (isDplusTable<HFTable>() || isDplusMcTable<HFTable>()) {
456+
return table.sliceBy(perDplusJet, jet.globalIndex());
457+
} else if constexpr (isLcTable<HFTable>() || isLcMcTable<HFTable>()) {
458+
return table.sliceBy(perLcJet, jet.globalIndex());
459+
} else if constexpr (isBplusTable<HFTable>() || isBplusMcTable<HFTable>()) {
460+
return table.sliceBy(perBplusJet, jet.globalIndex());
461+
} else {
462+
return table;
463+
}
464+
}
465+
443466
/**
444467
* returns the HF collision Id of candidate based on type of HF candidate
445468
*

PWGJE/Core/JetMatchingUtilities.h

Lines changed: 221 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "PWGJE/DataModel/EMCALClusters.h"
4141
#include "PWGJE/DataModel/Jet.h"
4242
#include "PWGJE/Core/JetCandidateUtilities.h"
43+
#include "PWGJE/Core/JetFindingUtilities.h"
4344

4445
namespace jetmatchingutilities
4546
{
@@ -406,8 +407,8 @@ auto constexpr getConstituentId(T const& track)
406407
}
407408
}
408409

409-
template <bool isEMCAL, bool jetsBaseIsMc, bool jetsTagIsMc, typename T, typename U, typename V, typename O>
410-
float getPtSum(T const& tracksBase, U const& clustersBase, V const& tracksTag, O const& clustersTag)
410+
template <bool isEMCAL, bool isCandidate, bool jetsBaseIsMc, bool jetsTagIsMc, typename T, typename U, typename V, typename O, typename P, typename Q, typename R, typename S>
411+
float getPtSum(T const& tracksBase, U const& candidatesBase, V const& clustersBase, O const& tracksTag, P const& candidatesTag, Q const& clustersTag, R const& fullTracksBase, S const& fullTracksTag)
411412
{
412413
std::vector<int> particleTracker;
413414
float ptSum = 0.;
@@ -464,6 +465,47 @@ float getPtSum(T const& tracksBase, U const& clustersBase, V const& tracksTag, O
464465
}
465466
}
466467
}
468+
if constexpr (isCandidate) {
469+
if constexpr (jetsTagIsMc) {
470+
for (auto const& candidateBase : candidatesBase) {
471+
if (jetcandidateutilities::isMatchedCandidate(candidateBase)) {
472+
const auto candidateBaseMcId = jetcandidateutilities::matchedParticleId(candidateBase, fullTracksBase, fullTracksTag);
473+
for (auto const& candidateTag : candidatesTag) {
474+
const auto candidateTagId = candidateTag.mcParticleId();
475+
if (candidateBaseMcId == candidateTagId) {
476+
ptSum += candidateBase.pt();
477+
}
478+
break; // should only be one
479+
}
480+
}
481+
break;
482+
}
483+
} else if constexpr (jetsBaseIsMc) {
484+
for (auto const& candidateTag : candidatesTag) {
485+
if (jetcandidateutilities::isMatchedCandidate(candidateTag)) {
486+
const auto candidateTagMcId = jetcandidateutilities::matchedParticleId(candidateTag, fullTracksTag, fullTracksBase);
487+
for (auto const& candidateBase : candidatesBase) {
488+
const auto candidateBaseId = candidateBase.mcParticleId();
489+
if (candidateTagMcId == candidateBaseId) {
490+
ptSum += candidateTag.pt();
491+
}
492+
break; // should only be one
493+
}
494+
}
495+
break;
496+
}
497+
} else {
498+
for (auto const& candidateBase : candidatesBase) {
499+
for (auto const& candidateTag : candidatesTag) {
500+
if (candidateBase.globalIndex() == candidateTag.globalIndex()) {
501+
ptSum += candidateBase.pt();
502+
}
503+
break; // should only be one
504+
}
505+
break;
506+
}
507+
}
508+
}
467509
return ptSum;
468510
}
469511

@@ -472,30 +514,34 @@ auto getConstituents(T const& jet, U const& /*constituents*/)
472514
{
473515
if constexpr (jetfindingutilities::isEMCALClusterTable<U>()) {
474516
return jet.template clusters_as<U>();
475-
} else if constexpr (jetfindingutilities::isDummyTable<U>()) { // this is for the case where EMCal clusters are tested but no clusters exist, like in the case of charged jet analyses
517+
} else if constexpr (jetcandidateutilities::isCandidateTable<U>() || jetcandidateutilities::isCandidateMcTable<U>()) {
518+
return jet.template candidates_as<U>();
519+
} else if constexpr (jetfindingutilities::isDummyTable<U>() || std::is_same_v<U, o2::aod::JCollisions> || std::is_same_v<U, o2::aod::JMcCollisions>) { // this is for the case where EMCal clusters or candidates are tested but no clusters or candidates exist and dummy tables are used, like in the case of charged jet analyses
476520
return nullptr;
477521
} else {
478522
return jet.template tracks_as<U>();
479523
}
480524
}
481525

482-
template <bool jetsBaseIsMc, bool jetsTagIsMc, typename T, typename U, typename V, typename M, typename N, typename O>
483-
void MatchPt(T const& jetsBasePerCollision, U const& jetsTagPerCollision, std::vector<std::vector<int>>& baseToTagMatchingPt, std::vector<std::vector<int>>& tagToBaseMatchingPt, V const& tracksBase, M const& clustersBase, N const& tracksTag, O const& clustersTag, float minPtFraction)
526+
template <bool jetsBaseIsMc, bool jetsTagIsMc, typename T, typename U, typename V, typename M, typename N, typename O, typename P, typename Q>
527+
void MatchPt(T const& jetsBasePerCollision, U const& jetsTagPerCollision, std::vector<std::vector<int>>& baseToTagMatchingPt, std::vector<std::vector<int>>& tagToBaseMatchingPt, V const& tracksBase, M const& candidatesBase, N const& clustersBase, O const& tracksTag, P const& candidatesTag, Q const& clustersTag, float minPtFraction)
484528
{
485529
float ptSumBase;
486530
float ptSumTag;
487531
for (const auto& jetBase : jetsBasePerCollision) {
488532
auto jetBaseTracks = getConstituents(jetBase, tracksBase);
489533
auto jetBaseClusters = getConstituents(jetBase, clustersBase);
534+
auto jetBaseCandidates = getConstituents(jetBase, candidatesBase);
490535
for (const auto& jetTag : jetsTagPerCollision) {
491536
if (std::round(jetBase.r()) != std::round(jetTag.r())) {
492537
continue;
493538
}
494539
auto jetTagTracks = getConstituents(jetTag, tracksTag);
495540
auto jetTagClusters = getConstituents(jetTag, clustersTag);
541+
auto jetTagCandidates = getConstituents(jetTag, candidatesTag);
496542

497-
ptSumBase = getPtSum < jetfindingutilities::isEMCALClusterTable<M>() || jetfindingutilities::isEMCALClusterTable<O>(), jetsBaseIsMc, jetsTagIsMc > (jetBaseTracks, jetBaseClusters, jetTagTracks, jetTagClusters);
498-
ptSumTag = getPtSum < jetfindingutilities::isEMCALClusterTable<M>() || jetfindingutilities::isEMCALClusterTable<O>(), jetsTagIsMc, jetsBaseIsMc > (jetTagTracks, jetTagClusters, jetBaseTracks, jetBaseClusters);
543+
ptSumBase = getPtSum < jetfindingutilities::isEMCALClusterTable<N>() || jetfindingutilities::isEMCALClusterTable<Q>(), (jetcandidateutilities::isCandidateTable<M>() || jetcandidateutilities::isCandidateMcTable<M>()) && (jetcandidateutilities::isCandidateTable<P>() || jetcandidateutilities::isCandidateMcTable<P>()), jetsBaseIsMc, jetsTagIsMc > (jetBaseTracks, jetBaseCandidates, jetBaseClusters, jetTagTracks, jetTagCandidates, jetTagClusters, tracksBase, tracksTag);
544+
ptSumTag = getPtSum < jetfindingutilities::isEMCALClusterTable<N>() || jetfindingutilities::isEMCALClusterTable<Q>(), (jetcandidateutilities::isCandidateTable<M>() || jetcandidateutilities::isCandidateMcTable<M>()) && (jetcandidateutilities::isCandidateTable<P>() || jetcandidateutilities::isCandidateMcTable<P>()), jetsTagIsMc, jetsBaseIsMc > (jetTagTracks, jetTagCandidates, jetTagClusters, jetBaseTracks, jetBaseCandidates, jetBaseClusters, tracksTag, tracksBase);
499545
if (ptSumBase > jetBase.pt() * minPtFraction) {
500546
baseToTagMatchingPt[jetBase.globalIndex()].push_back(jetTag.globalIndex());
501547
}
@@ -508,22 +554,187 @@ void MatchPt(T const& jetsBasePerCollision, U const& jetsTagPerCollision, std::v
508554

509555
// function that calls all the Match functions
510556
template <bool jetsBaseIsMc, bool jetsTagIsMc, typename T, typename U, typename V, typename M, typename N, typename O, typename P, typename R>
511-
void doAllMatching(T const& jetsBasePerCollision, U const& jetsTagPerCollision, std::vector<std::vector<int>>& baseToTagMatchingGeo, std::vector<std::vector<int>>& baseToTagMatchingPt, std::vector<std::vector<int>>& baseToTagMatchingHF, std::vector<std::vector<int>>& tagToBaseMatchingGeo, std::vector<std::vector<int>>& tagToBaseMatchingPt, std::vector<std::vector<int>>& tagToBaseMatchingHF, V const& candidatesBase, M const& candidatesTag, N const& tracksBase, O const& clustersBase, P const& tracksTag, R const& clustersTag, bool doMatchingGeo, bool doMatchingHf, bool doMatchingPt, float maxMatchingDistance, float minPtFraction)
557+
void doAllMatching(T const& jetsBasePerCollision, U const& jetsTagPerCollision, std::vector<std::vector<int>>& baseToTagMatchingGeo, std::vector<std::vector<int>>& baseToTagMatchingPt, std::vector<std::vector<int>>& baseToTagMatchingHF, std::vector<std::vector<int>>& tagToBaseMatchingGeo, std::vector<std::vector<int>>& tagToBaseMatchingPt, std::vector<std::vector<int>>& tagToBaseMatchingHF, V const& candidatesBase, M const& tracksBase, N const& clustersBase, O const& candidatesTag, P const& tracksTag, R const& clustersTag, bool doMatchingGeo, bool doMatchingHf, bool doMatchingPt, float maxMatchingDistance, float minPtFraction)
512558
{
513559
// geometric matching
514560
if (doMatchingGeo) {
515561
MatchGeo(jetsBasePerCollision, jetsTagPerCollision, baseToTagMatchingGeo, tagToBaseMatchingGeo, maxMatchingDistance);
516562
}
517563
// pt matching
518564
if (doMatchingPt) {
519-
MatchPt<jetsBaseIsMc, jetsTagIsMc>(jetsBasePerCollision, jetsTagPerCollision, baseToTagMatchingPt, tagToBaseMatchingPt, tracksBase, clustersBase, tracksTag, clustersTag, minPtFraction);
565+
MatchPt<jetsBaseIsMc, jetsTagIsMc>(jetsBasePerCollision, jetsTagPerCollision, baseToTagMatchingPt, tagToBaseMatchingPt, tracksBase, candidatesBase, clustersBase, tracksTag, candidatesTag, clustersTag, minPtFraction);
520566
}
521567
// HF matching
522-
if constexpr (jetcandidateutilities::isCandidateTable<V>()) {
568+
if constexpr (jetcandidateutilities::isCandidateTable<V>() || jetcandidateutilities::isCandidateMcTable<V>()) {
523569
if (doMatchingHf) {
524570
MatchHF<jetsBaseIsMc, jetsTagIsMc>(jetsBasePerCollision, jetsTagPerCollision, baseToTagMatchingHF, tagToBaseMatchingHF, candidatesBase, candidatesTag, tracksBase, tracksTag);
525571
}
526572
}
527573
}
574+
575+
// function that does pair matching
576+
template <bool jetsBaseIsMc, bool jetsTagIsMc, typename T, typename U, typename V, typename M, typename N, typename O>
577+
void doPairMatching(T const& pairsBase, U const& pairsTag, std::vector<std::vector<int>>& baseToTagMatching, std::vector<std::vector<int>>& tagToBaseMatching, V const& /*candidatesBase*/, M const& tracksBase, N const& /*candidatesTag*/, O const& tracksTag)
578+
{
579+
bool hasTrackBase1 = false;
580+
bool hasTrackBase2 = false;
581+
bool hasCandidateBase1 = false;
582+
bool hasCandidateBase2 = false;
583+
std::vector<int> pairsTagIndices;
584+
for (auto i = 0; i < pairsTag.size(); i++) {
585+
pairsTagIndices.push_back(i);
586+
}
587+
for (const auto& pairBase : pairsBase) {
588+
if (pairBase.has_track1()) {
589+
hasTrackBase1 = true;
590+
}
591+
if (pairBase.has_track2()) {
592+
hasTrackBase2 = true;
593+
}
594+
if (pairBase.has_candidate1()) {
595+
hasCandidateBase1 = true;
596+
}
597+
if (pairBase.has_candidate2()) {
598+
hasCandidateBase2 = true;
599+
}
600+
int matchedPairTagIndex = -1;
601+
for (auto pairTagIndex : pairsTagIndices) {
602+
const auto& pairTag = pairsTag.iteratorAt(pairTagIndex);
603+
if (hasTrackBase1 && !pairTag.has_track1()) {
604+
continue;
605+
}
606+
if (hasTrackBase2 && !pairTag.has_track2()) {
607+
continue;
608+
}
609+
if (hasCandidateBase1 && !pairTag.has_candidate1()) {
610+
continue;
611+
}
612+
if (hasCandidateBase2 && !pairTag.has_candidate2()) {
613+
continue;
614+
}
615+
int nMatched = 0;
616+
bool isMatched = false;
617+
if (hasTrackBase1) {
618+
const auto& trackBase1 = pairBase.template track1_as<M>();
619+
const auto& trackTag1 = pairTag.template track1_as<O>();
620+
if constexpr (jetsTagIsMc) {
621+
if (trackBase1.mcParticleId() == trackTag1.globalIndex()) {
622+
nMatched++;
623+
isMatched = true;
624+
}
625+
} else if constexpr (jetsBaseIsMc) {
626+
if (trackBase1.globalIndex() == trackTag1.mcParticleId()) {
627+
nMatched++;
628+
isMatched = true;
629+
}
630+
} else {
631+
if (trackBase1.globalIndex() == trackTag1.globalIndex()) {
632+
nMatched++;
633+
isMatched = true;
634+
}
635+
}
636+
if (!isMatched) {
637+
continue;
638+
}
639+
}
640+
isMatched = false;
641+
642+
if (hasTrackBase2) {
643+
const auto& trackBase2 = pairBase.template track2_as<M>();
644+
const auto& trackTag2 = pairTag.template track2_as<O>();
645+
if constexpr (jetsTagIsMc) {
646+
if (trackBase2.mcParticleId() == trackTag2.globalIndex()) {
647+
nMatched++;
648+
isMatched = true;
649+
}
650+
} else if constexpr (jetsBaseIsMc) {
651+
if (trackBase2.globalIndex() == trackTag2.mcParticleId()) {
652+
nMatched++;
653+
isMatched = true;
654+
}
655+
} else {
656+
if (trackBase2.globalIndex() == trackTag2.globalIndex()) {
657+
nMatched++;
658+
isMatched = true;
659+
}
660+
}
661+
if (!isMatched) {
662+
continue;
663+
}
664+
}
665+
isMatched = false;
666+
if (hasCandidateBase1) {
667+
const auto& candidateBase1 = pairBase.template candidate1_as<V>();
668+
const auto& candidateTag1 = pairTag.template candidate1_as<N>();
669+
if constexpr (jetsTagIsMc) {
670+
if (jetcandidateutilities::isMatchedCandidate(candidateBase1)) {
671+
const auto candidateBaseMcId = jetcandidateutilities::matchedParticleId(candidateBase1, tracksBase, tracksTag);
672+
if (candidateBaseMcId == candidateTag1.globalIndex()) {
673+
nMatched++;
674+
isMatched = true;
675+
}
676+
}
677+
} else if constexpr (jetsBaseIsMc) {
678+
if (jetcandidateutilities::isMatchedCandidate(candidateTag1)) {
679+
const auto candidateTagMcId = jetcandidateutilities::matchedParticleId(candidateTag1, tracksTag, tracksBase);
680+
if (candidateTagMcId == candidateBase1.globalIndex()) {
681+
nMatched++;
682+
isMatched = true;
683+
}
684+
}
685+
} else {
686+
if (candidateBase1.globalIndex() == candidateTag1.globalIndex()) {
687+
nMatched++;
688+
isMatched = true;
689+
}
690+
}
691+
if (!isMatched) {
692+
continue;
693+
}
694+
}
695+
isMatched = false;
696+
if (hasCandidateBase2) {
697+
const auto& candidateBase2 = pairBase.template candidate2_as<V>();
698+
const auto& candidateTag2 = pairTag.template candidate2_as<N>();
699+
if constexpr (jetsTagIsMc) {
700+
if (jetcandidateutilities::isMatchedCandidate(candidateBase2)) {
701+
const auto candidateBaseMcId = jetcandidateutilities::matchedParticleId(candidateBase2, tracksBase, tracksTag);
702+
if (candidateBaseMcId == candidateTag2.globalIndex()) {
703+
nMatched++;
704+
isMatched = true;
705+
}
706+
}
707+
} else if constexpr (jetsBaseIsMc) {
708+
if (jetcandidateutilities::isMatchedCandidate(candidateTag2)) {
709+
const auto candidateTagMcId = jetcandidateutilities::matchedParticleId(candidateTag2, tracksTag, tracksBase);
710+
if (candidateTagMcId == candidateBase2.globalIndex()) {
711+
nMatched++;
712+
isMatched = true;
713+
}
714+
}
715+
} else {
716+
if (candidateBase2.globalIndex() == candidateTag2.globalIndex()) {
717+
nMatched++;
718+
isMatched = true;
719+
}
720+
}
721+
if (!isMatched) {
722+
continue;
723+
}
724+
}
725+
726+
if (nMatched == 2) {
727+
baseToTagMatching[pairBase.globalIndex()].push_back(pairTag.globalIndex());
728+
tagToBaseMatching[pairTag.globalIndex()].push_back(pairBase.globalIndex());
729+
matchedPairTagIndex = pairTagIndex;
730+
break; // can only be one match per jet
731+
}
732+
}
733+
if (matchedPairTagIndex != -1) {
734+
pairsTagIndices.erase(std::find(pairsTagIndices.begin(), pairsTagIndices.end(), matchedPairTagIndex));
735+
}
736+
}
737+
}
738+
528739
}; // namespace jetmatchingutilities
529740
#endif // PWGJE_CORE_JETMATCHINGUTILITIES_H_

0 commit comments

Comments
 (0)