Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,47 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
return cast_op.as_expr(arg)


LOWER_COMPARISONS = tuple(
class LowerEqNullsMatchRule(op_lowering.OpLoweringRule):
@property
def op(self) -> type[ops.ScalarOp]:
return comparison_ops.EqNullsMatchOp

def lower(self, expr: expression.OpExpression) -> expression.Expression:
assert isinstance(expr.op, comparison_ops.EqNullsMatchOp)
arg1, arg2 = _coerce_comparables(expr.children[0], expr.children[1])

# True constant
true_const = expression.const(True)
# False constant
false_const = expression.const(False)

# equal = arg1 == arg2
equal_expr = ops.eq_op.as_expr(arg1, arg2)

# isnull1 = arg1.isnull()
isnull1_expr = ops.isnull_op.as_expr(arg1)

# isnull2 = arg2.isnull()
isnull2_expr = ops.isnull_op.as_expr(arg2)

# both_null = isnull1 & isnull2
both_null_expr = ops.and_op.as_expr(isnull1_expr, isnull2_expr)

# any_null = isnull1 | isnull2
any_null_expr = ops.or_op.as_expr(isnull1_expr, isnull2_expr)

# inner_where = where(false, any_null, equal)
inner_where_expr = ops.where_op.as_expr(false_const, any_null_expr, equal_expr)

# outer_where = where(true, both_null, inner_where)
null_safe_eq_expr = ops.where_op.as_expr(
true_const, both_null_expr, inner_where_expr
)

return null_safe_eq_expr


POLARS_LOWER_COMPARISONS = tuple(
CoerceArgsRule(op)
for op in (
comparison_ops.EqOp,
Expand All @@ -476,8 +516,20 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
)
)

SUBSTRAIT_LOWER_COMPARISONS = tuple(
CoerceArgsRule(op)
for op in (
comparison_ops.EqOp,
comparison_ops.NeOp,
comparison_ops.LtOp,
comparison_ops.GtOp,
comparison_ops.LeOp,
comparison_ops.GeOp,
)
)

POLARS_LOWERING_RULES = (
*LOWER_COMPARISONS,
*POLARS_LOWER_COMPARISONS,
LowerAddRule(),
LowerSubRule(),
LowerMulRule(),
Expand All @@ -492,11 +544,22 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
LowerFloorOp(),
)

SUBSTRAIT_LOWERING_RULES = (
LowerEqNullsMatchRule(),
*SUBSTRAIT_LOWER_COMPARISONS,
)


def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
return op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES)


def lower_ops_to_substrait(
root: bigframe_node.BigFrameNode,
) -> bigframe_node.BigFrameNode:
return op_lowering.lower_ops(root, rules=SUBSTRAIT_LOWERING_RULES)


def _numeric_to_timedelta(expr: expression.Expression) -> expression.Expression:
"""rounding logic used for emulating timedelta ops"""
rounded_value = ops.where_op.as_expr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import bigframes.operations.numeric_ops as num_ops
import bigframes.operations.string_ops as string_ops
from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec
from bigframes.core.compile.polars import lowering

polars_installed = True
if TYPE_CHECKING:
Expand Down Expand Up @@ -652,6 +651,8 @@ def compile(self, plan: nodes.BigFrameNode) -> pl.LazyFrame:
node = nodes.bottom_up(node, bigframes.core.rewrite.rewrite_slice)
node = bigframes.core.rewrite.pull_out_window_order(node)
node = bigframes.core.rewrite.schema_binding.bind_schema_to_tree(node)
from bigframes.core.compile import lowering

node = lowering.lower_ops_to_polars(node)
return self.compile_node(node)

Expand Down Expand Up @@ -743,6 +744,8 @@ def compile_join(self, node: nodes.JoinNode):
left_on = []
right_on = []
for left_ex, right_ex in node.conditions:
from bigframes.core.compile import lowering

left_ex, right_ex = lowering._coerce_comparables(left_ex, right_ex)
left_on.append(self.expr_compiler.compile_expression(left_ex))
right_on.append(self.expr_compiler.compile_expression(right_ex))
Expand All @@ -762,6 +765,8 @@ def compile_isin(self, node: nodes.InNode):
right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql))

right_col = ex.ResolvedDerefOp.from_field(node.right_child.fields[0])
from bigframes.core.compile import lowering

left_ex, right_ex = lowering._coerce_comparables(node.left_col, right_col)

left_pl_ex = self.expr_compiler.compile_expression(left_ex)
Expand Down
19 changes: 19 additions & 0 deletions packages/bigframes/bigframes/core/compile/substrait/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from .compiler import SubstraitCompiler

__all__ = ["SubstraitCompiler"]
Loading
Loading