Skip to content

Commit 5933fc3

Browse files
committed
chore: implement mul_op and div_op compilers
1 parent 132e0ed commit 5933fc3

File tree

7 files changed

+241
-13
lines changed

7 files changed

+241
-13
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,51 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7373
)
7474

7575

76+
@BINARY_OP_REGISTRATION.register(ops.div_op)
77+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
78+
left_expr = left.expr
79+
if left.dtype == dtypes.BOOL_DTYPE:
80+
left_expr = sge.Cast(this=left_expr, to="INT64")
81+
right_expr = right.expr
82+
if right.dtype == dtypes.BOOL_DTYPE:
83+
right_expr = sge.Cast(this=right_expr, to="INT64")
84+
85+
result = sge.func("IEEE_DIVIDE", left_expr, right_expr)
86+
if dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE:
87+
return sge.Cast(this=sge.Floor(this=result), to="INT64")
88+
else:
89+
return result
90+
91+
92+
@BINARY_OP_REGISTRATION.register(ops.ge_op)
93+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
94+
return sge.GTE(this=left.expr, expression=right.expr)
95+
96+
97+
@BINARY_OP_REGISTRATION.register(ops.JSONSet)
98+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
99+
return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr)
100+
101+
102+
@BINARY_OP_REGISTRATION.register(ops.mul_op)
103+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
104+
left_expr = left.expr
105+
if left.dtype == dtypes.BOOL_DTYPE:
106+
left_expr = sge.Cast(this=left_expr, to="INT64")
107+
right_expr = right.expr
108+
if right.dtype == dtypes.BOOL_DTYPE:
109+
right_expr = sge.Cast(this=right_expr, to="INT64")
110+
111+
result = sge.Mul(this=left_expr, expression=right_expr)
112+
113+
if (dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE) or (
114+
dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE
115+
):
116+
return sge.Cast(this=sge.Floor(this=result), to="INT64")
117+
else:
118+
return result
119+
120+
76121
@BINARY_OP_REGISTRATION.register(ops.sub_op)
77122
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
78123
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
@@ -113,13 +158,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
113158
raise TypeError(
114159
f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
115160
)
116-
117-
118-
@BINARY_OP_REGISTRATION.register(ops.ge_op)
119-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
120-
return sge.GTE(this=left.expr, expression=right.expr)
121-
122-
123-
@BINARY_OP_REGISTRATION.register(ops.JSONSet)
124-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
125-
return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr)

