22from typing import Optional , Type , Literal , Union , Any
33
44from flamesdk import FlameCoreSDK
5+ from flamesdk .resources .utils .constants import LogTypeLiteral
56from flame .star .aggregator_client import Aggregator
67from flame .star .analyzer_client import Analyzer
78from flame .utils .mock_flame_core import MockFlameCoreSDK
@@ -28,6 +29,8 @@ def __init__(self,
2829 simple_analysis : bool = True ,
2930 output_type : Union [Literal ['str' , 'bytes' , 'pickle' ], list ] = 'str' ,
3031 multiple_results : bool = False ,
32+ filename : Optional [Union [str , list [str ]]] = None ,
33+ stream_log_level : int = 20 ,
3134 analyzer_kwargs : Optional [dict ] = None ,
3235 aggregator_kwargs : Optional [dict ] = None ,
3336 test_mode : bool = False ,
@@ -38,27 +41,28 @@ def __init__(self,
3841 self .flame = MockFlameCoreSDK (test_kwargs = test_kwargs )
3942 else :
4043 self .test_kwargs = None
41- self .flame = FlameCoreSDK ()
44+ self .flame = FlameCoreSDK (stream_log_level = stream_log_level )
4245
4346 if self ._is_analyzer ():
4447 self .flame .flame_log (f"Analyzer { test_kwargs ['node_id' ] + ' ' if self .test_mode else '' } started" ,
45- log_type = 'info' )
48+ log_type = LogTypeLiteral . INFO . value )
4649 self ._start_analyzer (analyzer ,
4750 data_type = data_type ,
4851 query = query ,
4952 simple_analysis = simple_analysis ,
5053 analyzer_kwargs = analyzer_kwargs )
5154 elif self ._is_aggregator ():
52- self .flame .flame_log ("Aggregator started" , log_type = 'info' )
55+ self .flame .flame_log ("Aggregator started" , log_type = LogTypeLiteral . INFO . value )
5356 self ._start_aggregator (aggregator ,
5457 simple_analysis = simple_analysis ,
5558 output_type = output_type ,
5659 multiple_results = multiple_results ,
60+ filename = filename ,
5761 aggregator_kwargs = aggregator_kwargs )
5862 else :
5963 raise BrokenPipeError ("Has to be either analyzer or aggregator" )
6064 if not self .test_mode :
61- self .flame .flame_log ("Analysis finished!" , log_type = 'info' )
65+ self .flame .flame_log ("Analysis finished!" , log_type = LogTypeLiteral . INFO . value )
6266 while True :
6367 pass # keep the node alive to allow for orderly shutdown
6468
@@ -73,6 +77,7 @@ def _start_aggregator(self,
7377 simple_analysis : bool = True ,
7478 output_type : Union [Literal ['str' , 'bytes' , 'pickle' ], list ] = 'str' ,
7579 multiple_results : bool = False ,
80+ filename : Optional [Union [str , list [str ]]] = None ,
7681 aggregator_kwargs : Optional [dict ] = None ) -> None :
7782 if issubclass (aggregator , Aggregator ):
7883 # init custom aggregator subclass
@@ -89,21 +94,26 @@ def _start_aggregator(self,
8994
9095 while not aggregator .finished : # (**)
9196 # Await intermediate results
97+ self .flame .flame_log (f"Awaiting intermediate results..." , log_type = LogTypeLiteral .INFO .value )
9298 result_dict = self .flame .await_intermediate_data (analyzers )
9399
94100 # Aggregate results
95101 agg_res , converged = aggregator .aggregate (list (result_dict .values ()), simple_analysis )
96102
97103 if converged :
98104 if not self .test_mode :
99- self .flame .flame_log ("Submitting final results..." , log_type = 'info' , end = '' )
100- response = self .flame .submit_final_result (agg_res , output_type , multiple_results )
105+ self .flame .flame_log ("Submitting final results..." ,
106+ log_type = LogTypeLiteral .INFO .value ,
107+ halt_submission = True )
108+ response = self .flame .submit_final_result (agg_res , output_type , multiple_results ,
109+ filename = filename )
101110 if not self .test_mode :
102- self .flame .flame_log (f"success (response={ response } )" , log_type = 'info' )
111+ self .flame .flame_log (f"success (response={ response } )" , log_type = LogTypeLiteral . INFO . value )
103112 self .flame .analysis_finished ()
104113 aggregator .node_finished () # LOOP BREAK
105114 else :
106115 # Send aggregated result to analyzers
116+ self .flame .flame_log (f"Sending aggregated results..." , log_type = LogTypeLiteral .INFO .value )
107117 self .flame .send_intermediate_data (analyzers , agg_res )
108118 else :
109119 raise BrokenPipeError (_ERROR_MESSAGES .IS_INCORRECT_CLASS .value )
@@ -128,7 +138,6 @@ def _start_analyzer(self,
128138
129139 # Get data
130140 self ._get_data (query = query , data_type = data_type )
131- self .flame .flame_log (f"\t Data extracted: { str (self .data )[:100 ]} " , log_type = 'info' )
132141
133142 # Check converged status on Hub
134143 while not analyzer .finished : # (**)
@@ -151,23 +160,27 @@ def _wait_until_partners_ready(self) -> None:
151160 if self ._is_analyzer ():
152161 aggregator_id = self .flame .get_aggregator_id ()
153162 if not self .test_mode :
154- self .flame .flame_log ("Awaiting contact with aggregator node..." , log_type = 'info' )
163+ self .flame .flame_log ("Awaiting contact with aggregator node..." ,
164+ log_type = LogTypeLiteral .INFO .value )
155165 ready_check_dict = self .flame .ready_check ([aggregator_id ])
156166
157167 if not ready_check_dict [aggregator_id ]:
158168 raise BrokenPipeError ("Could not contact aggregator" )
159169
160170 if not self .test_mode :
161- self .flame .flame_log ("Awaiting contact with aggregator node...success" , log_type = 'info' )
171+ self .flame .flame_log ("Awaiting contact with aggregator node...success" ,
172+ log_type = LogTypeLiteral .INFO .value )
162173 else :
163174 analyzer_ids = self .flame .get_participant_ids ()
164175 if not self .test_mode :
165- self .flame .flame_log ("Awaiting contact with analyzer nodes..." , log_type = 'info' )
176+ self .flame .flame_log ("Awaiting contact with analyzer nodes..." ,
177+ log_type = LogTypeLiteral .INFO .value )
166178 ready_check_dict = self .flame .ready_check (analyzer_ids )
167179 if not all (ready_check_dict .values ()):
168180 raise BrokenPipeError ("Could not contact all analyzers" )
169181 if not self .test_mode :
170- self .flame .flame_log ("Awaiting contact with analyzer nodes...success" , log_type = 'info' )
182+ self .flame .flame_log ("Awaiting contact with analyzer nodes...success" ,
183+ log_type = LogTypeLiteral .INFO .value )
171184
172185 def _get_data (self ,
173186 data_type : Literal ['fhir' , 's3' ],
0 commit comments