Skip to content

Commit c2e7ca6

Browse files
author
miranov25
committed
AliasDataFrame: add index-based subframe join and robust error handling
- Updated `register_subframe()` to explicitly require `index_columns` for join key(s)
- Enhanced `_prepare_subframe_joins()` to:
  - auto-materialize subframe aliases if missing
  - raise informative KeyError when column or alias does not exist
- Added logic to propagate subframe metadata (including join indices) in save/load and ROOT export/import
- Expanded test coverage:
  - Added subframe alias tests for automatic materialization and error reporting
  - Added 2D index subframe join test (e.g. using ["run", "track_id"])
  - Refactored test setup to avoid shared state interference
  - Asserted raised exceptions for missing subframe attributes
- Minor fixes to alias materialization and type assertions
1 parent 9b7a038 commit c2e7ca6

File tree

2 files changed

+95
-56
lines changed

2 files changed

+95
-56
lines changed

UTILS/dfextensions/AliasDataFrame.py

Lines changed: 55 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,17 +11,23 @@
1111

1212
class SubframeRegistry:
1313
def __init__(self):
14-
self.subframes = {}
14+
self.subframes = {} # name → {'frame': adf, 'index': index_columns}
1515

16-
def add_subframe(self, name, alias_df):
17-
self.subframes[name] = alias_df
16+
def add_subframe(self, name, alias_df, index_columns, pre_index=False):
17+
if pre_index and not alias_df.df.index.names == index_columns:
18+
alias_df.df.set_index(index_columns, inplace=True)
19+
self.subframes[name] = {'frame': alias_df, 'index': index_columns}
1820

1921
def get(self, name):
22+
return self.subframes.get(name, {}).get('frame', None)
23+
24+
def get_entry(self, name):
2025
return self.subframes.get(name, None)
2126

2227
def items(self):
2328
return self.subframes.items()
2429

