@@ -14,6 +14,7 @@ from pandas._libs.tslibs.np_datetime cimport (
14
14
get_timedelta64_value, get_datetime64_value)
15
15
from pandas._libs.tslibs.nattype cimport (
16
16
checknull_with_nat, c_NaT as NaT, is_null_datetimelike)
17
+ from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op
17
18
18
19
from pandas.compat import is_platform_32bit
19
20
@@ -290,16 +291,29 @@ cdef inline bint is_null_period(v):
290
291
# Implementation of NA singleton
291
292
292
293
293
- def _create_binary_propagating_op (name , divmod = False ):
294
+ def _create_binary_propagating_op (name , is_divmod = False ):
294
295
295
296
def method (self , other ):
296
297
if (other is C_NA or isinstance (other, str )
297
- or isinstance (other, (numbers.Number, np.bool_))):
298
- if divmod :
298
+ or isinstance (other, (numbers.Number, np.bool_))
299
+ or isinstance (other, np.ndarray) and not other.shape):
300
+ # Need the other.shape clause to handle NumPy scalars,
301
+ # since we do a setitem on `out` below, which
302
+ # won't work for NumPy scalars.
303
+ if is_divmod:
299
304
return NA, NA
300
305
else :
301
306
return NA
302
307
308
+ elif isinstance (other, np.ndarray):
309
+ out = np.empty(other.shape, dtype = object )
310
+ out[:] = NA
311
+
312
+ if is_divmod:
313
+ return out, out.copy()
314
+ else :
315
+ return out
316
+
303
317
return NotImplemented
304
318
305
319
method.__name__ = name
@@ -369,8 +383,8 @@ class NAType(C_NAType):
369
383
__rfloordiv__ = _create_binary_propagating_op(" __rfloordiv__" )
370
384
__mod__ = _create_binary_propagating_op(" __mod__" )
371
385
__rmod__ = _create_binary_propagating_op(" __rmod__" )
372
- __divmod__ = _create_binary_propagating_op(" __divmod__" , divmod = True )
373
- __rdivmod__ = _create_binary_propagating_op(" __rdivmod__" , divmod = True )
386
+ __divmod__ = _create_binary_propagating_op(" __divmod__" , is_divmod = True )
387
+ __rdivmod__ = _create_binary_propagating_op(" __rdivmod__" , is_divmod = True )
374
388
# __lshift__ and __rshift__ are not implemented
375
389
376
390
__eq__ = _create_binary_propagating_op(" __eq__" )
@@ -397,6 +411,8 @@ class NAType(C_NAType):
397
411
return type (other)(1 )
398
412
else :
399
413
return NA
414
+ elif isinstance (other, np.ndarray):
415
+ return np.where(other == 0 , other.dtype.type(1 ), NA)
400
416
401
417
return NotImplemented
402
418
@@ -408,6 +424,8 @@ class NAType(C_NAType):
408
424
return other
409
425
else :
410
426
return NA
427
+ elif isinstance (other, np.ndarray):
428
+ return np.where((other == 1 ) | (other == - 1 ), other, NA)
411
429
412
430
return NotImplemented
413
431
@@ -440,6 +458,31 @@ class NAType(C_NAType):
440
458
441
459
__rxor__ = __xor__
442
460
461
+ __array_priority__ = 1000
462
+ _HANDLED_TYPES = (np.ndarray, numbers.Number, str , np.bool_)
463
+
464
+ def __array_ufunc__ (self , ufunc , method , *inputs , **kwargs ):
465
+ types = self ._HANDLED_TYPES + (NAType,)
466
+ for x in inputs:
467
+ if not isinstance (x, types):
468
+ return NotImplemented
469
+
470
+ if method != " __call__" :
471
+ raise ValueError (f" ufunc method '{method}' not supported for NA" )
472
+ result = maybe_dispatch_ufunc_to_dunder_op(
473
+ self , ufunc, method, * inputs, ** kwargs
474
+ )
475
+ if result is NotImplemented :
476
+ # For a NumPy ufunc that's not a binop, like np.logaddexp
477
+ index = [i for i, x in enumerate (inputs) if x is NA][0 ]
478
+ result = np.broadcast_arrays(* inputs)[index]
479
+ if result.ndim == 0 :
480
+ result = result.item()
481
+ if ufunc.nout > 1 :
482
+ result = (NA,) * ufunc.nout
483
+
484
+ return result
485
+
443
486
444
487
C_NA = NAType() # C-visible
445
488
NA = C_NA # Python-visible
0 commit comments