Skip to content

Commit c03f4c8

Browse files
authored
python(feat): Add more support for rule versions to SiftClient (#479)
1 parent 0204628 commit c03f4c8

7 files changed

Lines changed: 456 additions & 23 deletions

File tree

python/lib/sift_client/_internal/low_level_wrappers/rules.py

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,18 @@
33
import logging
44
from typing import TYPE_CHECKING, Any, Sequence, cast
55

6-
from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier, ResourceIdentifiers
6+
from sift.common.type.v1.resource_identifier_pb2 import (
7+
NamedResources,
8+
Names,
9+
ResourceIdentifier,
10+
ResourceIdentifiers,
11+
)
712
from sift.rule_evaluation.v1.rule_evaluation_pb2 import (
813
AssetsTimeRange,
14+
EvaluateRulesAnnotationOptions,
15+
EvaluateRulesFromCurrentRuleVersions,
16+
EvaluateRulesFromReportTemplate,
17+
EvaluateRulesFromRuleVersions,
918
EvaluateRulesRequest,
1019
EvaluateRulesResponse,
1120
RunTimeRange,
@@ -15,6 +24,8 @@
1524
ArchiveRuleRequest,
1625
BatchArchiveRulesRequest,
1726
BatchGetRulesRequest,
27+
BatchGetRuleVersionsRequest,
28+
BatchGetRuleVersionsResponse,
1829
BatchUnarchiveRulesRequest,
1930
BatchUpdateRulesRequest,
2031
BatchUpdateRulesResponse,
@@ -24,7 +35,11 @@
2435
CreateRuleResponse,
2536
GetRuleRequest,
2637
GetRuleResponse,
38+
GetRuleVersionRequest,
39+
GetRuleVersionResponse,
2740
ListRulesRequest,
41+
ListRuleVersionsRequest,
42+
ListRuleVersionsResponse,
2843
RuleAssetConfiguration,
2944
RuleConditionExpression,
3045
UnarchiveRuleRequest,
@@ -45,6 +60,7 @@
4560
Rule,
4661
RuleCreate,
4762
RuleUpdate,
63+
RuleVersion,
4864
)
4965
from sift_client.sift_types.tag import Tag
5066
from sift_client.transport import GrpcClient, WithGrpcClient
@@ -506,6 +522,57 @@ async def list_all_rules(
506522
max_results=max_results,
507523
)
508524

525+
async def list_rule_versions(
526+
self,
527+
rule_id: str,
528+
*,
529+
filter_query: str | None = None,
530+
order_by: str | None = None,
531+
page_size: int | None = None,
532+
page_token: str | None = None,
533+
) -> tuple[list[RuleVersion], str]:
534+
"""List rule versions for a rule.
535+
536+
Args:
537+
rule_id: The rule ID to list versions for.
538+
filter_query: Optional CEL filter (fields: rule_version_id, user_notes, change_message).
539+
order_by: Unused, for _handle_pagination compatibility.
540+
page_size: Maximum number of versions per page.
541+
page_token: Token for the next page.
542+
543+
Returns:
544+
Tuple of (list of RuleVersions, next page token or empty string).
545+
"""
546+
_ = order_by
547+
request_kwargs: dict[str, Any] = {
548+
"rule_id": rule_id,
549+
"page_size": page_size or DEFAULT_PAGE_SIZE,
550+
"page_token": page_token or "",
551+
}
552+
if filter_query:
553+
request_kwargs["filter"] = filter_query
554+
request = ListRuleVersionsRequest(**request_kwargs)
555+
response = await self._grpc_client.get_stub(RuleServiceStub).ListRuleVersions(request)
556+
response = cast("ListRuleVersionsResponse", response)
557+
versions = [RuleVersion._from_proto(p) for p in response.rule_versions]
558+
return versions, response.next_page_token or ""
559+
560+
async def list_all_rule_versions(
561+
self,
562+
rule_id: str,
563+
*,
564+
filter_query: str | None = None,
565+
max_results: int | None = None,
566+
page_size: int | None = DEFAULT_PAGE_SIZE,
567+
) -> list[RuleVersion]:
568+
"""List all rule versions for a rule, with optional CEL filter."""
569+
return await self._handle_pagination(
570+
self.list_rule_versions,
571+
kwargs={"rule_id": rule_id, "filter_query": filter_query},
572+
page_size=page_size,
573+
max_results=max_results,
574+
)
575+
509576
async def evaluate_rules(
510577
self,
511578
*,
@@ -571,13 +638,22 @@ async def evaluate_rules(
571638
if all_applicable_rules:
572639
kwargs["all_applicable_rules"] = all_applicable_rules
573640
if rule_ids:
574-
kwargs["rules"] = {"rules": ResourceIdentifiers(ids={"ids": rule_ids})} # type: ignore
641+
kwargs["rules"] = EvaluateRulesFromCurrentRuleVersions(
642+
rules=ResourceIdentifiers(ids={"ids": rule_ids}) # type: ignore[arg-type]
643+
)
575644
if rule_version_ids:
576-
kwargs["rule_versions"] = rule_version_ids
645+
kwargs["rule_versions"] = EvaluateRulesFromRuleVersions(
646+
rule_version_ids=rule_version_ids
647+
)
577648
if report_template_id:
578-
kwargs["report_template"] = report_template_id
649+
kwargs["report_template"] = EvaluateRulesFromReportTemplate(
650+
report_template=ResourceIdentifier(id=report_template_id)
651+
)
579652
if tags:
580-
kwargs["tags"] = [tag.name if isinstance(tag, Tag) else tag for tag in tags]
653+
tag_names = [tag.name if isinstance(tag, Tag) else tag for tag in tags]
654+
kwargs["annotation_options"] = EvaluateRulesAnnotationOptions(
655+
tags=NamedResources(names=Names(names=tag_names)) # type: ignore[arg-type]
656+
)
581657
if report_name:
582658
kwargs["report_name"] = report_name
583659
if organization_id:
@@ -595,3 +671,31 @@ async def evaluate_rules(
595671
report = await ReportsLowLevelClient(self._grpc_client).get_report(report_id=report_id)
596672
return created_annotation_count, report, job_id
597673
return created_annotation_count, None, job_id
674+
675+
async def get_rule_version(self, rule_version_id: str) -> Rule:
676+
"""Get a rule at a specific version by rule_version_id.
677+
678+
Args:
679+
rule_version_id: The rule version ID to get.
680+
681+
Returns:
682+
The Rule at that version.
683+
"""
684+
request = GetRuleVersionRequest(rule_version_id=rule_version_id)
685+
response = await self._grpc_client.get_stub(RuleServiceStub).GetRuleVersion(request)
686+
grpc_rule = cast("GetRuleVersionResponse", response).rule
687+
return Rule._from_proto(grpc_rule)
688+
689+
async def batch_get_rule_versions(self, rule_version_ids: list[str]) -> list[Rule]:
690+
"""Get multiple rules at specific versions by rule_version_ids.
691+
692+
Args:
693+
rule_version_ids: The rule version IDs to get.
694+
695+
Returns:
696+
List of Rules at those versions (order may match request order).
697+
"""
698+
request = BatchGetRuleVersionsRequest(rule_version_ids=rule_version_ids)
699+
response = await self._grpc_client.get_stub(RuleServiceStub).BatchGetRuleVersions(request)
700+
response = cast("BatchGetRuleVersionsResponse", response)
701+
return [Rule._from_proto(r) for r in response.rules]

python/lib/sift_client/_tests/resources/test_reports.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,35 @@ def test_client_binding(sift_client):
5454

5555
@pytest.mark.integration
5656
class TestReports:
57+
def test_create_from_rule_versions(self, nostromo_run, test_rule, sift_client):
58+
"""Create a report from specific rule version IDs."""
59+
rule_versions = sift_client.rules.list_rule_versions(test_rule)
60+
assert rule_versions, "test_rule should have at least one version"
61+
report = sift_client.reports.create_from_rule_versions(
62+
name="report_from_rule_versions",
63+
run=nostromo_run,
64+
organization_id=nostromo_run.organization_id,
65+
rule_versions=[rule_versions[0].rule_version_id],
66+
)
67+
assert report is not None
68+
assert report.run_id == nostromo_run.id_
69+
assert report.name == "report_from_rule_versions"
70+
71+
def test_create_from_rule_versions_with_rule_version_objects(
72+
self, nostromo_run, test_rule, sift_client
73+
):
74+
"""Create a report passing RuleVersion instances."""
75+
rule_versions = sift_client.rules.list_rule_versions(test_rule)
76+
assert rule_versions
77+
report = sift_client.reports.create_from_rule_versions(
78+
name="report_from_rule_versions_objs",
79+
run=nostromo_run,
80+
organization_id=nostromo_run.organization_id,
81+
rule_versions=rule_versions[:1],
82+
)
83+
assert report is not None
84+
assert report.run_id == nostromo_run.id_
85+
5786
def test_create_from_rules(self, nostromo_run, test_rule, sift_client):
5887
report_from_rules = sift_client.reports.create_from_rules(
5988
name="report_from_rules",
@@ -146,17 +175,16 @@ def test_archive(self, nostromo_run, test_rule, sift_client):
146175
assert archived_report is not None
147176
assert archived_report.is_archived == True
148177

149-
def test_unarchive(self, sift_client):
150-
reports_from_rules = sift_client.reports.list_(
151-
name="report_from_rules", include_archived=True
178+
def test_unarchive(self, nostromo_run, test_rule, sift_client):
179+
# create a report, archive it, then unarchive it
180+
report_from_rules = sift_client.reports.create_from_rules(
181+
name="report_from_rules_unarchive",
182+
run=nostromo_run,
183+
rules=[test_rule],
152184
)
153-
report_from_rules = None
154-
for report_from_rules in reports_from_rules:
155-
if report_from_rules.is_archived:
156-
report_from_rules = report_from_rules
157-
break
158185
assert report_from_rules is not None
159-
assert report_from_rules.is_archived == True
160-
unarchived_report = sift_client.reports.unarchive(report=report_from_rules)
186+
archived_report = sift_client.reports.archive(report=report_from_rules)
187+
assert archived_report.is_archived is True
188+
unarchived_report = sift_client.reports.unarchive(report=archived_report)
161189
assert unarchived_report is not None
162-
assert unarchived_report.is_archived == False
190+
assert unarchived_report.is_archived is False

python/lib/sift_client/_tests/resources/test_rules.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
RuleAnnotationType,
2323
RuleCreate,
2424
RuleUpdate,
25+
RuleVersion,
2526
)
2627

2728
pytestmark = pytest.mark.integration
@@ -215,6 +216,118 @@ async def test_list_with_time_filters(self, rules_api_async):
215216
for rule in rules:
216217
assert rule.created_date >= one_year_ago
217218

219+
class TestListRuleVersions:
220+
"""Tests for the async list_rule_versions method."""
221+
222+
@pytest.mark.asyncio
223+
async def test_list_rule_versions_by_rule(self, rules_api_async, test_rule):
224+
"""Test listing rule versions for a rule."""
225+
versions = await rules_api_async.list_rule_versions(test_rule)
226+
assert isinstance(versions, list)
227+
assert len(versions) >= 1
228+
for v in versions:
229+
assert isinstance(v, RuleVersion)
230+
assert v.rule_id == test_rule.id_
231+
assert v.rule_version_id
232+
assert v.version
233+
assert v.created_date
234+
235+
@pytest.mark.asyncio
236+
async def test_list_rule_versions_by_rule_id_str(self, rules_api_async, test_rule):
237+
"""Test listing rule versions by rule ID string."""
238+
versions = await rules_api_async.list_rule_versions(test_rule.id_)
239+
assert isinstance(versions, list)
240+
assert len(versions) >= 1
241+
for v in versions:
242+
assert v.rule_id == test_rule.id_
243+
244+
@pytest.mark.asyncio
245+
async def test_list_rule_versions_with_limit(self, rules_api_async, test_rule):
246+
"""Test listing rule versions with limit."""
247+
versions = await rules_api_async.list_rule_versions(test_rule, limit=1)
248+
assert isinstance(versions, list)
249+
assert len(versions) <= 1
250+
if versions:
251+
assert isinstance(versions[0], RuleVersion)
252+
253+
@pytest.mark.asyncio
254+
async def test_list_rule_versions_with_rule_version_ids_filter(
255+
self, rules_api_async, test_rule
256+
):
257+
"""Test listing rule versions filtered by rule_version_ids."""
258+
all_versions = await rules_api_async.list_rule_versions(test_rule)
259+
assert all_versions
260+
first_id = all_versions[0].rule_version_id
261+
versions = await rules_api_async.list_rule_versions(
262+
test_rule, rule_version_ids=[first_id]
263+
)
264+
assert len(versions) == 1
265+
assert versions[0].rule_version_id == first_id
266+
267+
class TestGetRuleVersion:
268+
"""Tests for the async get_rule_version method."""
269+
270+
@pytest.mark.asyncio
271+
async def test_get_rule_version_by_id(self, rules_api_async, test_rule):
272+
"""Test getting a rule at a specific version by rule_version_id."""
273+
versions = await rules_api_async.list_rule_versions(test_rule)
274+
assert versions
275+
rule_at_version = await rules_api_async.get_rule_version(versions[0].rule_version_id)
276+
assert rule_at_version is not None
277+
assert rule_at_version.id_ == test_rule.id_
278+
assert rule_at_version.rule_version is not None
279+
assert rule_at_version.rule_version.rule_version_id == versions[0].rule_version_id
280+
281+
@pytest.mark.asyncio
282+
async def test_get_rule_version_by_rule_version_instance(self, rules_api_async, test_rule):
283+
"""Test getting a rule at a specific version by passing RuleVersion instance."""
284+
versions = await rules_api_async.list_rule_versions(test_rule)
285+
assert versions
286+
rule_at_version = await rules_api_async.get_rule_version(versions[0])
287+
assert rule_at_version is not None
288+
assert rule_at_version.id_ == test_rule.id_
289+
assert rule_at_version.rule_version.rule_version_id == versions[0].rule_version_id
290+
291+
class TestBatchGetRuleVersions:
292+
"""Tests for the async batch_get_rule_versions method."""
293+
294+
@pytest.mark.asyncio
295+
async def test_batch_get_rule_versions_by_ids(self, rules_api_async, test_rule):
296+
"""Test batch getting rules by rule_version_id strings."""
297+
versions = await rules_api_async.list_rule_versions(test_rule)
298+
assert versions
299+
ids = [v.rule_version_id for v in versions[:2]]
300+
rules = await rules_api_async.batch_get_rule_versions(ids)
301+
assert len(rules) == len(ids)
302+
returned_ids = {r.rule_version.rule_version_id for r in rules if r.rule_version}
303+
assert returned_ids >= set(ids)
304+
for r in rules:
305+
assert r.id_ == test_rule.id_
306+
307+
@pytest.mark.asyncio
308+
async def test_batch_get_rule_versions_by_rule_version_instances(
309+
self, rules_api_async, test_rule
310+
):
311+
"""Test batch getting rules by passing RuleVersion instances."""
312+
versions = await rules_api_async.list_rule_versions(test_rule)
313+
assert versions
314+
rules = await rules_api_async.batch_get_rule_versions(versions[:2])
315+
assert len(rules) <= 2
316+
for r in rules:
317+
assert r.id_ == test_rule.id_
318+
if len(versions) >= 2:
319+
assert len(rules) == 2
320+
321+
@pytest.mark.asyncio
322+
async def test_batch_get_rule_versions_single(self, rules_api_async, test_rule):
323+
"""Test batch_get_rule_versions with a single version ID."""
324+
versions = await rules_api_async.list_rule_versions(test_rule)
325+
assert versions
326+
rules = await rules_api_async.batch_get_rule_versions([versions[0].rule_version_id])
327+
assert len(rules) == 1
328+
assert rules[0].id_ == test_rule.id_
329+
assert rules[0].rule_version.rule_version_id == versions[0].rule_version_id
330+
218331
class TestFind:
219332
"""Tests for the async find method."""
220333

0 commit comments

Comments
 (0)