Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ Run them with `poetry run pytest`.

Note that tests marked with `aws` are skipped by default, to avoid the need for an AWS setup.
They are however ran in the GitHub Action.
For this to work, they must have been ran once locally with an account with sufficient permissions (`poetry run pytest -m "aws"`), since for security reasons, the AWS account used on GitHub does not have permissions to create RDS instances.
You can run them locally by adding `-m 'aws or not(aws)'` to the `pytest` command.
51 changes: 36 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
REGION_NAME: BucketLocationConstraintType = "eu-central-1"


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]:
with pytest.MonkeyPatch.context() as mp:
yield mp


@pytest.fixture(autouse=True, scope="module")
@pytest.fixture(autouse=True, scope="session")
def patch_update_job(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock:
mock_update_job = MagicMock()
monkeypatch_module.setattr(job_tracking, "update_job", mock_update_job)
Expand All @@ -33,14 +33,16 @@ def patch_update_job(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock:
class RDSTestingInstance:
def __init__(self, db_name: str):
self.db_name = db_name

def create(self) -> None:
self.rds_client = boto3.client("rds", "eu-central-1")
self.ec2_client = boto3.client("ec2", "eu-central-1")
self.add_ingress_rule()
self.db_url = self.create_db_url()
self.engine = self.get_engine()
self.delete_db_tables()

@property
def engine(self) -> Engine:
def get_engine(self) -> Engine:
for _ in range(5):
try:
engine = create_engine(self.db_url)
Expand Down Expand Up @@ -112,12 +114,14 @@ def create_db_url(self) -> str:
DBName=self.db_name,
DBInstanceIdentifier=self.db_name,
AllocatedStorage=20,
DBInstanceClass="db.t3.micro",
DBInstanceClass="db.t4g.micro",
Engine="postgres",
MasterUsername=user,
MasterUserPassword=password,
DeletionProtection=False,
BackupRetentionPeriod=0,
MultiAZ=False,
EnablePerformanceInsights=False,
)
break
except self.rds_client.exceptions.DBInstanceAlreadyExistsFault:
Expand All @@ -138,19 +142,34 @@ def cleanup(self) -> None:
self.delete_db_tables()
self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params)

def delete(self) -> None:
self.rds_client.delete_db_instance(
DBInstanceIdentifier=self.db_name,
SkipFinalSnapshot=True,
DeleteAutomatedBackups=True,
)


class S3TestingBucket:
def __init__(self, bucket_name_suffix: str):
# S3 bucket names must be globally unique - avoid collisions by adding suffix
self.bucket_name = f"{TEST_BUCKET_PREFIX}-{bucket_name_suffix}"
self.region_name: BucketLocationConstraintType = REGION_NAME

def create(self) -> None:
self.s3_client = boto3.client(
"s3",
region_name=self.region_name,
# required for pre-signing URLs to work
endpoint_url=f"https://s3.{self.region_name}.amazonaws.com",
)
self.initialize_bucket()
exists = self.cleanup()
if not exists:
self.s3_client.create_bucket(
Bucket=self.bucket_name,
CreateBucketConfiguration={"LocationConstraint": self.region_name},
)
self.s3_client.get_waiter("bucket_exists").wait(Bucket=self.bucket_name)

def cleanup(self) -> bool:
"""Returns True if bucket exists and all objects are deleted."""
Expand All @@ -169,19 +188,21 @@ def cleanup(self) -> bool:
s3_bucket.objects.all().delete()
return True

def initialize_bucket(self) -> None:
def delete(self) -> None:
exists = self.cleanup()
if not exists:
self.s3_client.create_bucket(
Bucket=self.bucket_name,
CreateBucketConfiguration={"LocationConstraint": self.region_name},
)
self.s3_client.get_waiter("bucket_exists").wait(Bucket=self.bucket_name)
if exists:
self.s3_client.delete_bucket(Bucket=self.bucket_name)


@pytest.fixture(scope="session")
def rds_testing_instance() -> RDSTestingInstance:
return RDSTestingInstance("decodecloudintegrationtestsworkerapi")


@pytest.fixture(scope="session")
def bucket_suffix() -> str:
return datetime.datetime.now(datetime.UTC).strftime("%Y%m%d%H%M%S")
def s3_testing_bucket() -> S3TestingBucket:
bucket_suffix = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d%H%M%S")
return S3TestingBucket(bucket_suffix)


@pytest.mark.aws
Expand Down
59 changes: 36 additions & 23 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import shutil
from typing import Any, Generator
from typing import Any, Generator, cast

import pytest

Expand All @@ -18,33 +18,46 @@
from workerfacing_api.main import workerfacing_app


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def test_username() -> str:
return "test_user"


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def base_dir() -> str:
return "int_test_dir"


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def internal_api_key_secret() -> str:
return "test_internal_api_key"


