Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tests/unit/vertexai/genai/replays/test_create_evaluation_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,35 @@ def test_create_eval_run_with_metric_resource_name(mock_uuid4, client):
# == INPUT_DF_WITH_CONTEXT_AND_HISTORY.iloc[i]["response"]
# )
# assert evaluation_run.error is None
def test_create_eval_run_with_red_teaming_config(client):
    """Tests that create_evaluation_run() with red_teaming_config sends analysisConfigs."""
    rt_request = types.RedTeamingAnalysisConfig(
        attack_categories=["FINANCIAL_OR_CREDENTIAL_PHISHING"],
        vulnerable_tools=[
            types.VulnerableTool(
                tool_name="search_flights",
                json_paths=["$.flights[0].description"],
            )
        ],
    )
    run = client.evals.create_evaluation_run(
        name="test_red_teaming",
        display_name="test_red_teaming",
        dataset=types.EvaluationRunDataSource(evaluation_set=EVAL_SET_NAME),
        dest=GCS_DEST,
        metrics=[],
        red_teaming_config=rt_request,
    )
    # The run should come back PENDING with the red teaming config wrapped
    # into exactly one analysis config.
    assert isinstance(run, types.EvaluationRun)
    assert run.display_name == "test_red_teaming"
    assert run.state == types.EvaluationRunState.PENDING
    assert run.analysis_configs is not None
    assert len(run.analysis_configs) == 1
    resolved = run.analysis_configs[0].red_teaming_analysis_config
    assert resolved.attack_categories == ["FINANCIAL_OR_CREDENTIAL_PHISHING"]
    assert resolved.vulnerable_tools[0].tool_name == "search_flights"
    assert run.error is None


# Load the asyncio plugin so pytest can collect and run any async test
# coroutines defined in this module.
pytest_plugins = ("pytest_asyncio",)


Expand Down
140 changes: 140 additions & 0 deletions tests/unit/vertexai/genai/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,6 +1835,146 @@ def test_loss_analysis_metrics_accepts_metric_object(self):
assert result[0].candidate == "agent-1"


class TestRedTeamingTypes:
    """Unit tests for red teaming type definitions."""

    def test_red_teaming_analysis_config_construction(self):
        """A fully-populated config exposes the fields it was built with."""
        tool = common_types.VulnerableTool(
            tool_name="search_flights",
            json_paths=["$.flights[0].description"],
        )
        cfg = common_types.RedTeamingAnalysisConfig(
            attack_categories=["FINANCIAL_OR_CREDENTIAL_PHISHING"],
            vulnerable_tools=[tool],
        )
        assert len(cfg.attack_categories) == 1
        assert cfg.vulnerable_tools[0].tool_name == "search_flights"

    def test_red_teaming_analysis_config_optional_fields(self):
        """Every field defaults to None when the config is built empty."""
        cfg = common_types.RedTeamingAnalysisConfig()
        assert cfg.attack_categories is None
        assert cfg.vulnerable_tools is None

    def test_evaluation_run_results_has_red_teaming_results(self):
        """EvaluationRunResults carries nested red teaming category results."""
        category_result = common_types.AttackCategoryResult(
            attack_category="FINANCIAL_OR_CREDENTIAL_PHISHING",
            attack_success_rate=0.9,
        )
        run_results = common_types.EvaluationRunResults(
            red_teaming_analysis_results=[
                common_types.RedTeamingAnalysisResult(
                    category_results=[category_result],
                )
            ],
        )
        assert len(run_results.red_teaming_analysis_results) == 1
        first = run_results.red_teaming_analysis_results[0]
        assert first.category_results[0].attack_success_rate == 0.9

    def test_create_params_accepts_analysis_configs(self):
        """The internal create-run parameters accept a list of AnalysisConfig."""
        rt_cfg = common_types.RedTeamingAnalysisConfig(
            attack_categories=["FINANCIAL_OR_CREDENTIAL_PHISHING"],
        )
        params = common_types._CreateEvaluationRunParameters(
            name="test-run",
            analysis_configs=[
                common_types.AnalysisConfig(red_teaming_analysis_config=rt_cfg),
            ],
        )
        assert len(params.analysis_configs) == 1


