Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions examples/devices.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#!/usr/bin/env python3
# pylint: disable=W0621
"""Asynchronous client for the Tailscale API."""

import asyncio
import os

from tailscale import Tailscale

# "-" is the default tailnet of the API key
TAILNET = os.environ.get("TS_TAILNET", "-")
API_KEY = os.environ.get("TS_API_KEY", "")


async def main() -> None:
"""Show example on using the Tailscale API client."""
async with Tailscale(
tailnet="frenck",
api_key="tskey-somethingsomething",
tailnet=TAILNET,
api_key=API_KEY,
) as tailscale:
devices = await tailscale.devices()
print(devices)
Expand Down
30 changes: 30 additions & 0 deletions examples/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# pylint: disable=W0621
"""Asynchronous client for the Tailscale API."""

import asyncio
import os

from tailscale import Tailscale

# "-" is the default tailnet of the API key
TAILNET = os.environ.get("TS_TAILNET", "-")

# OAuth client ID and secret are required for OAuth authentication
OAUTH_CLIENT_ID = os.environ.get("TS_API_CLIENT_ID", "")
OAUTH_CLIENT_SECRET = os.environ.get("TS_API_CLIENT_SECRET", "")


async def main_oauth() -> None:
"""Show example on using the Tailscale API client with OAuth."""
async with Tailscale(
tailnet=TAILNET,
oauth_client_id=OAUTH_CLIENT_ID,
oauth_client_secret=OAUTH_CLIENT_SECRET,
) as tailscale:
devices = await tailscale.devices()
print(devices)


if __name__ == "__main__":
asyncio.run(main_oauth())
2 changes: 2 additions & 0 deletions src/tailscale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TailscaleError,
)
from .models import ClientConnectivity, ClientSupports, Device, Devices
from .storage import TokenStorage
from .tailscale import Tailscale

__all__ = [
Expand All @@ -17,4 +18,5 @@
"TailscaleAuthenticationError",
"TailscaleConnectionError",
"TailscaleError",
"TokenStorage",
]
29 changes: 29 additions & 0 deletions src/tailscale/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Abstract token storage."""

from abc import ABC, abstractmethod
from datetime import datetime


class TokenStorage(ABC):
"""Abstract class for token storage implementations."""

@abstractmethod
async def get_token(self) -> tuple[str, datetime] | None:
"""Get the stored token.

Returns:
The stored token and expiration time, or None if no token is stored.

"""
raise NotImplementedError

@abstractmethod
async def set_token(self, access_token: str, expires_at: datetime) -> None:
"""Store the given token.

Args:
access_token: The access token to store.
expires_at: The expiration time of the access token.

"""
raise NotImplementedError
152 changes: 140 additions & 12 deletions src/tailscale/tailscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from __future__ import annotations

import asyncio
import json
import socket
from dataclasses import dataclass
from typing import Any, Self
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Self

from aiohttp import BasicAuth
from aiohttp.client import ClientError, ClientResponseError, ClientSession
from aiohttp.hdrs import METH_GET
from aiohttp.hdrs import METH_GET, METH_POST
from yarl import URL

from .exceptions import (
Expand All @@ -19,25 +20,136 @@
)
from .models import Device, Devices

if TYPE_CHECKING:
from .storage import TokenStorage


@dataclass
# pylint: disable-next=too-many-instance-attributes
class Tailscale:
"""Main class for handling connections with the Tailscale API."""

tailnet: str
api_key: str
# tailnet of '-' is the default tailnet of the API key
tailnet: str = "-"
api_key: str | None = None
oauth_client_id: str | None = None
oauth_client_secret: str | None = None

request_timeout: int = 8
session: ClientSession | None = None
token_storage: TokenStorage | None = None

_get_oauth_token_task: asyncio.Task[None] | None = None
_expire_oauth_token_task: asyncio.Task[None] | None = None
_close_session: bool = False

async def _check_api_key(self) -> None:
"""Initialize the Tailscale client.

Raises:
TailscaleAuthenticationError: when neither api_key nor oauth_client_id and
oauth_client_secret are provided.

"""
if not (
(self.api_key and not self.oauth_client_id and not self.oauth_client_secret)
or (not self.api_key and self.oauth_client_id and self.oauth_client_secret)
or (
self.api_key
and self.oauth_client_id
and self.oauth_client_secret
and self._get_oauth_token_task
)
):
msg = (
"Either api_key or oauth_client_id and oauth_client_secret "
"are required when Tailscale client is initialized"
)
raise TailscaleAuthenticationError(msg)
if not self.api_key:
# Handle some inconsistent state
# possibly caused by user manually deleting api_key
if self._expire_oauth_token_task:
self._expire_oauth_token_task.cancel()
self._expire_oauth_token_task = None
if self._get_oauth_token_task:
self._get_oauth_token_task.cancel()
self._get_oauth_token_task = None
# Get a new OAuth token if not already in the process of getting one
if not self._get_oauth_token_task:
self._get_oauth_token_task = asyncio.create_task(
self._get_oauth_token()
)
# Wait for the OAuth token to be retrieved
await self._get_oauth_token_task

async def _get_oauth_token(self) -> None:
"""Get an OAuth token from the Tailscale API or token storage.

