|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | from pathlib import PosixPath |
18 | | -from typing import Any |
19 | 18 |
|
20 | 19 | import pyarrow as pa |
21 | 20 | import pytest |
|
24 | 23 |
|
25 | 24 | from pyiceberg.catalog import Catalog |
26 | 25 | from pyiceberg.exceptions import NoSuchTableError |
27 | | -from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression, EqualTo, In, IsNaN, IsNull, Or, Reference |
| 26 | +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, In, IsNaN, IsNull, Or, Reference |
28 | 27 | from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral |
29 | 28 | from pyiceberg.io.pyarrow import schema_to_pyarrow |
30 | 29 | from pyiceberg.schema import Schema |
@@ -441,73 +440,80 @@ def test_create_match_filter_single_condition() -> None: |
441 | 440 | ) |
442 | 441 |
|
443 | 442 |
|
444 | | -@pytest.mark.parametrize( |
445 | | - "data, expected", |
446 | | - [ |
447 | | - pytest.param( |
448 | | - [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}], |
449 | | - In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)}), |
450 | | - id="single-column-without-null", |
| 443 | +def test_create_match_filter_single_column_without_null() -> None: |
| 444 | + data = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}] |
| 445 | + |
| 446 | + schema = pa.schema([pa.field("x", pa.float64())]) |
| 447 | + table = pa.Table.from_pylist(data, schema=schema) |
| 448 | + |
| 449 | + expr = create_match_filter(table, join_cols=["x"]) |
| 450 | + |
| 451 | + assert expr == In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)}) |
| 452 | + |
| 453 | + |
| 454 | +def test_create_match_filter_single_column_with_null() -> None: |
| 455 | + data = [ |
| 456 | + {"x": 1.0}, |
| 457 | + {"x": 2.0}, |
| 458 | + {"x": None}, |
| 459 | + {"x": 4.0}, |
| 460 | + {"x": float("nan")}, |
| 461 | + ] |
| 462 | + schema = pa.schema([pa.field("x", pa.float64())]) |
| 463 | + table = pa.Table.from_pylist(data, schema=schema) |
| 464 | + |
| 465 | + expr = create_match_filter(table, join_cols=["x"]) |
| 466 | + |
| 467 | + assert expr == Or( |
| 468 | + left=IsNull(term=Reference(name="x")), |
| 469 | + right=Or( |
| 470 | + left=IsNaN(term=Reference(name="x")), |
| 471 | + right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), |
451 | 472 | ), |
452 | | - pytest.param( |
453 | | - [{"x": 1.0}, {"x": 2.0}, {"x": None}, {"x": 4.0}, {"x": float("nan")}], |
454 | | - Or( |
455 | | - left=IsNull(term=Reference(name="x")), |
456 | | - right=Or( |
457 | | - left=IsNaN(term=Reference(name="x")), |
458 | | - right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), |
459 | | - ), |
| 473 | + ) |
| 474 | + |
| 475 | + |
| 476 | +def test_create_match_filter_multi_column_with_null() -> None: |
| 477 | + data = [ |
| 478 | + {"x": 1.0, "y": 9.0}, |
| 479 | + {"x": 2.0, "y": None}, |
| 480 | + {"x": None, "y": 7.0}, |
| 481 | + {"x": 4.0, "y": float("nan")}, |
| 482 | + {"x": float("nan"), "y": 0.0}, |
| 483 | + ] |
| 484 | + schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) |
| 485 | + table = pa.Table.from_pylist(data, schema=schema) |
| 486 | + |
| 487 | + expr = create_match_filter(table, join_cols=["x", "y"]) |
| 488 | + |
| 489 | + assert expr == Or( |
| 490 | + left=Or( |
| 491 | + left=And( |
| 492 | + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)), |
| 493 | + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)), |
| 494 | + ), |
| 495 | + right=And( |
| 496 | + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)), |
| 497 | + right=IsNull(term=Reference(name="y")), |
460 | 498 | ), |
461 | | - id="single-column-with-null", |
462 | 499 | ), |
463 | | - pytest.param( |
464 | | - [ |
465 | | - {"x": 1.0, "y": 9.0}, |
466 | | - {"x": 2.0, "y": None}, |
467 | | - {"x": None, "y": 7.0}, |
468 | | - {"x": 4.0, "y": float("nan")}, |
469 | | - {"x": float("nan"), "y": 0.0}, |
470 | | - ], |
471 | | - Or( |
472 | | - left=Or( |
473 | | - left=And( |
474 | | - left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)), |
475 | | - right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)), |
476 | | - ), |
477 | | - right=And( |
478 | | - left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)), |
479 | | - right=IsNull(term=Reference(name="y")), |
480 | | - ), |
| 500 | + right=Or( |
| 501 | + left=And( |
| 502 | + left=IsNull(term=Reference(name="x")), |
| 503 | + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)), |
| 504 | + ), |
| 505 | + right=Or( |
| 506 | + left=And( |
| 507 | + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)), |
| 508 | + right=IsNaN(term=Reference(name="y")), |
481 | 509 | ), |
482 | | - right=Or( |
483 | | - left=And( |
484 | | - left=IsNull(term=Reference(name="x")), |
485 | | - right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)), |
486 | | - ), |
487 | | - right=Or( |
488 | | - left=And( |
489 | | - left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)), |
490 | | - right=IsNaN(term=Reference(name="y")), |
491 | | - ), |
492 | | - right=And( |
493 | | - left=IsNaN(term=Reference(name="x")), |
494 | | - right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)), |
495 | | - ), |
496 | | - ), |
| 510 | + right=And( |
| 511 | + left=IsNaN(term=Reference(name="x")), |
| 512 | + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)), |
497 | 513 | ), |
498 | 514 | ), |
499 | | - id="multi-column-with-null", |
500 | 515 | ), |
501 | | - ], |
502 | | -) |
503 | | -def test_create_match_filter(data: list[dict[str, Any]], expected: BooleanExpression) -> None: |
504 | | - schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) |
505 | | - table = pa.Table.from_pylist(data, schema=schema) |
506 | | - join_cols = sorted({col for record in data for col in record}) |
507 | | - |
508 | | - expr = create_match_filter(table, join_cols) |
509 | | - |
510 | | - assert expr == expected |
| 516 | + ) |
511 | 517 |
|
512 | 518 |
|
513 | 519 | def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: |
|
0 commit comments