Skip to content

Commit 1a63aec

Browse files
committed
Add return type annotations to all functions
Add -> None to all __init__ and test methods, and explicit return types to all public methods in response.py, field.py, query.py, and client.py. Closes #32
1 parent 50cf840 commit 1a63aec

7 files changed

Lines changed: 119 additions & 84 deletions

File tree

leakix/client.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from enum import Enum
3+
from typing import Any
34

45
import requests
56
from l9format import l9format
@@ -8,7 +9,12 @@
89
from leakix.domain import L9Subdomain
910
from leakix.plugin import APIResult
1011
from leakix.query import EmptyQuery, Query
11-
from leakix.response import ErrorResponse, RateLimitResponse, SuccessResponse
12+
from leakix.response import (
13+
AbstractResponse,
14+
ErrorResponse,
15+
RateLimitResponse,
16+
SuccessResponse,
17+
)
1218

1319
__VERSION__ = "0.1.9"
1420

@@ -33,17 +39,17 @@ def __init__(
3339
self,
3440
api_key: str | None = None,
3541
base_url: str | None = DEFAULT_URL,
36-
):
42+
) -> None:
3743
self.api_key = api_key
3844
self.base_url = base_url if base_url else DEFAULT_URL
39-
self.headers = {
45+
self.headers: dict[str, str] = {
4046
"Accept": "application/json",
4147
"User-agent": f"leakix-client-python/{__VERSION__}",
4248
}
4349
if api_key:
4450
self.headers["api-key"] = api_key
4551

46-
def __get(self, url, params):
52+
def __get(self, url: str, params: dict[str, Any] | None) -> AbstractResponse:
4753
r = requests.get(
4854
url,
4955
params=params,
@@ -59,7 +65,12 @@ def __get(self, url, params):
5965
else:
6066
return ErrorResponse(response=r, response_json=r.json())
6167

62-
def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0):
68+
def get(
69+
self,
70+
scope: Scope,
71+
queries: list[Query] | None = None,
72+
page: int = 0,
73+
) -> AbstractResponse:
6374
"""
6475
The function takes a scope (either "leaks" or "services"). The value can be constructed using `Scope.SERVICE` or
6576
`Scope.LEAK`.
@@ -92,11 +103,18 @@ def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0):
92103
serialized_query = f"{serialized_query}"
93104
url = f"{self.base_url}/search"
94105
r = self.__get(
95-
url=url, params={"scope": scope.value, "q": serialized_query, "page": page}
106+
url=url,
107+
params={
108+
"scope": scope.value,
109+
"q": serialized_query,
110+
"page": page,
111+
},
96112
)
97113
return r
98114

99-
def get_service(self, queries: list[Query] | None = None, page: int = 0):
115+
def get_service(
116+
self, queries: list[Query] | None = None, page: int = 0
117+
) -> AbstractResponse:
100118
"""
101119
Shortcut for `get` with the scope `Scope.Service`.
102120
@@ -108,7 +126,9 @@ def get_service(self, queries: list[Query] | None = None, page: int = 0):
108126
]
109127
return r
110128

111-
def get_leak(self, queries: list[Query] | None = None, page: int = 0):
129+
def get_leak(
130+
self, queries: list[Query] | None = None, page: int = 0
131+
) -> AbstractResponse:
112132
"""
113133
Shortcut for `get` with the scope `Scope.Leak`.
114134
"""
@@ -119,7 +139,7 @@ def get_leak(self, queries: list[Query] | None = None, page: int = 0):
119139
]
120140
return r
121141

122-
def get_host(self, ipv4: str):
142+
def get_host(self, ipv4: str) -> AbstractResponse:
123143
"""
124144
Returns the list of services and associated leaks for a given host. Only the ipv4 format is supported at the
125145
moment.
@@ -136,7 +156,7 @@ def get_host(self, ipv4: str):
136156
r.response_json = response_json
137157
return r
138158

139-
def get_plugins(self):
159+
def get_plugins(self) -> AbstractResponse:
140160
"""
141161
Returns the list of plugins the authenticated user with the given API key has access to.
142162
@@ -152,7 +172,7 @@ def get_plugins(self):
152172
r.response_json = [APIResult.from_dict(d) for d in r.json()]
153173
return r
154174