Raises:
TailscaleAuthenticationError: when access token not found in response or
access token expires in less than 5 minutes.

"""
if self.token_storage:
token_data = await self.token_storage.get_token()
if token_data:
access_token, expires_at = token_data
expires_in = (expires_at - datetime.now(timezone.utc)).total_seconds()
if expires_in > 60:
self._expire_oauth_token_task = asyncio.create_task(
self._expire_oauth_token(expires_in)
)
self.api_key = access_token
return

# Tailscale's OAuth endpoint requires form-encoded body
# with client_id and client_secret
data = {
"client_id": self.oauth_client_id,
"client_secret": self.oauth_client_secret,
}
response = await self._request(
"oauth/token",
data=data,
method=METH_POST,
_use_authentication=False,
_use_form_encoding=True,
)

json_response = json.loads(response)
access_token = str(json_response.get("access_token", ""))
expires_in = float(json_response.get("expires_in", 0))
if not access_token or not expires_in:
msg = "Failed to get OAuth token"
raise TailscaleAuthenticationError(msg)
if expires_in <= 60:
msg = "OAuth token expires in less than 1 minute"
raise TailscaleAuthenticationError(msg)

self._expire_oauth_token_task = asyncio.create_task(
self._expire_oauth_token(expires_in)
)
if self.token_storage:
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
await self.token_storage.set_token(access_token, expires_at)
self.api_key = access_token

async def _expire_oauth_token(self, expires_in: float) -> None:
"""Expires the OAuth token 1 minute before expiration."""
await asyncio.sleep(expires_in - 60)
self.api_key = None
self._get_oauth_token_task = None
self._expire_oauth_token_task = None

async def _request(
self,
uri: str,
*,
method: str = METH_GET,
data: dict[str, Any] | None = None,
_use_authentication: bool = True,
_use_form_encoding: bool = False,
) -> str:
"""Handle a request to the Tailscale API.

Expand All @@ -52,8 +164,7 @@ async def _request(

Returns:
-------
A Python dictionary (JSON decoded) with the response from
the Tailscale API.
The response from the Tailscale API.

Raises:
------
Expand All @@ -66,29 +177,42 @@ async def _request(
"""
url = URL("https://api.tailscale.com/api/v2/").join(URL(uri))

headers = {
headers: dict[str, str] = {
"Accept": "application/json",
}

if _use_authentication:
await self._check_api_key()
# API keys and oauth tokens can use Bearer authentication
headers["Authorization"] = f"Bearer {self.api_key}"

if self.session is None:
self.session = ClientSession()
self._close_session = True

try:
async with asyncio.timeout(self.request_timeout):
# Use form encoding for OAuth token requests, JSON for others
response = await self.session.request(
method,
url,
json=data,
auth=BasicAuth(self.api_key),
headers=headers,
headers=headers if headers else None,
data=data if _use_form_encoding else None,
json=data if not _use_form_encoding else None,
)
response.raise_for_status()
except asyncio.TimeoutError as exception:
msg = "Timeout occurred while connecting to the Tailscale API"
raise TailscaleConnectionError(msg) from exception
except ClientResponseError as exception:
if exception.status in [401, 403]:
if _use_authentication and self.api_key and self.oauth_client_id:
# Invalidate the current OAuth token
self.api_key = None
self._get_oauth_token_task = None
if self._expire_oauth_token_task:
self._expire_oauth_token_task.cancel()
self._expire_oauth_token_task = None
msg = "Authentication to the Tailscale API failed"
raise TailscaleAuthenticationError(msg) from exception
msg = "Error occurred while connecting to the Tailscale API"
Expand All @@ -114,9 +238,13 @@ async def devices(self) -> dict[str, Device]:
return Devices.from_json(data).devices

async def close(self) -> None:
"""Close open client session."""
"""Close open client session and cancel tasks."""
if self.session and self._close_session:
await self.session.close()
if self._get_oauth_token_task:
self._get_oauth_token_task.cancel()
if self._expire_oauth_token_task:
self._expire_oauth_token_task.cancel()

async def __aenter__(self) -> Self:
"""Async enter.
Expand Down
27 changes: 27 additions & 0 deletions tests/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Dummy token storage."""

from datetime import datetime

from tailscale.storage import TokenStorage


class InMemoryTokenStorage(TokenStorage):
"""In-memory token storage for testing purposes."""

def __init__(
self, access_token: str | None = None, expires_at: datetime | None = None
) -> None:
"""Initialize the in-memory token storage."""
self._access_token = access_token
self._expires_at = expires_at

async def get_token(self) -> tuple[str, datetime] | None:
"""Get the stored token."""
if self._access_token and self._expires_at:
return self._access_token, self._expires_at
return None

async def set_token(self, access_token: str, expires_at: datetime) -> None:
"""Store the token."""
self._access_token = access_token
self._expires_at = expires_at
Loading