Skip to content

Commit 55cee12

Browse files
mathieuouillon authored and baltzell committed
feat(track-finding): add GNN track finder
Introduce GNN_Track_Finding as a fourth track-finding mode alongside the renamed MLP_Track_Finding (was AI_Track_Finding), CV_Distance, and CV_Hough. The new path runs a GravNet edge scorer (TorchScript via DJL) on a per-event AHDC + ATOF hit graph, extracts tracks as connected components on edges with sigmoid score >= 0.1, then re-preclusters each surviving track's AHDC hits and pairs them into per-superlayer Clusters so the existing DOCA refinement + helix fit + Kalman stages consume them unchanged. Selected via ALERT.Mode in YAML. MLP regression is bit-identical (same pre-existing AHDC::track sum_adc/dEdx precision drift); only COAT::config changes, reflecting the renamed mode.
1 parent feca853 commit 55cee12

10 files changed

Lines changed: 648 additions & 8 deletions

File tree

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
/**
 * Normalization and graph-construction constants for the GNN track finder.
 *
 * <p>Mirrors {@code track-finding/gnn/config.py} — keep in sync with the training config.
 * This is a constants holder and cannot be instantiated.
 */
final class GNNConstants {

    private GNNConstants() {}

    // ---------------------------------------------------------------- tensor layout

    /** Number of features per node in the model input. */
    static final int NODE_FEAT_DIM = 10;

    /** Number of features per edge in the model input. */
    static final int EDGE_FEAT_DIM = 9;

    // ---------------------------------------------------------------- minimum graph size

    // Model architecture parameters control the minimum graph size at inference.
    // Nominally GravNet's progressive-k reaches 2*k and topk uses k+1, which would
    // require N_nodes >= 2*k + 2. The exported model clamps topk(k+1) to N internally
    // (see track-finding/export_torchscript.py::_knn_indices), so any graph with
    // >= 3 nodes runs without crashing. Graphs with fewer nodes are skipped before
    // inference.

    /** Smallest node count for which inference is attempted. */
    static final int MIN_NODES = 3;

    // ---------------------------------------------------------------- graph construction

    /** Largest abs_layer difference (destination minus source) allowed for an edge. */
    static final int MAX_LAYER_GAP = 2;

    /** Largest transverse (x, y) distance between two nodes joined by an edge, in mm. */
    static final double MAX_EDGE_DISTANCE = 35.0;

    /** {@link #MAX_EDGE_DISTANCE} squared, so the edge mask can skip the square root. */
    static final double MAX_EDGE_DIST_SQ = MAX_EDGE_DISTANCE * MAX_EDGE_DISTANCE;

    // ---------------------------------------------------------------- feature normalization

    /** Radial scale, in mm. */
    static final double MAX_R = 100.0;

    /** DOCA scale, in mm. */
    static final double DOCA_STD = 10.0;

    /** Half length along z, in mm. */
    static final double Z_HALF_LENGTH = 200.0;

    /** Stereo-angle scale, in rad. */
    static final double STEREO_ANGLE_MAX = 0.03;

    /** Reciprocal of {@link #STEREO_ANGLE_MAX}. */
    static final double STEREO_SCALE = 1.0 / STEREO_ANGLE_MAX;

    // ---------------------------------------------------------------- ATOF abs_layer convention
    // Taken from Python's build_graph.

    /** abs_layer assigned to ATOF hits whose component == 10. */
    static final int ATOF_BAR_ABS_LAYER = 10;

    /** abs_layer assigned to ATOF hits with any other component. */
    static final int ATOF_WEDGE_ABS_LAYER = 11;

    // ---------------------------------------------------------------- track extraction
    // Connected components at a single score threshold, matching gnn/evaluate.py
    // (extract_tracks(..., method="cc", threshold=0.1)). Tracks with fewer than
    // MIN_TRACK_NODES total nodes are dropped — the same filter evaluate.py applies
    // after the method call.

    /** Edge-score threshold used for track extraction. */
    static final double TRACK_SCORE_THRESHOLD = 0.1;

