|
1 | 1 | from beanie import PydanticObjectId |
2 | | -from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel |
3 | | - |
4 | 2 | from fastapi import HTTPException, status, BackgroundTasks |
5 | 3 |
|
| 4 | +from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel |
6 | 5 | from app.models.db.state import State |
7 | 6 | from app.models.state_status_enum import StateStatusEnum |
8 | 7 | from app.singletons.logs_manager import LogsManager |
9 | 8 | from app.tasks.create_next_states import create_next_states |
10 | 9 |
|
11 | 10 | logger = LogsManager().get_logger() |
12 | 11 |
|
13 | | -async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: ExecutedRequestModel, x_exosphere_request_id: str, background_tasks: BackgroundTasks) -> ExecutedResponseModel: |
14 | 12 |
|
| 13 | +async def executed_state( |
| 14 | + namespace_name: str, |
| 15 | + state_id: PydanticObjectId, |
| 16 | + body: ExecutedRequestModel, |
| 17 | + x_exosphere_request_id: str, |
| 18 | + background_tasks: BackgroundTasks, |
| 19 | +) -> ExecutedResponseModel: |
15 | 20 | try: |
16 | | - logger.info(f"Executed state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) |
| 21 | + logger.info( |
| 22 | + f"Executed state {state_id} for namespace {namespace_name}", |
| 23 | + x_exosphere_request_id=x_exosphere_request_id, |
| 24 | + ) |
17 | 25 |
|
18 | 26 | state = await State.find_one(State.id == state_id) |
19 | 27 | if not state or not state.id: |
20 | | - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") |
| 28 | + raise HTTPException( |
| 29 | + status_code=status.HTTP_404_NOT_FOUND, |
| 30 | + detail="State not found", |
| 31 | + ) |
21 | 32 |
|
22 | 33 | if state.status != StateStatusEnum.QUEUED: |
23 | | - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued") |
24 | | - |
25 | | - next_state_ids = [] |
| 34 | + raise HTTPException( |
| 35 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 36 | + detail="State is not queued", |
| 37 | + ) |
| 38 | + |
| 39 | + next_state_ids: list[PydanticObjectId] = [] |
| 40 | + |
| 41 | + # ---- Handle outputs ---- |
26 | 42 | if len(body.outputs) == 0: |
27 | 43 | state.status = StateStatusEnum.EXECUTED |
28 | 44 | state.outputs = {} |
29 | 45 | await state.save() |
30 | 46 |
|
31 | 47 | next_state_ids.append(state.id) |
32 | 48 |
|
33 | | - else: |
| 49 | + else: |
| 50 | + # First output updates the current state |
34 | 51 | state.outputs = body.outputs[0] |
35 | 52 | state.status = StateStatusEnum.EXECUTED |
36 | 53 | await state.save() |
| 54 | + |
37 | 55 | next_state_ids.append(state.id) |
38 | 56 |
|
| 57 | + # Remaining outputs create new states |
39 | 58 | new_states = [] |
40 | 59 | for output in body.outputs[1:]: |
41 | | - new_states.append(State( |
42 | | - node_name=state.node_name, |
43 | | - namespace_name=state.namespace_name, |
44 | | - identifier=state.identifier, |
45 | | - graph_name=state.graph_name, |
46 | | - run_id=state.run_id, |
47 | | - status=StateStatusEnum.EXECUTED, |
48 | | - inputs=state.inputs, |
49 | | - outputs=output, |
50 | | - error=None, |
51 | | - parents=state.parents |
52 | | - )) |
53 | | - |
54 | | - if len(new_states) > 0: |
55 | | - inserted_ids = (await State.insert_many(new_states)).inserted_ids |
| 60 | + new_states.append( |
| 61 | + State( |
| 62 | + node_name=state.node_name, |
| 63 | + namespace_name=state.namespace_name, |
| 64 | + identifier=state.identifier, |
| 65 | + graph_name=state.graph_name, |
| 66 | + run_id=state.run_id, |
| 67 | + status=StateStatusEnum.EXECUTED, |
| 68 | + inputs=state.inputs, |
| 69 | + outputs=output, |
| 70 | + error=None, |
| 71 | + parents=state.parents, |
| 72 | + ) |
| 73 | + ) |
| 74 | + |
| 75 | + if new_states: |
| 76 | + inserted_ids = ( |
| 77 | + await State.insert_many(new_states) |
| 78 | + ).inserted_ids |
56 | 79 | next_state_ids.extend(inserted_ids) |
57 | 80 |
|
58 | | - background_tasks.add_task(create_next_states, next_state_ids, state.identifier, state.namespace_name, state.graph_name, state.parents) |
| 81 | + # ---- Create next states ---- |
| 82 | + background_tasks.add_task( |
| 83 | + create_next_states, |
| 84 | + next_state_ids, |
| 85 | + state.identifier, |
| 86 | + state.namespace_name, |
| 87 | + state.graph_name, |
| 88 | + state.parents, |
| 89 | + ) |
59 | 90 |
|
60 | | - return ExecutedResponseModel(status=StateStatusEnum.EXECUTED) |
| 91 | + return ExecutedResponseModel( |
| 92 | + status=StateStatusEnum.EXECUTED |
| 93 | + ) |
61 | 94 |
|
62 | 95 | except Exception as e: |
63 | | - logger.error(f"Error executing state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) |
64 | | - raise e |
| 96 | + logger.error( |
| 97 | + f"Error executing state {state_id} for namespace {namespace_name}", |
| 98 | + x_exosphere_request_id=x_exosphere_request_id, |
| 99 | + error=e, |
| 100 | + ) |
| 101 | + raise |
0 commit comments