Skip to content

Commit 5e15b6f

Browse files
committed
test: change how claude test is running
1 parent 7acc2f1 commit 5e15b6f

File tree

1 file changed

+52
-46
lines changed

1 file changed

+52
-46
lines changed

tests/system/small/ml/test_llm.py

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable
15+
from contextlib import AbstractContextManager, nullcontext
16+
from typing import Any, Callable
1617
from unittest import mock
1718

1819
import pandas as pd
@@ -25,6 +26,21 @@
2526
from bigframes.testing import utils
2627

2728

29+
@pytest.fixture(scope="function")
30+
def text_generator_model(request, bq_connection, session):
31+
"""Creates a text generator model, mocking creation for Claude models."""
32+
model_class = request.param
33+
if model_class == llm.Claude3TextGenerator:
34+
# For Claude, mock the BQML model creation to avoid the network call
35+
# that fails due to the region issue.
36+
with mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model"):
37+
model = model_class(connection_name=bq_connection, session=session)
38+
else:
39+
# For other models like Gemini, create as usual.
40+
model = model_class(connection_name=bq_connection, session=session)
41+
yield model
42+
43+
2844
@pytest.mark.parametrize(
2945
"model_name",
3046
("text-embedding-005", "text-embedding-004", "text-multilingual-embedding-002"),
@@ -251,37 +267,35 @@ def __eq__(self, other):
251267
return self.equals(other)
252268

253269

254-
@pytest.mark.skip("b/436340035 test failed")
255270
@pytest.mark.parametrize(
256-
(
257-
"model_class",
258-
"options",
259-
),
271+
("text_generator_model", "options"),
260272
[
261-
(
273+
pytest.param(
262274
llm.GeminiTextGenerator,
263275
{
264276
"temperature": 0.9,
265277
"max_output_tokens": 8192,
266278
"top_p": 1.0,
267279
"ground_with_google_search": False,
268280
},
281+
id="gemini",
269282
),
270-
(
283+
pytest.param(
271284
llm.Claude3TextGenerator,
272285
{
273286
"max_output_tokens": 128,
274287
"top_k": 40,
275288
"top_p": 0.95,
276289
},
290+
id="claude",
277291
),
278292
],
293+
indirect=["text_generator_model"],
279294
)
280295
def test_text_generator_retry_success(
281296
session,
282-
model_class,
297+
text_generator_model,
283298
options,
284-
bq_connection,
285299
):
286300
# Requests.
287301
df0 = EqCmpAllDataFrame(
@@ -298,21 +312,13 @@ def test_text_generator_retry_success(
298312
df1 = EqCmpAllDataFrame(
299313
{
300314
"ml_generate_text_status": ["error", "error"],
301-
"prompt": [
302-
"What is BQML?",
303-
"What is BigQuery DataFrame?",
304-
],
315+
"prompt": ["What is BQML?", "What is BigQuery DataFrame?"],
305316
},
306317
index=[1, 2],
307318
session=session,
308319
)
309320
df2 = EqCmpAllDataFrame(
310-
{
311-
"ml_generate_text_status": ["error"],
312-
"prompt": [
313-
"What is BQML?",
314-
],
315-
},
321+
{"ml_generate_text_status": ["error"], "prompt": ["What is BQML?"]},
316322
index=[1],
317323
session=session,
318324
)
@@ -342,31 +348,21 @@ def test_text_generator_retry_success(
342348
EqCmpAllDataFrame(
343349
{
344350
"ml_generate_text_status": ["error", ""],
345-
"prompt": [
346-
"What is BQML?",
347-
"What is BigQuery DataFrame?",
348-
],
351+
"prompt": ["What is BQML?", "What is BigQuery DataFrame?"],
349352
},
350353
index=[1, 2],
351354
session=session,
352355
),
353356
EqCmpAllDataFrame(
354-
{
355-
"ml_generate_text_status": [""],
356-
"prompt": [
357-
"What is BQML?",
358-
],
359-
},
357+
{"ml_generate_text_status": [""], "prompt": ["What is BQML?"]},
360358
index=[1],
361359
session=session,
362360
),
363361
]
364362

365-
text_generator_model = model_class(connection_name=bq_connection, session=session)
366363
text_generator_model._bqml_model = mock_bqml_model
367364

368365
with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
369-
# 3rd retry isn't triggered
370366
result = text_generator_model.predict(df0, max_retries=3)
371367

372368
mock_generate_text.assert_has_calls(
@@ -391,36 +387,36 @@ def test_text_generator_retry_success(
391387
),
392388
check_dtype=False,
393389
check_index_type=False,
390+
check_like=True,
394391
)
395392

396393

397-
@pytest.mark.skip("b/436340035 test failed")
398394
@pytest.mark.parametrize(
399-
(
400-
"model_class",
401-
"options",
402-
),
395+
("text_generator_model", "options"),
403396
[
404-
(
397+
pytest.param(
405398
llm.GeminiTextGenerator,
406399
{
407400
"temperature": 0.9,
408401
"max_output_tokens": 8192,
409402
"top_p": 1.0,
410403
"ground_with_google_search": False,
411404
},
405+
id="gemini",
412406
),
413-
(
407+
pytest.param(
414408
llm.Claude3TextGenerator,
415409
{
416410
"max_output_tokens": 128,
417411
"top_k": 40,
418412
"top_p": 0.95,
419413
},
414+
id="claude",
420415
),
421416
],
417+
indirect=["text_generator_model"],
422418
)
423-
def test_text_generator_retry_no_progress(session, model_class, options, bq_connection):
419+
def test_text_generator_retry_no_progress(session, text_generator_model, options):
424420
# Requests.
425421
df0 = EqCmpAllDataFrame(
426422
{
@@ -480,7 +476,6 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
480476
),
481477
]
482478

483-
text_generator_model = model_class(connection_name=bq_connection, session=session)
484479
text_generator_model._bqml_model = mock_bqml_model
485480

486481
with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
@@ -508,10 +503,10 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
508503
),
509504
check_dtype=False,
510505
check_index_type=False,
506+
check_like=True,
511507
)
512508

513509

514-
@pytest.mark.skip("b/436340035 test failed")
515510
def test_text_embedding_generator_retry_success(session, bq_connection):
516511
# Requests.
517512
df0 = EqCmpAllDataFrame(
@@ -793,17 +788,28 @@ def test_gemini_preview_model_warnings(model_name):
793788
llm.GeminiTextGenerator(model_name=model_name)
794789

795790

796-
# b/436340035 temp disable the test to unblock presubmit
797791
@pytest.mark.parametrize(
798792
"model_class",
799793
[
800794
llm.TextEmbeddingGenerator,
801795
llm.MultimodalEmbeddingGenerator,
802796
llm.GeminiTextGenerator,
803-
# llm.Claude3TextGenerator,
797+
llm.Claude3TextGenerator,
804798
],
805799
)
806800
def test_text_embedding_generator_no_default_model_warning(model_class):
807801
message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
808-
with pytest.warns(FutureWarning, match=message):
809-
model_class(model_name=None)
802+
803+
# For Claude models, we must mock the model creation to avoid network errors.
804+
# For all other models, we do nothing. contextlib.nullcontext() is a
805+
# placeholder that allows the "with" statement to work for all cases.
806+
patcher: AbstractContextManager[Any]
807+
if model_class == llm.Claude3TextGenerator:
808+
patcher = mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model")
809+
else:
810+
# No patching is needed here; nullcontext() is a no-op context manager.
811+
patcher = nullcontext()
812+
813+
with patcher:
814+
with pytest.warns(FutureWarning, match=message):
815+
model_class(model_name=None)

0 commit comments

Comments
 (0)