Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/unit/vertexai/genai/replays/test_create_evaluation_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,31 @@ def test_create_eval_run_with_inference_configs(client):
assert evaluation_run.error is None


def test_create_eval_run_with_allow_cross_region_model(client):
    """create_evaluation_run() accepts allow_cross_region_model in its config."""
    # Pin the API surface used by the replay client for this call.
    # NOTE(review): presumably the flag is only served on v1beta1 — confirm.
    client._api_client._http_options.api_version = "v1beta1"

    prompt = types.EvaluationRunPromptTemplate(
        prompt_template="test prompt template"
    )
    run = client.evals.create_evaluation_run(
        name="test_inference_config",
        display_name="test_inference_config",
        dataset=types.EvaluationRunDataSource(evaluation_set=EVAL_SET_NAME),
        dest=GCS_DEST,
        metrics=[GENERAL_QUALITY_METRIC],
        inference_configs={
            "model_1": types.EvaluationRunInferenceConfig(
                model=MODEL_NAME, prompt_template=prompt
            )
        },
        labels={"label1": "value1"},
        config={"allow_cross_region_model": True},
    )

    # The run is accepted and comes back in the initial PENDING state.
    assert isinstance(run, types.EvaluationRun)
    assert run.display_name == "test_inference_config"
    assert run.state == types.EvaluationRunState.PENDING
    assert run.error is None


