1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Callable
15+ from contextlib import AbstractContextManager , nullcontext
16+ from typing import Any , Callable
1617from unittest import mock
1718
1819import pandas as pd
2526from bigframes .testing import utils
2627
2728
@pytest.fixture(scope="function")
def text_generator_model(request, bq_connection, session):
    """Create a text generator model for indirect parametrization.

    ``request.param`` is the model class supplied via
    ``@pytest.mark.parametrize(..., indirect=["text_generator_model"])``.

    For ``llm.Claude3TextGenerator`` the BQML model creation is mocked out,
    because the real creation performs a network call that fails due to the
    region issue.  Other classes (e.g. Gemini) are constructed normally.

    Args:
        request: pytest fixture request; ``request.param`` holds the class.
        bq_connection: BigQuery connection name fixture.
        session: BigFrames session fixture.

    Yields:
        An instance of the requested text generator model class.
    """
    model_class = request.param
    # Use identity comparison: we are matching a specific class object.
    if model_class is llm.Claude3TextGenerator:
        # Patch _create_bqml_model so constructing the Claude model does not
        # hit the network.
        with mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model"):
            model = model_class(connection_name=bq_connection, session=session)
    else:
        model = model_class(connection_name=bq_connection, session=session)
    yield model
43+
2844@pytest .mark .parametrize (
2945 "model_name" ,
3046 ("text-embedding-005" , "text-embedding-004" , "text-multilingual-embedding-002" ),
@@ -251,37 +267,35 @@ def __eq__(self, other):
251267 return self .equals (other )
252268
253269
254- @pytest .mark .skip ("b/436340035 test failed" )
255270@pytest .mark .parametrize (
256- (
257- "model_class" ,
258- "options" ,
259- ),
271+ ("text_generator_model" , "options" ),
260272 [
261- (
273+ pytest . param (
262274 llm .GeminiTextGenerator ,
263275 {
264276 "temperature" : 0.9 ,
265277 "max_output_tokens" : 8192 ,
266278 "top_p" : 1.0 ,
267279 "ground_with_google_search" : False ,
268280 },
281+ id = "gemini" ,
269282 ),
270- (
283+ pytest . param (
271284 llm .Claude3TextGenerator ,
272285 {
273286 "max_output_tokens" : 128 ,
274287 "top_k" : 40 ,
275288 "top_p" : 0.95 ,
276289 },
290+ id = "claude" ,
277291 ),
278292 ],
293+ indirect = ["text_generator_model" ],
279294)
280295def test_text_generator_retry_success (
281296 session ,
282- model_class ,
297+ text_generator_model ,
283298 options ,
284- bq_connection ,
285299):
286300 # Requests.
287301 df0 = EqCmpAllDataFrame (
@@ -298,21 +312,13 @@ def test_text_generator_retry_success(
298312 df1 = EqCmpAllDataFrame (
299313 {
300314 "ml_generate_text_status" : ["error" , "error" ],
301- "prompt" : [
302- "What is BQML?" ,
303- "What is BigQuery DataFrame?" ,
304- ],
315+ "prompt" : ["What is BQML?" , "What is BigQuery DataFrame?" ],
305316 },
306317 index = [1 , 2 ],
307318 session = session ,
308319 )
309320 df2 = EqCmpAllDataFrame (
310- {
311- "ml_generate_text_status" : ["error" ],
312- "prompt" : [
313- "What is BQML?" ,
314- ],
315- },
321+ {"ml_generate_text_status" : ["error" ], "prompt" : ["What is BQML?" ]},
316322 index = [1 ],
317323 session = session ,
318324 )
@@ -342,31 +348,21 @@ def test_text_generator_retry_success(
342348 EqCmpAllDataFrame (
343349 {
344350 "ml_generate_text_status" : ["error" , "" ],
345- "prompt" : [
346- "What is BQML?" ,
347- "What is BigQuery DataFrame?" ,
348- ],
351+ "prompt" : ["What is BQML?" , "What is BigQuery DataFrame?" ],
349352 },
350353 index = [1 , 2 ],
351354 session = session ,
352355 ),
353356 EqCmpAllDataFrame (
354- {
355- "ml_generate_text_status" : ["" ],
356- "prompt" : [
357- "What is BQML?" ,
358- ],
359- },
357+ {"ml_generate_text_status" : ["" ], "prompt" : ["What is BQML?" ]},
360358 index = [1 ],
361359 session = session ,
362360 ),
363361 ]
364362
365- text_generator_model = model_class (connection_name = bq_connection , session = session )
366363 text_generator_model ._bqml_model = mock_bqml_model
367364
368365 with mock .patch .object (core .BqmlModel , "generate_text_tvf" , generate_text_tvf ):
369- # 3rd retry isn't triggered
370366 result = text_generator_model .predict (df0 , max_retries = 3 )
371367
372368 mock_generate_text .assert_has_calls (
@@ -391,36 +387,36 @@ def test_text_generator_retry_success(
391387 ),
392388 check_dtype = False ,
393389 check_index_type = False ,
390+ check_like = True ,
394391 )
395392
396393
397- @pytest .mark .skip ("b/436340035 test failed" )
398394@pytest .mark .parametrize (
399- (
400- "model_class" ,
401- "options" ,
402- ),
395+ ("text_generator_model" , "options" ),
403396 [
404- (
397+ pytest . param (
405398 llm .GeminiTextGenerator ,
406399 {
407400 "temperature" : 0.9 ,
408401 "max_output_tokens" : 8192 ,
409402 "top_p" : 1.0 ,
410403 "ground_with_google_search" : False ,
411404 },
405+ id = "gemini" ,
412406 ),
413- (
407+ pytest . param (
414408 llm .Claude3TextGenerator ,
415409 {
416410 "max_output_tokens" : 128 ,
417411 "top_k" : 40 ,
418412 "top_p" : 0.95 ,
419413 },
414+ id = "claude" ,
420415 ),
421416 ],
417+ indirect = ["text_generator_model" ],
422418)
423- def test_text_generator_retry_no_progress (session , model_class , options , bq_connection ):
419+ def test_text_generator_retry_no_progress (session , text_generator_model , options ):
424420 # Requests.
425421 df0 = EqCmpAllDataFrame (
426422 {
@@ -480,7 +476,6 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
480476 ),
481477 ]
482478
483- text_generator_model = model_class (connection_name = bq_connection , session = session )
484479 text_generator_model ._bqml_model = mock_bqml_model
485480
486481 with mock .patch .object (core .BqmlModel , "generate_text_tvf" , generate_text_tvf ):
@@ -508,10 +503,10 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
508503 ),
509504 check_dtype = False ,
510505 check_index_type = False ,
506+ check_like = True ,
511507 )
512508
513509
514- @pytest .mark .skip ("b/436340035 test failed" )
515510def test_text_embedding_generator_retry_success (session , bq_connection ):
516511 # Requests.
517512 df0 = EqCmpAllDataFrame (
@@ -793,17 +788,28 @@ def test_gemini_preview_model_warnings(model_name):
793788 llm .GeminiTextGenerator (model_name = model_name )
794789
795790
796- # b/436340035 temp disable the test to unblock presumbit
@pytest.mark.parametrize(
    "model_class",
    [
        llm.TextEmbeddingGenerator,
        llm.MultimodalEmbeddingGenerator,
        llm.GeminiTextGenerator,
        llm.Claude3TextGenerator,
    ],
)
def test_text_embedding_generator_no_default_model_warning(model_class):
    """Constructing a model without an explicit model name emits a FutureWarning."""
    message = "Since upgrading the default model can cause unintended breakages, the\n default model will be removed in BigFrames 3.0. Please supply an\n explicit model to avoid this message."

    # For Claude models, we must mock the BQML model creation to avoid network
    # errors.  For all other models we do nothing: contextlib.nullcontext() is
    # a no-op placeholder so the "with" statement works uniformly.
    patcher: AbstractContextManager[Any]
    # Use identity comparison: we are matching a specific class object.
    if model_class is llm.Claude3TextGenerator:
        patcher = mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model")
    else:
        patcher = nullcontext()

    with patcher:
        with pytest.warns(FutureWarning, match=message):
            model_class(model_name=None)
0 commit comments