Skip to content

Commit 63239c0

Browse files
committed
Refactor code for improved type safety and null handling
1 parent 138b32d commit 63239c0

22 files changed

+197
-103
lines changed

adf_core_python/cli/template/src/team_name/module/complex/sample_road_detector.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
)
3838

3939
self.register_sub_module(self._path_planning)
40-
self._result = None
40+
self._result: Optional[EntityID] = None
4141

4242
def precompute(self, precompute_data: PrecomputeData) -> RoadDetector:
4343
super().precompute(precompute_data)
@@ -108,7 +108,7 @@ def update_info(self, message_manager: MessageManager) -> RoadDetector:
108108
if isinstance(entity, Building):
109109
self._result = None
110110
elif isinstance(entity, Road):
111-
road: Road = cast(Road, entity)
111+
road = entity
112112
if road.get_blockades() == []:
113113
self._target_areas.remove(self._result)
114114
self._result = None
@@ -117,7 +117,9 @@ def update_info(self, message_manager: MessageManager) -> RoadDetector:
117117

118118
def calculate(self) -> RoadDetector:
119119
if self._result is None:
120-
position_entity_id: EntityID = self._agent_info.get_position_entity_id()
120+
position_entity_id = self._agent_info.get_position_entity_id()
121+
if position_entity_id is None:
122+
return self
121123
if position_entity_id in self._target_areas:
122124
self._result = position_entity_id
123125
return self
@@ -128,21 +130,22 @@ def calculate(self) -> RoadDetector:
128130

129131
self._priority_roads = self._priority_roads - set(remove_list)
130132
if len(self._priority_roads) > 0:
131-
_nearest_target_area = self._agent_info.get_position_entity_id()
133+
agent_position = self._agent_info.get_position_entity_id()
134+
if agent_position is None:
135+
return self
136+
_nearest_target_area = agent_position
132137
_nearest_distance = float("inf")
133138
for target_area in self._target_areas:
134139
if (
135-
self._world_info.get_distance(
136-
self._agent_info.get_position_entity_id(), target_area
137-
)
140+
self._world_info.get_distance(agent_position, target_area)
138141
< _nearest_distance
139142
):
140143
_nearest_target_area = target_area
141144
_nearest_distance = self._world_info.get_distance(
142-
self._agent_info.get_position_entity_id(), target_area
145+
agent_position, target_area
143146
)
144147
path: list[EntityID] = self._path_planning.get_path(
145-
self._agent_info.get_position_entity_id(), _nearest_target_area
148+
agent_position, _nearest_target_area
146149
)
147150
if path is not None and len(path) > 0:
148151
self._result = path[-1]

adf_core_python/cli/template/src/team_name/module/complex/sample_search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def update_info(self, message_manager: MessageManager) -> Search:
6464
)
6565

6666
searched_building_id = self._agent_info.get_position_entity_id()
67-
self._unreached_building_ids.discard(searched_building_id)
67+
if searched_building_id is not None:
68+
self._unreached_building_ids.discard(searched_building_id)
6869

6970
if len(self._unreached_building_ids) == 0:
7071
self._unreached_building_ids = self._get_search_targets()

adf_core_python/core/agent/info/world_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, cast
1+
from typing import Any, Optional
22

33
from rcrscore.entities import EntityID
44
from rcrscore.entities.area import Area
@@ -199,7 +199,7 @@ def get_blockades(self, area: Area) -> set[Blockade]:
199199
for blockade_entity_id in blockade_entity_ids:
200200
blockades_entity = self.get_entity(blockade_entity_id)
201201
if isinstance(blockades_entity, Blockade):
202-
blockades.add(cast(Blockade, blockades_entity))
202+
blockades.add(blockades_entity)
203203
return blockades
204204

205205
def add_entity(self, entity: Entity) -> None:

adf_core_python/core/gateway/component/module/algorithm/gateway_clustering.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def get_cluster_entities(self, cluster_index: int) -> list[Entity]:
7878
entity_ids: list[int] = json.loads(json_str)
7979
entities: list[Entity] = []
8080
for entity_id in entity_ids:
81-
entities.append(self._world_info.get_entity(EntityID(entity_id)))
81+
entity = self._world_info.get_entity(EntityID(entity_id))
82+
if entity is not None:
83+
entities.append(entity)
8284
return entities
8385

