Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions src/simple_github/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import json
import time
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional

from aiohttp import ClientResponse, ClientSession
from gql import Client as GqlClient
Expand All @@ -12,22 +13,19 @@
from requests import Response as RequestsResponse
from requests import Session

from simple_github.util.rate_limit import get_wait_time, is_rate_limited
from simple_github.util.types import BaseDict, BaseNone, BaseResponse, RequestData

if TYPE_CHECKING:
from simple_github.auth import Auth

GITHUB_API_ENDPOINT = "https://api.github.com"
GITHUB_GRAPHQL_ENDPOINT = "https://api.github.com/graphql"

Response = Union[RequestsResponse, ClientResponse]
RequestData = Optional[Dict[str, Any]]

# Implementations of the base class can be either sync or async.
BaseDict = Union[Dict[str, Any], Coroutine[None, None, Dict[str, Any]]]
BaseNone = Union[None, Coroutine[None, None, None]]
BaseResponse = Union[Response, Coroutine[None, None, Response]]


class Client:
MAX_RETRIES = 5

def __init__(self, auth: "Auth"):
"""A Github client.

Expand Down Expand Up @@ -137,8 +135,15 @@ def request(self, method: str, query: str, **kwargs) -> RequestsResponse:
url = f"{GITHUB_API_ENDPOINT}/{query.lstrip('/')}"
session = self._get_requests_session()

with session.request(method, url, **kwargs) as resp:
return resp
attempt = 1
while True:
with session.request(method, url, **kwargs) as resp:
if attempt <= self.MAX_RETRIES and is_rate_limited(resp):
time.sleep(get_wait_time(resp, attempt))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might want to log something before sleep?

attempt += 1
continue

return resp

def get(self, query: str) -> RequestsResponse:
"""Make a GET request to Github's REST API.
Expand Down Expand Up @@ -208,7 +213,16 @@ def execute(self, query: str, variables: RequestData = None) -> Dict[str, Any]:
Dict: The result of the executed query.
"""
session = self._get_gql_session()
return session.execute(gql(query), variable_values=variables)

attempt = 1
while True:
resp = session.execute(gql(query), variable_values=variables)
if attempt <= self.MAX_RETRIES and is_rate_limited(resp):
time.sleep(get_wait_time(resp, attempt))
attempt += 1
continue

return resp


class AsyncClient(Client):
Expand Down Expand Up @@ -278,7 +292,15 @@ async def request(self, method: str, query: str, **kwargs: Any) -> ClientRespons
"""
url = f"{GITHUB_API_ENDPOINT}/{query.lstrip('/')}"
session = await self._get_aiohttp_session()
return await session.request(method, url, **kwargs)
attempt = 1
while True:
async with session.request(method, url, **kwargs) as resp:
if attempt <= self.MAX_RETRIES and is_rate_limited(resp):
await asyncio.sleep(get_wait_time(resp, attempt))
attempt += 1
continue

return resp

async def get(self, query: str) -> ClientResponse:
"""Make a GET request to Github's REST API.
Expand Down Expand Up @@ -350,4 +372,12 @@ async def execute(
Dict: The result of the executed query.
"""
session = await self._get_gql_session()
return await session.execute(gql(query), variable_values=variables)
attempt = 1
while True:
resp = await session.execute(gql(query), variable_values=variables)
if attempt <= self.MAX_RETRIES and is_rate_limited(resp):
await asyncio.sleep(get_wait_time(resp, attempt))
attempt += 1
continue

return resp
82 changes: 82 additions & 0 deletions src/simple_github/util/rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from time import time

from simple_github.util.types import BaseResponse


def is_rate_limited(resp: BaseResponse) -> bool:
"""
Determine if a response indicates a rate limit has been reached.
Checks the response headers and body to identify if the request was
rate-limited. It handles both GraphQL and REST API responses, looking for
specific status codes, headers, and error messages.
Args:
resp (Response): The HTTP response object to evaluate.
Returns:
bool: True if the response indicates a rate limit, False otherwise.
"""
resource = resp.headers.get("x-ratelimit-resource")

if resource == "graphql":
if resp.status_code in (200, 403):
data = resp.json()
errors = data.get("errors", [])
for error in errors:
if (
error.get("type") == "RATE_LIMITED"
or error.get("extensions", {}).get("code") == "RATE_LIMITED"
or "rate limit" in error.get("message", "").lower()
):
return True

elif resp.status_code in (403, 429):
if resp.headers.get("x-ratelimit-remaining") == "0" or resp.headers.get(
"retry-after"
):
return True

try:
data = resp.json()
message = data.get("message", "").lower()
if "rate limit exceeded" or "too many requests" in message:
return True
except ValueError:
pass

return False


def get_wait_time(resp: BaseResponse, attempt: int = 1) -> int:
"""
Calculate the wait time before retrying a request after hitting a rate limit.
Determines the appropriate wait time based on the response headers and the
number of retry attempts. It prioritizes the `x-ratelimit-reset` and
`retry-after` headers if available, and falls back to a default wait time
strategy otherwise.
Args:
resp (Response): The HTTP response object containing rate limit headers.
attempt (int): The current retry attempt number (default is 1).
Returns:
int: The calculated wait time in seconds.
"""
attempt = max(attempt, 1)
remaining = resp.headers.get("x-ratelimit-remaining")
reset = resp.headers.get("x-ratelimit-reset")
retry_after = resp.headers.get("retry-after")

if remaining == "0" and reset:
wait_time = max(0, int(reset) - int(time()))
elif retry_after:
wait_time = int(retry_after)
Comment on lines +73 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handle ValueError from the casts?

else:
# If the `x-ratelimit-reset` or `retry-after` headers aren't set, then
# the recommendation is to wait at least one minute and increase the
# interval with each new attempt.
wait_time = 60 + (20 * (attempt - 1))

return wait_time
12 changes: 12 additions & 0 deletions src/simple_github/util/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any, Coroutine, Dict, Optional, Union

from aiohttp import ClientResponse
from requests import Response as RequestsResponse

Response = Union[RequestsResponse, ClientResponse]
RequestData = Optional[Dict[str, Any]]

# Implementations of the base class can be either sync or async.
BaseDict = Union[Dict[str, Any], Coroutine[None, None, Dict[str, Any]]]
BaseNone = Union[None, Coroutine[None, None, None]]
BaseResponse = Union[Response, Coroutine[None, None, Response]]
Loading
Loading