Skip to content

Commit f2da321

Browse files
TrevorBergeronGenesis929
authored and committed
fix: Restore string to date/time type coercion (#565)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 5b76885 commit f2da321

File tree

4 files changed

+107
-26
lines changed

4 files changed

+107
-26
lines changed

bigframes/dtypes.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]:
648648

649649

650650
def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
651+
"""Get the supertype of the two types."""
651652
if dtype1 == dtype2:
652653
return dtype1
653654
# Implicit conversion currently only supported for numeric types
@@ -664,12 +665,26 @@ def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
664665
return hierarchy[lcd_index]
665666

666667

667-
def lcd_etype(etype1: ExpressionType, etype2: ExpressionType) -> ExpressionType:
668-
if etype1 is None:
668+
def coerce_to_common(etype1: ExpressionType, etype2: ExpressionType) -> ExpressionType:
    """Resolve two expression types to a single common type.

    Resolution order:
      1. When both types are not None, ask ``lcd_type`` for a supertype and
         return it if one is found (i.e. the result is not None).
      2. Otherwise, if ``etype1`` can be coerced to ``etype2`` (per
         ``can_coerce``), return ``etype2``.
      3. Otherwise, if ``etype2`` can be coerced to ``etype1``, return
         ``etype1``.

    Raises:
        TypeError: if none of the rules above yields a common type.
    """
    both_typed = etype1 is not None and etype2 is not None
    if both_typed:
        supertype = lcd_type(etype1, etype2)
        if supertype is not None:
            return supertype

    # Try coercion in each direction; the first direction that works wins,
    # and the result is the type that was coerced *to*.
    for source, target in ((etype1, etype2), (etype2, etype1)):
        if can_coerce(source, target):
            return target

    raise TypeError(f"Cannot coerce {etype1} and {etype2} to a common type.")
679+
680+
681+
def can_coerce(source_type: ExpressionType, target_type: ExpressionType) -> bool:
    """Report whether ``source_type`` may be implicitly coerced to ``target_type``.

    A ``None`` source is always coercible. Otherwise the only coercion
    allowed is from ``STRING_DTYPE`` to one of the date/time dtypes
    (``DATETIME_DTYPE``, ``TIMESTAMP_DTYPE``, ``TIME_DTYPE``, ``DATE_DTYPE``).
    """
    if source_type is None:
        # None can be coerced to any supported type
        return True

    source_is_string = source_type == STRING_DTYPE
    if not source_is_string:
        return False

    temporal_targets = (DATETIME_DTYPE, TIMESTAMP_DTYPE, TIME_DTYPE, DATE_DTYPE)
    return target_type in temporal_targets
673688

674689

675690
def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype:

bigframes/operations/__init__.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -548,16 +548,10 @@ def output_type(self, *input_types):
548548

549549

550550
# Binary Ops
551-
fillna_op = create_binary_op(name="fillna", type_signature=op_typing.COMMON_SUPERTYPE)
552-
cliplower_op = create_binary_op(
553-
name="clip_lower", type_signature=op_typing.COMMON_SUPERTYPE
554-
)
555-
clipupper_op = create_binary_op(
556-
name="clip_upper", type_signature=op_typing.COMMON_SUPERTYPE
557-
)
558-
coalesce_op = create_binary_op(
559-
name="coalesce", type_signature=op_typing.COMMON_SUPERTYPE
560-
)
551+
fillna_op = create_binary_op(name="fillna", type_signature=op_typing.COERCE)
552+
cliplower_op = create_binary_op(name="clip_lower", type_signature=op_typing.COERCE)
553+
clipupper_op = create_binary_op(name="clip_upper", type_signature=op_typing.COERCE)
554+
coalesce_op = create_binary_op(name="coalesce", type_signature=op_typing.COERCE)
561555

562556

563557
## Math Ops
@@ -575,7 +569,7 @@ def output_type(self, *input_types):
575569
right_type is None or dtypes.is_numeric(right_type)
576570
):
577571
# Numeric addition
578-
return dtypes.lcd_etype(left_type, right_type)
572+
return dtypes.coerce_to_common(left_type, right_type)
579573
# TODO: Add temporal addition once delta types supported
580574
raise TypeError(f"Cannot add dtypes {left_type} and {right_type}")
581575

