Skip to content

Commit f029faf

Browse files
choich08365Changhwan Choi
andauthored
[PWGJE] GNN-based b-jet tagging analysis code updated (#12644)
Co-authored-by: Changhwan Choi <changhwan.choi@cern.ch>
1 parent 122cea3 commit f029faf

File tree

2 files changed

+44
-71
lines changed

2 files changed

+44
-71
lines changed

PWGJE/Tasks/bjetTaggingGnn.cxx

Lines changed: 32 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ struct BjetTaggingGnn {
5757
Configurable<float> trackEtaMin{"trackEtaMin", -0.9, "minimum track eta"};
5858
Configurable<float> trackEtaMax{"trackEtaMax", 0.9, "maximum track eta"};
5959

60+
Configurable<float> maxIPxy{"maxIPxy", 10, "maximum track DCA in xy plane"};
61+
Configurable<float> maxIPz{"maxIPz", 10, "maximum track DCA in z direction"};
62+
6063
Configurable<float> trackNppCrit{"trackNppCrit", 0.95, "track not physical primary ratio"};
6164

62-
// track level configurables
65+
// sv level configurables
6366
Configurable<float> svPtMin{"svPtMin", 0.5, "minimum SV pT"};
6467

6568
// jet level configurables
@@ -70,9 +73,15 @@ struct BjetTaggingGnn {
7073

7174
Configurable<std::vector<double>> jetRadii{"jetRadii", std::vector<double>{0.4}, "jet resolution parameters"};
7275

76+
Configurable<double> dbMin{"dbMin", -10., "minimum GNN Db"};
77+
Configurable<double> dbMax{"dbMax", 20., "maximum GNN Db"};
78+
Configurable<int> dbNbins{"dbNbins", 3000, "number of bins in axisDbFine"};
79+
7380
Configurable<bool> doDataDriven{"doDataDriven", false, "Flag whether to use fill THnSpase for data driven methods"};
7481
Configurable<bool> callSumw2{"callSumw2", false, "Flag whether to call THnSparse::Sumw2() for error calculation"};
7582

83+
Configurable<int> trainingDatasetRatioParam{"trainingDatasetRatioParam", 0, "Parameter for splitting training/evaluation datasets by collisionId"};
84+
7685
std::vector<int> eventSelectionBits;
7786

7887
std::vector<double> jetRadiiValues;
@@ -83,52 +92,36 @@ struct BjetTaggingGnn {
8392

8493
eventSelectionBits = jetderiveddatautilities::initialiseEventSelectionBits(static_cast<std::string>(eventSelections));
8594

86-
registry.add("h_vertexZ", "Vertex Z;#it{Z} (cm)", {HistType::kTH1F, {{40, -20.0, 20.0}}});
95+
registry.add("h_vertexZ", "Vertex Z;#it{Z} (cm)", {HistType::kTH1F, {{100, -20.0, 20.0}}}, callSumw2);
8796

8897
const AxisSpec axisJetpT{200, 0., 200., "#it{p}_{T} (GeV/#it{c})"};
89-
const AxisSpec axisDb{200, -10., 20., "#it{D}_{b}"};
90-
const AxisSpec axisDbFine{3000, -10., 20., "#it{D}_{b}"};
98+
const AxisSpec axisDb{200, dbMin, dbMax, "#it{D}_{b}"};
99+
const AxisSpec axisDbFine{dbNbins, dbMin, dbMax, "#it{D}_{b}"};
91100
const AxisSpec axisSVMass{200, 0., 10., "#it{m}_{SV} (GeV/#it{c}^{2})"};
92101
const AxisSpec axisSVEnergy{200, 0., 100., "#it{E}_{SV} (GeV)"};
93102
const AxisSpec axisSLxy{200, 0., 100., "#it{SL}_{xy}"};
94103
const AxisSpec axisJetMass{200, 0., 50., "#it{m}_{jet} (GeV/#it{c}^{2})"};
95104
const AxisSpec axisJetProb{200, 0., 40., "-ln(JP)"};
96105
const AxisSpec axisNTracks{42, 0, 42, "#it{n}_{tracks}"};
97106

98-
registry.add("h_jetpT", "", {HistType::kTH1F, {axisJetpT}});
107+
registry.add("h_jetpT", "", {HistType::kTH1F, {axisJetpT}}, callSumw2);
99108
registry.add("h_Db", "", {HistType::kTH1F, {axisDbFine}});
100109
registry.add("h2_jetpT_Db", "", {HistType::kTH2F, {axisJetpT, axisDb}});
101-
registry.add("h2_jetpT_SVMass", "", {HistType::kTH2F, {axisJetpT, axisSVMass}});
102-
registry.add("h2_jetpT_jetMass", "", {HistType::kTH2F, {axisJetpT, axisJetMass}});
103-
registry.add("h2_jetpT_jetProb", "", {HistType::kTH2F, {axisJetpT, axisJetProb}});
104-
registry.add("h2_jetpT_nTracks", "", {HistType::kTH2F, {axisJetpT, axisNTracks}});
105110

106111
if (doprocessMCJets) {
107-
registry.add("h_jetpT_b", "b-jet", {HistType::kTH1F, {axisJetpT}});
108-
registry.add("h_jetpT_c", "c-jet", {HistType::kTH1F, {axisJetpT}});
109-
registry.add("h_jetpT_lf", "lf-jet", {HistType::kTH1F, {axisJetpT}});
112+
registry.add("h_jetpT_b", "b-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
113+
registry.add("h_jetpT_c", "c-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
114+
registry.add("h_jetpT_lf", "lf-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
110115
registry.add("h_Db_b", "b-jet", {HistType::kTH1F, {axisDbFine}});
111116
registry.add("h_Db_c", "c-jet", {HistType::kTH1F, {axisDbFine}});
112117
registry.add("h_Db_lf", "lf-jet", {HistType::kTH1F, {axisDbFine}});
113118
registry.add("h2_jetpT_Db_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisDb}});
114119
registry.add("h2_jetpT_Db_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisDb}});
115120
registry.add("h2_jetpT_Db_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisDb}});
116-
registry.add("h2_jetpT_SVMass_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisSVMass}});
117-
registry.add("h2_jetpT_SVMass_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisSVMass}});
118-
registry.add("h2_jetpT_SVMass_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisSVMass}});
119-
registry.add("h2_jetpT_jetMass_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetMass}});
120-
registry.add("h2_jetpT_jetMass_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetMass}});
121-
registry.add("h2_jetpT_jetMass_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisJetMass}});
122-
registry.add("h2_jetpT_jetProb_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetProb}});
123-
registry.add("h2_jetpT_jetProb_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetProb}});
124-
registry.add("h2_jetpT_jetProb_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisJetProb}});
125-
registry.add("h2_jetpT_nTracks_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisNTracks}});
126-
registry.add("h2_jetpT_nTracks_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisNTracks}});
127-
registry.add("h2_jetpT_nTracks_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisNTracks}});
128-
registry.add("h2_Response_DetjetpT_PartjetpT", "", {HistType::kTH2F, {axisJetpT, axisJetpT}});
129-
registry.add("h2_Response_DetjetpT_PartjetpT_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}});
130-
registry.add("h2_Response_DetjetpT_PartjetpT_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}});
131-
registry.add("h2_Response_DetjetpT_PartjetpT_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}});
121+
registry.add("h2_Response_DetjetpT_PartjetpT", "", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
122+
registry.add("h2_Response_DetjetpT_PartjetpT_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
123+
registry.add("h2_Response_DetjetpT_PartjetpT_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
124+
registry.add("h2_Response_DetjetpT_PartjetpT_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}, callSumw2});
132125
registry.add("h2_jetpT_Db_lf_none", "lf-jet (none)", {HistType::kTH2F, {axisJetpT, axisDb}});
133126
registry.add("h2_jetpT_Db_lf_matched", "lf-jet (matched)", {HistType::kTH2F, {axisJetpT, axisDb}});
134127
registry.add("h2_jetpT_Db_npp", "NotPhysPrim", {HistType::kTH2F, {axisJetpT, axisDb}});
@@ -139,17 +132,13 @@ struct BjetTaggingGnn {
139132
registry.add("h_Db_npp_b", "NotPhysPrim b-jet", {HistType::kTH1F, {axisDbFine}});
140133
registry.add("h_Db_npp_c", "NotPhysPrim c-jet", {HistType::kTH1F, {axisDbFine}});
141134
registry.add("h_Db_npp_lf", "NotPhysPrim lf-jet", {HistType::kTH1F, {axisDbFine}});
142-
// registry.add("h2_pT_dcaXY_pp", "tracks", {HistType::kTH2F, {axisJetpT, {200, 0., 1.}}});
143-
// registry.add("h2_pT_dcaXY_npp", "NotPhysPrim tracks", {HistType::kTH2F, {axisJetpT, {200, 0., 1.}}});
144-
// registry.add("h2_pT_dcaZ_pp", "tracks", {HistType::kTH2F, {axisJetpT, {200, 0., 2.}}});
145-
// registry.add("h2_pT_dcaZ_npp", "NotPhysPrim tracks", {HistType::kTH2F, {axisJetpT, {200, 0., 2.}}});
146135
}
147136

148137
if (doprocessMCTruthJets) {
149-
registry.add("h_jetpT_particle", "", {HistType::kTH1F, {axisJetpT}});
150-
registry.add("h_jetpT_particle_b", "particle b-jet", {HistType::kTH1F, {axisJetpT}});
151-
registry.add("h_jetpT_particle_c", "particle c-jet", {HistType::kTH1F, {axisJetpT}});
152-
registry.add("h_jetpT_particle_lf", "particle lf-jet", {HistType::kTH1F, {axisJetpT}});
138+
registry.add("h_jetpT_particle", "", {HistType::kTH1F, {axisJetpT}}, callSumw2);
139+
registry.add("h_jetpT_particle_b", "particle b-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
140+
registry.add("h_jetpT_particle_c", "particle c-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
141+
registry.add("h_jetpT_particle_lf", "particle lf-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
153142
}
154143

155144
if (doDataDriven) {
@@ -181,7 +170,7 @@ struct BjetTaggingGnn {
181170
int nTracks = 0;
182171
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {
183172

184-
if (constituent.pt() < trackPtMin) {
173+
if (constituent.pt() < trackPtMin || !jettaggingutilities::trackAcceptanceWithDca(constituent, maxIPxy, maxIPz)) {
185174
continue;
186175
}
187176

@@ -266,10 +255,6 @@ struct BjetTaggingGnn {
266255
registry.fill(HIST("h_jetpT"), analysisJet.pt());
267256
registry.fill(HIST("h_Db"), analysisJet.scoreML());
268257
registry.fill(HIST("h2_jetpT_Db"), analysisJet.pt(), analysisJet.scoreML());
269-
registry.fill(HIST("h2_jetpT_SVMass"), analysisJet.pt(), mSV);
270-
registry.fill(HIST("h2_jetpT_jetMass"), analysisJet.pt(), analysisJet.mass());
271-
registry.fill(HIST("h2_jetpT_jetProb"), analysisJet.pt(), analysisJet.jetProb());
272-
registry.fill(HIST("h2_jetpT_nTracks"), analysisJet.pt(), nTracks);
273258

274259
if (doDataDriven) {
275260
registry.fill(HIST("hSparse_Incljets"), analysisJet.pt(), analysisJet.scoreML(), mSV, analysisJet.mass(), nTracks);
@@ -288,7 +273,12 @@ struct BjetTaggingGnn {
288273
return;
289274
}
290275

291-
registry.fill(HIST("h_vertexZ"), collision.posZ());
276+
// When trainingDatasetRatioParam is non-zero, keep only collisions with collisionId % trainingDatasetRatioParam != 0 (the evaluation dataset)
277+
if (trainingDatasetRatioParam && collision.collisionId() % trainingDatasetRatioParam == 0) {
278+
return;
279+
}
280+
281+
registry.fill(HIST("h_vertexZ"), collision.posZ(), useEventWeight ? collision.weight() : 1.f);
292282

293283
for (const auto& analysisJet : MCDjets) {
294284

@@ -347,35 +337,19 @@ struct BjetTaggingGnn {
347337
registry.fill(HIST("h_jetpT"), analysisJet.pt(), weight);
348338
registry.fill(HIST("h_Db"), analysisJet.scoreML(), weight);
349339
registry.fill(HIST("h2_jetpT_Db"), analysisJet.pt(), analysisJet.scoreML(), weight);
350-
registry.fill(HIST("h2_jetpT_SVMass"), analysisJet.pt(), mSV, weight);
351-
registry.fill(HIST("h2_jetpT_jetMass"), analysisJet.pt(), analysisJet.mass(), weight);
352-
registry.fill(HIST("h2_jetpT_jetProb"), analysisJet.pt(), analysisJet.jetProb(), weight);
353-
registry.fill(HIST("h2_jetpT_nTracks"), analysisJet.pt(), nTracks, weight);
354340

355341
if (jetFlavor == JetTaggingSpecies::beauty) {
356342
registry.fill(HIST("h_jetpT_b"), analysisJet.pt(), weight);
357343
registry.fill(HIST("h_Db_b"), analysisJet.scoreML(), weight);
358344
registry.fill(HIST("h2_jetpT_Db_b"), analysisJet.pt(), analysisJet.scoreML(), weight);
359-
registry.fill(HIST("h2_jetpT_SVMass_b"), analysisJet.pt(), mSV, weight);
360-
registry.fill(HIST("h2_jetpT_jetMass_b"), analysisJet.pt(), analysisJet.mass(), weight);
361-
registry.fill(HIST("h2_jetpT_jetProb_b"), analysisJet.pt(), analysisJet.jetProb(), weight);
362-
registry.fill(HIST("h2_jetpT_nTracks_b"), analysisJet.pt(), nTracks, weight);
363345
} else if (jetFlavor == JetTaggingSpecies::charm) {
364346
registry.fill(HIST("h_jetpT_c"), analysisJet.pt(), weight);
365347
registry.fill(HIST("h_Db_c"), analysisJet.scoreML(), weight);
366348
registry.fill(HIST("h2_jetpT_Db_c"), analysisJet.pt(), analysisJet.scoreML(), weight);
367-
registry.fill(HIST("h2_jetpT_SVMass_c"), analysisJet.pt(), mSV, weight);
368-
registry.fill(HIST("h2_jetpT_jetMass_c"), analysisJet.pt(), analysisJet.mass(), weight);
369-
registry.fill(HIST("h2_jetpT_jetProb_c"), analysisJet.pt(), analysisJet.jetProb(), weight);
370-
registry.fill(HIST("h2_jetpT_nTracks_c"), analysisJet.pt(), nTracks, weight);
371349
} else {
372350
registry.fill(HIST("h_jetpT_lf"), analysisJet.pt(), weight);
373351
registry.fill(HIST("h_Db_lf"), analysisJet.scoreML(), weight);
374352
registry.fill(HIST("h2_jetpT_Db_lf"), analysisJet.pt(), analysisJet.scoreML(), weight);
375-
registry.fill(HIST("h2_jetpT_SVMass_lf"), analysisJet.pt(), mSV, weight);
376-
registry.fill(HIST("h2_jetpT_jetMass_lf"), analysisJet.pt(), analysisJet.mass(), weight);
377-
registry.fill(HIST("h2_jetpT_jetProb_lf"), analysisJet.pt(), analysisJet.jetProb(), weight);
378-
registry.fill(HIST("h2_jetpT_nTracks_lf"), analysisJet.pt(), nTracks, weight);
379353
if (jetFlavor == JetTaggingSpecies::none) {
380354
registry.fill(HIST("h2_jetpT_Db_lf_none"), analysisJet.pt(), analysisJet.scoreML(), weight);
381355
} else {

PWGJE/Tasks/bjetTreeCreator.cxx

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ struct BJetTreeCreator {
241241

242242
Configurable<float> vtxRes{"vtxRes", 0.01, "Vertex position resolution (cluster size) for GNN vertex predictions (cm)"};
243243

244+
Configurable<int> trainingDatasetRatioParam{"trainingDatasetRatioParam", 0, "Parameter for splitting training/evaluation datasets by collisionId"};
245+
244246
std::vector<int> eventSelectionBits;
245247

246248
std::vector<double> jetRadiiValues;
@@ -488,7 +490,7 @@ struct BJetTreeCreator {
488490

489491
trkIdx++;
490492

491-
if (constituent.pt() < trackPtMin) {
493+
if (constituent.pt() < trackPtMin || !jettaggingutilities::trackAcceptanceWithDca(constituent, maxIPxy, maxIPz)) {
492494
continue;
493495
}
494496

@@ -707,7 +709,7 @@ struct BJetTreeCreator {
707709
}
708710
PROCESS_SWITCH(BJetTreeCreator, processMCJets, "jet information in MC", false);
709711

710-
using MCDJetTableNoSV = soa::Filtered<soa::Join<aod::ChargedMCDetectorLevelJets, aod::ChargedMCDetectorLevelJetConstituents, aod::ChargedMCDetectorLevelJetsMatchedToChargedMCParticleLevelJets, aod::ChargedMCDetectorLevelJetEventWeights>>;
712+
using MCDJetTableNoSV = soa::Filtered<soa::Join<aod::ChargedMCDetectorLevelJets, aod::ChargedMCDetectorLevelJetConstituents, aod::ChargedMCDetectorLevelJetsMatchedToChargedMCParticleLevelJets, aod::ChargedMCDetectorLevelJetFlavourDef, aod::ChargedMCDetectorLevelJetEventWeights>>;
711713
using JetParticleswID = soa::Join<aod::JetParticles, aod::JMcParticlePIs>;
712714

713715
void processMCJetsForGNN(FilteredCollisionMCD::iterator const& collision, aod::JMcCollisions const&, MCDJetTableNoSV const& MCDjets, MCPJetTable const& MCPjets, JetTracksMCDwID const& allTracks, JetParticleswID const& MCParticles, OriginalTracks const& origTracks, aod::McParticles const& origParticles)
@@ -716,7 +718,12 @@ struct BJetTreeCreator {
716718
return;
717719
}
718720

719-
registry.fill(HIST("h_vertexZ"), collision.posZ());
721+
// When trainingDatasetRatioParam is non-zero, keep only collisions with collisionId % trainingDatasetRatioParam == 0 (the training dataset)
722+
if (trainingDatasetRatioParam && collision.collisionId() % trainingDatasetRatioParam != 0) {
723+
return;
724+
}
725+
726+
registry.fill(HIST("h_vertexZ"), collision.posZ(), collision.weight());
720727

721728
auto const mcParticlesPerColl = MCParticles.sliceBy(mcParticlesPerCollision, collision.mcCollisionId());
722729
auto const mcPJetsPerColl = MCPjets.sliceBy(mcpJetsPerCollision, collision.mcCollisionId());
@@ -738,15 +745,7 @@ struct BJetTreeCreator {
738745
std::vector<int> indicesTracks;
739746
std::vector<int> indicesSVs;
740747

741-
int16_t jetFlavor = 0;
742-
743-
for (const auto& mcpjet : analysisJet.template matchedJetGeo_as<MCPJetTable>()) {
744-
if (useQuarkDef) {
745-
jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl);
746-
} else {
747-
jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, mcParticlesPerColl);
748-
}
749-
}
748+
int16_t jetFlavor = analysisJet.origin();
750749

751750
if ((jetFlavor != JetTaggingSpecies::charm && jetFlavor != JetTaggingSpecies::beauty) && (static_cast<double>(std::rand()) / RAND_MAX < getReductionFactor(analysisJet.pt()))) {
752751
continue;
@@ -760,7 +759,7 @@ struct BJetTreeCreator {
760759
analyzeJetTrackInfoForGNN(collision, analysisJet, allTracks, origTracks, indicesTracks, jetFlavor, eventWeight, &trkLabels);
761760

762761
registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass(), eventWeight);
763-
registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), indicesTracks.size());
762+
registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), indicesTracks.size(), eventWeight);
764763

765764
//+jet
766765
registry.fill(HIST("h_jet_pt"), analysisJet.pt());

0 commit comments

Comments
 (0)