Skip to content

Commit 74e9c71

Browse files
authored
Bump flame-sdk version and update project dependencies (#13)
- Bump flame-sdk dependency to version 0.4.2 - Update node_finished method to log completion and add orderly shutdown loop - Reset shared stop event before thread start for failure handling - Enable parallelization and collective termination in local testing - Remove unused variable in star_model.py - Update README with project title and description - Fix typos in comments Co-authored-by: Nightknight3000 <alexander.roehl@uni-tuebingen.de> Co-authored-by: antidodo <albin2993@gmail.com>
1 parent 41d31f0 commit 74e9c71

File tree

4 files changed

+88
-43
lines changed

4 files changed

+88
-43
lines changed

flame/star/star_model_tester.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import threading
33
import uuid
44
from typing import Any, Type, Literal, Optional, Union
5+
import traceback
56

67
from flame.star import StarModel, StarLocalDPModel, StarAnalyzer, StarAggregator
8+
from flame.utils.mock_flame_core import MockFlameCoreSDK
79

810

911
class StarModelTester:
@@ -28,6 +30,9 @@ def __init__(self,
2830
participant_ids = [str(uuid.uuid4()) for _ in range(len(node_roles) + 1)]
2931

3032
threads = []
33+
thread_errors = {}
34+
results_queue = []
35+
MockFlameCoreSDK.stop_event = [] # shared stop event for all threads in case of failure in any thread
3136
for i, participant_id in enumerate(participant_ids):
3237
test_kwargs = {
3338
'analyzer': analyzer,
@@ -54,13 +59,28 @@ def __init__(self,
5459
test_kwargs['epsilon'] = epsilon
5560
test_kwargs['sensitivity'] = sensitivity
5661

57-
results_queue = []
5862
def run_node(kwargs=test_kwargs, use_dp=use_local_dp):
59-
if not use_dp:
60-
flame = StarModel(**kwargs).flame
61-
else:
62-
flame = StarLocalDPModel(**kwargs).flame
63-
results_queue.append(flame.final_results_storage)
63+
try:
64+
if not use_dp:
65+
flame = StarModel(**kwargs).flame
66+
else:
67+
flame = StarLocalDPModel(**kwargs).flame
68+
results_queue.append(flame.final_results_storage)
69+
except Exception:
70+
stop_event = MockFlameCoreSDK.stop_event
71+
if not stop_event:
72+
stack_trace = traceback.format_exc()#.replace('\n', '\\n').replace('\t', '\\t')
73+
thread_errors[(kwargs['test_kwargs']['role'],
74+
kwargs['test_kwargs']['node_id'])] = f"\033[31m{stack_trace}\033[0m"
75+
stop_event.append(kwargs['test_kwargs']['node_id'])
76+
mock = MockFlameCoreSDK(test_kwargs=kwargs['test_kwargs'])
77+
mock.__pop_logs__(failure_message=True)
78+
else:
79+
thread_errors[(kwargs['test_kwargs']['role'],
80+
kwargs['test_kwargs']['node_id'])] = (Exception("Another thread already failed, "
81+
"stopping this thread as well."))
82+
return
83+
6484
thread = threading.Thread(target=run_node)
6585
threads.append(thread)
6686

@@ -70,8 +90,14 @@ def run_node(kwargs=test_kwargs, use_dp=use_local_dp):
7090
for thread in threads:
7191
thread.join()
7292

93+
7394
# write final results
74-
self.write_result(results_queue[0], output_type, result_filepath, multiple_results)
95+
if results_queue:
96+
self.write_result(results_queue[0], output_type, result_filepath, multiple_results)
97+
else:
98+
print("No results to write. All threads failed with errors:")
99+
for (role, node_id), error in thread_errors.items():
100+
print(f"\t{(role if role != 'default' else 'analyzer').capitalize()} {node_id}: {error}")
75101

76102

77103
@staticmethod

flame/utils/mock_flame_core.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,23 @@ def __init__(self, test_kwargs) -> None:
4444
self.finished: bool = False
4545

4646

47+
class IterationTracker:
48+
def __init__(self):
49+
self.iter = 0
50+
51+
def increment(self):
52+
self.iter += 1
53+
54+
def get_iterations(self):
55+
return self.iter
56+
57+
4758
class MockFlameCoreSDK:
48-
num_iterations: int = 0
59+
num_iterations: IterationTracker = IterationTracker()
4960
logger: dict[str, list[str]] = {}
5061
message_broker: dict[str, list[dict[str, Any]]] = {}
5162
final_results_storage: Optional[Any] = None
63+
stop_event: list[tuple[str]] = []
5264

5365
def __init__(self, test_kwargs):
5466
self.sanity_check(test_kwargs)
@@ -202,6 +214,8 @@ def await_messages(self,
202214
break
203215
raise KeyError
204216
except KeyError:
217+
if self.stop_event:
218+
raise Exception
205219
time.sleep(.01)
206220
pass
207221

@@ -323,12 +337,17 @@ def _node_finished(self) -> bool:
323337
self.config.finished = True
324338
return self.config.finished
325339

326-
def __pop_logs__(self) -> None:
327-
print(f"--- Starting Iteration {self.num_iterations} ---")
340+
def __pop_logs__(self, failure_message: bool = False) -> None:
341+
print(f"--- Starting Iteration {self.__get_iteration__()} ---")
342+
if failure_message:
343+
self.flame_log("Exception was raised (see Stacktrace)!", log_type='error')
328344
for k, v in self.logger.items():
329345
role, log = self.logger[k]
330346
print(f"Logs for {'Analyzer' if role == 'default' else role.capitalize()} {k}:")
331347
self.logger[k] = [role, '']
332348
print(log, end='')
333-
print(f"--- Ending Iteration {self.num_iterations} ---\n")
334-
self.num_iterations += 1
349+
print(f"--- Ending Iteration {self.__get_iteration__()} ---\n")
350+
self.num_iterations.increment()
351+
352+
def __get_iteration__(self):
353+
return self.num_iterations.get_iterations()

poetry.lock

Lines changed: 29 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "flame"
3-
version = "0.6.0"
3+
version = "0.6.1"
44
description = ""
55
authors = ["Alexander Röhl <alexander.roehl@uni-tuebingen.de>", "David Hieber <david.hieber@uni-tuebingen.de>"]
66
readme = "README.md"
@@ -9,7 +9,7 @@ packages = [{ include = "flame" }]
99

1010
[tool.poetry.dependencies]
1111
python = ">=3.9,<4.0"
12-
flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.4.1"}
12+
flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.4.2"}
1313
opendp = ">=0.12.1,<0.13.0"
1414

1515
[tool.poetry.group.dev.dependencies]

0 commit comments

Comments
 (0)