def replace(
    self, to_replace: typing.Any, value: typing.Any = None, *, regex: bool = False
):
    """Return a series with occurrences of ``to_replace`` swapped for ``value``.

    The work is dispatched on the shape of ``to_replace``:

    * ``regex=True``: handled by ``_regex_replace``, but only when
      ``to_replace`` is a ``str`` and the series dtype is a
      ``pandas.StringDtype``; in every other case the series is returned
      unchanged.
    * dict-like: handled by ``_mapping_replace``; ``value`` is not passed on.
    * list-like or scalar: candidates that ``bigframes.dtypes.is_comparable``
      rejects for this dtype are dropped, and the rest go to
      ``_simple_replace``. If nothing is left the series is returned
      unchanged.

    Args:
        to_replace: Pattern, mapping, list of items or single item to look for.
        value: Replacement item (ignored for dict-like ``to_replace``).
        regex: Treat ``to_replace`` as a regular expression.

    Returns:
        The series produced by the chosen helper, or ``self`` when the
        request is a no-op.
    """
    if regex:
        # Regex replacement only makes sense for a string pattern applied to
        # a string column; anything else is a no-op.
        if isinstance(to_replace, str) and isinstance(
            self.dtype, pandas.StringDtype
        ):
            return self._regex_replace(to_replace, value)
        return self

    if utils.is_dict_like(to_replace):
        return self._mapping_replace(to_replace)  # type: ignore

    # A scalar is treated as a one-element list.
    candidates = to_replace if utils.is_list_like(to_replace) else [to_replace]
    comparable = [
        item
        for item in candidates
        if bigframes.dtypes.is_comparable(item, self.dtype)
    ]
    if not comparable:
        return self
    return self._simple_replace(comparable, value)
461
+
462
def _regex_replace(self, to_replace: str, value: str):
    """Replace regex matches of ``to_replace`` in the value column with ``value``.

    Args:
        to_replace: Regular expression pattern handed to ``ops.ReplaceRegexOp``.
        value: Replacement item; must satisfy ``bigframes.dtypes.is_dtype``
            for this series' dtype.

    Returns:
        A new ``Series`` built from the op's result column, labelled with
        this series' name.

    Raises:
        NotImplementedError: If ``value`` is not compatible with the series
            dtype.
    """
    value_fits_dtype = bigframes.dtypes.is_dtype(value, self.dtype)
    if not value_fits_dtype:
        raise NotImplementedError(
            f"Cannot replace {self.dtype} elements with incompatible item {value} as mixed-type columns not supported. {constants.FEEDBACK_LINK}"
        )

    regex_op = ops.ReplaceRegexOp(to_replace, value)
    new_block, replaced_col = self._block.apply_unary_op(
        self._value_column, regex_op, result_label=self.name
    )
    return Series(new_block.select_column(replaced_col))
473
+
474
def _simple_replace(self, to_replace_list: typing.Sequence, value):
    """Replace every element found in ``to_replace_list`` with ``value``.

    Args:
        to_replace_list: Items to look for in the value column.
        value: Replacement item; must satisfy ``bigframes.dtypes.is_dtype``
            for this series' dtype.

    Returns:
        A new ``Series`` built from the result column, labelled with this
        series' name.

    Raises:
        NotImplementedError: If ``value`` is not compatible with the series
            dtype.
    """
    value_fits_dtype = bigframes.dtypes.is_dtype(value, self.dtype)
    if not value_fits_dtype:
        raise NotImplementedError(
            f"Cannot replace {self.dtype} elements with incompatible item {value} as mixed-type columns not supported. {constants.FEEDBACK_LINK}"
        )

    # Step 1: membership column produced by IsInOp over the value column.
    membership_op = ops.IsInOp(to_replace_list)
    masked_block, is_match = self._block.apply_unary_op(
        self._value_column, membership_op
    )

    # Step 2: combine the membership column with the value column through
    # where_op, with `value` bound via partial_arg1.
    substitute_op = ops.partial_arg1(ops.where_op, value)
    final_block, replaced_col = masked_block.apply_binary_op(
        is_match,
        self._value_column,
        substitute_op,
        result_label=self.name,
    )
    return Series(final_block.select_column(replaced_col))
490
+
491
def _mapping_replace(self, mapping: dict[typing.Hashable, typing.Hashable]):
    """Replace elements according to a ``{old: new}`` mapping.

    Keys that ``bigframes.dtypes.is_comparable`` rejects for this series'
    dtype are skipped. The remaining ``(key, value)`` pairs are applied in
    one pass through ``ops.MapOp``.

    Args:
        mapping: Maps each item to look for to its replacement.

    Returns:
        ``self`` when no key is applicable (nothing to replace); otherwise a
        new ``Series`` built from the mapped column, labelled with this
        series' name.

    Raises:
        NotImplementedError: If the replacement for an applicable key is not
            compatible with the series dtype.
    """
    pairs = []
    for key, value in mapping.items():
        if not bigframes.dtypes.is_comparable(key, self.dtype):
            continue
        if not bigframes.dtypes.is_dtype(value, self.dtype):
            raise NotImplementedError(
                f"Cannot replace {self.dtype} elements with incompatible item {value} as mixed-type columns not supported. {constants.FEEDBACK_LINK}"
            )
        pairs.append((key, value))

    # With nothing to map the operation is a no-op, matching how `replace`
    # returns self for an empty list of candidates.
    if not pairs:
        return self

    # Pass result_label, as the regex and simple replace helpers do, so the
    # result column is labelled with this series' name.
    block, result = self._block.apply_unary_op(
        self._value_column,
        ops.MapOp(tuple(pairs)),
        result_label=self.name,
    )
    return Series(block.select_column(result))
481
506
482
507
def interpolate (self , method : str = "linear" ) -> Series :
483
508
if method == "pad" :
0 commit comments