diff --git a/packages/bigframes/bigframes/core/compile/polars/lowering.py b/packages/bigframes/bigframes/core/compile/lowering.py similarity index 90% rename from packages/bigframes/bigframes/core/compile/polars/lowering.py rename to packages/bigframes/bigframes/core/compile/lowering.py index 7416ebc963b4..de245f68538b 100644 --- a/packages/bigframes/bigframes/core/compile/polars/lowering.py +++ b/packages/bigframes/bigframes/core/compile/lowering.py @@ -463,7 +463,47 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): return cast_op.as_expr(arg) -LOWER_COMPARISONS = tuple( +class LowerEqNullsMatchRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return comparison_ops.EqNullsMatchOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, comparison_ops.EqNullsMatchOp) + arg1, arg2 = _coerce_comparables(expr.children[0], expr.children[1]) + + # True constant + true_const = expression.const(True) + # False constant + false_const = expression.const(False) + + # equal = arg1 == arg2 + equal_expr = ops.eq_op.as_expr(arg1, arg2) + + # isnull1 = arg1.isnull() + isnull1_expr = ops.isnull_op.as_expr(arg1) + + # isnull2 = arg2.isnull() + isnull2_expr = ops.isnull_op.as_expr(arg2) + + # both_null = isnull1 & isnull2 + both_null_expr = ops.and_op.as_expr(isnull1_expr, isnull2_expr) + + # any_null = isnull1 | isnull2 + any_null_expr = ops.or_op.as_expr(isnull1_expr, isnull2_expr) + + # inner_where = where(false, any_null, equal) + inner_where_expr = ops.where_op.as_expr(false_const, any_null_expr, equal_expr) + + # outer_where = where(true, both_null, inner_where) + null_safe_eq_expr = ops.where_op.as_expr( + true_const, both_null_expr, inner_where_expr + ) + + return null_safe_eq_expr + + +POLARS_LOWER_COMPARISONS = tuple( CoerceArgsRule(op) for op in ( comparison_ops.EqOp, @@ -476,8 +516,20 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): ) ) +SUBSTRAIT_LOWER_COMPARISONS = tuple( + CoerceArgsRule(op) + for op in ( + comparison_ops.EqOp, + comparison_ops.NeOp, + comparison_ops.LtOp, + comparison_ops.GtOp, + comparison_ops.LeOp, + comparison_ops.GeOp, + ) +) + POLARS_LOWERING_RULES = ( - *LOWER_COMPARISONS, + *POLARS_LOWER_COMPARISONS, LowerAddRule(), LowerSubRule(), LowerMulRule(), @@ -492,11 +544,22 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): LowerFloorOp(), ) +SUBSTRAIT_LOWERING_RULES = ( + LowerEqNullsMatchRule(), + *SUBSTRAIT_LOWER_COMPARISONS, +) + def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode: return op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES) +def lower_ops_to_substrait( + root: bigframe_node.BigFrameNode, +) -> bigframe_node.BigFrameNode: + return op_lowering.lower_ops(root, rules=SUBSTRAIT_LOWERING_RULES) + + def _numeric_to_timedelta(expr: expression.Expression) -> expression.Expression: """rounding logic used for emulating timedelta ops""" rounded_value = ops.where_op.as_expr( diff --git a/packages/bigframes/bigframes/core/compile/polars/compiler.py b/packages/bigframes/bigframes/core/compile/polars/compiler.py index 6f24929eeb4e..61d2a42b2138 100644 --- a/packages/bigframes/bigframes/core/compile/polars/compiler.py +++ b/packages/bigframes/bigframes/core/compile/polars/compiler.py @@ -39,7 +39,6 @@ import bigframes.operations.numeric_ops as num_ops import bigframes.operations.string_ops as string_ops from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec -from bigframes.core.compile.polars import lowering polars_installed = True if TYPE_CHECKING: @@ -652,6 +651,8 @@ def compile(self, plan: nodes.BigFrameNode) -> pl.LazyFrame: node = nodes.bottom_up(node, bigframes.core.rewrite.rewrite_slice) node = bigframes.core.rewrite.pull_out_window_order(node) node = bigframes.core.rewrite.schema_binding.bind_schema_to_tree(node) + from bigframes.core.compile import lowering + node = lowering.lower_ops_to_polars(node) return self.compile_node(node) @@ -743,6 +744,8 @@ def compile_join(self, node: nodes.JoinNode): left_on = [] right_on = [] for left_ex, right_ex in node.conditions: + from bigframes.core.compile import lowering + left_ex, right_ex = lowering._coerce_comparables(left_ex, right_ex) left_on.append(self.expr_compiler.compile_expression(left_ex)) right_on.append(self.expr_compiler.compile_expression(right_ex)) @@ -762,6 +765,8 @@ def compile_isin(self, node: nodes.InNode): right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql)) right_col = ex.ResolvedDerefOp.from_field(node.right_child.fields[0]) + from bigframes.core.compile import lowering + left_ex, right_ex = lowering._coerce_comparables(node.left_col, right_col) left_pl_ex = self.expr_compiler.compile_expression(left_ex) diff --git a/packages/bigframes/bigframes/core/compile/substrait/__init__.py b/packages/bigframes/bigframes/core/compile/substrait/__init__.py new file mode 100644 index 000000000000..13021b2fec3f --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .compiler import SubstraitCompiler + +__all__ = ["SubstraitCompiler"] diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py new file mode 100644 index 000000000000..dcece545a387 --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -0,0 +1,1164 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import singledispatchmethod +from typing import Any, Dict, Literal, Optional, Sequence + +import pandas as pd +import substrait.algebra_pb2 as algebra_pb2 +import substrait.plan_pb2 as plan_pb2 +from google.protobuf import json_format + +import bigframes.core.expression as ex +import bigframes.dtypes as dtypes +import bigframes.operations as ops +import bigframes.operations.bool_ops as bool_ops +import bigframes.operations.comparison_ops as comparison_ops +import bigframes.operations.generic_ops as generic_ops +import bigframes.operations.numeric_ops as numeric_ops +import bigframes.operations.struct_ops as struct_ops +from bigframes.core import bigframe_node, nodes, rewrite +from bigframes.core.compile import lowering + + +class SubstraitCompiler: + """ + Compiles BigFrameNode plans to Substrait schema (JSON representation). + """ + + def __init__( + self, + duration_type: Literal["interval_day", "int"], + use_precision_types: bool = True, + ): + self._duration_type = duration_type + self._use_precision_types = use_precision_types + + def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: + """ + Compiles a BigFrameNode to Substrait bytes (JSON encoded via protobuf). + """ + if not self.can_compile(plan): + return None + + # Need to bind types in before lowering + plan = rewrite.bind_schema_to_tree(plan) + plan = lowering.lower_ops_to_substrait(plan) + pb_rel = self._compile_node(plan) + + pb_plan = plan_pb2.Plan() + pb_plan.version.minor_number = 42 + + plan_rel = pb_plan.relations.add() + plan_rel.root.input.CopyFrom(pb_rel) + + for item in plan.schema.items: + plan_rel.root.names.extend( + self._get_substrait_names(item.column, item.dtype) + ) + + for name, anchor in self._EXTENSIONS.items(): + ext = pb_plan.extensions.add() + ext.extension_function.function_anchor = anchor + ext.extension_function.name = name + + return pb_plan.SerializeToString() + + def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: + """ + Checks if the plan can be compiled to Substrait. + """ + supported_nodes = ( + nodes.ReadLocalNode, + nodes.SelectionNode, + nodes.FilterNode, + nodes.ProjectionNode, + nodes.JoinNode, + nodes.AggregateNode, + nodes.OrderByNode, + nodes.ConcatNode, + ) + return all(isinstance(n, supported_nodes) for n in plan.unique_nodes()) + + def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel: + if isinstance(node, nodes.ReadLocalNode): + return self._compile_read(node) + elif isinstance(node, nodes.SelectionNode): + return self._compile_selection(node) + elif isinstance(node, nodes.FilterNode): + return self._compile_filter(node) + elif isinstance(node, nodes.ProjectionNode): + return self._compile_projection(node) + elif isinstance(node, nodes.JoinNode): + return self._compile_join(node) + elif isinstance(node, nodes.AggregateNode): + return self._compile_aggregate(node) + elif isinstance(node, nodes.OrderByNode): + return self._compile_orderby(node) + elif isinstance(node, nodes.ConcatNode): + return self._compile_concat(node) + else: + raise NotImplementedError( + f"Node type {type(node)} not supported in Substrait compiler yet" + ) + + def _compile_read(self, node: nodes.ReadLocalNode) -> algebra_pb2.Rel: + table_name = f"table_{id(node)}" + + rel = algebra_pb2.Rel() + read_rel = rel.read + read_rel.named_table.names.append(table_name) + + import bigframes.dtypes as dtypes + + fields = [] + types = [] + for item in node.scan_list.items: + col_dtype = node.local_data_source.schema.get_type(item.source_id) + fields.extend(self._get_substrait_names(item.id.sql, col_dtype)) + types.append(self._convert_type(col_dtype)) + + if node.offsets_col is not None: + fields.append(node.offsets_col.sql) + types.append(self._convert_type(dtypes.INT_DTYPE)) + + schema_dict = {"names": fields, "struct": {"types": types}} + json_format.ParseDict(schema_dict, read_rel.base_schema) + + return rel + + def _compile_selection(self, node: nodes.SelectionNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + project_rel = rel.project + project_rel.input.CopyFrom(input_rel) + + child_ids = list(node.child.ids) + num_exprs = 0 + for aliased_ref in node.input_output_pairs: + source_id = aliased_ref.ref.id + idx = child_ids.index(source_id) + expr = project_rel.expressions.add() + expr.selection.direct_reference.struct_field.field = idx + num_exprs += 1 + + project_rel.common.emit.output_mapping.extend( + [len(child_ids) + i for i in range(num_exprs)] + ) + + return rel + + def _compile_filter(self, node: nodes.FilterNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + filter_rel = rel.filter + filter_rel.input.CopyFrom(input_rel) + + condition_expr = self._compile_expression(node.predicate, node.child) + filter_rel.condition.CopyFrom(condition_expr) + + return rel + + def _compile_projection(self, node: nodes.ProjectionNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + project_rel = rel.project + project_rel.input.CopyFrom(input_rel) + + for expr, _ in node.assignments: + expr_pb = self._compile_expression(expr, node.child) + project_rel.expressions.add().CopyFrom(expr_pb) + + child_ids = list(node.child.ids) + num_exprs = len(node.assignments) + project_rel.common.emit.output_mapping.extend(range(len(child_ids) + num_exprs)) + + return rel + + def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: + left_rel = self._compile_node(node.left_child) + right_rel = self._compile_node(node.right_child) + + rel = algebra_pb2.Rel() + if node.type == "cross": + cross_rel = rel.cross + cross_rel.left.CopyFrom(left_rel) + cross_rel.right.CopyFrom(right_rel) + return rel + + join_rel = rel.join + join_rel.left.CopyFrom(left_rel) + join_rel.right.CopyFrom(right_rel) + + type_map = { + "inner": algebra_pb2.JoinRel.JOIN_TYPE_INNER, + "left": algebra_pb2.JoinRel.JOIN_TYPE_LEFT, + "right": algebra_pb2.JoinRel.JOIN_TYPE_RIGHT, + "outer": algebra_pb2.JoinRel.JOIN_TYPE_OUTER, + } + join_rel.type = type_map.get( + node.type, algebra_pb2.JoinRel.JOIN_TYPE_UNSPECIFIED + ) + + left_len = len(node.left_child.schema) + + eq_expressions = [] + for left_deref, right_deref in node.conditions: + left_idx = list(node.left_child.ids).index(left_deref.id) + right_idx = list(node.right_child.ids).index(right_deref.id) + left_len + + arg1 = algebra_pb2.Expression() + arg1.selection.direct_reference.struct_field.field = left_idx + + arg2 = algebra_pb2.Expression() + arg2.selection.direct_reference.struct_field.field = right_idx + + eq_expr = algebra_pb2.Expression() + eq_expr.scalar_function.function_reference = self._EXTENSIONS["equal"] + eq_expr.scalar_function.arguments.add().value.CopyFrom(arg1) + eq_expr.scalar_function.arguments.add().value.CopyFrom(arg2) + + isnull1_expr = algebra_pb2.Expression() + isnull1_expr.scalar_function.function_reference = self._EXTENSIONS[ + "is_null" + ] + isnull1_expr.scalar_function.arguments.add().value.CopyFrom(arg1) + + isnull2_expr = algebra_pb2.Expression() + isnull2_expr.scalar_function.function_reference = self._EXTENSIONS[ + "is_null" + ] + isnull2_expr.scalar_function.arguments.add().value.CopyFrom(arg2) + + both_null_expr = algebra_pb2.Expression() + both_null_expr.scalar_function.function_reference = self._EXTENSIONS["and"] + both_null_expr.scalar_function.arguments.add().value.CopyFrom(isnull1_expr) + both_null_expr.scalar_function.arguments.add().value.CopyFrom(isnull2_expr) + + null_safe_eq = algebra_pb2.Expression() + null_safe_eq.scalar_function.function_reference = self._EXTENSIONS["or"] + null_safe_eq.scalar_function.arguments.add().value.CopyFrom(eq_expr) + null_safe_eq.scalar_function.arguments.add().value.CopyFrom(both_null_expr) + + eq_expressions.append(null_safe_eq) + + if len(eq_expressions) > 1: + expr = eq_expressions[0] + for e in eq_expressions[1:]: + and_expr = algebra_pb2.Expression() + and_expr.scalar_function.function_reference = 13 # and + and_expr.scalar_function.arguments.add().value.CopyFrom(expr) + and_expr.scalar_function.arguments.add().value.CopyFrom(e) + expr = and_expr + elif len(eq_expressions) == 1: + expr = eq_expressions[0] + else: + expr = algebra_pb2.Expression() + expr.literal.boolean = True + + join_rel.expression.CopyFrom(expr) + + return rel + + def _compile_bound( + self, + val: Optional[int], + bound_msg: algebra_pb2.Expression.WindowFunction.Bound, + ): + if val is None: + bound_msg.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded() + ) + elif val == 0: + bound_msg.current_row.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.CurrentRow() + ) + elif val < 0: + bound_msg.preceding.offset = -val + else: + bound_msg.following.offset = val + + def _compile_concat(self, node: nodes.ConcatNode) -> algebra_pb2.Rel: + rel = algebra_pb2.Rel() + set_rel = rel.set + set_rel.op = algebra_pb2.SetRel.SetOp.SET_OP_UNION_ALL + + for child in node.children: + set_rel.inputs.append(self._compile_node(child)) + + return rel + + def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + agg_rel = rel.aggregate + agg_rel.input.CopyFrom(input_rel) + + if node.by_column_ids: + grouping = agg_rel.groupings.add() + for deref in node.by_column_ids: + idx = list(node.child.ids).index(deref.id) + expr = grouping.grouping_expressions.add() + expr.selection.direct_reference.struct_field.field = idx + + import bigframes.dtypes as dtypes + import bigframes.operations.aggregations as agg_ops + + size_count = 0 + for agg, out_col_id in node.aggregations: + distinct = False + if isinstance(agg.op, agg_ops.SumOp): + func_ref = self._EXTENSIONS["sum"] + elif isinstance(agg.op, agg_ops.MaxOp): + func_ref = self._EXTENSIONS["max"] + elif isinstance(agg.op, agg_ops.MinOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.MeanOp): + func_ref = self._EXTENSIONS["mean"] + elif isinstance(agg.op, agg_ops.CountOp): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, agg_ops.NuniqueOp): + func_ref = self._EXTENSIONS["count"] + distinct = True + elif isinstance(agg.op, agg_ops.StdOp): + func_ref = self._EXTENSIONS["stddev"] + elif isinstance(agg.op, agg_ops.VarOp): + func_ref = self._EXTENSIONS["var"] + elif isinstance(agg.op, agg_ops.PopVarOp): + func_ref = self._EXTENSIONS["var_pop"] + elif isinstance(agg.op, agg_ops.AnyValueOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.AllOp): + func_ref = self._EXTENSIONS["bool_and"] + elif isinstance(agg.op, agg_ops.AnyOp): + func_ref = self._EXTENSIONS["bool_or"] + elif isinstance(agg.op, agg_ops.ProductOp): + func_ref = self._EXTENSIONS["product"] + elif isinstance(agg.op, agg_ops.MedianOp): + func_ref = self._EXTENSIONS["median"] + elif isinstance(agg.op, agg_ops.CovOp): + func_ref = self._EXTENSIONS["cov"] + elif isinstance(agg.op, agg_ops.CorrOp): + func_ref = self._EXTENSIONS["corr"] + else: + raise NotImplementedError( + f"Aggregation {type(agg.op)} not supported in Substrait compiler yet" + ) + + measure = agg_rel.measures.add() + measure.measure.function_reference = func_ref + if distinct or isinstance(agg.op, agg_ops.NuniqueOp): + measure.measure.invocation = ( + algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT + ) + + if isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + size_count += 1 + arg = measure.measure.arguments.add() + arg.value.literal.i64 = size_count + elif hasattr(agg, "column_references"): + for col_id in agg.column_references: + try: + idx = list(node.child.ids).index(col_id) + field_expr = algebra_pb2.Expression() + field_expr.selection.direct_reference.struct_field.field = idx + + col_dtype = node.child.schema.items[idx].dtype + is_bool = col_dtype == dtypes.BOOL_DTYPE + if isinstance( + agg.op, (agg_ops.StdOp, agg_ops.VarOp, agg_ops.PopVarOp) + ) or ( + isinstance(agg.op, (agg_ops.SumOp, agg_ops.MeanOp)) + and is_bool + ): + casted_expr = self._compile_cast( + field_expr, dtypes.FLOAT_DTYPE + ) + arg = measure.measure.arguments.add() + arg.value.CopyFrom(casted_expr) + else: + arg = measure.measure.arguments.add() + arg.value.CopyFrom(field_expr) + except ValueError: + pass + + if node.dropna and node.by_column_ids: + not_null_exprs = [] + for idx in range(len(node.by_column_ids)): + key_expr = algebra_pb2.Expression() + key_expr.selection.direct_reference.struct_field.field = idx + + not_null_op = algebra_pb2.Expression() + not_null_op.scalar_function.function_reference = self._EXTENSIONS[ + "is_not_null" + ] + not_null_op.scalar_function.arguments.add().value.CopyFrom(key_expr) + not_null_exprs.append(not_null_op) + + if len(not_null_exprs) > 1: + expr = not_null_exprs[0] + for e in not_null_exprs[1:]: + and_expr = algebra_pb2.Expression() + and_expr.scalar_function.function_reference = self._EXTENSIONS[ + "and" + ] + and_expr.scalar_function.arguments.add().value.CopyFrom(expr) + and_expr.scalar_function.arguments.add().value.CopyFrom(e) + expr = and_expr + else: + expr = not_null_exprs[0] + + filter_rel = algebra_pb2.Rel() + filter_rel.filter.input.CopyFrom(rel) + filter_rel.filter.condition.CopyFrom(expr) + rel = filter_rel + + return rel + + def _compile_orderby(self, node: nodes.OrderByNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + sort_rel = rel.sort + sort_rel.input.CopyFrom(input_rel) + + for ord_expr in node.by: + sort_field = sort_rel.sorts.add() + + # Compile the expression: + expr_pb = self._compile_expression(ord_expr.scalar_expression, node.child) + sort_field.expr.CopyFrom(expr_pb) + + # Map sort direction: + is_asc = ord_expr.direction.is_ascending + if is_asc: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + + return rel + + _EXTENSIONS = { + "add": 1, + "subtract": 2, + "multiply": 3, + "divide": 4, + "equal": 5, + "not_equal": 6, + "lt": 7, + "gt": 8, + "lte": 9, + "gte": 10, + "sum": 11, + "max": 12, + "and": 13, + "min": 14, + "mean": 15, + "count": 16, + "stddev": 17, + "var": 18, + "any_value": 19, + "all": 20, + "any": 21, + "coalesce": 22, + "or": 23, + "least": 24, + "greatest": 25, + "is_null": 26, + "is_not_null": 27, + "nullif": 28, + "sqrt": 29, + "bool_and": 30, + "bool_or": 31, + "product": 32, + "not": 33, + "mod": 34, + "floor": 35, + "abs": 36, + "ceil": 37, + "median": 38, + "xor": 40, + "var_pop": 53, + "row_number": 60, + "rank": 61, + "dense_rank": 62, + "first_value": 63, + "last_value": 64, + "lag": 65, + "lead": 66, + "struct": 67, + "get_field": 68, + "pow": 69, + "cov": 70, + "corr": 71, + } + + _OP_TO_EXTENSION = { + numeric_ops.AddOp: "add", + numeric_ops.SubOp: "subtract", + numeric_ops.MulOp: "multiply", + numeric_ops.DivOp: "divide", + numeric_ops.ModOp: "mod", + numeric_ops.PowOp: "pow", + numeric_ops.UnsafePowOp: "pow", + comparison_ops.EqOp: "equal", + comparison_ops.NeOp: "not_equal", + comparison_ops.LtOp: "lt", + comparison_ops.GtOp: "gt", + comparison_ops.LeOp: "lte", + comparison_ops.GeOp: "gte", + generic_ops.FillNaOp: "coalesce", + generic_ops.CoalesceOp: "coalesce", + bool_ops.AndOp: "and", + bool_ops.OrOp: "or", + bool_ops.XorOp: "xor", + generic_ops.InvertOp: "not", + numeric_ops.AbsOp: "abs", + numeric_ops.CeilOp: "ceil", + numeric_ops.FloorOp: "floor", + generic_ops.IsNullOp: "is_null", + generic_ops.NotNullOp: "is_not_null", + } + + @singledispatchmethod + def _compile_expression( + self, expr: ex.Expression, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + raise NotImplementedError( + f"Expression type {type(expr)} not supported in Substrait compiler yet" + ) + + @_compile_expression.register + def _compile_scalar_constant( + self, expr: ex.ScalarConstantExpression, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + import datetime + + pb_expr = algebra_pb2.Expression() + val = expr.value + if isinstance(val, bool): + pb_expr.literal.boolean = val + elif isinstance(val, int): + pb_expr.literal.i64 = val + elif isinstance(val, float): + pb_expr.literal.fp64 = val + elif isinstance(val, str): + pb_expr.literal.string = val + elif isinstance(val, (pd.Timestamp, datetime.datetime)): + if getattr(val, "tzinfo", None) is not None: + epoch = pd.Timestamp("1970-01-01", tz=val.tzinfo) + us = int((val - epoch).total_seconds() * 1_000_000) + pb_expr.literal.precision_timestamp_tz.precision = 6 + pb_expr.literal.precision_timestamp_tz.value = us + else: + epoch = pd.Timestamp("1970-01-01") + us = int((val - epoch).total_seconds() * 1_000_000) + pb_expr.literal.precision_timestamp.precision = 6 + pb_expr.literal.precision_timestamp.value = us + elif isinstance(val, datetime.date): + epoch = datetime.date(1970, 1, 1) + days = (val - epoch).days + pb_expr.literal.date = days + elif pd.isna(val): + pb_expr.literal.null.varchar.length = 0 + else: + pb_expr.literal.string = str(val) + return pb_expr + + @_compile_expression.register + def _compile_deref( + self, expr: ex.DerefOp, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + try: + idx = list(child.ids).index(expr.id) + pb_expr.selection.direct_reference.struct_field.field = idx + return pb_expr + except ValueError: + raise ValueError(f"Column {expr.id} not found in child schema") + + @_compile_expression.register + def _compile_op_expr( + self, expr: ex.OpExpression, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = self._compile_op(expr.op, expr.inputs, child) + if pb_expr.HasField("scalar_function"): + if not pb_expr.scalar_function.HasField("output_type"): + output_dtype = self._get_expression_dtype(expr, child) + type_dict = self._convert_type(output_dtype) + json_format.ParseDict(type_dict, pb_expr.scalar_function.output_type) + return pb_expr + + @singledispatchmethod + def _compile_op( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + raise NotImplementedError( + f"Op type {type(op)} not supported in Substrait compiler yet" + ) + + @_compile_op.register(ops.AsTypeOp) + def _compile_astype( + self, + op: ops.AsTypeOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + return self._compile_cast(arg_expr, op.to_type) + + @_compile_op.register(struct_ops.StructOp) + def _compile_struct_op( + self, + op: struct_ops.StructOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["struct"] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + @_compile_op.register(struct_ops.StructFieldOp) + def _compile_struct_field_op( + self, + op: struct_ops.StructFieldOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["get_field"] + + # Arg 0: the struct + arg_expr = self._compile_expression(inputs[0], child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + + # Arg 1: the field name as string literal + literal_expr = algebra_pb2.Expression() + literal_expr.literal.string = str(op.name_or_index) + pb_expr.scalar_function.arguments.add().value.CopyFrom(literal_expr) + return pb_expr + + def _compile_cast( + self, input_expr: algebra_pb2.Expression, target_dtype: Any + ) -> algebra_pb2.Expression: + if input_expr.HasField("literal") and input_expr.literal.HasField("null"): + pb_expr = algebra_pb2.Expression() + type_dict = self._convert_type(target_dtype) + json_format.ParseDict(type_dict, pb_expr.literal.null) + return pb_expr + + pb_expr = algebra_pb2.Expression() + cast = pb_expr.cast + cast.input.CopyFrom(input_expr) + + type_dict = self._convert_type(target_dtype) + json_format.ParseDict(type_dict, cast.type) + + # alternative: FAILURE_BEHAVIOR_RETURN_NULL not supported by acero + cast.failure_behavior = ( + algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION + ) + return pb_expr + + def _get_expression_dtype( + self, expr: ex.Expression, child: nodes.BigFrameNode + ) -> Any: + import bigframes.dtypes as dtypes + + if isinstance(expr, ex.ScalarConstantExpression): + if expr.value is None or pd.isna(expr.value): + return None + return expr.dtype or dtypes.infer_literal_type(expr.value) + elif isinstance(expr, ex.DerefOp): + try: + idx = list(child.ids).index(expr.id) + return child.schema.items[idx].dtype + except ValueError: + pass + elif isinstance(expr, ex.OpExpression): + try: + input_dtypes = [ + self._get_expression_dtype(inp, child) for inp in expr.inputs + ] + return expr.op.output_type(*input_dtypes) + except Exception: + pass + return dtypes.STRING_DTYPE + + def _get_common_type(self, dtypes_list: Sequence[Any]) -> Any: + import bigframes.dtypes as dtypes + + non_null_dtypes = [dt for dt in dtypes_list if dt is not None] + if not non_null_dtypes: + return dtypes.STRING_DTYPE + if len(set(non_null_dtypes)) == 1: + return non_null_dtypes[0] + if any(dt == dtypes.STRING_DTYPE for dt in non_null_dtypes): + return dtypes.STRING_DTYPE + if any(dt == dtypes.FLOAT_DTYPE for dt in non_null_dtypes): + return dtypes.FLOAT_DTYPE + if any(dt == dtypes.INT_DTYPE for dt in non_null_dtypes): + return dtypes.INT_DTYPE + return dtypes.STRING_DTYPE + + @_compile_op.register(ops.CaseWhenOp) + def _compile_casewhen( + self, + op: ops.CaseWhenOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + + then_dtypes = [ + self._get_expression_dtype(inputs[idx], child) + for idx in range(1, len(inputs), 2) + ] + common_dtype = self._get_common_type(then_dtypes) + + for idx in range(0, len(inputs), 2): + pred = self._compile_expression(inputs[idx], child) + val_expr = self._compile_expression(inputs[idx + 1], child) + + val_dtype = then_dtypes[idx // 2] + if val_dtype != common_dtype: + val = self._compile_cast(val_expr, common_dtype) + else: + val = val_expr + + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(pred) + if_clause.then.CopyFrom(val) + + type_dict = self._convert_type(common_dtype) + json_format.ParseDict(type_dict, getattr(ifthen, "else").literal.null) + return pb_expr + + @_compile_op.register(generic_ops.WhereOp) + def _compile_where( + self, + op: generic_ops.WhereOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + + pred = self._compile_expression(inputs[1], child) + then_val = self._compile_expression(inputs[0], child) + else_val = self._compile_expression(inputs[2], child) + + then_dtype = self._get_expression_dtype(inputs[0], child) + else_dtype = self._get_expression_dtype(inputs[2], child) + common_dtype = self._get_common_type([then_dtype, else_dtype]) + + casted_then = self._compile_cast(then_val, common_dtype) + casted_else = self._compile_cast(else_val, common_dtype) + + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(pred) + if_clause.then.CopyFrom(casted_then) + + getattr(ifthen, "else").CopyFrom(casted_else) + return pb_expr + + @_compile_op.register(numeric_ops.DivOp) + def _compile_div_op( + self, + op: numeric_ops.DivOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + casted_arg = self._compile_cast(arg_expr, dtypes.FLOAT_DTYPE) + pb_expr.scalar_function.arguments.add().value.CopyFrom(casted_arg) + return pb_expr + + @_compile_op.register(numeric_ops.FloorDivOp) + def _compile_floor_div_op( + self, + op: numeric_ops.FloorDivOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + dividend_expr = self._compile_expression(inputs[0], child) + divisor_expr = self._compile_expression(inputs[1], child) + + # Calculate standard floor division + div_expr = algebra_pb2.Expression() + div_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + + # Cast to float for standard division + casted_dividend = self._compile_cast(dividend_expr, dtypes.FLOAT_DTYPE) + casted_divisor = self._compile_cast(divisor_expr, dtypes.FLOAT_DTYPE) + + div_expr.scalar_function.arguments.add().value.CopyFrom(casted_dividend) + div_expr.scalar_function.arguments.add().value.CopyFrom(casted_divisor) + + floor_expr = algebra_pb2.Expression() + floor_expr.scalar_function.function_reference = self._EXTENSIONS["floor"] + floor_expr.scalar_function.arguments.add().value.CopyFrom(div_expr) + + # If both operands are integer/boolean, we short-circuit division by 0 to return 0 + left_dtype = self._get_expression_dtype(inputs[0], child) + right_dtype = self._get_expression_dtype(inputs[1], child) + + is_left_int = left_dtype == dtypes.INT_DTYPE or left_dtype == dtypes.BOOL_DTYPE + is_right_int = ( + right_dtype == dtypes.INT_DTYPE or right_dtype == dtypes.BOOL_DTYPE + ) + + if is_left_int and is_right_int: + # If divisor is 0, return 0 * dividend (to propagate nulls) + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + + eq_expr = algebra_pb2.Expression() + eq_expr.scalar_function.function_reference = self._EXTENSIONS["equal"] + eq_expr.scalar_function.arguments.add().value.CopyFrom(divisor_expr) + eq_expr.scalar_function.arguments.add().value.CopyFrom(zero_i64) + + zero_result = algebra_pb2.Expression() + zero_result.scalar_function.function_reference = self._EXTENSIONS[ + "multiply" + ] + zero_result.scalar_function.arguments.add().value.CopyFrom(dividend_expr) + zero_result.scalar_function.arguments.add().value.CopyFrom(zero_i64) + + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(eq_expr) + if_clause.then.CopyFrom(zero_result) + + # Else, cast float floor_expr to int64 + casted_floor = self._compile_cast(floor_expr, dtypes.INT_DTYPE) + getattr(ifthen, "else").CopyFrom(casted_floor) + return pb_expr + + return floor_expr + + @_compile_op.register(generic_ops.IsInOp) + def _compile_isin( + self, + op: generic_ops.IsInOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.singular_or_list.value.CopyFrom( + self._compile_expression(inputs[0], child) + ) + for val in op.values: + opt_expr = self._compile_expression(ex.const(val), child) + pb_expr.singular_or_list.options.add().CopyFrom(opt_expr) + return pb_expr + + @_compile_op.register(generic_ops.FillNaOp) + def _compile_fillna_op( + self, + op: generic_ops.FillNaOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + first_expr = self._compile_expression(inputs[0], child) + first_dtype = self._get_expression_dtype(inputs[0], child) + second_expr = self._compile_expression(inputs[1], child) + second_dtype = self._get_expression_dtype(inputs[1], child) + + if first_dtype is not None and second_dtype != first_dtype: + second_expr = self._compile_cast(second_expr, first_dtype) + + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["coalesce"] + pb_expr.scalar_function.arguments.add().value.CopyFrom(first_expr) + pb_expr.scalar_function.arguments.add().value.CopyFrom(second_expr) + return pb_expr + + @_compile_op.register(generic_ops.CoalesceOp) + @_compile_op.register(numeric_ops.AddOp) + @_compile_op.register(numeric_ops.SubOp) + @_compile_op.register(numeric_ops.MulOp) + @_compile_op.register(numeric_ops.PowOp) + @_compile_op.register(numeric_ops.UnsafePowOp) + @_compile_op.register(comparison_ops.EqOp) + @_compile_op.register(comparison_ops.NeOp) + @_compile_op.register(comparison_ops.LtOp) + @_compile_op.register(comparison_ops.GtOp) + @_compile_op.register(comparison_ops.LeOp) + @_compile_op.register(comparison_ops.GeOp) + @_compile_op.register(bool_ops.AndOp) + @_compile_op.register(bool_ops.OrOp) + @_compile_op.register(bool_ops.XorOp) + def _compile_basic_binops( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_binop(ext_name, inputs, child) + + def _compile_basic_binop( + self, ext_name: str, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS[ext_name] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + @_compile_op.register(numeric_ops.ModOp) + def _compile_mod_op( + self, + op: numeric_ops.ModOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + a_expr = self._compile_expression(inputs[0], child) + b_expr = self._compile_expression(inputs[1], child) + + div_expr = algebra_pb2.Expression() + div_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + + a_float = self._compile_cast(a_expr, dtypes.FLOAT_DTYPE) + b_float = self._compile_cast(b_expr, dtypes.FLOAT_DTYPE) + div_expr.scalar_function.arguments.add().value.CopyFrom(a_float) + div_expr.scalar_function.arguments.add().value.CopyFrom(b_float) + + floor_expr = algebra_pb2.Expression() + floor_expr.scalar_function.function_reference = self._EXTENSIONS["floor"] + floor_expr.scalar_function.arguments.add().value.CopyFrom(div_expr) + + mul_expr = algebra_pb2.Expression() + mul_expr.scalar_function.function_reference = self._EXTENSIONS["multiply"] + mul_expr.scalar_function.arguments.add().value.CopyFrom(b_float) + mul_expr.scalar_function.arguments.add().value.CopyFrom(floor_expr) + + sub_expr = algebra_pb2.Expression() + sub_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + sub_expr.scalar_function.arguments.add().value.CopyFrom(a_float) + sub_expr.scalar_function.arguments.add().value.CopyFrom(mul_expr) + + a_dtype = self._get_expression_dtype(inputs[0], child) + b_dtype = self._get_expression_dtype(inputs[1], child) + common_dtype = self._get_common_type([a_dtype, b_dtype]) + + if common_dtype == dtypes.INT_DTYPE: + return self._compile_cast(sub_expr, dtypes.INT_DTYPE) + return sub_expr + + @_compile_op.register(numeric_ops.AbsOp) + @_compile_op.register(numeric_ops.CeilOp) + @_compile_op.register(numeric_ops.FloorOp) + @_compile_op.register(generic_ops.IsNullOp) + @_compile_op.register(generic_ops.NotNullOp) + def _compile_standard_unaryops( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_unaryop(ext_name, inputs, child) + + @_compile_op.register(numeric_ops.PosOp) + def _compile_pos_op( + self, + op: numeric_ops.PosOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + # Unary plus is a no-op + return self._compile_expression(inputs[0], child) + + @_compile_op.register(numeric_ops.NegOp) + def _compile_neg_op( + self, + op: numeric_ops.NegOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + # Compile negation as subtraction: 0 - x + arg_expr = self._compile_expression(inputs[0], child) + arg_dtype = self._get_expression_dtype(inputs[0], child) + + zero_expr = algebra_pb2.Expression() + if arg_dtype == dtypes.FLOAT_DTYPE: + zero_expr.literal.fp64 = 0.0 + else: + zero_expr.literal.i64 = 0 + + sub_expr = algebra_pb2.Expression() + sub_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + sub_expr.scalar_function.arguments.add().value.CopyFrom(zero_expr) + sub_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return sub_expr + + @_compile_op.register(generic_ops.InvertOp) + def _compile_invert_op( + self, + op: generic_ops.InvertOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + arg_dtype = self._get_expression_dtype(inputs[0], child) + + if arg_dtype == dtypes.BOOL_DTYPE: + # Logical negation + not_expr = algebra_pb2.Expression() + not_expr.scalar_function.function_reference = self._EXTENSIONS["not"] + not_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return not_expr + else: + # Bitwise negation (two's complement mathematically equivalent to: -x - 1) + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + + neg_expr = algebra_pb2.Expression() + neg_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + neg_expr.scalar_function.arguments.add().value.CopyFrom(zero_i64) + neg_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + + one_i64 = algebra_pb2.Expression() + one_i64.literal.i64 = 1 + + result_expr = algebra_pb2.Expression() + result_expr.scalar_function.function_reference = self._EXTENSIONS[ + "subtract" + ] + result_expr.scalar_function.arguments.add().value.CopyFrom(neg_expr) + result_expr.scalar_function.arguments.add().value.CopyFrom(one_i64) + return result_expr + + def _compile_basic_unaryop( + self, ext_name: str, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS[ext_name] + arg_expr = self._compile_expression(inputs[0], child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + def _convert_schema(self, schema: Any) -> Dict[str, Any]: + # Convert bigframes schema to Substrait Type.NamedStruct + fields = [] + types = [] + for item in schema.items: + col = item.column + name = col.name if hasattr(col, "name") else str(col) + fields.append(name) + types.append(self._convert_type(item.dtype)) + + return {"names": fields, "struct": {"types": types}} + + def _get_substrait_names(self, name: str, dtype: Any) -> list[str]: + import bigframes.dtypes as dtypes + + names = [name] + if dtypes.is_struct_like(dtype): + fields_dict = dtypes.get_struct_fields(dtype) + for f_name, f_dtype in fields_dict.items(): + names.extend(self._get_substrait_names(f_name, f_dtype)) + return names + + def _convert_type(self, dtype: Any) -> Dict[str, Any]: + import bigframes.dtypes + + if dtype == bigframes.dtypes.INT_DTYPE: + return {"i64": {}} + elif dtype == bigframes.dtypes.FLOAT_DTYPE: + return {"fp64": {}} + elif dtype == bigframes.dtypes.BOOL_DTYPE: + return {"bool": {}} + elif dtype == bigframes.dtypes.STRING_DTYPE: + return {"string": {}} + elif dtype == bigframes.dtypes.BYTES_DTYPE: + return {"binary": {}} + elif dtype == bigframes.dtypes.DATE_DTYPE: + return {"date": {}} + elif dtype == bigframes.dtypes.DATETIME_DTYPE: + if self._use_precision_types: + return {"precision_timestamp": {"precision": 6}} + else: + return {"timestamp": {}} + elif dtype == bigframes.dtypes.TIMESTAMP_DTYPE: + if self._use_precision_types: + return {"precision_timestamp_tz": {"precision": 6}} + else: + return {"timestamp_tz": {}} + elif dtype == bigframes.dtypes.TIME_DTYPE: + if self._use_precision_types: + # type_variation_reference 1 is for time64, precision 6 is for microseconds + return { + "precision_time": {"precision": 6, "type_variation_reference": 1} + } + else: + return {"time": {}} + elif dtype in ( + bigframes.dtypes.NUMERIC_DTYPE, + bigframes.dtypes.BIGNUMERIC_DTYPE, + ): + arrow_dtype = dtype.pyarrow_dtype + return { + "decimal": { + "precision": arrow_dtype.precision, + "scale": arrow_dtype.scale, + } + } + elif dtype == bigframes.dtypes.TIMEDELTA_DTYPE: + if self._duration_type == "interval_day": + return {"interval_day": {"precision": 6, "type_variation_reference": 1}} + else: + return {"i64": {}} + elif bigframes.dtypes.is_struct_like(dtype): + fields_dict = bigframes.dtypes.get_struct_fields(dtype) + return { + "struct": { + "types": [ + self._convert_type(f_dtype) for f_dtype in fields_dict.values() + ] + } + } + elif bigframes.dtypes.is_array_like(dtype): + inner_dtype = bigframes.dtypes.get_array_inner_type(dtype) + return {"list": {"type": self._convert_type(inner_dtype)}} + else: + # Fallback to string for now + return {"string": {}} diff --git a/packages/bigframes/bigframes/core/rewrite/__init__.py b/packages/bigframes/bigframes/core/rewrite/__init__.py index ae4b142b1a46..319b1f69bcd5 100644 --- a/packages/bigframes/bigframes/core/rewrite/__init__.py +++ b/packages/bigframes/bigframes/core/rewrite/__init__.py @@ -25,6 +25,7 @@ try_reduce_to_local_scan, try_reduce_to_table_scan, ) +from bigframes.core.rewrite.schema_binding import bind_schema_to_tree from bigframes.core.rewrite.select_pullup import defer_selection from bigframes.core.rewrite.slices import pull_out_limit, pull_up_limits, rewrite_slice from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions @@ -37,6 +38,7 @@ __all__ = [ "as_sql_nodes", + "bind_schema_to_tree", "extract_ctes", "legacy_join_as_projection", "try_row_join", diff --git a/packages/bigframes/bigframes/session/substrait_executor.py b/packages/bigframes/bigframes/session/substrait_executor.py new file mode 100644 index 000000000000..2f4ba53b4b5f --- /dev/null +++ b/packages/bigframes/bigframes/session/substrait_executor.py @@ -0,0 +1,205 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +import asyncio +from typing import TYPE_CHECKING, Optional, cast + +import bigframes.core.compile.substrait.compiler as substrait_compiler +import bigframes.core.rewrite as rewrite +from bigframes.core import bigframe_node, nodes +from bigframes.session import executor, semi_executor + +if TYPE_CHECKING: + import pyarrow as pa + + +class SubstraitConsumer(abc.ABC): + """ + Interface for consuming Substrait plans and executing them. + This acts as a plugin interface for different Substrait execution engines. + """ + + @abc.abstractmethod + def consume(self, plan: bytes, tables: dict[str, pa.Table]) -> pa.Table: + """ + Executes a Substrait plan and returns a PyArrow Table. + + Args: + plan: The Substrait plan as bytes (usually a serialized Protobuf). + tables: A dictionary of table names to PyArrow Tables for local data. + + Returns: + A PyArrow Table containing the results. + """ + pass + + +class DataFusionSubstraitConsumer(SubstraitConsumer): + """ + Executes Substrait plans using Apache DataFusion. + """ + + def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table: + # Import datafusion lazily to avoid hard dependency + try: + import datafusion + except ImportError: + raise ImportError( + "The datafusion package is required to use DataFusionSubstraitConsumer. " + "Install it with `pip install datafusion`." + ) + + ctx = datafusion.SessionContext() + + for name, table in tables.items(): + df = ctx.from_arrow_table(table) + ctx.register_table(name, df) + + import datafusion.substrait + + datafusion_substrait_plan = datafusion.substrait.Serde.deserialize_bytes( + plan_proto + ) + logical_plan = datafusion.substrait.Consumer.from_substrait_plan( + ctx, datafusion_substrait_plan + ) + df = ctx.create_dataframe_from_logical_plan(logical_plan) + return df.to_arrow_table() + + +class AceroSubstraitConsumer(SubstraitConsumer): + """ + Executes Substrait plans using Apache Arrow Acero. + """ + + def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table: + import pyarrow.substrait as pa_substrait + + def provide_table(name: list[str], schema: pa.Schema) -> pa.Table: + return tables[name[0]] + + batch_reader = pa_substrait.run_query(plan_proto, table_provider=provide_table) + return batch_reader.read_all() + + +class SubstraitExecutor(semi_executor.SemiExecutor): + """ + Executes plans by compiling them to Substrait and running them via a consumer. + """ + + def __init__( + self, + consumer: SubstraitConsumer, + compiler: substrait_compiler.SubstraitCompiler, + ): + self._consumer = consumer + self._compiler = compiler + + @classmethod + def default_for_engine(cls, engine_name: str) -> SubstraitExecutor: + if engine_name == "acero": + return cls( + AceroSubstraitConsumer(), + substrait_compiler.SubstraitCompiler( + duration_type="int", use_precision_types=False + ), + ) + elif engine_name == "datafusion": + return cls( + DataFusionSubstraitConsumer(), + substrait_compiler.SubstraitCompiler(duration_type="int"), + ) + else: + raise ValueError(f"Unknown engine: {engine_name}") + + async def execute( + self, + plan: bigframe_node.BigFrameNode, + ordered: bool, + peek: Optional[int] = None, + ) -> Optional[executor.ExecuteResult]: + plan = plan.bottom_up(rewrite.rewrite_slice) + # Only needed for acero technically, datafusion can handle timedeltas + plan = plan.bottom_up(rewrite.rewrite_timedelta_expressions) + + from bigframes.core import expression + + output_cols = tuple((expression.DerefOp(id), id.name) for id in plan.ids) + result_node = nodes.ResultNode( + plan, + output_cols=output_cols, + ) + result_node = cast(nodes.ResultNode, rewrite.column_pruning(result_node)) + result_node = rewrite.defer_order(result_node, output_hidden_row_keys=False) + + rewritten_plan = result_node.child + + if ( + ordered + and result_node.order_by + and result_node.order_by.all_ordering_columns + ): + rewritten_plan = nodes.OrderByNode( + rewritten_plan, + by=tuple(result_node.order_by.all_ordering_columns), + ) + + original_ids = tuple(id for id in plan.ids) + if rewritten_plan.ids != original_ids: + rewritten_plan = nodes.SelectionNode( + rewritten_plan, + input_output_pairs=tuple( + nodes.AliasedRef.identity(id) for id in original_ids + ), + ) + + if not self._can_execute(rewritten_plan): + return None + + substrait_plan_proto = self._compiler.compile(rewritten_plan) + if substrait_plan_proto is None: + return None + + tables = {} + for node in rewritten_plan.unique_nodes(): + if isinstance(node, nodes.ReadLocalNode): + table_name = f"table_{id(node)}" + table = node.local_data_source.to_pyarrow_table(duration_type="int") + table = table.select([item.source_id for item in node.scan_list.items]) + table = table.rename_columns( + [item.id.sql for item in node.scan_list.items] + ) + if node.offsets_col is not None: + from bigframes.core import pyarrow_utils + + table = pyarrow_utils.append_offsets(table, node.offsets_col.sql) + tables[table_name] = table + + pa_table = await asyncio.to_thread( + self._consumer.consume, substrait_plan_proto, tables + ) + + if peek is not None: + pa_table = pa_table.slice(0, peek) + + return executor.LocalExecuteResult( + data=pa_table, + bf_schema=rewritten_plan.schema, + ) + + def _can_execute(self, plan: bigframe_node.BigFrameNode) -> bool: + return self._compiler.can_compile(plan) diff --git a/packages/bigframes/bigframes/testing/substrait_session.py b/packages/bigframes/bigframes/testing/substrait_session.py new file mode 100644 index 000000000000..98d573fd5df8 --- /dev/null +++ b/packages/bigframes/bigframes/testing/substrait_session.py @@ -0,0 +1,118 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import weakref +from typing import TYPE_CHECKING, Union + +import pandas + +import bigframes +import bigframes.core.blocks +import bigframes.dataframe +import bigframes.session.execution_spec +import bigframes.session.executor +import bigframes.session.metrics + +if TYPE_CHECKING: + import bigframes.core + + +class SubstraitTestExecutor(bigframes.session.executor.Executor): + def __init__( + self, consumer: bigframes.session.substrait_executor.SubstraitConsumer + ): + from bigframes.session.substrait_executor import SubstraitExecutor + + self.executor = SubstraitExecutor(consumer) + + def execute( + self, + array_value: bigframes.core.ArrayValue, + execution_spec: bigframes.session.execution_spec.ExecutionSpec, + ): + if execution_spec.destination_spec is not None: + raise ValueError( + f"SubstraitTestExecutor does not support destination spec: {execution_spec.destination_spec}" + ) + + result = asyncio.run( + self.executor.execute( + array_value.node, ordered=True, peek=execution_spec.peek + ) + ) + if result is None: + raise NotImplementedError("SubstraitExecutor cannot execute this plan") + + return result + + def cached( + self, + array_value: bigframes.core.ArrayValue, + *, + config, + ) -> None: + return + + +class TestSession(bigframes.session.Session): + def __init__(self, executor: SubstraitTestExecutor): + self._location = None # type: ignore + self._bq_kms_key_name = None # type: ignore + self._clients_provider = None # type: ignore + self._bq_connection = None # type: ignore + self._skip_bq_connection_check = True + self._session_id: str = "substrait_test_session" + self._objects: list[ + weakref.ReferenceType[ + Union[ + bigframes.core.indexes.Index, + bigframes.series.Series, + bigframes.dataframe.DataFrame, + ] + ] + ] = [] + self._strictly_ordered: bool = True + self._allow_ambiguity = False # type: ignore + self._default_index_type = bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64 + self._metrics = bigframes.session.metrics.ExecutionMetrics() + self._function_session = None # type: ignore + self._temp_storage_manager = None # type: ignore + self._executor = executor + self._loader = None # type: ignore + + def read_pandas(self, pandas_dataframe, write_engine="default"): + original_input = pandas_dataframe + + if isinstance(pandas_dataframe, (pandas.Series, pandas.Index)): + pandas_dataframe = pandas_dataframe.to_frame() + + local_block = bigframes.core.blocks.Block.from_local(pandas_dataframe, self) + bf_df = bigframes.dataframe.DataFrame(local_block) + + if isinstance(original_input, pandas.Series): + series = bf_df[bf_df.columns[0]] + series.name = original_input.name + return series + + if isinstance(original_input, pandas.Index): + return bf_df.index + + return bf_df + + @property + def bqclient(self): + return None diff --git a/packages/bigframes/noxfile.py b/packages/bigframes/noxfile.py index c8d08c6787d0..824a07254bb8 100644 --- a/packages/bigframes/noxfile.py +++ b/packages/bigframes/noxfile.py @@ -72,12 +72,12 @@ UNIT_TEST_DEPENDENCIES: List[str] = [] UNIT_TEST_EXTRAS: List[str] = ["tests"] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = { - "3.10": ["tests", "scikit-learn", "anywidget"], - "3.11": ["tests", "polars", "scikit-learn", "anywidget"], + "3.10": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], + "3.11": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], # Make sure we leave some versions without "extras" so we know those # dependencies are actually optional. - "3.13": ["tests", "polars", "scikit-learn", "anywidget"], - "3.14": ["tests", "polars", "scikit-learn", "anywidget"], + "3.13": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], + "3.14": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], } # 3.11 is used by colab. @@ -106,9 +106,9 @@ # Make sure we leave some versions without "extras" so we know those # dependencies are actually optional. "3.10": ["tests", "scikit-learn", "anywidget"], - "3.12": ["tests", "scikit-learn", "polars", "anywidget"], - "3.13": ["tests", "polars", "anywidget"], - "3.14": ["tests", "polars", "anywidget"], + "3.12": ["tests", "scikit-learn", "polars", "datafusion", "substrait", "anywidget"], + "3.13": ["tests", "polars", "datafusion", "substrait", "anywidget"], + "3.14": ["tests", "polars", "datafusion", "substrait", "anywidget"], } LOGGING_NAME_ENV_VAR = "BIGFRAMES_PERFORMANCE_LOG_NAME" @@ -123,9 +123,7 @@ # TODO(tswast): Consider removing this when unit_noextras and cover is run # from GitHub actions. "unit_noextras", - "system-3.10", # No extras. "system-3.12", # No extras. - f"system-{DEFAULT_PYTHON_VERSION}", # All extras. "cover", # TODO(b/401609005): remove "cleanup", diff --git a/packages/bigframes/setup.py b/packages/bigframes/setup.py index 138c52879526..30a021c04510 100644 --- a/packages/bigframes/setup.py +++ b/packages/bigframes/setup.py @@ -79,6 +79,8 @@ ], # used for local engine "polars": ["polars >= 1.21.0"], + "datafusion": ["datafusion >= 45.0.0", "substrait >= 0.29"], + "substrait": ["substrait >= 0.29"], "scikit-learn": ["scikit-learn>=1.2.2"], # Packages required for basic development flow. "dev": [ diff --git a/packages/bigframes/testing/constraints-3.10.txt b/packages/bigframes/testing/constraints-3.10.txt index 137710df3299..685d3baae05e 100644 --- a/packages/bigframes/testing/constraints-3.10.txt +++ b/packages/bigframes/testing/constraints-3.10.txt @@ -16,6 +16,9 @@ grpc-google-iam-v1==0.14.2 numpy==1.24.0 pandas==1.5.3 pandas-gbq==0.26.1 +polars==1.21.0 +substrait==0.29.0 +datafusion==45.0.0 pyarrow==15.0.2 pydata-google-auth==1.8.2 pyiceberg==0.7.1 diff --git a/packages/bigframes/tests/system/small/engines/conftest.py b/packages/bigframes/tests/system/small/engines/conftest.py index 823ba9806d58..bdaa0d74749a 100644 --- a/packages/bigframes/tests/system/small/engines/conftest.py +++ b/packages/bigframes/tests/system/small/engines/conftest.py @@ -17,6 +17,18 @@ import google.cloud.bigquery_storage_v1 import pandas as pd import pytest + +# Skip the entire engines test directory if required libraries are missing +try: + import datafusion # noqa: F401 + import polars # noqa: F401 + import substrait # noqa: F401 +except ImportError as e: + pytest.skip( + f"Skipping engines tests because dependencies are missing: {e}", + allow_module_level=True, + ) + from google.cloud import bigquery import bigframes @@ -26,6 +38,7 @@ local_scan_executor, polars_executor, semi_executor, + substrait_executor, ) CURRENT_DIR = pathlib.Path(__file__).parent @@ -80,9 +93,35 @@ def sqlglot_engine( ) -@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"]) +@pytest.fixture(scope="session") +def substrait_datafusion_engine() -> semi_executor.SemiExecutor: + return substrait_executor.SubstraitExecutor.default_for_engine("datafusion") + + +@pytest.fixture(scope="session") +def substrait_acero_engine() -> semi_executor.SemiExecutor: + return substrait_executor.SubstraitExecutor.default_for_engine("acero") + + +@pytest.fixture( + scope="session", + params=[ + "pyarrow", + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], +) def engine( - request, pyarrow_engine, polars_engine, bq_engine, sqlglot_engine + request, + pyarrow_engine, + polars_engine, + bq_engine, + sqlglot_engine, + substrait_datafusion_engine, + substrait_acero_engine, ) -> semi_executor.SemiExecutor: if request.param == "pyarrow": return pyarrow_engine @@ -92,6 +131,10 @@ def engine( return bq_engine if request.param == "bq-sqlglot": return sqlglot_engine + if request.param == "substrait-datafusion": + return substrait_datafusion_engine + if request.param == "substrait-acero": + return substrait_acero_engine raise ValueError(f"Unrecognized param: {request.param}") diff --git a/packages/bigframes/tests/system/small/engines/test_aggregation.py b/packages/bigframes/tests/system/small/engines/test_aggregation.py index 669eae9ebf75..16a34a664589 100644 --- a/packages/bigframes/tests/system/small/engines/test_aggregation.py +++ b/packages/bigframes/tests/system/small/engines/test_aggregation.py @@ -54,7 +54,9 @@ def apply_agg_to_all_valid( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", ["polars", "bq", "bq-sqlglot", "substrait-datafusion"], indirect=True +) def test_engines_aggregate_post_filter_size( scalars_array_value: array_value.ArrayValue, engine, diff --git a/packages/bigframes/tests/system/small/engines/test_bool_ops.py b/packages/bigframes/tests/system/small/engines/test_bool_ops.py index a6ef702885b0..b3a66526515b 100644 --- a/packages/bigframes/tests/system/small/engines/test_bool_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_bool_ops.py @@ -46,7 +46,17 @@ def apply_op_pairwise( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) @pytest.mark.parametrize( "op", [ diff --git a/packages/bigframes/tests/system/small/engines/test_comparison_ops.py b/packages/bigframes/tests/system/small/engines/test_comparison_ops.py index de0f110fa9b1..223e3068cb40 100644 --- a/packages/bigframes/tests/system/small/engines/test_comparison_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_comparison_ops.py @@ -48,7 +48,17 @@ def apply_op_pairwise( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) @pytest.mark.parametrize( "op", [ diff --git a/packages/bigframes/tests/system/small/engines/test_filtering.py b/packages/bigframes/tests/system/small/engines/test_filtering.py index fcb85aa8859b..5f0c4ad052ca 100644 --- a/packages/bigframes/tests/system/small/engines/test_filtering.py +++ b/packages/bigframes/tests/system/small/engines/test_filtering.py @@ -24,7 +24,17 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_bool_col( scalars_array_value: array_value.ArrayValue, engine, @@ -35,7 +45,17 @@ def test_engines_filter_bool_col( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_expr_cond( scalars_array_value: array_value.ArrayValue, engine, @@ -47,7 +67,17 @@ def test_engines_filter_expr_cond( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_true( scalars_array_value: array_value.ArrayValue, engine, @@ -57,7 +87,17 @@ def test_engines_filter_true( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_false( scalars_array_value: array_value.ArrayValue, engine, diff --git a/packages/bigframes/tests/system/small/engines/test_generic_ops.py b/packages/bigframes/tests/system/small/engines/test_generic_ops.py index 22ad1bfefa4e..acbd3e23db9b 100644 --- a/packages/bigframes/tests/system/small/engines/test_generic_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_generic_ops.py @@ -52,7 +52,17 @@ def apply_op( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine): polars_version = tuple([int(part) for part in polars.__version__.split(".")]) if polars_version >= (1, 34, 0): @@ -69,7 +79,11 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + ["polars", "bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, engine): vals = ["1", "100", "-3"] arr, _ = scalars_array_value.compute_values( @@ -84,7 +98,17 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -95,7 +119,17 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_float( scalars_array_value: array_value.ArrayValue, engine ): @@ -112,7 +146,17 @@ def test_engines_astype_string_float( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE) @@ -121,7 +165,17 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engine): # floats work slightly different with trailing zeroes rn arr = apply_op( @@ -133,7 +187,17 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -144,7 +208,17 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_numeric( scalars_array_value: array_value.ArrayValue, engine ): @@ -161,7 +235,17 @@ def test_engines_astype_string_numeric( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -172,7 +256,17 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_date( scalars_array_value: array_value.ArrayValue, engine ): @@ -189,7 +283,17 @@ def test_engines_astype_string_date( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -200,7 +304,17 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_datetime( scalars_array_value: array_value.ArrayValue, engine ): @@ -217,7 +331,17 @@ def test_engines_astype_string_datetime( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -228,7 +352,17 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_timestamp( scalars_array_value: array_value.ArrayValue, engine ): @@ -249,7 +383,17 @@ def test_engines_astype_string_timestamp( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -260,7 +404,17 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr( @@ -281,7 +435,17 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr( @@ -304,7 +468,17 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, diff --git a/packages/bigframes/tests/system/small/engines/test_numeric_ops.py b/packages/bigframes/tests/system/small/engines/test_numeric_ops.py index c188e37370c5..e573b848e0e3 100644 --- a/packages/bigframes/tests/system/small/engines/test_numeric_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_numeric_ops.py @@ -94,7 +94,15 @@ def test_engines_project_floor( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + ], + indirect=True, +) def test_engines_project_add( scalars_array_value: array_value.ArrayValue, engine, diff --git a/packages/bigframes/tests/system/small/engines/test_slicing.py b/packages/bigframes/tests/system/small/engines/test_slicing.py index 022758893d29..d21218c8f4d5 100644 --- a/packages/bigframes/tests/system/small/engines/test_slicing.py +++ b/packages/bigframes/tests/system/small/engines/test_slicing.py @@ -24,7 +24,17 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) @pytest.mark.parametrize( ("start", "stop", "step"), [