Skip to content

Commit cfb3c15

Browse files
fix: import/export typing
1 parent 3bde9ce commit cfb3c15

File tree

6 files changed

+66
-47
lines changed

6 files changed

+66
-47
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ You will also need the following env variables set for the Exchange integration
105105
2. Activate the virtual environment
106106
`source venv/bin/activate`
107107
3. Install the client
108-
`pip3 install --editable .`
108+
`pip3 install --editable .[all]`
109109
4. Install test deps
110110
`pip3 install "pytest<8" "requests-mock>=1.8.0" "pytest-asyncio>0.21"`
111111
5. Run tests

docker-compose.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
version: "3"
2-
31
services:
42
indico-client-build:
53
build:

indico/queries/model_export.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
from typing import TYPE_CHECKING
2+
13
from indico.client.request import Delay, GraphQLRequest, RequestChain
24
from indico.types.model_export import ModelExport
35

6+
if TYPE_CHECKING: # pragma: no cover
7+
from typing import Any, Iterator, List, Union
8+
9+
from indico.typing import Payload
10+
411

5-
class _CreateModelExport(GraphQLRequest):
12+
class _CreateModelExport(GraphQLRequest["ModelExport"]):
613
query = """
714
mutation ($modelId: Int!) {
815
createModelExport(
@@ -20,11 +27,11 @@ def __init__(self, model_id: int):
2027
self.model_id = model_id
2128
super().__init__(self.query, variables={"modelId": model_id})
2229

23-
def process_response(self, response) -> ModelExport:
24-
return ModelExport(**super().process_response(response)["createModelExport"])
30+
def process_response(self, response: "Payload") -> ModelExport:
31+
return ModelExport(**super().parse_payload(response)["createModelExport"])
2532

2633

27-
class CreateModelExport(RequestChain):
34+
class CreateModelExport(RequestChain["List[ModelExport]"]):
2835
"""
2936
Create a model export.
3037
@@ -36,20 +43,20 @@ class CreateModelExport(RequestChain):
3643
request_interval (int | float): the interval between requests in seconds. Defaults to 5.
3744
"""
3845

39-
previous: ModelExport | None = None
46+
previous: "Any" = None
4047

4148
def __init__(
4249
self,
4350
model_id: int,
4451
wait: bool = True,
45-
request_interval: int | float = 5,
52+
request_interval: "Union[int, float]" = 5,
4653
):
4754
self.wait = wait
4855
self.model_id = model_id
4956
self.request_interval = request_interval
5057
super().__init__()
5158

52-
def requests(self):
59+
def requests(self) -> "Iterator[Union[_CreateModelExport, Delay, GetModelExports]]":
5360
yield _CreateModelExport(self.model_id)
5461
if self.wait:
5562
while self.previous and self.previous.status not in ["COMPLETE", "FAILED"]:
@@ -60,7 +67,7 @@ def requests(self):
6067
yield GetModelExports([self.previous.id], with_signed_url=self.wait is True)
6168

6269

63-
class GetModelExports(GraphQLRequest):
70+
class GetModelExports(GraphQLRequest["List[ModelExport]"]):
6471
"""
6572
Get model export(s).
6673
@@ -91,17 +98,17 @@ class GetModelExports(GraphQLRequest):
9198
"createdBy",
9299
]
93100

94-
def __init__(self, export_ids: list[int], with_signed_url: bool = False):
101+
def __init__(self, export_ids: "List[int]", with_signed_url: bool = False):
95102
if with_signed_url:
96103
self._base_fields.append("signedUrl")
97104

98105
query_with_fields = self.query.replace("{fields}", "\n".join(self._base_fields))
99106
super().__init__(query_with_fields, variables={"exportIds": export_ids})
100107

101-
def process_response(self, response) -> list[ModelExport]:
108+
def process_response(self, response: "Payload") -> "List[ModelExport]":
102109
return [
103110
ModelExport(**export)
104-
for export in super().process_response(response)["modelExports"][
111+
for export in super().parse_payload(response)["modelExports"][
105112
"modelExports"
106113
]
107114
]

indico/queries/model_import.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
from typing import Generator
1+
from typing import TYPE_CHECKING, cast
22

33
import requests
44

55
from indico.client.request import GraphQLRequest, RequestChain
66
from indico.errors import IndicoInputError, IndicoRequestError
7-
from indico.queries.jobs import JobStatus
87
from indico.types.jobs import Job
98

9+
from .jobs import JobStatus
1010

11-
class _UploadSMExport(GraphQLRequest):
11+
if TYPE_CHECKING: # pragma: no cover
12+
from typing import Dict, Iterator, Optional, Union # noqa: F401
13+
14+
from indico.typing import Payload
15+
16+
17+
class _UploadSMExport(GraphQLRequest[str]):
1218
query = """
1319
query exportUpload {
1420
exportUpload {
@@ -22,25 +28,26 @@ def __init__(self, file_path: str):
2228
self.file_path = file_path
2329
super().__init__(self.query)
2430

25-
def process_response(self, response) -> str:
26-
resp = super().process_response(response)["exportUpload"]
31+
def process_response(self, response: "Payload") -> str:
32+
resp: "Dict[str, str]" = super().parse_payload(response)["exportUpload"]
2733
signed_url = resp["signedUrl"]
2834
storage_uri = resp["storageUri"]
2935

3036
with open(self.file_path, "rb") as file:
3137
file_content = file.read()
3238

3339
headers = {"Content-Type": "application/zip"}
34-
response = requests.put(signed_url, data=file_content, headers=headers)
40+
export_response = requests.put(signed_url, data=file_content, headers=headers)
3541

36-
if response.status_code != 200:
42+
if export_response.status_code != 200:
3743
raise IndicoRequestError(
38-
f"Failed to upload static model export: {response.text}"
44+
f"Failed to upload static model export: {export_response.text}",
45+
export_response.status_code,
3946
)
4047
return storage_uri
4148

4249

43-
class ProcessStaticModelExport(GraphQLRequest):
50+
class ProcessStaticModelExport(GraphQLRequest["Job"]):
4451
"""
4552
Process a static model export.
4653
@@ -77,12 +84,12 @@ def __init__(
7784
},
7885
)
7986