8486
def get_cluster_entity_ids(self, cluster_index: int) -> list[EntityID]:

adf_core_python/core/gateway/component/module/complex/gateway_target_allocator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ def calculate(self) -> GatewayTargetAllocator:
6363

6464
def get_result(self) -> dict[EntityID, EntityID]:
6565
response = self._gateway_module.execute("getResult")
66-
response_keys = response.get_all_keys()
66+
response_keys = response.data.keys()
6767
result: dict[EntityID, EntityID] = {}
6868
for key in response_keys:
69+
value = response.get_value(key)
6970
result[EntityID(int(key))] = EntityID(
70-
int(response.get_value_or_default(key, "-1"))
71+
int(value if value is not None else "-1")
7172
)
7273

7374
return result

adf_core_python/core/gateway/gateway_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def initialize(self, module_name: str, default_class_name: str) -> str:
5454
return self.get_gateway_class_name()
5555

5656
def get_execute_response(self) -> Config:
57+
if self._result is None:
58+
raise RuntimeError("No execution result available")
5759
return self._result
5860

5961
def set_execute_response(self, result: Config) -> None:

adf_core_python/implement/action/default_extend_action_clear.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959
)
6060
)
6161

62-
self._target_entity_id = None
62+
self._target_entity_id: Optional[EntityID] = None
6363
self._move_point_cache: dict[EntityID, Optional[set[tuple[float, float]]]] = {}
6464
self._old_clear_x = 0
6565
self._old_clear_y = 0
@@ -135,14 +135,14 @@ def calculate(self) -> ExtendAction:
135135
return self
136136

137137
agent_position_entity_id = police_force.get_position()
138+
if agent_position_entity_id is None:
139+
return self
138140
target_entity = self.world_info.get_entity(self._target_entity_id)
139141
position_entity = self.world_info.get_entity(agent_position_entity_id)
140142
if target_entity is None or isinstance(target_entity, Area) is False:
141143
return self
142144
if isinstance(position_entity, Road):
143-
self.result = self._get_rescue_action(
144-
police_force, cast(Road, position_entity)
145-
)
145+
self.result = self._get_rescue_action(police_force, position_entity)
146146
if self.result is not None:
147147
return self
148148

@@ -180,7 +180,7 @@ def calculate(self) -> ExtendAction:
180180
police_force, cast(Area, entity)
181181
)
182182
if self.result is not None and isinstance(self.result, ActionMove):
183-
action_move = cast(ActionMove, self.result)
183+
action_move = self.result
184184
if action_move.is_destination_defined():
185185
self.result = None
186186

@@ -193,15 +193,16 @@ def _need_rest(self, police_force: PoliceForce) -> bool:
193193
hp = police_force.get_hp()
194194
damage = police_force.get_damage()
195195

196-
if hp == 0 or damage == 0:
196+
if hp is None or damage is None or hp == 0 or damage == 0:
197197
return False
198198

199199
active_time = (hp / damage) + (1 if (hp % damage) != 0 else 0)
200200
if self._kernel_time == -1:
201201
self._kernel_time = self.scenario_info.get_value("kernel.timesteps", -1)
202202

