Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,17 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
)
except IndexError:
parent_type = 0

if primary_energy > 0:
primary_fraction = EonEntrance / primary_energy
else:
primary_fraction = -1

if visible_length != -1:
primary = frame[self.mctree].get_primary(HEParticle.id)
primary_is_nu, primary_type = primary.is_neutrino, primary.type
else:
primary_type = 0
primary_is_nu = False
output.update(
{
"e_fraction_" + self._extractor_name: primary_fraction,
Expand All @@ -153,6 +159,8 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
"particle_type_" + self._extractor_name: HEParticle.type,
"containment_" + self._extractor_name: containment,
"parent_type_" + self._extractor_name: parent_type,
"primary_type_" + self._extractor_name: primary_type,
"primary_is_nu_" + self._extractor_name: primary_is_nu,
}
)

Expand Down Expand Up @@ -182,7 +190,9 @@ def get_tracks(
primaries = [self.check_primary_energy(frame, p) for p in primaries]

MMCTrackList = frame[self.mmctracklist]
if self.daughters:
if self.daughters & (
not self._is_corsika
): # expensive operation unecessary for CORSIKA
temp_MMCTrackList = []
for track in MMCTrackList:
for p in primaries:
Expand All @@ -192,6 +202,26 @@ def get_tracks(
temp_MMCTrackList.append(track)
break
MMCTrackList = simclasses.I3MMCTrackList(temp_MMCTrackList)
elif self._is_corsika & self.daughters:
MMCTrackList_filtered = []
for track in MMCTrackList:
try:
if (
frame[self.mctree].get_primary(track.GetI3Particle())
in primaries
):
MMCTrackList_filtered.append(track)
except RuntimeError as e:
if "particle not found" in str(e):
# get event header
self.warning(
f"Particle {track.GetI3Particle().id} not found in MCTree."
f" Skipping track in event {frame['I3EventHeader']}"
)
else:
raise e # re-raise unexpected errors

MMCTrackList = simclasses.I3MMCTrackList(MMCTrackList_filtered)

MuonGun_tracks = np.array(
MuonGun.Track.harvest(frame[self.mctree], MMCTrackList)
Expand Down Expand Up @@ -347,14 +377,6 @@ def highest_energy_track(
if tmp_EonEntrance > EonEntrance:
particle = track_particle

closest_pos = np.array(
[
track.GetXc(),
track.GetYc(),
track.GetZc(),
]
)

EonEntrance = tmp_EonEntrance

visible_length = intersections.second - max(
Expand Down Expand Up @@ -391,6 +413,13 @@ def highest_energy_track(
)
particle.time = track.GetTi()
else:
closest_pos = np.array(
[
track.GetXc(),
track.GetYc(),
track.GetZc(),
]
)
# If the track is stopping or throughgoing,
# pos is point closest to detector center.
distance = np.sqrt((closest_pos**2).sum())
Expand Down Expand Up @@ -718,7 +747,7 @@ def highest_energy_bundle(
lengths = lengths[length_mask]

containment = GN_containment_types.stopping_bundle.value
closest_pos = []
highest_e = 0
for track, MGtrack in zip(MMCTrackList, MuonGun_tracks):
intersections = self.hull.surface.intersection(
MGtrack.pos, MGtrack.dir
Expand Down Expand Up @@ -751,30 +780,24 @@ def highest_energy_bundle(
raise # re-raise unexpected errors

EonEntrance += track_energy

closest_pos.append(
np.array(
if track_energy > highest_e:
highest_e = track_energy
closest_pos = np.array(
[
track.GetXc(),
track.GetYc(),
track.GetZc(),
]
)
* track_energy
)
if closest_time is None:
closest_time = track.GetTc()
elif closest_time < track.GetTc():

closest_time = track.GetTc()

if intersections.second > 0:
visible_length = max(
visible_length, intersections.second - intersections.first
)
if MGtrack.length > intersections.second:
containment = (
GN_containment_types.throughgoing_bundle.value
)
if intersections.second > 0:
visible_length = intersections.second - intersections.first
if MGtrack.length > intersections.second:
containment = (
GN_containment_types.throughgoing_bundle.value
)

# If no intersection.second is every positive
# the visible_length can still be negative here
Expand All @@ -793,8 +816,6 @@ def highest_energy_bundle(
visible_length >= 0
), f"Visible length is negative for particle {frame['I3EventHeader']}"

closest_pos = np.sum(closest_pos, axis=0) / EonEntrance

bundle.pos = dataclasses.I3Position(
closest_pos[0], closest_pos[1], closest_pos[2]
)
Expand Down
Loading