@mock.patch("uuid.uuid4")
def test_create_eval_run_with_metric_resource_name(mock_uuid4, client):
"""Tests create_evaluation_run with metric_resource_name."""
Expand Down
101 changes: 101 additions & 0 deletions tests/unit/vertexai/genai/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9136,3 +9136,104 @@ def test_computation_metric_retry_on_resource_exhausted(
summary_metric = result.summary_metrics[0]
assert summary_metric.metric_name == "bleu"
assert summary_metric.mean_score == 0.85


class TestAllowCrossRegionModel:
    """Tests for the allow_cross_region_model flag on create_evaluation_run."""

    # Canned server reply used by every request in this class.
    _RUN_JSON = {
        "name": "projects/123/locations/us-central1/evaluationRuns/456",
        "displayName": "test_run",
        "state": "PENDING",
    }

    def setup_method(self, method):
        # Fake Vertex API client whose request() always returns the canned run.
        self.mock_api_client = mock.MagicMock()
        self.mock_api_client.vertexai = True

        self.mock_response = mock.MagicMock()
        self.mock_response.body = json.dumps(self._RUN_JSON)
        self.mock_api_client.request.return_value = self.mock_response

    @staticmethod
    def _run_kwargs():
        """Minimal create_evaluation_run arguments plus the cross-region opt-in."""
        return dict(
            dataset=vertexai_genai_types.EvaluationRunDataSource(
                evaluation_set="projects/123/locations/us-central1/evaluationSets/789"
            ),
            metrics=[
                vertexai_genai_types.EvaluationRunMetric(
                    metric="general_quality_v1",
                    metric_config=vertexai_genai_types.UnifiedMetric(
                        predefined_metric_spec=genai_types.PredefinedMetricSpec(
                            metric_spec_name="general_quality_v1",
                        )
                    ),
                )
            ],
            dest="gs://test-bucket/output",
            config={"allow_cross_region_model": True},
        )

    @staticmethod
    def _cross_region_flag(call_args):
        """Pull evaluationConfig.allowCrossRegionModel out of a recorded request."""
        body = call_args[0][2]  # Third positional arg is the request dict.
        return body.get("evaluationConfig", {}).get("allowCrossRegionModel")

    def test_create_evaluation_run_config_has_allow_cross_region_model(self):
        """allow_cross_region_model is a first-class field on CreateEvaluationRunConfig."""
        cfg = vertexai_genai_types.CreateEvaluationRunConfig(
            allow_cross_region_model=True,
        )
        assert cfg.allow_cross_region_model is True

    def test_create_evaluation_run_config_from_dict(self):
        """allow_cross_region_model survives dict validation on CreateEvaluationRunConfig."""
        cfg = vertexai_genai_types.CreateEvaluationRunConfig.model_validate(
            {"allow_cross_region_model": True}
        )
        assert cfg.allow_cross_region_model is True

    def test_create_evaluation_run_config_default_is_none(self):
        """When unset, allow_cross_region_model defaults to None."""
        assert (
            vertexai_genai_types.CreateEvaluationRunConfig().allow_cross_region_model
            is None
        )

    def test_create_evaluation_run_passes_allow_cross_region_model(self):
        """Sync path: the flag is serialized under evaluationConfig in the request."""
        evals.Evals(api_client_=self.mock_api_client).create_evaluation_run(
            **self._run_kwargs()
        )

        self.mock_api_client.request.assert_called_once()
        recorded = self.mock_api_client.request.call_args
        assert self._cross_region_flag(recorded) is True

    @pytest.mark.asyncio
    async def test_create_evaluation_run_async_passes_allow_cross_region_model(self):
        """Async path: the flag is serialized under evaluationConfig in the request."""
        self.mock_api_client.async_request = mock.AsyncMock(
            return_value=self.mock_response
        )

        await evals.AsyncEvals(
            api_client_=self.mock_api_client
        ).create_evaluation_run(**self._run_kwargs())

        self.mock_api_client.async_request.assert_called_once()
        recorded = self.mock_api_client.async_request.call_args
        assert self._cross_region_flag(recorded) is True
30 changes: 30 additions & 0 deletions vertexai/_genai/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ def _EvaluationRunConfig_from_vertex(
[item for item in getv(from_object, ["lossAnalysisConfig"])],
)

if getv(from_object, ["allowCrossRegionModel"]) is not None:
setv(
to_object,
["allow_cross_region_model"],
getv(from_object, ["allowCrossRegionModel"]),
)

return to_object


Expand Down Expand Up @@ -425,6 +432,13 @@ def _EvaluationRunConfig_to_vertex(
[item for item in getv(from_object, ["loss_analysis_config"])],
)

if getv(from_object, ["allow_cross_region_model"]) is not None:
setv(
to_object,
["allowCrossRegionModel"],
getv(from_object, ["allow_cross_region_model"]),
)

return to_object


Expand Down Expand Up @@ -2653,6 +2667,8 @@ def create_evaluation_run(
``max_top_cluster_count``. Mutually exclusive with
``loss_analysis_metrics``.
config: The configuration for the evaluation run.
- allow_cross_region_model: Opt-in flag to authorize cross-region
routing for model inference. Applies to both scraping and evaluation.

Returns:
The created evaluation run.
Expand All @@ -2672,6 +2688,11 @@ def create_evaluation_run(
else (agent_info or evals_types.AgentInfo())
)

if not config:
config = types.CreateEvaluationRunConfig()
if isinstance(config, dict):
config = types.CreateEvaluationRunConfig.model_validate(config)

if agent_info and not inference_configs:
parsed_user_simulator_config = (
evals_types.UserSimulatorConfig.model_validate(user_simulator_config)
Expand Down Expand Up @@ -2712,6 +2733,7 @@ def create_evaluation_run(
output_config=output_config,
metrics=resolved_metrics,
loss_analysis_config=resolved_loss_configs,
allow_cross_region_model=getattr(config, "allow_cross_region_model", None),
)
resolved_inference_configs = _evals_common._resolve_inference_configs(
self._api_client, resolved_dataset, inference_configs, parsed_agent_info
Expand Down Expand Up @@ -4422,6 +4444,8 @@ async def create_evaluation_run(
``max_top_cluster_count``. Mutually exclusive with
``loss_analysis_metrics``.
config: The configuration for the evaluation run.
- allow_cross_region_model: Opt-in flag to authorize cross-region
routing for model inference. Applies to both scraping and evaluation.

Returns:
The created evaluation run.
Expand All @@ -4441,6 +4465,11 @@ async def create_evaluation_run(
else (agent_info or evals_types.AgentInfo())
)

if not config:
config = types.CreateEvaluationRunConfig()
if isinstance(config, dict):
config = types.CreateEvaluationRunConfig.model_validate(config)

if agent_info and not inference_configs:
parsed_user_simulator_config = (
evals_types.UserSimulatorConfig.model_validate(user_simulator_config)
Expand Down Expand Up @@ -4481,6 +4510,7 @@ async def create_evaluation_run(
output_config=output_config,
metrics=resolved_metrics,
loss_analysis_config=resolved_loss_configs,
allow_cross_region_model=getattr(config, "allow_cross_region_model", None),
)
resolved_inference_configs = _evals_common._resolve_inference_configs(
self._api_client, resolved_dataset, inference_configs, parsed_agent_info
Expand Down
18 changes: 18 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,6 +2386,11 @@ class EvaluationRunConfig(_common.BaseModel):
default=None,
description="""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns.""",
)
allow_cross_region_model: Optional[bool] = Field(
default=None,
description="""Opt-in flag to authorize cross-region routing for model inference.
Applies to both scraping and evaluation.""",
)


class EvaluationRunConfigDict(TypedDict, total=False):
Expand All @@ -2406,6 +2411,10 @@ class EvaluationRunConfigDict(TypedDict, total=False):
loss_analysis_config: Optional[list[LossAnalysisConfigDict]]
"""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns."""

allow_cross_region_model: Optional[bool]
"""Opt-in flag to authorize cross-region routing for model inference.
Applies to both scraping and evaluation."""


EvaluationRunConfigOrDict = Union[EvaluationRunConfig, EvaluationRunConfigDict]

Expand Down Expand Up @@ -2536,6 +2545,11 @@ class CreateEvaluationRunConfig(_common.BaseModel):
http_options: Optional[genai_types.HttpOptions] = Field(
default=None, description="""Used to override HTTP request options."""
)
allow_cross_region_model: Optional[bool] = Field(
default=None,
description="""Opt-in flag to authorize cross-region routing for model inference.
Applies to both scraping and evaluation.""",
)


class CreateEvaluationRunConfigDict(TypedDict, total=False):
Expand All @@ -2544,6 +2558,10 @@ class CreateEvaluationRunConfigDict(TypedDict, total=False):
http_options: Optional[genai_types.HttpOptionsDict]
"""Used to override HTTP request options."""

allow_cross_region_model: Optional[bool]
"""Opt-in flag to authorize cross-region routing for model inference.
Applies to both scraping and evaluation."""


CreateEvaluationRunConfigOrDict = Union[
CreateEvaluationRunConfig, CreateEvaluationRunConfigDict
Expand Down
Loading