    /** Minimum number of nodes a track must contain to be kept. */
    static final int MIN_TRACK_NODES = 3;
}
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import java.util.ArrayList;
4+
import java.util.HashSet;
5+
import java.util.List;
6+
import java.util.Set;
7+
8+
import org.jlab.geom.prim.Line3D;
9+
import org.jlab.geom.prim.Point3D;
10+
import org.jlab.geom.prim.Vector3D;
11+
import org.jlab.io.base.DataBank;
12+
import org.jlab.rec.ahdc.Hit.Hit;
13+
14+
/** Builds the graph tensors expected by the exported GNN edge scorer.
15+
* Ports track-finding/gnn/dataset.py::build_graph — must stay byte-compatible
16+
* with the training-time feature layout and normalization.
17+
*/
18+
final class GNNGraphBuilder {
19+
20+
/** Container for the tensors + node provenance that the caller needs. */
21+
static final class GraphInput {
22+
final float[][] nodeFeatures; // shape [N, 10]
23+
final long[][] edgeIndex; // shape [2, E]
24+
final float[][] edgeAttr; // shape [E, 9]
25+
/** nodeToSource[i] is the backing Hit for AHDC nodes, or null for ATOF nodes. */
26+
final Hit[] nodeToSource;
27+
28+
GraphInput(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr, Hit[] nodeToSource) {
29+
this.nodeFeatures = nodeFeatures;
30+
this.edgeIndex = edgeIndex;
31+
this.edgeAttr = edgeAttr;
32+
this.nodeToSource = nodeToSource;
33+
}
34+
}
35+
36+
private GNNGraphBuilder() {}
37+
38+
/** Build a graph from AHDC hits (required) plus the ATOF::hits bank (optional). */
39+
static GraphInput build(List<Hit> ahdcHits, DataBank atofHitsBank) {
40+
int nAhdc = ahdcHits == null ? 0 : ahdcHits.size();
41+
42+
// Node state buffers (grow as we append AHDC then ATOF nodes).
43+
List<double[]> nodeBuf = new ArrayList<>(); // per-node raw floats (see NodeField indexes)
44+
List<Line3D> nodeLine = new ArrayList<>(); // wire line for AHDC; null for ATOF
45+
List<Hit> nodeHit = new ArrayList<>(); // backing Hit for AHDC; null for ATOF
46+
47+
// --- AHDC nodes -------------------------------------------------------------
48+
for (int i = 0; i < nAhdc; i++) {
49+
Hit h = ahdcHits.get(i);
50+
Line3D line = h.getLine();
51+
if (line == null) continue; // missing geometry → skip (shouldn't happen after setWirePosition)
52+
53+
Point3D mid = line.midpoint();
54+
Vector3D dir = line.toVector();
55+
double len = Math.max(dir.mag(), 1e-12);
56+
double ux = dir.x() / len, uy = dir.y() / len, uz = dir.z() / len;
57+
double stereo = Math.atan2(Math.sqrt(ux*ux + uy*uy), uz);
58+
59+
int absLayer = (h.getSuperLayerId() - 1) * 2 + (h.getLayerId() - 1);
60+
nodeBuf.add(new double[]{
61+
absLayer, // 0: abs_layer
62+
h.getPhi(), // 1: phi
63+
h.getRadius(), // 2: r
64+
stereo, // 3: stereo_angle
65+
mid.x(), // 4: x_mid
66+
mid.y(), // 5: y_mid
67+
mid.z(), // 6: z_mid
68+
ux, // 7: ux
69+
uy, // 8: uy
70+
uz, // 9: uz
71+
h.getX(), // 10: x (raw, for edge distance mask)
72+
h.getY(), // 11: y (raw, for edge distance mask)
73+
0.0, // 12: det_type = 0 (AHDC)
74+
});
75+
nodeLine.add(line);
76+
nodeHit.add(h);
77+
}
78+
79+
// --- ATOF nodes -------------------------------------------------------------
80+
// Deduplicate by (sector, layer, component) — inference-time variant of the
81+
// Python dedup which also keys on track id (only needed at training time).
82+
if (atofHitsBank != null) {
83+
Set<Long> seen = new HashSet<>();
84+
int rows = atofHitsBank.rows();
85+
for (int r = 0; r < rows; r++) {
86+
int sector = atofHitsBank.getInt("sector", r);
87+
int layer = atofHitsBank.getInt("layer", r);
88+
int component = atofHitsBank.getInt("component", r);
89+
long key = (((long)sector * 1000L) + layer) * 1000L + component;
90+
if (!seen.add(key)) continue;
91+
92+
double x = atofHitsBank.getFloat("x", r);
93+
double y = atofHitsBank.getFloat("y", r);
94+
double radius = Math.hypot(x, y);
95+
double phi = Math.atan2(y, x);
96+
int absLayer = (component == 10) ? GNNConstants.ATOF_BAR_ABS_LAYER
97+
: GNNConstants.ATOF_WEDGE_ABS_LAYER;
98+
99+
nodeBuf.add(new double[]{
100+
absLayer, phi, radius,
101+
0.0, // stereo
102+
x, y, 0.0, // mid
103+
0.0, 0.0, 1.0, // (ux, uy, uz)
104+
x, y, // raw x, y (for edge mask)
105+
1.0, // det_type = 1 (ATOF)
106+
});
107+
nodeLine.add(null);
108+
nodeHit.add(null);
109+
}
110+
}
111+
112+
int n = nodeBuf.size();
113+
if (n < 2) {
114+
return new GraphInput(new float[0][GNNConstants.NODE_FEAT_DIM],
115+
new long[][]{new long[0], new long[0]},
116+
new float[0][GNNConstants.EDGE_FEAT_DIM],
117+
new Hit[0]);
118+
}
119+
120+
// --- Node feature tensor [N, 10] --------------------------------------------
121+
float[][] nodeFeatures = new float[n][GNNConstants.NODE_FEAT_DIM];
122+
for (int i = 0; i < n; i++) {
123+
double[] v = nodeBuf.get(i);
124+
nodeFeatures[i][0] = (float)(v[0] / 11.0);
125+
nodeFeatures[i][1] = (float)(v[1] / Math.PI);
126+
nodeFeatures[i][2] = (float)(v[2] / GNNConstants.DOCA_STD);
127+
nodeFeatures[i][3] = (float)(v[3] / GNNConstants.STEREO_ANGLE_MAX);
128+
nodeFeatures[i][4] = (float)(v[4] / GNNConstants.MAX_R);
129+
nodeFeatures[i][5] = (float)(v[5] / GNNConstants.MAX_R);
130+
nodeFeatures[i][6] = (float)(v[6] / GNNConstants.Z_HALF_LENGTH);
131+
nodeFeatures[i][7] = (float)(v[7] * GNNConstants.STEREO_SCALE);
132+
nodeFeatures[i][8] = (float)(v[8] * GNNConstants.STEREO_SCALE);
133+
nodeFeatures[i][9] = (float)(v[9]);
134+
}
135+
136+
// --- Edge construction (directed, layer_gap in [1, MAX_LAYER_GAP]) -----------
137+
// Mirrors Python's np.where(mask) on a non-symmetric mask.
138+
int[] absLayer = new int[n];
139+
double[] xRaw = new double[n];
140+
double[] yRaw = new double[n];
141+
double[] rRaw = new double[n];
142+
double[] phiRaw = new double[n];
143+
double[] stereoRaw = new double[n];
144+
double[] detTypeRaw = new double[n];
145+
for (int i = 0; i < n; i++) {
146+
double[] v = nodeBuf.get(i);
147+
absLayer[i] = (int) v[0];
148+
phiRaw[i] = v[1];
149+
rRaw[i] = v[2];
150+
stereoRaw[i] = v[3];
151+
xRaw[i] = v[10];
152+
yRaw[i] = v[11];
153+
detTypeRaw[i] = v[12];
154+
}
155+
156+
List<long[]> edgePairs = new ArrayList<>();
157+
for (int i = 0; i < n; i++) {
158+
for (int j = 0; j < n; j++) {
159+
if (i == j) continue;
160+
int gap = absLayer[j] - absLayer[i];
161+
if (gap < 1 || gap > GNNConstants.MAX_LAYER_GAP) continue;
162+
double dx = xRaw[i] - xRaw[j];
163+
double dy = yRaw[i] - yRaw[j];
164+
if (dx*dx + dy*dy > GNNConstants.MAX_EDGE_DIST_SQ) continue;
165+
edgePairs.add(new long[]{i, j});
166+
}
167+
}
168+
169+
int e = edgePairs.size();
170+
long[][] edgeIndex = new long[2][e];
171+
float[][] edgeAttr = new float[e][GNNConstants.EDGE_FEAT_DIM];
172+
173+
for (int k = 0; k < e; k++) {
174+
long[] p = edgePairs.get(k);
175+
int s = (int) p[0];
176+
int d = (int) p[1];
177+
edgeIndex[0][k] = s;
178+
edgeIndex[1][k] = d;
179+
180+
// dphi wrapped into [-pi, pi]
181+
double dphi = phiRaw[s] - phiRaw[d];
182+
dphi = ((dphi + Math.PI) % (2.0 * Math.PI) + 2.0 * Math.PI) % (2.0 * Math.PI) - Math.PI;
183+
double dlayer = (double)(absLayer[d] - absLayer[s]) / GNNConstants.MAX_LAYER_GAP;
184+
185+
double doca, z1, z2;
186+
Line3D ls = nodeLine.get(s);
187+
Line3D ld = nodeLine.get(d);
188+
if (ls != null && ld != null) {
189+
doca = ls.distance(ld).length();
190+
// Python: z1 = cp_d.z, where cp_d is the point on line_s closest to line_d's midpoint
191+
// z2 = cp_s.z, where cp_s is the point on line_d closest to line_s's midpoint
192+
z1 = clampZ(ls.distance(ld.midpoint()).origin().z());
193+
z2 = clampZ(ld.distance(ls.midpoint()).origin().z());
194+
} else {
195+
double ex = xRaw[s] - xRaw[d];
196+
double ey = yRaw[s] - yRaw[d];
197+
doca = Math.hypot(ex, ey);
198+
z1 = 0.0;
199+
z2 = 0.0;
200+
}
201+
202+
double edgeDetType = 0.5 * (detTypeRaw[s] + detTypeRaw[d]);
203+
204+
edgeAttr[k][0] = (float)(dphi / Math.PI);
205+
edgeAttr[k][1] = (float) dlayer;
206+
edgeAttr[k][2] = (float)(doca / GNNConstants.MAX_R);
207+
edgeAttr[k][3] = (float)(z1 / GNNConstants.Z_HALF_LENGTH);
208+
edgeAttr[k][4] = (float)(z2 / GNNConstants.Z_HALF_LENGTH);
209+
edgeAttr[k][5] = (float)(rRaw[s] / GNNConstants.DOCA_STD);
210+
edgeAttr[k][6] = (float)(rRaw[d] / GNNConstants.DOCA_STD);
211+
edgeAttr[k][7] = (float)((stereoRaw[s] - stereoRaw[d]) / (2.0 * GNNConstants.STEREO_ANGLE_MAX));
212+
edgeAttr[k][8] = (float) edgeDetType;
213+
}
214+
215+
Hit[] nodeToHit = nodeHit.toArray(new Hit[0]);
216+
return new GraphInput(nodeFeatures, edgeIndex, edgeAttr, nodeToHit);
217+
}
218+
219+
private static double clampZ(double z) {
220+
if (z < -GNNConstants.Z_HALF_LENGTH) return -GNNConstants.Z_HALF_LENGTH;
221+
if (z > GNNConstants.Z_HALF_LENGTH) return GNNConstants.Z_HALF_LENGTH;
222+
return z;
223+
}
224+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import java.util.ArrayList;
4+
import java.util.HashMap;
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.logging.Logger;
8+
9+
import org.jlab.io.base.DataBank;
10+
import org.jlab.rec.ahdc.Cluster.Cluster;
11+
import org.jlab.rec.ahdc.Hit.Hit;
12+
import org.jlab.rec.ahdc.PreCluster.PreCluster;
13+
import org.jlab.rec.ahdc.PreCluster.PreClusterFinder;
14+
import org.jlab.rec.ahdc.Track.Track;
15+
16+
/** Orchestrates GNN-based track finding: builds the graph, runs the exported
17+
* edge scorer, extracts tracks via connected components on edge scores
18+
* thresholded at 0.1, and converts each node-set back into a {@link Track}
19+
* carrying per-superlayer Clusters so the downstream helix fit / Kalman
20+
* stages can consume it.
21+
*/
22+
public final class GNNPrediction {
23+
24+
private static final Logger LOGGER = Logger.getLogger(GNNPrediction.class.getName());
25+
26+
public ArrayList<Track> prediction(List<Hit> ahdcHits,
27+
DataBank atofHitsBank,
28+
ModelTrackFindingGNN model) {
29+
ArrayList<Track> out = new ArrayList<>();
30+
if (ahdcHits == null || ahdcHits.isEmpty() || model == null) return out;
31+
32+
GNNGraphBuilder.GraphInput g = GNNGraphBuilder.build(ahdcHits, atofHitsBank);
33+
int nNodes = g.nodeToSource.length;
34+
int nEdges = g.edgeIndex[0].length;
35+
if (nNodes < GNNConstants.MIN_NODES || nEdges == 0) {
36+
return out; // model cannot run on graphs this small
37+
}
38+
39+
float[] edgeScores;
40+
try {
41+
edgeScores = model.predictEdgeScores(g.nodeFeatures, g.edgeIndex, g.edgeAttr);
42+
} catch (Exception ex) {
43+
LOGGER.warning(() -> "GNN inference failed: " + ex);
44+
return out;
45+
}
46+
47+
// Connected components at TRACK_SCORE_THRESHOLD, filtered to
48+
// components of size >= MIN_TRACK_NODES — mirrors gnn/evaluate.py.
49+
List<int[]> trackNodeSets = SeedExtendTrackExtractor.extract(edgeScores, g.edgeIndex, nNodes);
50+
51+
for (int[] nodes : trackNodeSets) {
52+
// Collect just the AHDC Hits in this track — ATOF nodes were graph
53+
// context only, they don't belong in AHDC::track or AHDC::hits.
54+
ArrayList<Hit> trackHits = new ArrayList<>(nodes.length);
55+
for (int n : nodes) {
56+
Hit h = g.nodeToSource[n];
57+
if (h != null) trackHits.add(h);
58+
}
59+
if (trackHits.isEmpty()) continue;
60+
61+
ArrayList<Cluster> clusters = buildSuperlayerClusters(trackHits);
62+
if (clusters.size() < 3) continue; // matches the downstream >=3 filter
63+
64+
out.add(new Track(clusters));
65+
}
66+
67+
return out;
68+
}
69+
70+
/** One {@link Cluster} per superlayer built from two {@link PreCluster}s (one
71+
* per layer within the superlayer). Using real PreClusters — instead of the
72+
* 3-arg {@code Cluster(x,y,z)} constructor — keeps
73+
* {@code Track.generateHitList()} and {@code DocaClusterRefiner}'s stereo
74+
* pairing working for GNN-discovered tracks just like they do for MLP tracks.
75+
*/
76+
private static ArrayList<Cluster> buildSuperlayerClusters(List<Hit> hits) {
77+
// Feed the track's hits through the same preclustering the MLP path uses.
78+
// findPreclusters mutates its input (it calls setUse(true) on consumed
79+
// hits), so pass a copy and ensure each hit starts unmarked.
80+
ArrayList<Hit> hitsForPre = new ArrayList<>(hits.size());
81+
for (Hit h : hits) { h.setUse(false); hitsForPre.add(h); }
82+
PreClusterFinder pcf = new PreClusterFinder();
83+
pcf.findPreclusters(hitsForPre);
84+
ArrayList<PreCluster> preclusters = pcf.get_AHDCPreClusters();
85+
86+
// Index by (superlayer, layer). If the GNN assigns two PreClusters of the
87+
// same superlayer+layer to one track (rare — it would mean two disjoint
88+
// wire runs on the same layer), keep the largest and drop the rest.
89+
Map<Integer, PreCluster[]> bySuperlayer = new HashMap<>();
90+
for (PreCluster pc : preclusters) {
91+
int sl = pc.get_Super_layer();
92+
int layerIdx = pc.get_Layer() - 1; // layer is 1-based, slots are [0,1]
93+
if (layerIdx < 0 || layerIdx > 1) continue;
94+
PreCluster[] slot = bySuperlayer.computeIfAbsent(sl, k -> new PreCluster[2]);
95+
PreCluster prev = slot[layerIdx];
96+
if (prev == null || pc.get_Num_wire() > prev.get_Num_wire()) slot[layerIdx] = pc;
97+
}
98+
99+
ArrayList<Cluster> clusters = new ArrayList<>();
100+
// Iterate superlayers in ascending order to keep downstream output stable.
101+
// If both stereo layers have a PreCluster, pair them (full stereo cluster).
102+
// If only one has hits, use the single-layer Cluster(PreCluster) ctor —
103+
// DocaClusterRefiner handles PreClusters_list.size() != 2 with a
104+
// degenerate DocaCluster fallback, so the helix fit still runs.
105+
for (int sl = 1; sl <= 5; sl++) {
106+
PreCluster[] slot = bySuperlayer.get(sl);
107+
if (slot == null) continue;
108+
if (slot[0] != null && slot[1] != null) {
109+
clusters.add(new Cluster(slot[0], slot[1]));
110+
} else {
111+
PreCluster single = (slot[0] != null) ? slot[0] : slot[1];
112+
if (single != null) clusters.add(new Cluster(single));
113+
}
114+
}
115+
return clusters;
116+
}
117+
}

0 commit comments

Comments
 (0)