class TestResolveRedTeamingConfig:
    """Unit tests for _resolve_red_teaming_config."""

    def test_none_when_no_config(self):
        """With no input config the resolver produces None."""
        assert _evals_utils._resolve_red_teaming_config() is None

    def test_wraps_config_in_analysis_configs(self):
        """A typed config is wrapped into a single AnalysisConfig entry."""
        rt_cfg = common_types.RedTeamingAnalysisConfig(
            attack_categories=["FINANCIAL_OR_CREDENTIAL_PHISHING"],
        )
        wrapped = _evals_utils._resolve_red_teaming_config(rt_cfg)
        assert len(wrapped) == 1
        assert isinstance(wrapped[0], common_types.AnalysisConfig)
        categories = wrapped[0].red_teaming_analysis_config.attack_categories
        assert categories[0] == "FINANCIAL_OR_CREDENTIAL_PHISHING"

    def test_accepts_dict_input(self):
        """A plain dict is validated into the typed config and wrapped."""
        wrapped = _evals_utils._resolve_red_teaming_config(
            {"attack_categories": ["INJECTED_HOSTILITY_AND_HARASSMENT"]}
        )
        assert len(wrapped) == 1
        assert isinstance(wrapped[0], common_types.AnalysisConfig)


class TestRedTeamingSerializationConverters:
    """Unit tests for red teaming serialization converters."""

    def test_analysis_config_to_vertex(self):
        """Typed config serializes to camelCase keys for the Vertex API."""
        source = common_types.AnalysisConfig(
            analysis_name="my-analysis",
            red_teaming_analysis_config=common_types.RedTeamingAnalysisConfig(
                attack_categories=["FINANCIAL_OR_CREDENTIAL_PHISHING"],
                vulnerable_tools=[
                    common_types.VulnerableTool(
                        tool_name="search_flights",
                        json_paths=["$.flights[0].description"],
                    ),
                ],
            ),
        )
        serialized = evals._AnalysisConfig_to_vertex(source)
        assert serialized["analysisName"] == "my-analysis"
        rt_payload = serialized["redTeamingAnalysisConfig"]
        assert rt_payload["attackCategories"] == ["FINANCIAL_OR_CREDENTIAL_PHISHING"]
        assert rt_payload["vulnerableTools"][0]["toolName"] == "search_flights"

    def test_analysis_config_from_vertex(self):
        """An API response deserializes into snake_case keys."""
        api_response = {
            "analysisName": "my-analysis",
            "redTeamingAnalysisConfig": {
                "attackCategories": ["FINANCIAL_OR_CREDENTIAL_PHISHING"],
                "vulnerableTools": [
                    {"toolName": "search_flights", "jsonPaths": ["$.flights[0].description"]},
                ],
            },
        }
        deserialized = evals._AnalysisConfig_from_vertex(api_response)
        assert deserialized["analysis_name"] == "my-analysis"
        rt_payload = deserialized["red_teaming_analysis_config"]
        assert rt_payload["attack_categories"] == ["FINANCIAL_OR_CREDENTIAL_PHISHING"]
        assert rt_payload["vulnerable_tools"][0]["tool_name"] == "search_flights"

    def test_analysis_config_round_trip(self):
        """to_vertex followed by from_vertex preserves the config's content."""
        original = common_types.AnalysisConfig(
            analysis_name="round-trip",
            red_teaming_analysis_config=common_types.RedTeamingAnalysisConfig(
                attack_categories=["FINANCIAL_OR_CREDENTIAL_PHISHING"],
                vulnerable_tools=[
                    common_types.VulnerableTool(tool_name="search_flights"),
                ],
            ),
        )
        round_tripped = evals._AnalysisConfig_from_vertex(
            evals._AnalysisConfig_to_vertex(original)
        )
        assert round_tripped["analysis_name"] == "round-trip"
        tools = round_tripped["red_teaming_analysis_config"]["vulnerable_tools"]
        assert tools[0]["tool_name"] == "search_flights"


class TestResolveMetricName:
"""Unit tests for _resolve_metric_name."""

Expand Down
14 changes: 14 additions & 0 deletions vertexai/_genai/_evals_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,20 @@ def _resolve_eval_run_loss_configs(
return configs


def _resolve_red_teaming_config(
    red_teaming_config: Optional[types.RedTeamingAnalysisConfigOrDict] = None,
) -> Optional[list[types.AnalysisConfig]]:
    """Wraps a RedTeamingAnalysisConfig into analysis_configs for the API."""
    # No (or empty) config means no analysis configs should be sent.
    if not red_teaming_config:
        return None
    if isinstance(red_teaming_config, dict):
        # Accept a plain dict and validate it into the typed config.
        resolved = types.RedTeamingAnalysisConfig.model_validate(red_teaming_config)
    else:
        resolved = red_teaming_config
    return [types.AnalysisConfig(red_teaming_analysis_config=resolved)]


def _resolve_loss_analysis_config(
eval_result: types.EvaluationResult,
config: Optional[types.LossAnalysisConfig] = None,
Expand Down
Loading
Loading