Skip to content

Commit d88c562

Browse files
refactor: add output type annotations to scalar ops (#338)
* refactor: add output type annotations to scalar ops * use same expression type annotation everywhere * pr comments
1 parent 73e997b commit d88c562

File tree

8 files changed

+407
-129
lines changed

8 files changed

+407
-129
lines changed

bigframes/core/expression.py

+34-9
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,13 @@
1818
import dataclasses
1919
import itertools
2020
import typing
21-
from typing import Optional
2221

23-
import bigframes.dtypes
22+
import bigframes.dtypes as dtypes
2423
import bigframes.operations
2524

2625

27-
def const(
28-
value: typing.Hashable, dtype: Optional[bigframes.dtypes.Dtype] = None
29-
) -> Expression:
30-
return ScalarConstantExpression(value, dtype)
26+
def const(value: typing.Hashable, dtype: dtypes.ExpressionType = None) -> Expression:
27+
return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value))
3128

3229

3330
def free_var(id: str) -> Expression:
@@ -45,9 +42,16 @@ def unbound_variables(self) -> typing.Tuple[str, ...]:
4542
def rename(self, name_mapping: dict[str, str]) -> Expression:
4643
return self
4744

48-
@abc.abstractproperty
45+
@property
46+
@abc.abstractmethod
4947
def is_const(self) -> bool:
50-
return False
48+
...
49+
50+
@abc.abstractmethod
51+
def output_type(
52+
self, input_types: dict[str, dtypes.ExpressionType]
53+
) -> dtypes.ExpressionType:
54+
...
5155

5256

5357
@dataclasses.dataclass(frozen=True)
@@ -56,12 +60,17 @@ class ScalarConstantExpression(Expression):
5660

5761
# TODO: Further constrain?
5862
value: typing.Hashable
59-
dtype: Optional[bigframes.dtypes.Dtype] = None
63+
dtype: dtypes.ExpressionType = None
6064

6165
@property
6266
def is_const(self) -> bool:
6367
return True
6468

69+
def output_type(
70+
self, input_types: dict[str, bigframes.dtypes.Dtype]
71+
) -> dtypes.ExpressionType:
72+
return self.dtype
73+
6574

6675
@dataclasses.dataclass(frozen=True)
6776
class UnboundVariableExpression(Expression):
@@ -83,6 +92,14 @@ def rename(self, name_mapping: dict[str, str]) -> Expression:
8392
def is_const(self) -> bool:
8493
return False
8594

95+
def output_type(
96+
self, input_types: dict[str, bigframes.dtypes.Dtype]
97+
) -> dtypes.ExpressionType:
98+
if self.id in input_types:
99+
return input_types[self.id]
100+
else:
101+
raise ValueError("Type of variable has not been fixed.")
102+
86103

87104
@dataclasses.dataclass(frozen=True)
88105
class OpExpression(Expression):
@@ -110,3 +127,11 @@ def rename(self, name_mapping: dict[str, str]) -> Expression:
110127
@property
111128
def is_const(self) -> bool:
112129
return all(child.is_const for child in self.inputs)
130+
131+
def output_type(
132+
self, input_types: dict[str, dtypes.ExpressionType]
133+
) -> dtypes.ExpressionType:
134+
operand_types = tuple(
135+
map(lambda x: x.output_type(input_types=input_types), self.inputs)
136+
)
137+
return self.op.output_type(*operand_types)

bigframes/dtypes.py

+62-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import geopandas as gpd # type: ignore
2424
import google.cloud.bigquery as bigquery
2525
import ibis
26+
from ibis.backends.bigquery.datatypes import BigQueryType
2627
import ibis.expr.datatypes as ibis_dtypes
28+
from ibis.expr.datatypes.core import dtype as python_type_to_bigquery_type
2729
import ibis.expr.types as ibis_types
2830
import numpy as np
2931
import pandas as pd
@@ -42,6 +44,14 @@
4244
pd.ArrowDtype,
4345
gpd.array.GeometryDtype,
4446
]
47+
# Represents both column types (dtypes) and local-only types
48+
# None represents the type of a None scalar.
49+
ExpressionType = typing.Optional[Dtype]
50+
51+
INT_DTYPE = pd.Int64Dtype()
52+
FLOAT_DTYPE = pd.Float64Dtype()
53+
BOOL_DTYPE = pd.BooleanDtype()
54+
STRING_DTYPE = pd.StringDtype(storage="pyarrow")
4555