155-
def get_subdomains(self, domain: str):
175+
def get_subdomains(self, domain: str) -> AbstractResponse:
156176
"""
157177
Returns the list of subdomains for a given domain.
158178
The output is a list of `L9Subdomain` objects. The fields are `subdomain`, `distinct_ips` and `last_seen`.
@@ -164,7 +184,7 @@ def get_subdomains(self, domain: str):
164184
r.response_json = [L9Subdomain.from_dict(d) for d in r.json()]
165185
return r
166186

167-
def bulk_export(self, queries: list[Query] | None = None):
187+
def bulk_export(self, queries: list[Query] | None = None) -> AbstractResponse:
168188
url = f"{self.base_url}/bulk/search"
169189
if queries is None or len(queries) == 0:
170190
serialized_query = EmptyQuery().serialize()
@@ -188,7 +208,9 @@ def bulk_export(self, queries: list[Query] | None = None):
188208
return ErrorResponse(response=r, response_json=r.json())
189209
return r
190210

191-
def bulk_export_last_event(self, queries: list[Query] | None = None):
211+
def bulk_export_last_event(
212+
self, queries: list[Query] | None = None
213+
) -> AbstractResponse:
192214
response = self.bulk_export(queries)
193215
if response.is_success():
194216
for aggreg in response.json():
@@ -201,7 +223,7 @@ def bulk_export_last_event(self, queries: list[Query] | None = None):
201223
aggreg.events = [sorted_events[0]]
202224
return response
203225

204-
def bulk_service(self, queries: list[Query] | None = None):
226+
def bulk_service(self, queries: list[Query] | None = None) -> AbstractResponse:
205227
url = f"{self.base_url}/bulk/service"
206228
if queries is None or len(queries) == 0:
207229
serialized_query = EmptyQuery().serialize()

leakix/field.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ class Operator(Enum):
1111

1212

1313
class CustomField:
14-
def __init__(self, v: str, field_name: str, operator: Operator | None = None):
14+
def __init__(
15+
self, v: str, field_name: str, operator: Operator | None = None
16+
) -> None:
1517
if operator is None:
1618
operator = Operator.Equal
1719
self.operator = operator
@@ -27,40 +29,40 @@ def serialize(self) -> str:
2729

2830

2931
class TimeField(CustomField):
30-
def __init__(self, d: datetime, operator: Operator | None = None):
32+
def __init__(self, d: datetime, operator: Operator | None = None) -> None:
3133
v = '"{}"'.format(d.strftime("%Y-%m-%d"))
3234
super().__init__(v=v, operator=operator, field_name="time")
3335

3436

3537
class UpdateDateField(CustomField):
36-
def __init__(self, d: datetime, operator: Operator | None = None):
38+
def __init__(self, d: datetime, operator: Operator | None = None) -> None:
3739
# v = '"%s"' % d.strftime("%Y-%m-%d %H:%M:%S")
3840
v = '"{}"'.format(d.strftime("%Y-%m-%d"))
3941
super().__init__(v=v, operator=operator, field_name="update_date")
4042

4143

4244
class AgeField(CustomField):
43-
def __init__(self, age: int, operator: Operator | None = None):
45+
def __init__(self, age: int, operator: Operator | None = None) -> None:
4446
super().__init__(v=str(age), operator=operator, field_name="age")
4547

4648

4749
class PluginField(CustomField):
48-
def __init__(self, p: Plugin):
50+
def __init__(self, p: Plugin) -> None:
4951
v = p.value
5052
super().__init__(v=v, operator=None, field_name="plugin")
5153

5254

5355
class IPField(CustomField):
54-
def __init__(self, ip: str, operator: Operator | None = None):
56+
def __init__(self, ip: str, operator: Operator | None = None) -> None:
5557
super().__init__(v=ip, operator=operator, field_name="ip")
5658

5759

5860
class PortField(CustomField):
59-
def __init__(self, port: int, operator: Operator | None = None):
61+
def __init__(self, port: int, operator: Operator | None = None) -> None:
6062
assert 0 <= port < 65536
6163
super().__init__(v=str(port), operator=operator, field_name="port")
6264

6365

6466
class CountryField(CustomField):
65-
def __init__(self, country: str):
67+
def __init__(self, country: str) -> None:
6668
super().__init__(v=country, operator=None, field_name="country")

leakix/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Query(AbstractQuery):
2929
A list of fields can be found in `field.py`.
3030
"""
3131