30+
2531
def convert_expr_to_root(expr):
2632
class RootTransformer(ast.NodeTransformer):
2733
FUNC_MAP = {
@@ -77,8 +83,8 @@ def __getattr__(self, item):
7783
return self.df[item]
7884
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
7985

80-
def register_subframe(self, name, adf):
81-
self._subframes.add_subframe(name, adf)
86+
def register_subframe(self, name, adf, index_columns, pre_index=False):
87+
self._subframes.add_subframe(name, adf, index_columns, pre_index=pre_index)
8288

8389
def get_subframe(self, name):
8490
return self._subframes.get(name)
@@ -88,10 +94,41 @@ def _default_functions(self):
8894
env = {k: getattr(math, k) for k in dir(math) if not k.startswith("_")}
8995
env.update({k: getattr(np, k) for k in dir(np) if not k.startswith("_")})
9096
env["np"] = np
91-
for sf_name, sf in self._subframes.items():
92-
env[sf_name] = sf
97+
for sf_name, sf_entry in self._subframes.items():
98+
env[sf_name] = sf_entry['frame']
9399
return env
94100

101+
def _prepare_subframe_joins(self, expr):
102+
tokens = re.findall(r'(\b\w+)\.(\w+)', expr)
103+
for sf_name, sf_col in tokens:
104+
entry = self._subframes.get_entry(sf_name)
105+
if not entry:
106+
continue
107+
sub_adf = entry['frame']
108+
sub_df = sub_adf.df
109+
index_cols = entry['index']
110+
if isinstance(index_cols, str):
111+
index_cols = [index_cols]
112+
merge_cols = index_cols + [sf_col]
113+
suffix = f'__{sf_name}'
114+
115+
try:
116+
cols_to_merge = sub_df[merge_cols]
117+
except KeyError:
118+
if sf_col in sub_adf.aliases:
119+
sub_adf.materialize_alias(sf_col)
120+
sub_df = sub_adf.df
121+
cols_to_merge = sub_df[merge_cols]
122+
else:
123+
raise KeyError(f"Subframe '{sf_name}' does not contain or define alias '{sf_col}'")
124+
125+
joined = self.df.merge(cols_to_merge, on=index_cols, suffixes=('', suffix))
126+
col_renamed = f'{sf_col}{suffix}'
127+
if col_renamed in joined.columns:
128+
self.df[col_renamed] = joined[col_renamed].values
129+
expr = expr.replace(f'{sf_name}.{sf_col}', col_renamed)
130+
return expr
131+
95132
def _check_for_cycles(self):
96133
try:
97134
self._topological_sort()
@@ -107,8 +144,8 @@ def add_alias(self, name, expression, dtype=None, is_constant=False):
107144
self._check_for_cycles()
108145

109146
def _eval_in_namespace(self, expr):
147+
expr = self._prepare_subframe_joins(expr)
110148
local_env = {col: self.df[col] for col in self.df.columns}
111-
local_env.update({k: self.df[k] for k in self.aliases if k in self.df})
112149
local_env.update(self._default_functions())
113150
return eval(expr, {}, local_env)
114151

@@ -300,8 +337,8 @@ def export_tree(self, filename_or_file, treename="tree", dropAliasColumns=True):
300337
self._write_metadata_to_root(filename_or_file, treename)
301338
else:
302339
self._write_to_uproot(filename_or_file, treename, dropAliasColumns)
303-
for subframe_name, sub_adf in self._subframes.items():
304-
sub_adf._write_metadata_to_root(filename_or_file, f"{treename}__subframe__{subframe_name}")
340+
for subframe_name, entry in self._subframes.items():
341+
entry["frame"]._write_metadata_to_root(filename_or_file, f"{treename}__subframe__{subframe_name}")
305342

306343
def _write_to_uproot(self, uproot_file, treename, dropAliasColumns):
307344
export_cols = [col for col in self.df.columns if not dropAliasColumns or col not in self.aliases]
@@ -310,8 +347,8 @@ def _write_to_uproot(self, uproot_file, treename, dropAliasColumns):
310347

311348
uproot_file[treename] = export_df
312349

313-
for subframe_name, sub_adf in self._subframes.items():
314-
sub_adf.export_tree(uproot_file, f"{treename}__subframe__{subframe_name}", dropAliasColumns)
350+
for subframe_name, entry in self._subframes.items():
351+
entry["frame"].export_tree(uproot_file, f"{treename}__subframe__{subframe_name}", dropAliasColumns)
315352

316353
def _write_metadata_to_root(self, filename, treename):
317354
f = ROOT.TFile.Open(filename, "UPDATE")
@@ -325,6 +362,7 @@ def _write_metadata_to_root(self, filename, treename):
325362
tree.SetAlias(alias, expr_str)
326363
metadata = {
327364
"aliases": self.aliases,
365+
"subframe_indices": {k: v["index"] for k, v in self._subframes.items()},
328366
"dtypes": {k: v.__name__ for k, v in self.alias_dtypes.items()},
329367
"constants": list(self.constant_aliases),
330368
"subframes": list(self._subframes.subframes.keys())
@@ -334,6 +372,7 @@ def _write_metadata_to_root(self, filename, treename):
334372
tree.Write("", ROOT.TObject.kOverwrite)
335373
f.Close()
336374

375+
@staticmethod
337376
def read_tree(filename, treename="tree"):
338377
with uproot.open(filename) as f:
339378
df = f[treename].arrays(library="pd")
@@ -354,7 +393,10 @@ def read_tree(filename, treename="tree"):
354393
adf.constant_aliases.update(jmeta.get("constants", []))
355394
for sf_name in jmeta.get("subframes", []):
356395
sf = AliasDataFrame.read_tree(filename, treename=f"{treename}__subframe__{sf_name}")
357-
adf.register_subframe(sf_name, sf)
396+
index = jmeta.get("subframe_indices", {}).get(sf_name)
397+
if index is None:
398+
raise ValueError(f"Missing index_columns for subframe '{sf_name}' in metadata")
399+
adf.register_subframe(sf_name, sf, index_columns=index)
358400
break
359401
except Exception:
360402
pass

UTILS/dfextensions/AliasDataFrameTest.py

Lines changed: 40 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -94,11 +94,10 @@ def test_export_import_tree_roundtrip(self):
9494
os.remove(tmp_path)
9595

9696
class TestAliasDataFrameWithSubframes(unittest.TestCase):
97-
@classmethod
98-
def setUpClass(cls):
97+
def setUp(self):
9998
n_tracks = 1000
10099
n_clusters = 100
101-
cls.df_tracks = pd.DataFrame({
100+
df_tracks = pd.DataFrame({
102101
"track_index": np.arange(n_tracks),
103102
"mX": np.random.normal(0, 10, n_tracks),
104103
"mY": np.random.normal(0, 10, n_tracks),
@@ -107,65 +106,63 @@ def setUpClass(cls):
107106
"mEta": np.random.normal(0, 1, n_tracks),
108107
})
109108

110-
cluster_idx = np.repeat(cls.df_tracks["track_index"], n_clusters)
111-
cls.df_clusters = pd.DataFrame({
109+
cluster_idx = np.repeat(df_tracks["track_index"], n_clusters)
110+
df_clusters = pd.DataFrame({
112111
"track_index": cluster_idx,
113112
"mX": np.random.normal(0, 10, len(cluster_idx)),
114113
"mY": np.random.normal(0, 10, len(cluster_idx)),
115114
"mZ": np.random.normal(0, 10, len(cluster_idx)),
116115
})
117116

118-
cls.adf_tracks = AliasDataFrame(cls.df_tracks)
119-
cls.adf_clusters = AliasDataFrame(cls.df_clusters)
120-
cls.adf_clusters.register_subframe("T", cls.adf_tracks)
121-
122-
def test_alias_cluster_radius(self):
123-
self.adf_clusters.add_alias("mR", "sqrt(mX**2 + mY**2)")
124-
self.adf_clusters.materialize_all()
125-
expected = np.sqrt(self.adf_clusters.df["mX"]**2 + self.adf_clusters.df["mY"]**2)
126-
pd.testing.assert_series_equal(self.adf_clusters.df["mR"], expected, check_names=False)
117+
self.df_tracks = df_tracks
118+
self.df_clusters = df_clusters
127119

128120
def test_alias_cluster_track_dx(self):
129-
self.adf_clusters.add_alias("mDX", "mX - T.mX")
130-
self.adf_clusters.materialize_all()
131-
merged = self.adf_clusters.df.merge(self.adf_tracks.df, on="track_index", suffixes=("", "_track"))
132-
expected = merged["mX"] - merged["mX_track"]
133-
pd.testing.assert_series_equal(self.adf_clusters.df["mDX"].reset_index(drop=True), expected.reset_index(drop=True), check_names=False)
134-
135-
def test_unregistered_subframe_raises_error(self):
136-
adf_tmp = AliasDataFrame(self.df_clusters)
137-
adf_tmp.add_alias("mDX", "mX - T.mX")
138-
with self.assertRaises(NameError):
139-
adf_tmp.materialize_all()
121+
adf_clusters = AliasDataFrame(self.df_clusters.copy())
122+
adf_tracks = AliasDataFrame(self.df_tracks.copy())
123+
adf_clusters.register_subframe("T", adf_tracks, index_columns="track_index")
124+
adf_clusters.add_alias("mDX", "mX - T.mX")
125+
adf_clusters.materialize_all()
126+
merged = adf_clusters.df.merge(adf_tracks.df, on="track_index", suffixes=("", "_trk"))
127+
expected = merged["mX"] - merged["mX_trk"]
128+
pd.testing.assert_series_equal(adf_clusters.df["mDX"].reset_index(drop=True), expected.reset_index(drop=True), check_names=False)
129+
130+
def test_subframe_invalid_alias_raises(self):
131+
adf_clusters = AliasDataFrame(self.df_clusters.copy())
132+
adf_tracks = AliasDataFrame(self.df_tracks.copy())
133+
adf_clusters.register_subframe("T", adf_tracks, index_columns="track_index")
134+
adf_clusters.add_alias("invalid", "T.nonexistent")
135+
136+
with self.assertRaises(KeyError) as cm:
137+
adf_clusters.materialize_alias("invalid")
138+
139+
self.assertIn("T", str(cm.exception))
140+
self.assertIn("nonexistent", str(cm.exception))
140141

141142
def test_save_and_load_integrity(self):
142-
import tempfile
143+
adf_clusters = AliasDataFrame(self.df_clusters.copy())
144+
adf_tracks = AliasDataFrame(self.df_tracks.copy())
145+
adf_clusters.register_subframe("T", adf_tracks, index_columns="track_index")
146+
adf_clusters.add_alias("mDX", "mX - T.mX")
147+
adf_clusters.materialize_all()
148+
143149
with tempfile.TemporaryDirectory() as tmpdir:
144150
path_clusters = os.path.join(tmpdir, "clusters.parquet")
145151
path_tracks = os.path.join(tmpdir, "tracks.parquet")
146-
self.adf_clusters.save(path_clusters)
147-
self.adf_tracks.save(path_tracks)
152+
adf_clusters.save(path_clusters)
153+
adf_tracks.save(path_tracks)
148154

149155
adf_tracks_loaded = AliasDataFrame.load(path_tracks)
150156
adf_clusters_loaded = AliasDataFrame.load(path_clusters)
151-
adf_clusters_loaded.register_subframe("T", adf_tracks_loaded)
157+
adf_clusters_loaded.register_subframe("T", adf_tracks_loaded, index_columns="track_index")
152158
adf_clusters_loaded.add_alias("mDX", "mX - T.mX")
153159
adf_clusters_loaded.materialize_all()
154160

155-
assert "mDX" in adf_clusters_loaded.df.columns
156-
mean_diff = np.mean(adf_clusters_loaded.df["mDX"] - self.adf_clusters.df["mDX"])
157-
assert abs(mean_diff) < 1e-3, f"Mean difference too large: {mean_diff}"
158-
self.assertDictEqual(self.adf_clusters.aliases, adf_clusters_loaded.aliases)
159-
160-
def test_export_tree_read_tree_with_subframe(self):
161-
with tempfile.NamedTemporaryFile(suffix=".root", delete=False) as tmp:
162-
self.adf_clusters.export_tree(tmp.name, treename="clusters")
163-
tmp_path = tmp.name
164-
165-
adf_loaded = AliasDataFrame.read_tree(tmp_path, treename="clusters")
166-
self.assertIn("T", adf_loaded._subframes.subframes)
167-
self.assertTrue(isinstance(adf_loaded.get_subframe("T"), AliasDataFrame))
168-
os.remove(tmp_path)
161+
self.assertIn("mDX", adf_clusters_loaded.df.columns)
162+
merged = adf_clusters_loaded.df.merge(adf_tracks_loaded.df, on="track_index", suffixes=("", "_trk"))
163+
expected = merged["mX"] - merged["mX_trk"]
164+
pd.testing.assert_series_equal(adf_clusters_loaded.df["mDX"].reset_index(drop=True), expected.reset_index(drop=True), check_names=False)
165+
self.assertDictEqual(adf_clusters.aliases, adf_clusters_loaded.aliases)
169166

170167
if __name__ == "__main__":
171168
unittest.main()

0 commit comments

Comments
 (0)