tests/system/small/engines/test_numeric_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_engines_project_sub(
7171
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
7272

7373

74-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
74+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
7575
def test_engines_project_mul(
7676
scalars_array_value: array_value.ArrayValue,
7777
engine,
@@ -80,7 +80,7 @@ def test_engines_project_mul(
8080
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
8181

8282

83-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
83+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
8484
def test_engines_project_div(scalars_array_value: array_value.ArrayValue, engine):
8585
# TODO: Duration div is sensitive to zeroes
8686
# TODO: Numeric col is sensitive to scale shifts
@@ -90,7 +90,7 @@ def test_engines_project_div(scalars_array_value: array_value.ArrayValue, engine
9090
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
9191

9292

93-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
93+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
9494
def test_engines_project_div_durations(
9595
scalars_array_value: array_value.ArrayValue, engine
9696
):
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
IEEE_DIVIDE(`bfcol_1`, `bfcol_1`) AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
IEEE_DIVIDE(`bfcol_7`, 1) AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
IEEE_DIVIDE(`bfcol_15`, CAST(`bfcol_16` AS INT64)) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
IEEE_DIVIDE(CAST(`bfcol_26` AS INT64), `bfcol_25`) AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_div_int`,
51+
`bfcol_40` AS `int_div_1`,
52+
`bfcol_41` AS `int_div_bool`,
53+
`bfcol_42` AS `bool_div_int`
54+
FROM `bfcte_4`
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`,
5+
`timestamp_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
43200000000 AS `bfcol_6`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_1` AS `rowindex`,
15+
`bfcol_2` AS `timestamp_col`,
16+
`bfcol_0` AS `date_col`,
17+
`bfcol_6` AS `timedelta_div_numeric`
18+
FROM `bfcte_1`
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
`bfcol_1` * `bfcol_1` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
`bfcol_7` * 1 AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
`bfcol_15` * CAST(`bfcol_16` AS INT64) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
CAST(`bfcol_26` AS INT64) * `bfcol_25` AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_mul_int`,
51+
`bfcol_40` AS `int_mul_1`,
52+
`bfcol_41` AS `int_mul_bool`,
53+
`bfcol_42` AS `bool_mul_int`
54+
FROM `bfcte_4`
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`,
5+
`timestamp_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
172800000000 AS `bfcol_6`
11+
FROM `bfcte_0`
12+
), `bfcte_2` AS (
13+
SELECT
14+
*,
15+
172800000000 AS `bfcol_7`
16+
FROM `bfcte_1`
17+
)
18+
SELECT
19+
`bfcol_1` AS `rowindex`,
20+
`bfcol_2` AS `timestamp_col`,
21+
`bfcol_0` AS `date_col`,
22+
`bfcol_6` AS `timedelta_mul_numeric`,
23+
`bfcol_7` AS `numeric_mul_timedelta`
24+
FROM `bfcte_2`

tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,27 @@ def test_add_unsupported_raises(scalar_types_df: bpd.DataFrame):
8282
_apply_binary_op(scalar_types_df, ops.add_op, "int64_col", "string_col")
8383

8484

85+
def test_div_numeric(scalar_types_df: bpd.DataFrame, snapshot):
86+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
87+
88+
bf_df["int_div_int"] = bf_df["int64_col"] / bf_df["int64_col"]
89+
bf_df["int_div_1"] = bf_df["int64_col"] / 1
90+
91+
bf_df["int_div_bool"] = bf_df["int64_col"] / bf_df["bool_col"]
92+
bf_df["bool_div_int"] = bf_df["bool_col"] / bf_df["int64_col"]
93+
94+
snapshot.assert_match(bf_df.sql, "out.sql")
95+
96+
97+
def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
98+
bf_df = scalar_types_df[["timestamp_col", "date_col"]]
99+
timedelta = pd.Timedelta(1, unit="d")
100+
101+
bf_df["timedelta_div_numeric"] = timedelta / 2
102+
103+
snapshot.assert_match(bf_df.sql, "out.sql")
104+
105+
85106
def test_json_set(json_types_df: bpd.DataFrame, snapshot):
86107
bf_df = json_types_df[["json_col"]]
87108
sql = _apply_binary_op(
@@ -122,3 +143,25 @@ def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame):
122143

123144
with pytest.raises(TypeError):
124145
_apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col")
146+
147+
148+
def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot):
149+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
150+
151+
bf_df["int_mul_int"] = bf_df["int64_col"] * bf_df["int64_col"]
152+
bf_df["int_mul_1"] = bf_df["int64_col"] * 1
153+
154+
bf_df["int_mul_bool"] = bf_df["int64_col"] * bf_df["bool_col"]
155+
bf_df["bool_mul_int"] = bf_df["bool_col"] * bf_df["int64_col"]
156+
157+
snapshot.assert_match(bf_df.sql, "out.sql")
158+
159+
160+
def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
161+
bf_df = scalar_types_df[["timestamp_col", "date_col"]]
162+
timedelta = pd.Timedelta(1, unit="d")
163+
164+
bf_df["timedelta_mul_numeric"] = timedelta * 2
165+
bf_df["numeric_mul_timedelta"] = 2 * timedelta
166+
167+
snapshot.assert_match(bf_df.sql, "out.sql")

0 commit comments

Comments
 (0)