4656
# On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable
4757
UNORDERED_DTYPES = [gpd.array.GeometryDtype()]
@@ -539,31 +549,80 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]:
539549
return lcd_type(pd.Int64Dtype(), dtype)
540550
if isinstance(scalar, decimal.Decimal):
541551
# TODO: Check context to see if can use NUMERIC instead of BIGNUMERIC
542-
return lcd_type(pd.ArrowDtype(pa.decimal128(76, 38)), dtype)
552+
return lcd_type(pd.ArrowDtype(pa.decimal256(76, 38)), dtype)
543553
return None
544554

545555

546-
def lcd_type(dtype1: Dtype, dtype2: Dtype) -> typing.Optional[Dtype]:
556+
def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
547557
if dtype1 == dtype2:
548558
return dtype1
549559
# Implicit conversion currently only supported for numeric types
550560
hierarchy: list[Dtype] = [
551561
pd.BooleanDtype(),
552562
pd.Int64Dtype(),
553-
pd.Float64Dtype(),
554563
pd.ArrowDtype(pa.decimal128(38, 9)),
555564
pd.ArrowDtype(pa.decimal256(76, 38)),
565+
pd.Float64Dtype(),
556566
]
557567
if (dtype1 not in hierarchy) or (dtype2 not in hierarchy):
558568
return None
559569
lcd_index = max(hierarchy.index(dtype1), hierarchy.index(dtype2))
560570
return hierarchy[lcd_index]
561571

562572

573+
def lcd_etype(etype1: ExpressionType, etype2: ExpressionType) -> ExpressionType:
574+
if etype1 is None:
575+
return etype2
576+
if etype2 is None:
577+
return etype1
578+
return lcd_type_or_throw(etype1, etype2)
579+
580+
563581
def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype:
564582
result = lcd_type(dtype1, dtype2)
565583
if result is None:
566584
raise NotImplementedError(
567585
f"BigFrames cannot upcast {dtype1} and {dtype2} to common type. {constants.FEEDBACK_LINK}"
568586
)
569587
return result
588+
589+
590+
def infer_literal_type(literal) -> typing.Optional[Dtype]:
591+
if pd.isna(literal):
592+
return None # Null value without a definite type
593+
# Temporary logic, use ibis inferred type
594+
ibis_literal = literal_to_ibis_scalar(literal)
595+
return ibis_dtype_to_bigframes_dtype(ibis_literal.type())
596+
597+
598+
# Input and output types supported by BigQuery DataFrames remote functions.
599+
# TODO(shobs): Extend the support to all types supported by BQ remote functions
600+
# https://cloud.google.com/bigquery/docs/remote-functions#limitations
601+
SUPPORTED_IO_PYTHON_TYPES = {bool, float, int, str}
602+
SUPPORTED_IO_BIGQUERY_TYPEKINDS = {
603+
"BOOLEAN",
604+
"BOOL",
605+
"FLOAT",
606+
"FLOAT64",
607+
"INT64",
608+
"INTEGER",
609+
"STRING",
610+
}
611+
612+
613+
class UnsupportedTypeError(ValueError):
614+
def __init__(self, type_, supported_types):
615+
self.type = type_
616+
self.supported_types = supported_types
617+
618+
619+
def ibis_type_from_python_type(t: type) -> ibis_dtypes.DataType:
620+
if t not in SUPPORTED_IO_PYTHON_TYPES:
621+
raise UnsupportedTypeError(t, SUPPORTED_IO_PYTHON_TYPES)
622+
return python_type_to_bigquery_type(t)
623+
624+
625+
def ibis_type_from_type_kind(tk: bigquery.StandardSqlTypeNames) -> ibis_dtypes.DataType:
626+
if tk not in SUPPORTED_IO_BIGQUERY_TYPEKINDS:
627+
raise UnsupportedTypeError(tk, SUPPORTED_IO_BIGQUERY_TYPEKINDS)
628+
return BigQueryType.to_ibis(tk)

bigframes/functions/remote_function.py

