Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions examples/devices.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#!/usr/bin/env python3
# pylint: disable=W0621
"""Asynchronous client for the Tailscale API."""

import asyncio
import os

from tailscale import Tailscale

# "-" is the default tailnet of the API key
TAILNET = os.environ.get("TS_TAILNET", "-")
API_KEY = os.environ.get("TS_API_KEY", "")


async def main() -> None:
"""Show example on using the Tailscale API client."""
async with Tailscale(
tailnet="frenck",
api_key="tskey-somethingsomething",
tailnet=TAILNET,
api_key=API_KEY,
) as tailscale:
devices = await tailscale.devices()
print(devices)
Expand Down
30 changes: 30 additions & 0 deletions examples/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# pylint: disable=W0621
"""Asynchronous client for the Tailscale API."""

import asyncio
import os

from tailscale import Tailscale

# "-" is the default tailnet of the API key
TAILNET = os.environ.get("TS_TAILNET", "-")

# OAuth client ID and secret are required for OAuth authentication
OAUTH_CLIENT_ID = os.environ.get("TS_API_CLIENT_ID", "")
OAUTH_CLIENT_SECRET = os.environ.get("TS_API_CLIENT_SECRET", "")


async def main_oauth() -> None:
"""Show example on using the Tailscale API client with OAuth."""
async with Tailscale(
tailnet=TAILNET,
oauth_client_id=OAUTH_CLIENT_ID,
oauth_client_secret=OAUTH_CLIENT_SECRET,
) as tailscale:
devices = await tailscale.devices()
print(devices)


if __name__ == "__main__":
asyncio.run(main_oauth())
78 changes: 70 additions & 8 deletions src/tailscale/tailscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from __future__ import annotations

import asyncio
import json
import socket
from dataclasses import dataclass
from typing import Any, Self

from aiohttp import BasicAuth
from aiohttp.client import ClientError, ClientResponseError, ClientSession
from aiohttp.hdrs import METH_GET
from aiohttp.hdrs import METH_GET, METH_POST
from yarl import URL

from .exceptions import (
Expand All @@ -19,25 +19,80 @@
)
from .models import Device, Devices

# Placeholder value for the access token when it is not yet set.
ACCESS_TOKEN_PENDING = "<pending>" # noqa: S105


@dataclass
class Tailscale:
"""Main class for handling connections with the Tailscale API."""

tailnet: str
api_key: str
# tailnet of '-' is the default tailnet of the API key
tailnet: str = "-"
api_key: str = "" # nosec
oauth_client_id: str = "" # nosec
oauth_client_secret: str = "" # nosec

request_timeout: int = 8
session: ClientSession | None = None

_close_session: bool = False

async def _check_access(self) -> None:
"""Initialize the Tailscale client.

Raises:
TailscaleAuthenticationError: when neither api_key nor oauth_client_id and
oauth_client_secret are provided.

"""
if (
not self.api_key
and not self.oauth_client_id
and not self.oauth_client_secret
):
msg = "Either api_key or oauth client is required"
raise TailscaleAuthenticationError(msg)
if not self.api_key:
self.api_key = ACCESS_TOKEN_PENDING
self.api_key = await self._get_oauth_token()

async def _get_oauth_token(self) -> str:
"""Get an OAuth token from the Tailscale API.

Raises:
TailscaleAuthenticationError: when access key not found in response.

Returns:
A string with the OAuth token, or nothing on error

"""
# Tailscale's OAuth endpoint requires form-encoded body
# with client_id and client_secret
data = {
"client_id": self.oauth_client_id,
"client_secret": self.oauth_client_secret,
}
response = await self._request(
"oauth/token",
data=data,
method=METH_POST,
use_form_encoding=True,
)

token = json.loads(response).get("access_token", "")
if not token:
msg = "Failed to get OAuth token"
raise TailscaleAuthenticationError(msg)
return str(token)

async def _request(
self,
uri: str,
*,
method: str = METH_GET,
data: dict[str, Any] | None = None,
use_form_encoding: bool = False,
) -> str:
"""Handle a request to the Tailscale API.

Expand Down Expand Up @@ -66,22 +121,29 @@ async def _request(
"""
url = URL("https://api.tailscale.com/api/v2/").join(URL(uri))

headers = {
await self._check_access()

headers: dict[str, str] = {
"Accept": "application/json",
}

if self.api_key and self.api_key != ACCESS_TOKEN_PENDING:
# API keys and oauth tokens can use Bearer authentication
headers["Authorization"] = f"Bearer {self.api_key}"

if self.session is None:
self.session = ClientSession()
self._close_session = True

try:
async with asyncio.timeout(self.request_timeout):
# Use form encoding for OAuth token requests, JSON for others
response = await self.session.request(
method,
url,
json=data,
auth=BasicAuth(self.api_key),
headers=headers,
headers=headers if headers else None,
data=data if use_form_encoding else None,
json=data if not use_form_encoding else None,
)
response.raise_for_status()
except asyncio.TimeoutError as exception:
Expand Down
78 changes: 78 additions & 0 deletions tests/test_tailscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,84 @@
)


@pytest.mark.asyncio
async def test_no_access() -> None:
"""Test api key or oauth key is checked correctly."""
async with Tailscale(tailnet="frenck") as tailscale:
with pytest.raises(TailscaleAuthenticationError):
assert await tailscale._request("test")


@pytest.mark.asyncio
async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None:
"""Test oauth key response is handled correctly."""
aresponses.add(
"api.tailscale.com",
"/api/v2/oauth/token",
"POST",
aresponses.Response(
status=200,
headers={"Content-Type": "application/json"},
text='{"access_token": "short-lived-token"}',
),
)
aresponses.add(
"api.tailscale.com",
"/api/v2/test",
"GET",
aresponses.Response(
status=200,
headers={"Content-Type": "application/json"},
text='{"status": "ok"}',
),
)
async with aiohttp.ClientSession() as session:
tailscale = Tailscale(
tailnet="frenck",
oauth_client_id="client", # nosec
oauth_client_secret="notsosecret", # noqa: S106
session=session,
)
await tailscale._request("test")
second_request = aresponses.history[1].request
assert "Bearer" in second_request.headers["Authorization"]
await tailscale.close()

aresponses.assert_plan_strictly_followed()


@pytest.mark.asyncio
async def test_bad_oauth(aresponses: ResponsesMockServer) -> None:
"""Test bad oauth error is handled correctly."""
aresponses.add(
"api.tailscale.com",
"/api/v2/oauth/token",
"POST",
aresponses.Response(
status=200,
headers={"Content-Type": "application/json"},
text='{"no_access_token": "unauthorized"}',
),
)

async with aiohttp.ClientSession() as session:
tailscale = Tailscale(
tailnet="frenck",
oauth_client_id="client", # nosec
oauth_client_secret="notsosecret", # noqa: S106
session=session,
)
with pytest.raises(TailscaleAuthenticationError) as excinfo:
assert await tailscale._request("test")

assert excinfo.value.args[0] == "Failed to get OAuth token"

await tailscale.close()

aresponses.assert_plan_strictly_followed()


@pytest.mark.asyncio
async def test_json_request(aresponses: ResponsesMockServer) -> None:
"""Test JSON response is handled correctly."""
aresponses.add(
Expand Down
Loading