@@ -592,7 +586,7 @@ def output_type(self, *input_types):
592586
right_type is None or dtypes.is_numeric(right_type)
593587
):
594588
# Numeric subtraction
595-
return dtypes.lcd_etype(left_type, right_type)
589+
return dtypes.coerce_to_common(left_type, right_type)
596590
# TODO: Add temporal addition once delta types supported
597591
raise TypeError(f"Cannot subtract dtypes {left_type} and {right_type}")
598592

@@ -652,7 +646,7 @@ class WhereOp(TernaryOp):
652646
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the common type of the first and third inputs.

    The second input is the condition and must be ``dtypes.BOOL_DTYPE``.

    Raises:
        TypeError: if the condition is not boolean, or if
            ``dtypes.coerce_to_common`` cannot reconcile the other two inputs.
    """
    # Validate the condition before looking at the value operands.
    condition_type = input_types[1]
    if condition_type != dtypes.BOOL_DTYPE:
        raise TypeError("where condition must be a boolean")

    value_type, other_type = input_types[0], input_types[2]
    return dtypes.coerce_to_common(value_type, other_type)
656650

657651

658652
where_op = WhereOp()
@@ -663,8 +657,8 @@ class ClipOp(TernaryOp):
663657
name: typing.ClassVar[str] = "clip"
664658

665659
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the common type of all three inputs.

    The second and third inputs are reconciled with each other first, and
    the result is then reconciled with the first input; either step may
    raise ``TypeError`` from ``dtypes.coerce_to_common``.
    """
    bounds_type = dtypes.coerce_to_common(input_types[1], input_types[2])
    return dtypes.coerce_to_common(input_types[0], bounds_type)
669663

670664

