Skip to content

Commit f4f4370

Browse files
Release v0.6.5: add SDK 0.6.0 (#19)
* **New Features** * Added configurable log streaming level for better control over logging output. * Introduced optional output file support for saving results to specified file paths. * Enhanced support for handling and saving multiple results to individual files with automatic path generation. * **Bug Fixes** * Improved logging consistency across result submission and node lifecycle events. * **Chores** * Updated package version and dependencies. --------- Co-authored-by: davidhieber <david.hieber@uni-tuebingen.de> Co-authored-by: Nightknight3000 <alexander.roehl@uni-tuebingen.de>
1 parent 7340e32 commit f4f4370

6 files changed

Lines changed: 244 additions & 226 deletions

File tree

flame/star/star_localdp/star_localdp_model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Type, Literal, Union, Any
22

33
from flamesdk import FlameCoreSDK
4+
from flamesdk.resources.utils.constants import LogTypeLiteral
45
from flame.star.aggregator_client import Aggregator
56
from flame.star.analyzer_client import Analyzer
67
from flame.star.star_model import StarModel, _ERROR_MESSAGES
@@ -24,6 +25,8 @@ def __init__(self,
2425
simple_analysis: bool = True,
2526
output_type: Union[Literal['str', 'bytes', 'pickle'], list] = 'str',
2627
multiple_results: bool = False,
28+
filename: Optional[Union[str, list[str]]] = None,
29+
stream_log_level: int = 20,
2730
analyzer_kwargs: Optional[dict] = None,
2831
aggregator_kwargs: Optional[dict] = None,
2932
epsilon: Optional[float] = None,
@@ -39,6 +42,8 @@ def __init__(self,
3942
simple_analysis=simple_analysis,
4043
output_type=output_type,
4144
multiple_results=multiple_results,
45+
filename=filename,
46+
stream_log_level=stream_log_level,
4247
analyzer_kwargs=analyzer_kwargs,
4348
aggregator_kwargs=aggregator_kwargs,
4449
test_mode=test_mode,
@@ -49,6 +54,7 @@ def _start_aggregator(self,
4954
simple_analysis: bool = True,
5055
output_type: Union[Literal['str', 'bytes', 'pickle'], list] = 'str',
5156
multiple_results: bool = False,
57+
filename: Optional[Union[str, list[str]]] = None,
5258
aggregator_kwargs: Optional[dict] = None) -> None:
5359
if issubclass(aggregator, Aggregator):
5460
# init custom aggregator subclass
@@ -69,24 +75,27 @@ def _start_aggregator(self,
6975

7076
# Aggregate results
7177
agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis)
72-
self.flame.flame_log(f"Aggregated results: {str(agg_res)[:100]}")
7378

7479
if converged:
7580
if not self.test_mode:
7681
self.flame.flame_log("Submitting final results using differential privacy...",
77-
log_type='info',
78-
end='')
82+
log_type=LogTypeLiteral.INFO.value,
83+
halt_submission=True)
7984
if aggregator.delta_criteria and (self.epsilon is not None) and (self.sensitivity is not None):
8085
local_dp = {"epsilon": self.epsilon, "sensitivity": self.sensitivity}
8186
else:
8287
local_dp = None
8388
if self.test_mode and (local_dp is not None):
8489
self.flame.flame_log(f"\tTest mode: Would apply local DP with epsilon={local_dp['epsilon']} "
8590
f"and sensitivity={local_dp['sensitivity']}",
86-
log_type='info')
87-
response = self.flame.submit_final_result(agg_res, output_type, multiple_results, local_dp=local_dp)
91+
log_type=LogTypeLiteral.INFO.value)
92+
response = self.flame.submit_final_result(agg_res,
93+
output_type,
94+
multiple_results,
95+
local_dp=local_dp,
96+
filename=filename)
8897
if not self.test_mode:
89-
self.flame.flame_log(f"success (response={response})", log_type='info')
98+
self.flame.flame_log(f"success (response={response})", log_type=LogTypeLiteral.INFO.value)
9099
self.flame.analysis_finished()
91100
aggregator.node_finished() # LOOP BREAK
92101
else:

flame/star/star_model.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Type, Literal, Union, Any
33

44
from flamesdk import FlameCoreSDK
5+
from flamesdk.resources.utils.constants import LogTypeLiteral
56
from flame.star.aggregator_client import Aggregator
67
from flame.star.analyzer_client import Analyzer
78
from flame.utils.mock_flame_core import MockFlameCoreSDK
@@ -28,6 +29,8 @@ def __init__(self,
2829
simple_analysis: bool = True,
2930
output_type: Union[Literal['str', 'bytes', 'pickle'], list] = 'str',
3031
multiple_results: bool = False,
32+
filename: Optional[Union[str, list[str]]] = None,
33+
stream_log_level: int = 20,
3134
analyzer_kwargs: Optional[dict] = None,
3235
aggregator_kwargs: Optional[dict] = None,
3336
test_mode: bool = False,
@@ -38,27 +41,28 @@ def __init__(self,
3841
self.flame = MockFlameCoreSDK(test_kwargs=test_kwargs)
3942
else:
4043
self.test_kwargs = None
41-
self.flame = FlameCoreSDK()
44+
self.flame = FlameCoreSDK(stream_log_level=stream_log_level)
4245

4346
if self._is_analyzer():
4447
self.flame.flame_log(f"Analyzer {test_kwargs['node_id'] + ' ' if self.test_mode else ''}started",
45-
log_type='info')
48+
log_type=LogTypeLiteral.INFO.value)
4649
self._start_analyzer(analyzer,
4750
data_type=data_type,
4851
query=query,
4952
simple_analysis=simple_analysis,
5053
analyzer_kwargs=analyzer_kwargs)
5154
elif self._is_aggregator():
52-
self.flame.flame_log("Aggregator started", log_type='info')
55+
self.flame.flame_log("Aggregator started", log_type=LogTypeLiteral.INFO.value)
5356
self._start_aggregator(aggregator,
5457
simple_analysis=simple_analysis,
5558
output_type=output_type,
5659
multiple_results=multiple_results,
60+
filename=filename,
5761
aggregator_kwargs=aggregator_kwargs)
5862
else:
5963
raise BrokenPipeError("Has to be either analyzer or aggregator")
6064
if not self.test_mode:
61-
self.flame.flame_log("Analysis finished!", log_type='info')
65+
self.flame.flame_log("Analysis finished!", log_type=LogTypeLiteral.INFO.value)
6266
while True:
6367
pass # keep the node alive to allow for orderly shutdown
6468

@@ -73,6 +77,7 @@ def _start_aggregator(self,
7377
simple_analysis: bool = True,
7478
output_type: Union[Literal['str', 'bytes', 'pickle'], list] = 'str',
7579
multiple_results: bool = False,
80+
filename: Optional[Union[str, list[str]]] = None,
7681
aggregator_kwargs: Optional[dict] = None) -> None:
7782
if issubclass(aggregator, Aggregator):
7883
# init custom aggregator subclass
@@ -89,21 +94,26 @@ def _start_aggregator(self,
8994

9095
while not aggregator.finished: # (**)
9196
# Await intermediate results
97+
self.flame.flame_log(f"Awaiting intermediate results...", log_type=LogTypeLiteral.INFO.value)
9298
result_dict = self.flame.await_intermediate_data(analyzers)
9399

94100
# Aggregate results
95101
agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis)
96102

97103
if converged:
98104
if not self.test_mode:
99-
self.flame.flame_log("Submitting final results...", log_type='info', end='')
100-
response = self.flame.submit_final_result(agg_res, output_type, multiple_results)
105+
self.flame.flame_log("Submitting final results...",
106+
log_type=LogTypeLiteral.INFO.value,
107+
halt_submission=True)
108+
response = self.flame.submit_final_result(agg_res, output_type, multiple_results,
109+
filename=filename)
101110
if not self.test_mode:
102-
self.flame.flame_log(f"success (response={response})", log_type='info')
111+
self.flame.flame_log(f"success (response={response})", log_type=LogTypeLiteral.INFO.value)
103112
self.flame.analysis_finished()
104113
aggregator.node_finished() # LOOP BREAK
105114
else:
106115
# Send aggregated result to analyzers
116+
self.flame.flame_log(f"Sending aggregated results...", log_type=LogTypeLiteral.INFO.value)
107117
self.flame.send_intermediate_data(analyzers, agg_res)
108118
else:
109119
raise BrokenPipeError(_ERROR_MESSAGES.IS_INCORRECT_CLASS.value)
@@ -128,7 +138,6 @@ def _start_analyzer(self,
128138

129139
# Get data
130140
self._get_data(query=query, data_type=data_type)
131-
self.flame.flame_log(f"\tData extracted: {str(self.data)[:100]}", log_type='info')
132141

133142
# Check converged status on Hub
134143
while not analyzer.finished: # (**)
@@ -151,23 +160,27 @@ def _wait_until_partners_ready(self) -> None:
151160
if self._is_analyzer():
152161
aggregator_id = self.flame.get_aggregator_id()
153162
if not self.test_mode:
154-
self.flame.flame_log("Awaiting contact with aggregator node...", log_type='info')
163+
self.flame.flame_log("Awaiting contact with aggregator node...",
164+
log_type=LogTypeLiteral.INFO.value)
155165
ready_check_dict = self.flame.ready_check([aggregator_id])
156166

157167
if not ready_check_dict[aggregator_id]:
158168
raise BrokenPipeError("Could not contact aggregator")
159169

160170
if not self.test_mode:
161-
self.flame.flame_log("Awaiting contact with aggregator node...success", log_type='info')
171+
self.flame.flame_log("Awaiting contact with aggregator node...success",
172+
log_type=LogTypeLiteral.INFO.value)
162173
else:
163174
analyzer_ids = self.flame.get_participant_ids()
164175
if not self.test_mode:
165-
self.flame.flame_log("Awaiting contact with analyzer nodes...", log_type='info')
176+
self.flame.flame_log("Awaiting contact with analyzer nodes...",
177+
log_type=LogTypeLiteral.INFO.value)
166178
ready_check_dict = self.flame.ready_check(analyzer_ids)
167179
if not all(ready_check_dict.values()):
168180
raise BrokenPipeError("Could not contact all analyzers")
169181
if not self.test_mode:
170-
self.flame.flame_log("Awaiting contact with analyzer nodes...success", log_type='info')
182+
self.flame.flame_log("Awaiting contact with analyzer nodes...success",
183+
log_type=LogTypeLiteral.INFO.value)
171184

172185
def _get_data(self,
173186
data_type: Literal['fhir', 's3'],

flame/star/star_model_tester.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ def __init__(self,
1919
simple_analysis: bool = True,
2020
output_type: Union[Literal['str', 'bytes', 'pickle'], list] = 'str',
2121
multiple_results: bool = False,
22+
filename: Optional[Union[str, list[str]]] = None,
23+
stream_log_level: int = 20,
2224
analyzer_kwargs: Optional[dict] = None,
2325
aggregator_kwargs: Optional[dict] = None,
2426
epsilon: Optional[float] = None,
25-
sensitivity: Optional[float] = None,
26-
result_filepath: Optional[Union[str, list[str]]] = None) -> None:
27+
sensitivity: Optional[float] = None) -> None:
2728
num_splits = len(data_splits)
2829
self.test_input(data_splits[0])
2930
participants = []
@@ -52,6 +53,7 @@ def __init__(self,
5253
'simple_analysis': simple_analysis,
5354
'output_type': output_type,
5455
'multiple_results': multiple_results,
56+
'stream_log_level': stream_log_level,
5557
'analyzer_kwargs': analyzer_kwargs,
5658
'aggregator_kwargs': aggregator_kwargs,
5759
'test_mode': True,
@@ -75,7 +77,8 @@ def run_node(kwargs=test_kwargs, use_dp=use_local_dp):
7577
flame = StarModel(**kwargs).flame
7678
else:
7779
flame = StarLocalDPModel(**kwargs).flame
78-
results_queue.append(flame.final_results_storage)
80+
if kwargs['test_kwargs']['role'] == 'aggregator':
81+
results_queue.append(flame.final_results_storage)
7982
except Exception:
8083
stop_event = MockFlameCoreSDK.stop_event
8184
if not stop_event:
@@ -103,7 +106,7 @@ def run_node(kwargs=test_kwargs, use_dp=use_local_dp):
103106

104107
# write final results
105108
if results_queue:
106-
self.write_result(results_queue[0], output_type, result_filepath, multiple_results)
109+
self.write_result(results_queue[0], output_type, filename, multiple_results)
107110
else:
108111
print("No results to write. All threads failed with errors:")
109112
for (role, node_id), error in thread_errors.items():
@@ -147,12 +150,12 @@ def test_input(data: Any) -> None:
147150
@staticmethod
148151
def write_result(result: Any,
149152
output_type: Union[Literal['str', 'bytes', 'pickle'], list],
150-
result_filepath: Optional[Union[str, list[str]]] = None,
153+
filename: Optional[Union[str, list[str]]] = None,
151154
multiple_results: bool = False) -> None:
152155
if multiple_results:
153156
if isinstance(result, list) or isinstance(result, tuple):
154-
if isinstance(result_filepath, list) and (len(result_filepath) != len(result)):
155-
print(f"Warning! Inconsistent number of result_filepaths (len={result_filepath}) "
157+
if isinstance(filename, list) and (len(filename) != len(result)):
158+
print(f"Warning! Inconsistent number of filenames (len={filename}) "
156159
f"and results (len={len(result)}) -> multiple_results will be ignored.")
157160
multi_iterable_results = False
158161
else:
@@ -164,20 +167,21 @@ def write_result(result: Any,
164167
else:
165168
multi_iterable_results = False
166169

167-
if result_filepath is not None:
170+
if filename is not None:
168171
if not multi_iterable_results:
169172
result = [result]
170-
result_filepath = [result_filepath]
173+
if not isinstance(filename, list):
174+
filename = [filename]
171175

172176
for i, res in enumerate(result):
173-
if isinstance(result_filepath, list):
174-
current_path = result_filepath[i]
177+
if isinstance(filename, list):
178+
current_path = filename[i]
175179
else:
176-
if '.' in result_filepath:
177-
result_filename, result_extension = result_filepath.rsplit('.', 1)
180+
if '.' in filename:
181+
result_filename, result_extension = filename.rsplit('.', 1)
178182
current_path = f"{result_filename}_{i + 1}.{result_extension}"
179183
else:
180-
current_path = f"{result_filepath}_{i + 1}"
184+
current_path = f"{filename}_{i + 1}"
181185
if isinstance(output_type, list) and (len(output_type) == len(result)):
182186
out_type = output_type[i]
183187
else:

0 commit comments

Comments
 (0)