Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions python/lib/sift_client/_internal/low_level_wrappers/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from sift.rules.v1.rules_pb2_grpc import RuleServiceStub

from sift_client._internal.low_level_wrappers.base import DEFAULT_PAGE_SIZE, LowLevelClientBase
from sift_client._internal.low_level_wrappers.reports import ReportsLowLevelClient
from sift_client._internal.low_level_wrappers.jobs import JobsLowLevelClient
from sift_client._internal.util.timestamp import to_pb_timestamp
from sift_client._internal.util.util import count_non_none
from sift_client.sift_types.rule import (
Expand All @@ -69,7 +69,7 @@
from datetime import datetime

from sift_client.sift_types.channel import ChannelReference
from sift_client.sift_types.report import Report
from sift_client.sift_types.job import Job

# Configure logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -587,8 +587,8 @@ async def evaluate_rules(
report_name: str | None = None,
tags: list[str | Tag] | None = None,
organization_id: str | None = None,
) -> tuple[int, Report | None, str | None]:
"""Evaluate a rule.
) -> tuple[int, str | None, Job | None]:
"""Evaluate rules.

Args:
run_id: The run ID to evaluate.
Expand All @@ -604,7 +604,7 @@ async def evaluate_rules(
organization_id: The organization ID to evaluate.

Returns:
The result of the rule execution.
The annotation_count, report_id, and job for the pending report.
"""
if count_non_none(run_id, asset_ids) > 1:
raise ValueError(
Expand Down Expand Up @@ -664,13 +664,13 @@ async def evaluate_rules(
request
)
response = cast("EvaluateRulesResponse", response)
created_annotation_count = response.created_annotation_count
report_id = response.report_id
job_id = response.job_id
if report_id:
report = await ReportsLowLevelClient(self._grpc_client).get_report(report_id=report_id)
return created_annotation_count, report, job_id
return created_annotation_count, None, job_id

if job_id:
job = await JobsLowLevelClient(self._grpc_client).get_job(job_id=job_id)
else:
job = None
return response.created_annotation_count, response.report_id, job

async def get_rule_version(self, rule_version_id: str) -> Rule:
"""Get a rule at a specific version by rule_version_id.
Expand Down
101 changes: 101 additions & 0 deletions python/lib/sift_client/_tests/resources/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from grpc.aio import AioRpcError
Expand Down Expand Up @@ -291,6 +292,106 @@ async def test_retry_finished_job_no_effect(self, jobs_api_async):
with pytest.raises(AioRpcError, match="job cannot be retried"):
await jobs_api_async.retry(job)

class TestWaitUntilComplete:
"""Tests for the async wait_until_complete method."""

@pytest.mark.asyncio
async def test_returns_immediately_when_job_already_complete(self, jobs_api_async):
"""When get returns a completed job on first call, wait returns immediately."""
job_id = "test-job-id"
mock_job = MagicMock()
mock_job.job_status = JobStatus.FINISHED

with patch(
"sift_client.resources.jobs.JobsAPIAsync.get",
new_callable=AsyncMock,
return_value=mock_job,
) as mock_get:
result = await jobs_api_async.wait_until_complete(job=job_id)

assert result is mock_job
assert result.job_status == JobStatus.FINISHED
mock_get.assert_called_once_with(job_id)

@pytest.mark.asyncio
async def test_returns_immediately_when_job_already_failed(self, jobs_api_async):
"""When get returns a failed job on first call, wait returns immediately."""
job_id = "test-job-id"
mock_job = MagicMock()
mock_job.job_status = JobStatus.FAILED

with patch(
"sift_client.resources.jobs.JobsAPIAsync.get",
new_callable=AsyncMock,
return_value=mock_job,
) as mock_get:
result = await jobs_api_async.wait_until_complete(job=job_id)

assert result is mock_job
assert result.job_status == JobStatus.FAILED
mock_get.assert_called_once_with(job_id)

@pytest.mark.asyncio
async def test_returns_immediately_when_job_already_cancelled(self, jobs_api_async):
"""When get returns a cancelled job on first call, wait returns immediately."""
job_id = "test-job-id"
mock_job = MagicMock()
mock_job.job_status = JobStatus.CANCELLED

with patch(
"sift_client.resources.jobs.JobsAPIAsync.get",
new_callable=AsyncMock,
return_value=mock_job,
) as mock_get:
result = await jobs_api_async.wait_until_complete(job=job_id)

assert result is mock_job
assert result.job_status == JobStatus.CANCELLED
mock_get.assert_called_once_with(job_id)

@pytest.mark.asyncio
async def test_polls_until_complete(self, jobs_api_async):
"""When get returns running then finished, wait returns after second poll."""
job_id = "test-job-id"
running_job = MagicMock()
running_job.job_status = JobStatus.RUNNING
finished_job = MagicMock()
finished_job.job_status = JobStatus.FINISHED

with patch(
"sift_client.resources.jobs.JobsAPIAsync.get",
new_callable=AsyncMock,
side_effect=[running_job, finished_job],
) as mock_get:
result = await jobs_api_async.wait_until_complete(
job=job_id,
polling_interval_secs=0.01,
timeout_secs=10.0,
)

assert result is finished_job
assert result.job_status == JobStatus.FINISHED
assert mock_get.call_count == 2

@pytest.mark.asyncio
async def test_raises_timeout_error_when_not_complete_in_time(self, jobs_api_async):
"""When job never reaches a completed state, TimeoutError is raised."""
job_id = "test-job-id"
running_job = MagicMock()
running_job.job_status = JobStatus.RUNNING

with patch(
"sift_client.resources.jobs.JobsAPIAsync.get",
new_callable=AsyncMock,
return_value=running_job,
):
with pytest.raises(TimeoutError):
await jobs_api_async.wait_until_complete(
job=job_id,
polling_interval_secs=0.05,
timeout_secs=0.1,
)

class TestJobProperties:
"""Tests for job property methods."""

Expand Down
Loading
Loading