+12-38
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,19 @@
4646
from ibis.backends.bigquery.compiler import compiles
4747
from ibis.backends.bigquery.datatypes import BigQueryType
4848
from ibis.expr.datatypes.core import DataType as IbisDataType
49-
from ibis.expr.datatypes.core import dtype as python_type_to_bigquery_type
5049
import ibis.expr.operations as ops
5150
import ibis.expr.rules as rlz
5251

5352
from bigframes import clients
5453
import bigframes.constants as constants
54+
import bigframes.dtypes
5555

5656
logger = logging.getLogger(__name__)
5757

5858
# Protocol version 4 is available in python version 3.4 and above
5959
# https://docs.python.org/3/library/pickle.html#data-stream-format
6060
_pickle_protocol_version = 4
6161

62-
# Input and output types supported by BigQuery DataFrames remote functions.
63-
# TODO(shobs): Extend the support to all types supported by BQ remote functions
64-
# https://cloud.google.com/bigquery/docs/remote-functions#limitations
65-
SUPPORTED_IO_PYTHON_TYPES = {bool, float, int, str}
66-
SUPPORTED_IO_BIGQUERY_TYPEKINDS = {
67-
"BOOLEAN",
68-
"BOOL",
69-
"FLOAT",
70-
"FLOAT64",
71-
"INT64",
72-
"INTEGER",
73-
"STRING",
74-
}
75-
7662

7763
def get_remote_function_locations(bq_location):
7864
"""Get BQ location and cloud functions region given a BQ client."""
@@ -558,33 +544,17 @@ def f(*args, **kwargs):
558544
return f
559545

560546

561-
class UnsupportedTypeError(ValueError):
562-
def __init__(self, type_, supported_types):
563-
self.type = type_
564-
self.supported_types = supported_types
565-
566-
567-
def ibis_type_from_python_type(t: type) -> IbisDataType:
568-
if t not in SUPPORTED_IO_PYTHON_TYPES:
569-
raise UnsupportedTypeError(t, SUPPORTED_IO_PYTHON_TYPES)
570-
return python_type_to_bigquery_type(t)
571-
572-
573-
def ibis_type_from_type_kind(tk: bigquery.StandardSqlTypeNames) -> IbisDataType:
574-
if tk not in SUPPORTED_IO_BIGQUERY_TYPEKINDS:
575-
raise UnsupportedTypeError(tk, SUPPORTED_IO_BIGQUERY_TYPEKINDS)
576-
return BigQueryType.to_ibis(tk)
577-
578-
579547
def ibis_signature_from_python_signature(
580548
signature: inspect.Signature,
581549
input_types: Sequence[type],
582550
output_type: type,
583551
) -> IbisSignature:
584552
return IbisSignature(
585553
parameter_names=list(signature.parameters.keys()),
586-
input_types=[ibis_type_from_python_type(t) for t in input_types],
587-
output_type=ibis_type_from_python_type(output_type),
554+
input_types=[
555+
bigframes.dtypes.ibis_type_from_python_type(t) for t in input_types
556+
],
557+
output_type=bigframes.dtypes.ibis_type_from_python_type(output_type),
588558
)
589559

590560

@@ -599,10 +569,14 @@ def ibis_signature_from_routine(routine: bigquery.Routine) -> IbisSignature:
599569
return IbisSignature(
600570
parameter_names=[arg.name for arg in routine.arguments],
601571
input_types=[
602-
ibis_type_from_type_kind(arg.data_type.type_kind) if arg.data_type else None
572+
bigframes.dtypes.ibis_type_from_type_kind(arg.data_type.type_kind)
573+
if arg.data_type
574+
else None
603575
for arg in routine.arguments
604576
],
605-
output_type=ibis_type_from_type_kind(routine.return_type.type_kind),
577+
output_type=bigframes.dtypes.ibis_type_from_type_kind(
578+
routine.return_type.type_kind
579+
),
606580
)
607581

608582

@@ -908,7 +882,7 @@ def read_gbq_function(
908882
raise ValueError(
909883
"Function return type must be specified. {constants.FEEDBACK_LINK}"
910884
)
911-
except UnsupportedTypeError as e:
885+
except bigframes.dtypes.UnsupportedTypeError as e:
912886
raise ValueError(
913887
f"Type {e.type} not supported, supported types are {e.supported_types}. "
914888
f"{constants.FEEDBACK_LINK}"

0 commit comments

Comments
 (0)