@pytest.fixture(
scope="module",
scope="session",
params=["local", pytest.param("aws", marks=pytest.mark.aws)],
)
def env(request: pytest.FixtureRequest) -> str:
assert isinstance(request.param, str)
return request.param
def env(
request: pytest.FixtureRequest,
rds_testing_instance: RDSTestingInstance,
s3_testing_bucket: S3TestingBucket,
) -> Generator[str, Any, None]:
env = cast(str, request.param)
if env == "aws":
rds_testing_instance.create()
s3_testing_bucket.create()
yield env
if env == "aws":
rds_testing_instance.delete()
s3_testing_bucket.delete()


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def base_filesystem(
env: str, base_dir: str, monkeypatch_module: pytest.MonkeyPatch, bucket_suffix: str
env: str,
base_dir: str,
monkeypatch_module: pytest.MonkeyPatch,
s3_testing_bucket: S3TestingBucket,
) -> Generator[FileSystem, Any, None]:
monkeypatch_module.setattr(
settings,
Expand All @@ -63,39 +76,39 @@ def base_filesystem(
shutil.rmtree(base_dir, ignore_errors=True)

elif env == "aws":
testing_bucket = S3TestingBucket(bucket_suffix)
# Update settings to use the actual unique bucket name created by S3TestingBucket
monkeypatch_module.setattr(
settings,
"s3_bucket",
testing_bucket.bucket_name,
s3_testing_bucket.bucket_name,
)
yield S3Filesystem(testing_bucket.s3_client, testing_bucket.bucket_name)
testing_bucket.cleanup()
yield S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name)
s3_testing_bucket.cleanup()

else:
raise NotImplementedError


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def queue(
env: str, tmpdir_factory: pytest.TempdirFactory
env: str,
rds_testing_instance: RDSTestingInstance,
tmpdir_factory: pytest.TempdirFactory,
) -> Generator[RDSJobQueue, Any, None]:
if env == "local":
queue = RDSJobQueue(
f"sqlite:///{tmpdir_factory.mktemp('integration')}/local.db"
)
else:
db = RDSTestingInstance("decodecloudintegrationtests")
queue = RDSJobQueue(db.db_url)
queue = RDSJobQueue(rds_testing_instance.db_url)
queue.create(err_on_exists=True)
yield queue
queue.delete()
if env == "aws":
db.cleanup()
rds_testing_instance.cleanup()


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="session", autouse=True)
def override_filesystem_dep(
base_filesystem: FileSystem, monkeypatch_module: pytest.MonkeyPatch
) -> None:
Expand All @@ -106,7 +119,7 @@ def override_filesystem_dep(
)


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="session", autouse=True)
def override_queue_dep(
queue: RDSJobQueue, monkeypatch_module: pytest.MonkeyPatch
) -> None:
Expand All @@ -117,7 +130,7 @@ def override_queue_dep(
)


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="session", autouse=True)
def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> None:
monkeypatch_module.setitem(
workerfacing_app.dependency_overrides, # type: ignore
Expand All @@ -132,7 +145,7 @@ def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) ->
)


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="session", autouse=True)
def override_internal_api_key_secret(
monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str
) -> str:
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/endpoints/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class _TestEndpoint(abc.ABC):
endpoint = ""

@abc.abstractmethod
@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def passing_params(self, *args: Any, **kwargs: Any) -> list[EndpointParams]:
raise NotImplementedError

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def client(self) -> TestClient:
return TestClient(workerfacing_app)

Expand Down
10 changes: 5 additions & 5 deletions tests/integration/endpoints/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,25 @@
from workerfacing_api.core.filesystem import FileSystem, S3Filesystem


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def data_file1_name(base_dir: str) -> str:
return f"{base_dir}/data/test/data_file1.txt"


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def data_file1_path(env: str, data_file1_name: str, base_filesystem: FileSystem) -> str:
if env == "aws":
base_filesystem = cast(S3Filesystem, base_filesystem)
return f"s3://{base_filesystem.bucket}/{data_file1_name}"
return data_file1_name


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def data_file1_contents() -> str:
return "data_file1"


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="session", autouse=True)
def data_file1(
env: str,
base_filesystem: FileSystem,
Expand All @@ -51,7 +51,7 @@ def data_file1(
class TestFiles(_TestEndpoint):
endpoint = "/files"

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def passing_params(self, data_file1_path: str) -> list[EndpointParams]:
return [
EndpointParams("get", f"{data_file1_path}/url"),
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/endpoints/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@
from workerfacing_api.schemas.rds_models import JobStates


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def app() -> AppSpecs:
return AppSpecs(cmd=["cmd"], env={"env": "var"})


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def handler() -> HandlerSpecs:
return HandlerSpecs(image_url="u", files_up={"output": "out"})


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def paths_upload(
env: str, test_username: str, base_filesystem: FileSystem
) -> PathsUploadSpecs:
Expand All @@ -55,7 +55,7 @@ def paths_upload(
class TestJobs(_TestEndpoint):
endpoint = "/jobs"

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def passing_params(self) -> list[EndpointParams]:
return [EndpointParams("get", params={"memory": 1})]

Expand Down
10 changes: 6 additions & 4 deletions tests/unit/core/test_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,15 @@ def mock_aws_(self, request: pytest.FixtureRequest) -> bool:

@pytest.fixture(scope="class")
def base_filesystem(
self, mock_aws_: bool, bucket_suffix: str
self, mock_aws_: bool, s3_testing_bucket: S3TestingBucket
) -> Generator[S3Filesystem, Any, None]:
context_manager = mock_aws if mock_aws_ else nullcontext
with context_manager():
testing_bucket = S3TestingBucket(bucket_suffix)
yield S3Filesystem(testing_bucket.s3_client, testing_bucket.bucket_name)
testing_bucket.cleanup()
s3_testing_bucket.create()
yield S3Filesystem(
s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name
)
s3_testing_bucket.delete()

@pytest.fixture(scope="class", autouse=True)
def data_file1(
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/core/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def base_queue(
@pytest.mark.aws
class TestRDSAWSQueue(_TestRDSQueue):
@pytest.fixture(scope="class")
def base_queue(self) -> Generator[RDSJobQueue, Any, None]:
db = RDSTestingInstance("decodecloudqueuetests")
yield RDSJobQueue(db.db_url)
db.cleanup()
def base_queue(
self, rds_testing_instance: RDSTestingInstance
) -> Generator[RDSJobQueue, Any, None]:
yield RDSJobQueue(rds_testing_instance.db_url)
rds_testing_instance.cleanup()