203-
return damage >= self._threshold_rest or (
204-
active_time + self.agent_info.get_time() < self._kernel_time
203+
return damage is not None and (
204+
damage >= self._threshold_rest
205+
or (active_time + self.agent_info.get_time() < self._kernel_time)
205206
)
206207

207208
def _calc_rest(
@@ -211,6 +212,8 @@ def _calc_rest(
211212
target_entity_ids: list[EntityID],
212213
) -> Optional[Action]:
213214
position_entity_id = police_force.get_position()
215+
if position_entity_id is None:
216+
return None
214217
refuges = self.world_info.get_entity_ids_of_types([Refuge])
215218
current_size = len(refuges)
216219
if position_entity_id in refuges:
@@ -244,12 +247,13 @@ def _calc_rest(
244247
def _get_rescue_action(
245248
self, police_entity: PoliceForce, road: Road
246249
) -> Optional[Action]:
250+
road_blockades = road.get_blockades()
247251
blockades = set(
248252
[]
249-
if road.get_blockades() is None
253+
if road_blockades is None
250254
else [
251255
cast(Blockade, self.world_info.get_entity(blockade_entity_id))
252-
for blockade_entity_id in road.get_blockades()
256+
for blockade_entity_id in road_blockades
253257
]
254258
)
255259
agent_entities = set(
@@ -263,15 +267,30 @@ def _get_rescue_action(
263267

264268
for agent_entity in agent_entities:
265269
human = cast(Human, agent_entity)
266-
if human.get_position().get_value() != road.get_entity_id().get_value():
270+
human_position = human.get_position()
271+
if (
272+
human_position is None
273+
or human_position.get_value() != road.get_entity_id().get_value()
274+
):
267275
continue
268276

269277
human_x = human.get_x()
270278
human_y = human.get_y()
279+
if (
280+
human_x is None
281+
or human_y is None
282+
or police_x is None
283+
or police_y is None
284+
):
285+
continue
286+
271287
action_clear: Optional[ActionClear | ActionClearArea] = None
272288
clear_blockade: Optional[Blockade] = None
273289
for blockade in blockades:
274-
if not self._is_inside(human_x, human_y, blockade.get_apexes()):
290+
blockade_apexes = blockade.get_apexes()
291+
if blockade_apexes is None or not self._is_inside(
292+
human_x, human_y, blockade_apexes
293+
):
275294
continue
276295

277296
distance = self._get_distance(police_x, police_y, human_x, human_y)
@@ -373,7 +392,10 @@ def _get_distance(self, x1: float, y1: float, x2: float, y2: float) -> float:
373392
def _is_intersecting_area(
374393
self, agent_x: float, agent_y: float, point_x: float, point_y: float, area: Area
375394
) -> bool:
376-
for edge in area.get_edges():
395+
edges = area.get_edges()
396+
if edges is None:
397+
return False
398+
for edge in edges:
377399
start_x = edge.get_start_x()
378400
start_y = edge.get_start_y()
379401
end_x = edge.get_end_x()
@@ -483,11 +505,13 @@ def _get_move_points(self, road: Road) -> set[tuple[float, float]]:
483505
if self._is_inside(mid_x, mid_y, apex):
484506
points.add((mid_x, mid_y))
485507

486-
for edge in road.get_edges():
487-
mid_x = (edge.get_start_x() + edge.get_end_x()) / 2.0
488-
mid_y = (edge.get_start_y() + edge.get_end_y()) / 2.0
489-
if (mid_x, mid_y) in points:
490-
points.remove((mid_x, mid_y))
508+
edges = road.get_edges()
509+
if edges is not None:
510+
for edge in edges:
511+
mid_x = (edge.get_start_x() + edge.get_end_x()) / 2.0
512+
mid_y = (edge.get_start_y() + edge.get_end_y()) / 2.0
513+
if (mid_x, mid_y) in points:
514+
points.remove((mid_x, mid_y))
491515

492516
self._move_point_cache[road.get_entity_id()] = points
493517

@@ -518,6 +542,8 @@ def _is_intersecting_blockade(
518542
blockade: Blockade,
519543
) -> bool:
520544
apexes = blockade.get_apexes()
545+
if apexes is None or len(apexes) < 4:
546+
return False
521547
for i in range(0, len(apexes) - 3, 2):
522548
line1 = LineString(
523549
[(apexes[i], apexes[i + 1]), (apexes[i + 2], apexes[i + 3])]
@@ -532,6 +558,8 @@ def _is_intersecting_blockades(
532558
) -> bool:
533559
apexes1 = blockade1.get_apexes()
534560
apexes2 = blockade2.get_apexes()
561+
if apexes1 is None or apexes2 is None or len(apexes1) < 4 or len(apexes2) < 4:
562+
return False
535563
for i in range(0, len(apexes1) - 2, 2):
536564
for j in range(0, len(apexes2) - 2, 2):
537565
line1 = LineString(
@@ -593,20 +621,26 @@ def _get_area_clear_action(
593621
if min_distance < self._clear_distance:
594622
return ActionClear(clear_blockade)
595623
else:
596-
return ActionMove(
597-
[police_entity.get_position()],
598-
clear_blockade.get_x(),
599-
clear_blockade.get_y(),
600-
)
624+
position = police_entity.get_position()
625+
if position is not None:
626+
return ActionMove(
627+
[position],
628+
clear_blockade.get_x(),
629+
clear_blockade.get_y(),
630+
)
601631

602632
agent_x = police_entity.get_x()
603633
agent_y = police_entity.get_y()
634+
if agent_x is None or agent_y is None:
635+
return None
604636
clear_blockade = None
605637
min_point_distance = sys.float_info.max
606638
clear_x = 0
607639
clear_y = 0
608640
for blockade in blockades:
609641
apexes = blockade.get_apexes()
642+
if apexes is None or len(apexes) < 4:
643+
continue
610644
for i in range(0, len(apexes) - 2, 2):
611645
distance = self._get_distance(
612646
agent_x, agent_y, apexes[i], apexes[i + 1]
@@ -625,7 +659,9 @@ def _get_area_clear_action(
625659
clear_x = int(agent_x + vector[0])
626660
clear_y = int(agent_y + vector[1])
627661
return ActionClearArea(clear_x, clear_y)
628-
return ActionMove([police_entity.get_position()], clear_x, clear_y)
662+
position = police_entity.get_position()
663+
if position is not None:
664+
return ActionMove([position], clear_x, clear_y)
629665

630666
return None
631667

@@ -637,7 +673,12 @@ def _get_neighbour_position_action(
637673
) -> Optional[Action]:
638674
agent_x = police_entity.get_x()
639675
agent_y = police_entity.get_y()
640-
position = self.world_info.get_entity(police_entity.get_position())
676+
if agent_x is None or agent_y is None:
677+
return None
678+
position_id = police_entity.get_position()
679+
if position_id is None:
680+
return None
681+
position = self.world_info.get_entity(position_id)
641682
if position is None:
642683
return None
643684

@@ -646,7 +687,7 @@ def _get_neighbour_position_action(
646687
return None
647688

648689
if isinstance(position, Road):
649-
road = cast(Road, position)
690+
road = position
650691
if road.get_blockades() != []:
651692
mid_x = (edge.get_start_x() + edge.get_end_x()) / 2.0
652693
mid_y = (edge.get_start_y() + edge.get_end_y()) / 2.0
@@ -716,7 +757,7 @@ def _get_neighbour_position_action(
716757
return action_move
717758

718759
if isinstance(target, Road):
719-
road = cast(Road, target)
760+
road = target
720761
if road.get_blockades() == []:
721762
return ActionMove([position.get_entity_id(), target.get_entity_id()])
722763

@@ -726,6 +767,8 @@ def _get_neighbour_position_action(
726767
clear_y = 0
727768
for blockade in self.world_info.get_blockades(road):
728769
apexes = blockade.get_apexes()
770+
if apexes is None or len(apexes) < 4:
771+
continue
729772
for i in range(0, len(apexes) - 2, 2):
730773
distance = self._get_distance(
731774
agent_x, agent_y, apexes[i], apexes[i + 1]

adf_core_python/implement/action/default_extend_action_rescue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _calc_rescue(
117117
return None
118118

119119
if isinstance(target_entity, Human):
120-
human = cast(Human, target_entity)
120+
human = target_entity
121121
if human.get_hp() == 0:
122122
return None
123123

@@ -139,7 +139,7 @@ def _calc_rescue(
139139
return None
140140

141141
if isinstance(target_entity, Blockade):
142-
blockade = cast(Blockade, target_entity)
142+
blockade = target_entity
143143
blockade_position = blockade.get_position()
144144
if blockade_position is None:
145145
return None

0 commit comments

Comments
 (0)