32-
def __init__(self, field: CustomField):
32+
def __init__(self, field: CustomField) -> None:
3333
self.field = field
3434

3535

@@ -70,7 +70,7 @@ class RawQuery(AbstractQuery):
7070
RawQuery("+host:.be").
7171
"""
7272

73-
def __init__(self, raw_q: str):
73+
def __init__(self, raw_q: str) -> None:
7474
self.raw_q = raw_q
7575

7676
def serialize(self) -> str:

leakix/response.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from abc import ABCMeta, abstractmethod
2+
from typing import Any
23

34

45
class AbstractResponse(metaclass=ABCMeta):
5-
def __init__(self, response, response_json=None, status_code=None):
6+
def __init__(
7+
self,
8+
response: Any,
9+
response_json: Any = None,
10+
status_code: int | None = None,
11+
) -> None:
612
self.response = response
713
self._status_code = (
814
status_code if status_code is not None else self.response.status_code
@@ -11,34 +17,34 @@ def __init__(self, response, response_json=None, status_code=None):
1117
response_json if response_json is not None else response.json()
1218
)
1319

14-
def json(self):
20+
def json(self) -> Any:
1521
return self.response_json
1622

17-
def status_code(self):
23+
def status_code(self) -> int:
1824
return self._status_code
1925

2026
@abstractmethod
21-
def is_success(self):
27+
def is_success(self) -> bool:
2228
pass
2329

2430
@abstractmethod
25-
def is_error(self):
31+
def is_error(self) -> bool:
2632
pass
2733

2834

2935
class SuccessResponse(AbstractResponse):
30-
def is_success(self):
36+
def is_success(self) -> bool:
3137
return True
3238

33-
def is_error(self):
39+
def is_error(self) -> bool:
3440
return False
3541

3642

3743
class ErrorResponse(AbstractResponse):
38-
def is_success(self):
44+
def is_success(self) -> bool:
3945
return False
4046

41-
def is_error(self):
47+
def is_error(self) -> bool:
4248
return True
4349

4450

@@ -47,5 +53,10 @@ class RateLimitResponse(ErrorResponse):
4753

4854

4955
class R(AbstractResponse):
50-
def __init__(self, response, response_json=None, status_code=None):
56+
def __init__(
57+
self,
58+
response: Any,
59+
response_json: Any = None,
60+
status_code: int | None = None,
61+
) -> None:
5162
super().__init__(response, response_json=response_json, status_code=status_code)

tests/test_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,42 @@
1717

1818

1919
@pytest.fixture
20-
def client():
20+
def client() -> None:
2121
return Client()
2222

2323

2424
@pytest.fixture
25-
def client_with_api_key():
25+
def client_with_api_key() -> None:
2626
return Client(api_key="test-api-key")
2727

2828

2929
@pytest.fixture
30-
def fake_ipv4():
30+
def fake_ipv4() -> None:
3131
return "33.33.33.33"
3232

3333

3434
class TestClientInit:
35-
def test_default_base_url(self):
35+
def test_default_base_url(self) -> None:
3636
client = Client()
3737
assert client.base_url == "https://leakix.net"
3838

39-
def test_custom_base_url(self):
39+
def test_custom_base_url(self) -> None:
4040
client = Client(base_url="https://custom.leakix.net")
4141
assert client.base_url == "https://custom.leakix.net"
4242

43-
def test_api_key_in_headers(self):
43+
def test_api_key_in_headers(self) -> None:
4444
client = Client(api_key="my-api-key")
4545
assert client.headers["api-key"] == "my-api-key"
4646

47-
def test_no_api_key_header_when_not_provided(self):
47+
def test_no_api_key_header_when_not_provided(self) -> None:
4848
client = Client()
4949
assert "api-key" not in client.headers
5050

51-
def test_user_agent_header(self):
51+
def test_user_agent_header(self) -> None:
5252
client = Client()
5353
assert "leakix-client-python" in client.headers["User-agent"]
5454

55-
def test_accept_header(self):
55+
def test_accept_header(self) -> None:
5656
client = Client()
5757
assert client.headers["Accept"] == "application/json"
5858

0 commit comments

Comments
 (0)