80-
def process_response(self, response) -> Job:
81-
job_id = super().process_response(response)["processStaticModelExport"]["jobId"]
87+
def process_response(self, response: "Payload") -> Job:
88+
job_id = super().parse_payload(response)["processStaticModelExport"]["jobId"]
8289
return Job(id=job_id)
8390

8491

85-
class UploadStaticModelExport(RequestChain):
92+
class UploadStaticModelExport(RequestChain["Union[Job, str]"]):
8693
"""
8794
Upload a static model export to Indico.
8895
@@ -100,22 +107,27 @@ class UploadStaticModelExport(RequestChain):
100107
"""
101108

102109
def __init__(
103-
self, file_path: str, auto_process: bool = False, workflow_id: int | None = None
110+
self,
111+
file_path: str,
112+
auto_process: bool = False,
113+
workflow_id: "Optional[int]" = None,
104114
):
105-
self.file_path = file_path
106-
self.auto_process = auto_process
107-
if auto_process and not workflow_id:
115+
if auto_process and workflow_id is None:
108116
raise IndicoInputError(
109117
"Must provide `workflow_id` if `auto_process` is True."
110118
)
111119

120+
self.file_path = file_path
121+
self.auto_process = auto_process
112122
self.workflow_id = workflow_id
113123

114-
def requests(self) -> Generator[str | Job, None, None]:
124+
def requests(
125+
self,
126+
) -> "Iterator[Union[_UploadSMExport, ProcessStaticModelExport, JobStatus]]":
115127
if self.auto_process:
116128
yield _UploadSMExport(self.file_path)
117129
yield ProcessStaticModelExport(
118-
storage_uri=self.previous, workflow_id=self.workflow_id
130+
storage_uri=self.previous, workflow_id=cast(int, self.workflow_id)
119131
)
120132
yield JobStatus(self.previous.id)
121133
if self.previous.status == "FAILURE":

indico/queries/workflow_components.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, cast
22

33
import jsons
44

@@ -15,7 +15,7 @@
1515
)
1616

1717
if TYPE_CHECKING: # pragma: no cover
18-
from typing import Iterator, List, Optional, Union
18+
from typing import Any, Iterator, List, Optional, Union
1919

2020
from indico.typing import AnyDict, Payload
2121

@@ -455,7 +455,7 @@ def process_response(self, response: "Payload") -> "Workflow":
455455
)
456456

457457

458-
class AddStaticModelComponent(RequestChain):
458+
class AddStaticModelComponent(RequestChain["Workflow"]):
459459
"""
460460
Add a static model component to a workflow.
461461
@@ -470,17 +470,17 @@ class AddStaticModelComponent(RequestChain):
470470
`export_file(str)`: the path to the static model export file.
471471
"""
472472

473-
previous = None
473+
previous: "Any" = None
474474

475475
def __init__(
476476
self,
477477
workflow_id: int,
478-
after_component_id: int | None = None,
479-
after_component_link_id: int | None = None,
480-
static_component_config: dict[str, Any] | None = None,
481-
component_name: str | None = None,
478+
after_component_id: "Optional[int]" = None,
479+
after_component_link_id: "Optional[int]" = None,
480+
static_component_config: "Optional[AnyDict]" = None,
481+
component_name: "Optional[str]" = None,
482482
auto_process: bool = False,
483-
export_file: str | None = None,
483+
export_file: "Optional[str]" = None,
484484
):
485485
if not export_file and auto_process:
486486
raise IndicoInputError("Must provide export_file if auto_process is True.")
@@ -511,11 +511,13 @@ def __init__(
511511
self.auto_process = auto_process
512512
self.export_file = export_file
513513

514-
def requests(self):
514+
def requests(
515+
self,
516+
) -> "Iterator[Union[UploadStaticModelExport, _AddWorkflowComponent]]":
515517
if self.auto_process:
516518
yield UploadStaticModelExport(
517519
auto_process=True,
518-
file_path=self.export_file,
520+
file_path=cast(str, self.export_file),
519521
workflow_id=self.workflow_id,
520522
)
521523
self.component.update(

tests/integration/queries/test_workflow_component.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ModelGroup,
2121
ModelTaskType,
2222
NewLabelsetArguments,
23-
StaticModelConfig,
23+
# StaticModelConfig,
2424
)
2525

2626
from ..data.datasets import * # noqa
@@ -257,9 +257,9 @@ def test_add_static_model_component(indico, org_annotate_dataset):
257257
static_model_req = AddStaticModelComponent(
258258
workflow_id=wf.id,
259259
after_component_id=after_component_id,
260-
static_component_config=StaticModelConfig(
261-
export_meta=finished_job.result,
262-
),
260+
static_component_config={
261+
"export_meta": finished_job.result,
262+
},
263263
)
264264
wf = client.call(static_model_req)
265265

0 commit comments

Comments
 (0)