Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions tests/unit/vertexai/test_offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from google import auth
from vertexai.resources.preview.feature_store import (
offline_store,
FeatureGroup,
Feature,
)

pytestmark = [
Expand Down Expand Up @@ -266,7 +264,7 @@ def mock_session(bigframes_import_mock):

@pytest.fixture()
def mock_fg():
with mock.patch.object(FeatureGroup, "__new__") as mock_fg:
with mock.patch.object(offline_store, "FeatureGroup") as mock_fg:
yield mock_fg


Expand All @@ -282,12 +280,6 @@ def create_mock_fg(
return fg


@pytest.fixture()
def mock_feature():
with mock.patch.object(Feature, "__new__") as mock_feature:
yield mock_feature


def create_mock_feature(
name: str,
version_column_name: str,
Expand Down Expand Up @@ -360,7 +352,6 @@ def test_one_feature_same_and_different_bq_col_name(
mock_convert_to_bigquery_dataframe,
mock_session,
mock_fg,
mock_feature,
bigframes_import_mock,
):
bigframes, _ = bigframes_import_mock
Expand All @@ -382,6 +373,7 @@ def test_one_feature_same_and_different_bq_col_name(
mock_fg.return_value = create_mock_fg(
name="fake", entity_id_cols=["customer_id"], bq_uri="bq://my.table"
)
mock_feature = mock.MagicMock()
mock_feature.return_value = create_mock_feature(
name=feature_name, version_column_name=bq_column_name
)
Expand All @@ -407,7 +399,6 @@ def test_one_feature_with_explicit_project_and_location(
mock_convert_to_bigquery_dataframe,
mock_session,
mock_fg,
mock_feature,
bigframes_import_mock,
):
bigframes, _ = bigframes_import_mock
Expand All @@ -428,6 +419,7 @@ def test_one_feature_with_explicit_project_and_location(
mock_fg.return_value = create_mock_fg(
name="fake", entity_id_cols=["customer_id"], bq_uri="bq://my.table"
)
mock_feature = mock.MagicMock()
mock_feature.return_value = create_mock_feature(
name="my_feature", version_column_name="my_feature"
)
Expand Down Expand Up @@ -457,7 +449,6 @@ def test_one_feature_with_explicit_credentials(
mock_convert_to_bigquery_dataframe,
mock_session,
mock_fg,
mock_feature,
bigframes_import_mock,
):
bigframes, _ = bigframes_import_mock
Expand All @@ -478,6 +469,7 @@ def test_one_feature_with_explicit_credentials(
mock_fg.return_value = create_mock_fg(
name="fake", entity_id_cols=["customer_id"], bq_uri="bq://my.table"
)
mock_feature = mock.MagicMock()
mock_feature.return_value = create_mock_feature(
name="my_feature", version_column_name="my_feature"
)
Expand Down Expand Up @@ -507,7 +499,6 @@ def test_one_feature_with_explicit_credentials(
# Ensure when getting the FeatureGroup and Feature, the credentials are
# passed through.
mock_fg.assert_called_once_with(
FeatureGroup,
"fake",
project=None,
credentials=credentials,
Expand Down
Loading