bigframes/operations/type.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def output_type(
118118
raise TypeError(f"Type {left_type} is not numeric")
119119
if (right_type is not None) and not bigframes.dtypes.is_numeric(right_type):
120120
raise TypeError(f"Type {right_type} is not numeric")
121-
return bigframes.dtypes.lcd_etype(left_type, right_type)
121+
return bigframes.dtypes.coerce_to_common(left_type, right_type)
122122

123123

124124
@dataclasses.dataclass
@@ -132,21 +132,29 @@ def output_type(
132132
raise TypeError(f"Type {left_type} is not numeric")
133133
if (right_type is not None) and not bigframes.dtypes.is_numeric(right_type):
134134
raise TypeError(f"Type {right_type} is not numeric")
135-
lcd_type = bigframes.dtypes.lcd_etype(left_type, right_type)
135+
lcd_type = bigframes.dtypes.coerce_to_common(left_type, right_type)
136136
if lcd_type == bigframes.dtypes.INT_DTYPE:
137137
# Real numeric ops produce floats on int input
138138
return bigframes.dtypes.FLOAT_DTYPE
139139
return lcd_type
140140

141141

142142
@dataclasses.dataclass
143-
class Supertype(BinaryTypeSignature):
144-
"""Type signature for functions that return a the supertype of its inputs. Currently BigFrames just supports upcasting numerics."""
143+
class CoerceCommon(BinaryTypeSignature):
    """Attempt to coerce inputs to a compatible type."""

    def output_type(
        self, left_type: ExpressionType, right_type: ExpressionType
    ) -> ExpressionType:
        """Return the common type of ``left_type`` and ``right_type``.

        Delegates to ``bigframes.dtypes.coerce_to_common``, which already
        tries the supertype lookup and then ``can_coerce`` in both
        directions before raising. No local fallback is needed: repeating
        those same checks after a failure could never succeed, and
        swallowing the error first only obscured where it came from.

        Raises:
            TypeError: propagated from ``coerce_to_common`` when the two
                types cannot be reconciled.
        """
        return bigframes.dtypes.coerce_to_common(left_type, right_type)
150158

151159

152160
@dataclasses.dataclass
@@ -156,7 +164,7 @@ class Comparison(BinaryTypeSignature):
156164
def output_type(
    self, left_type: ExpressionType, right_type: ExpressionType
) -> ExpressionType:
    """Return ``BOOL_DTYPE`` when the two inputs coerce to a comparable type.

    Raises:
        TypeError: if the inputs cannot be coerced to a common type, or if
            that common type is not comparable per
            ``bigframes.dtypes.is_comparable``.
    """
    coerced = CoerceCommon().output_type(left_type, right_type)
    if bigframes.dtypes.is_comparable(coerced):
        return bigframes.dtypes.BOOL_DTYPE
    raise TypeError(f"Types {left_type} and {right_type} are not comparable")
@@ -188,7 +196,7 @@ def output_type(
188196
BINARY_NUMERIC = BinaryNumeric()
189197
BINARY_REAL_NUMERIC = BinaryRealNumeric()
190198
COMPARISON = Comparison()
191-
COMMON_SUPERTYPE = Supertype()
199+
COERCE = CoerceCommon()
192200
LOGICAL = Logical()
193201
STRING_TRANSFORM = TypePreserving(
194202
bigframes.dtypes.is_string_like, description="numeric"

tests/system/small/operations/test_datetimes.py

+64
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import datetime
16+
1517
import pandas as pd
1618
import pytest
1719

@@ -303,3 +305,65 @@ def test_dt_floor(scalars_dfs, col_name, freq):
303305
pd_result.astype(scalars_df[col_name].dtype), # floor preserves type
304306
bf_result,
305307
)
308+
309+
310+
def test_dt_compare_coerce_str_datetime(scalars_dfs):
    """A datetime column can be compared against a date string literal."""
    scalars_df, scalars_pandas_df = scalars_dfs
    cutoff = "2024-01-01"

    bf_series: bigframes.series.Series = scalars_df["datetime_col"]
    bf_mask = bf_series >= cutoff
    bf_result = bf_mask.to_pandas()

    # The pandas side is given an already-parsed timestamp as the reference.
    pd_result = scalars_pandas_df["datetime_col"] >= pd.to_datetime(cutoff)

    # pandas produces pyarrow bool dtype, so only the values are compared
    assert_series_equal(pd_result, bf_result, check_dtype=False)
319+
320+
321+
def test_dt_clip_datetime_literals(scalars_dfs):
    """Clipping a date column with ``datetime.date`` bounds matches pandas."""
    scalars_df, scalars_pandas_df = scalars_dfs
    lower = datetime.date(2020, 1, 1)
    upper = datetime.date(2024, 1, 1)

    bf_series: bigframes.series.Series = scalars_df["date_col"]
    bf_result = bf_series.clip(lower, upper).to_pandas()

    pd_result = scalars_pandas_df["date_col"].clip(lower, upper)

    assert_series_equal(pd_result, bf_result)
336+
337+
338+
def test_dt_clip_coerce_str_date(scalars_dfs):
    """Clipping a date column with string bounds coerces the strings to dates."""
    scalars_df, scalars_pandas_df = scalars_dfs

    bf_series: bigframes.series.Series = scalars_df["date_col"]
    bf_clipped = bf_series.clip("2020-01-01", "2024-01-01")
    bf_result = bf_clipped.to_pandas()

    # Pandas can't coerce with pyarrow types so convert first
    lower = datetime.date(2020, 1, 1)
    upper = datetime.date(2024, 1, 1)
    pd_result = scalars_pandas_df["date_col"].clip(lower, upper)

    assert_series_equal(pd_result, bf_result)
352+
353+
354+
def test_dt_clip_coerce_str_timestamp(scalars_dfs):
    """Clipping a timestamp column with string bounds coerces the strings."""
    scalars_df, scalars_pandas_df = scalars_dfs
    lower_str = "2020-01-01T20:03:50Z"
    upper_str = "2024-01-01T20:03:50Z"

    bf_series: bigframes.series.Series = scalars_df["timestamp_col"]
    bf_result = bf_series.clip(lower_str, upper_str).to_pandas()

    # The pandas side is given UTC-aware timestamps parsed from the same strings.
    pd_lower = pd.to_datetime(lower_str, utc=True)
    pd_upper = pd.to_datetime(upper_str, utc=True)
    pd_result = scalars_pandas_df["timestamp_col"].clip(pd_lower, pd_upper)

    assert_series_equal(pd_result, bf_result)

0 commit comments

Comments
 (0)