Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
run: |
make test

- name: Lint with mypy and flake8.
- name: Lint with mypy and ruff.
run: |
make lint

Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ coverage: venv
source ${VENV}/bin/activate && pytest --cov=disruptive tests/

lint: venv
source ${VENV}/bin/activate && mypy --config-file ./mypy.ini disruptive/ && flake8 disruptive/
source ${VENV}/bin/activate && mypy --config-file ./mypy.ini disruptive/ && ruff check .

clean:
rm -rf build/ dist/ pip-wheel-metadata/ *.egg-info .pytest_cache/ .mypy_cache/ $(VENV) coverage.xml
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ Run unit-tests against the currently active python version:
make test
```

Lint the package code using MyPy and flake8:
Lint the package code using MyPy and ruff:
```
make lint
```
Expand Down
6 changes: 3 additions & 3 deletions disruptive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Metadata
__version__ = '1.7.3'
__version__ = "1.7.3"

# If set, logs of chosen level and higher are printed to console.
# Default value None results in no logs at any level.
log_level = None

# REST API base URLs of which all endpoints are an expansion.
base_url = 'https://api.disruptive-technologies.com/v2'
emulator_base_url = 'https://emulator.disruptive-technologies.com/v2'
base_url = "https://api.disruptive-technologies.com/v2"
emulator_base_url = "https://emulator.disruptive-technologies.com/v2"

# If a request response contains an error for which a series of retries is
# worth considering, these variable determine how long to wait without an
Expand Down
191 changes: 102 additions & 89 deletions disruptive/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,30 @@ def base64url_encode(data: bytes) -> str:


def base64url_decode(data: str) -> bytes:
padding = '=' * (4 - (len(data) % 4))
padding = "=" * (4 - (len(data) % 4))
return base64.urlsafe_b64decode(data + padding)


def create_jwt(payload: dict,
secret: str,
algorithm: str,
headers: dict,
) -> str:
def create_jwt(
payload: dict,
secret: str,
algorithm: str,
headers: dict,
) -> str:
headers["typ"] = "JWT"

header_encoded = base64url_encode(data=json.dumps(
obj=headers,
separators=(',', ':'),
).encode("utf-8"))
payload_encoded = base64url_encode(json.dumps(
obj=payload,
separators=(',', ':'),
).encode("utf-8"))
header_encoded = base64url_encode(
data=json.dumps(
obj=headers,
separators=(",", ":"),
).encode("utf-8")
)
payload_encoded = base64url_encode(
json.dumps(
obj=payload,
separators=(",", ":"),
).encode("utf-8")
)

message = f"{header_encoded}.{payload_encoded}"

Expand All @@ -53,11 +58,10 @@ def create_jwt(payload: dict,


class _AuthRoutineBase(object):

def __init__(self) -> None:
# Set default attributes.
self._expiration: int = 0
self._token: str = ''
self._token: str = ""

def _has_expired(self) -> bool:
"""
Expand Down Expand Up @@ -106,7 +110,6 @@ def refresh(self) -> None:


class Unauthenticated(_AuthRoutineBase):

def __init__(self) -> None:
# Inherit parent class methods and attributes.
super().__init__()
Expand All @@ -125,21 +128,23 @@ def refresh(self) -> None:

"""

msg = 'Missing Service Account credentials.\n\n' \
'Either set the following environment variables:\n\n' \
' DT_SERVICE_ACCOUNT_KEY_ID: Unique Service Account key ID.\n' \
' DT_SERVICE_ACCOUNT_SECRET: Unique Service Account secret.\n' \
' DT_SERVICE_ACCOUNT_EMAIL: Unique Service Account email.\n\n' \
'or provide them programmatically:\n\n' \
' import disruptive as dt\n\n' \
' dt.default_auth = dt.Auth.service_account(\n' \
' key_id="<SERVICE_ACCOUNT_KEY_ID>",\n' \
' secret="<SERVICE_ACCOUNT_SECRET>",\n' \
' email="<SERVICE_ACCOUNT_EMAIL>",\n' \
' )\n\n' \
'See https://developer.d21s.com/api/' \
'libraries/python/client/authentication.html' \
' for more details.\n'
msg = (
"Missing Service Account credentials.\n\n"
"Either set the following environment variables:\n\n"
" DT_SERVICE_ACCOUNT_KEY_ID: Unique Service Account key ID.\n"
" DT_SERVICE_ACCOUNT_SECRET: Unique Service Account secret.\n"
" DT_SERVICE_ACCOUNT_EMAIL: Unique Service Account email.\n\n"
"or provide them programmatically:\n\n"
" import disruptive as dt\n\n"
" dt.default_auth = dt.Auth.service_account(\n"
' key_id="<SERVICE_ACCOUNT_KEY_ID>",\n'
' secret="<SERVICE_ACCOUNT_SECRET>",\n'
' email="<SERVICE_ACCOUNT_EMAIL>",\n'
" )\n\n"
"See https://developer.d21s.com/api/"
"libraries/python/client/authentication.html"
" for more details.\n"
)

raise dterrors.Unauthorized(msg)

Expand All @@ -155,9 +160,10 @@ class ServiceAccountAuth(_AuthRoutineBase):

"""

supported_algorithms = ['HS256']
token_endpoint = 'https://identity.'\
'disruptive-technologies.com/oauth2/token'
supported_algorithms = ["HS256"]
token_endpoint = (
"https://identity.disruptive-technologies.com/oauth2/token"
)

def __init__(self, key_id: str, secret: str, email: str):
# Inherit parent class methods and attributes.
Expand All @@ -184,7 +190,7 @@ def email(self) -> str:
return self._email

def __repr__(self) -> str:
return '{}.{}({}, {}, {})'.format(
return "{}.{}({}, {}, {})".format(
self.__class__.__module__,
self.__class__.__name__,
repr(self.key_id),
Expand All @@ -202,25 +208,25 @@ def algorithm(self, algorithm: str) -> None:
self._algorithm = algorithm
else:
raise dterrors.ConfigurationError(
f'unsupported algorithm {algorithm}'
f"unsupported algorithm {algorithm}"
)

@classmethod
def from_credentials_file(cls, credentials: dict) -> ServiceAccountAuth:
for key in ['keyId', 'secret', 'email', 'algorithm', 'tokenEndpoint']:
if key not in credentials['serviceAccount']:
for key in ["keyId", "secret", "email", "algorithm", "tokenEndpoint"]:
if key not in credentials["serviceAccount"]:
raise dterrors.ConfigurationError(
f'Invalid credentials file. Missing field "{key}".'
)

cfg = credentials['serviceAccount']
cfg = credentials["serviceAccount"]
auth_obj = cls(
key_id=cfg['keyId'],
secret=cfg['secret'],
email=cfg['email'],
key_id=cfg["keyId"],
secret=cfg["secret"],
email=cfg["email"],
)
auth_obj.algorithm = cfg['algorithm']
auth_obj.token_endpoint = cfg['tokenEndpoint']
auth_obj.algorithm = cfg["algorithm"]
auth_obj.token_endpoint = cfg["tokenEndpoint"]

return auth_obj

Expand All @@ -234,8 +240,8 @@ def refresh(self) -> None:
"""

response: dict = self._get_access_token()
self._expiration = time.time() + response['expires_in']
self._token = 'Bearer {}'.format(response['access_token'])
self._expiration = time.time() + response["expires_in"]
self._token = "Bearer {}".format(response["access_token"])

def _get_access_token(self) -> dict:
"""
Expand All @@ -255,16 +261,16 @@ def _get_access_token(self) -> dict:

# Construct the JWT header.
jwt_headers: dict[str, str] = {
'alg': self.algorithm,
'kid': self.key_id,
"alg": self.algorithm,
"kid": self.key_id,
}

# Construct the JWT payload.
jwt_payload: dict[str, Any] = {
'iat': int(time.time()), # current unixtime
'exp': int(time.time()) + 3600, # expiration unixtime
'aud': self.token_endpoint,
'iss': self.email,
"iat": int(time.time()), # current unixtime
"exp": int(time.time()) + 3600, # expiration unixtime
"aud": self.token_endpoint,
"iss": self.email,
}

# Sign and encode JWT with the secret.
Expand All @@ -277,62 +283,66 @@ def _get_access_token(self) -> dict:

# Prepare HTTP POST request data.
# Note: The requests package applies Form URL-Encoding by default.
request_data: str = urllib.parse.urlencode({
'assertion': encoded_jwt,
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer'
})
request_data: str = urllib.parse.urlencode(
{
"assertion": encoded_jwt,
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
}
)

# Exchange the JWT for an access token.
try:
access_token_response: dict = dtrequests.DTRequest.post(
url='',
url="",
base_url=self.token_endpoint,
data=request_data,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
headers={"Content-Type": "application/x-www-form-urlencoded"},
skip_auth=True,
)
except dterrors.BadRequest:
# Re-raise exception with more specific information.
raise dterrors.Unauthorized(
'Could not authenticate with the provided credentials.\n\n'
'Read more: https://developer.d21s.com/docs/authentication'
'/oauth2#common-errors'
"Could not authenticate with the provided credentials.\n\n"
"Read more: https://developer.d21s.com/docs/authentication"
"/oauth2#common-errors"
)

# Return the access token in the request.
return access_token_response


def _service_account_env_vars() -> Unauthenticated | ServiceAccountAuth:
key_id = os.getenv('DT_SERVICE_ACCOUNT_KEY_ID', '')
secret = os.getenv('DT_SERVICE_ACCOUNT_SECRET', '')
email = os.getenv('DT_SERVICE_ACCOUNT_EMAIL', '')
key_id = os.getenv("DT_SERVICE_ACCOUNT_KEY_ID", "")
secret = os.getenv("DT_SERVICE_ACCOUNT_SECRET", "")
email = os.getenv("DT_SERVICE_ACCOUNT_EMAIL", "")

if '' in [key_id, secret, email]:
if "" in [key_id, secret, email]:
return Unauthenticated()
else:
return Auth.service_account(key_id, secret, email)


def _credentials_file() -> Unauthenticated | ServiceAccountAuth:
file_path = os.getenv('DT_CREDENTIALS_FILE')
file_path = os.getenv("DT_CREDENTIALS_FILE")
if file_path is not None:
if not os.path.exists(file_path):
msg = 'Missing credentials file.\n\n' \
'Environment variable DT_CREDENTIALS_FILE is set, but' \
' no file found at target path.\n' \
f'{file_path}'
msg = (
"Missing credentials file.\n\n"
"Environment variable DT_CREDENTIALS_FILE is set, but"
" no file found at target path.\n"
f"{file_path}"
)
raise FileNotFoundError(msg)
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
credentials = json.load(f)

if 'serviceAccount' in credentials:
if "serviceAccount" in credentials:
return ServiceAccountAuth.from_credentials_file(credentials)

return Unauthenticated()


class Auth():
class Auth:
"""
Authenticates the API using a factory design pattern.
The Auth class itself is only for namespacing purposes.
Expand Down Expand Up @@ -360,11 +370,12 @@ def unauthenticated() -> Unauthenticated:
return Unauthenticated()

@classmethod
def service_account(cls,
key_id: str,
secret: str,
email: str,
) -> ServiceAccountAuth:
def service_account(
cls,
key_id: str,
secret: str,
email: str,
) -> ServiceAccountAuth:
"""
This method uses an OAuth2 authentication flow. With the provided
credentials, a `JWT <https://jwt.io/>`_ is created and exchanged for
Expand Down Expand Up @@ -396,11 +407,13 @@ def service_account(cls,
"""

# Check that credentials are populated strings.
cls._verify_str_credentials({
'key_id': key_id,
'secret': secret,
'email': email,
})
cls._verify_str_credentials(
{
"key_id": key_id,
"secret": secret,
"email": email,
}
)

return ServiceAccountAuth(key_id, secret, email)

Expand Down Expand Up @@ -428,16 +441,16 @@ def _verify_str_credentials(credentials: dict) -> None:
# the environment with a fallback to an empty string.
if len(credentials[key]) == 0:
raise dterrors.ConfigurationError(
'Authentication credential <{}> is'
' empty string.'.format(key)
"Authentication credential <{}> is"
" empty string.".format(key)
)

# If not, raise TypeError.
else:
raise dterrors._raise_builtin(
TypeError,
'Authentication credential <{}> got type <{}>. '
'Expected <str>.'.format(
"Authentication credential <{}> got type <{}>. "
"Expected <str>.".format(
key, type(credentials[key]).__name__
)
),
)
Loading