11import unittest
22import pandas as pd
33import numpy as np
4- from AliasDataFrame import AliasDataFrame # Adjust this if you're using a different import method
4+ import os
5+ from AliasDataFrame import AliasDataFrame # Adjust if needed
56
67class TestAliasDataFrame (unittest .TestCase ):
78 def setUp (self ):
@@ -14,36 +15,94 @@ def setUp(self):
1415
1516 def test_basic_alias (self ):
1617 self .adf .add_alias ("z" , "x + y" )
17- self .adf .materialize_alias ( "z" )
18+ self .adf .materialize_all ( )
1819 expected = self .adf .df ["x" ] + self .adf .df ["y" ]
1920 pd .testing .assert_series_equal (self .adf .df ["z" ], expected , check_names = False )
2021
2122 def test_dtype (self ):
2223 self .adf .add_alias ("z" , "x + y" , dtype = np .float16 )
23- self .adf .materialize_alias ( "z" )
24+ self .adf .materialize_all ( )
2425 self .assertEqual (self .adf .df ["z" ].dtype , np .float16 )
2526
2627 def test_constant (self ):
2728 self .adf .add_alias ("c" , "42.0" , dtype = np .float32 , is_constant = True )
2829 self .adf .add_alias ("z" , "x + c" )
29- self .adf .materialize_alias ( "z" )
30+ self .adf .materialize_all ( )
3031 expected = self .adf .df ["x" ] + 42.0
3132 pd .testing .assert_series_equal (self .adf .df ["z" ], expected , check_names = False )
3233
3334 def test_dependency_order (self ):
3435 self .adf .add_alias ("a" , "x + y" )
3536 self .adf .add_alias ("b" , "a * 2" )
36- self .adf .materialize_alias ( "b" )
37+ self .adf .materialize_all ( )
3738 expected = (self .adf .df ["x" ] + self .adf .df ["y" ]) * 2
3839 pd .testing .assert_series_equal (self .adf .df ["b" ], expected , check_names = False )
3940
4041 def test_log_rate_with_constant (self ):
4142 median = self .adf .df ["CTPLumi_countsFV0" ].median ()
4243 self .adf .add_alias ("countsFV0_median" , f"{ median } " , dtype = np .float16 , is_constant = True )
4344 self .adf .add_alias ("logRate" , "log(CTPLumi_countsFV0/countsFV0_median)" , dtype = np .float16 )
44- self .adf .materialize_alias ( "logRate" )
45+ self .adf .materialize_all ( )
4546 expected = np .log (self .adf .df ["CTPLumi_countsFV0" ] / median ).astype (np .float16 )
4647 pd .testing .assert_series_equal (self .adf .df ["logRate" ], expected , check_names = False )
4748
49+ class TestAliasDataFrameWithSubframes (unittest .TestCase ):
50+ @classmethod
51+ def setUpClass (cls ):
52+ n_tracks = 1000
53+ n_clusters = 100
54+ cls .df_tracks = pd .DataFrame ({
55+ "track_index" : np .arange (n_tracks ),
56+ "mX" : np .random .normal (0 , 10 , n_tracks ),
57+ "mY" : np .random .normal (0 , 10 , n_tracks ),
58+ "mZ" : np .random .normal (0 , 10 , n_tracks ),
59+ "mPt" : np .random .exponential (1.0 , n_tracks ),
60+ "mEta" : np .random .normal (0 , 1 , n_tracks ),
61+ })
62+
63+ cluster_idx = np .repeat (cls .df_tracks ["track_index" ], n_clusters )
64+ cls .df_clusters = pd .DataFrame ({
65+ "track_index" : cluster_idx ,
66+ "mX" : np .random .normal (0 , 10 , len (cluster_idx )),
67+ "mY" : np .random .normal (0 , 10 , len (cluster_idx )),
68+ "mZ" : np .random .normal (0 , 10 , len (cluster_idx )),
69+ })
70+
71+ cls .adf_tracks = AliasDataFrame (cls .df_tracks )
72+ cls .adf_clusters = AliasDataFrame (cls .df_clusters )
73+ cls .adf_clusters .register_subframe ("T" , cls .adf_tracks )
74+
75+ def test_alias_cluster_radius (self ):
76+ self .adf_clusters .add_alias ("mR" , "sqrt(mX**2 + mY**2)" )
77+ self .adf_clusters .materialize_all ()
78+ expected = np .sqrt (self .adf_clusters .df ["mX" ]** 2 + self .adf_clusters .df ["mY" ]** 2 )
79+ pd .testing .assert_series_equal (self .adf_clusters .df ["mR" ], expected , check_names = False )
80+
81+ def test_alias_cluster_track_dx (self ):
82+ self .adf_clusters .add_alias ("mDX" , "mX - T.mX" )
83+ self .adf_clusters .materialize_all ()
84+ merged = self .adf_clusters .df .merge (self .adf_tracks .df , on = "track_index" , suffixes = ("" , "_track" ))
85+ expected = merged ["mX" ] - merged ["mX_track" ]
86+ pd .testing .assert_series_equal (self .adf_clusters .df ["mDX" ].reset_index (drop = True ), expected .reset_index (drop = True ), check_names = False )
87+
88+ def test_save_and_load_integrity (self ):
89+ import tempfile
90+ with tempfile .TemporaryDirectory () as tmpdir :
91+ path_clusters = os .path .join (tmpdir , "clusters.parquet" )
92+ path_tracks = os .path .join (tmpdir , "tracks.parquet" )
93+ self .adf_clusters .save (path_clusters )
94+ self .adf_tracks .save (path_tracks )
95+
96+ adf_tracks_loaded = AliasDataFrame .load (path_tracks )
97+ adf_clusters_loaded = AliasDataFrame .load (path_clusters )
98+ adf_clusters_loaded .register_subframe ("T" , adf_tracks_loaded )
99+ adf_clusters_loaded .add_alias ("mDX" , "mX - T.mX" )
100+ adf_clusters_loaded .materialize_all ()
101+
102+ assert "mDX" in adf_clusters_loaded .df .columns
103+ # Check mean difference is negligible
104+ mean_diff = np .mean (adf_clusters_loaded .df ["mDX" ] - self .adf_clusters .df ["mDX" ])
105+ assert abs (mean_diff ) < 1e-3 , f"Mean difference too large: { mean_diff } "
106+
48107if __name__ == "__main__" :
49108 unittest .main ()
0 commit comments