Skip to content

Commit 97153bf

Browse files
TomAugspurgerjreback
authored andcommitted
Implement NA.__array_ufunc__ (#30245)
1 parent ea73e0b commit 97153bf

File tree

8 files changed

+256
-113
lines changed

8 files changed

+256
-113
lines changed

doc/source/getting_started/dsintro.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -676,11 +676,11 @@ similar to an ndarray:
676676
# only show the first 5 rows
677677
df[:5].T
678678
679+
.. _dsintro.numpy_interop:
680+
679681
DataFrame interoperability with NumPy functions
680682
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
681683

682-
.. _dsintro.numpy_interop:
683-
684684
Elementwise NumPy ufuncs (log, exp, sqrt, ...) and various other NumPy functions
685685
can be used with no issues on Series and DataFrame, assuming the data within
686686
are numeric:

doc/source/user_guide/missing_data.rst

+26
Original file line numberDiff line numberDiff line change
@@ -920,3 +920,29 @@ filling missing values beforehand.
920920

921921
A similar situation occurs when using Series or DataFrame objects in ``if``
922922
statements, see :ref:`gotchas.truth`.
923+
924+
NumPy ufuncs
925+
------------
926+
927+
:attr:`pandas.NA` implements NumPy's ``__array_ufunc__`` protocol. Most ufuncs
928+
work with ``NA``, and generally return ``NA``:
929+
930+
.. ipython:: python
931+
932+
np.log(pd.NA)
933+
np.add(pd.NA, 1)
934+
935+
.. warning::
936+
937+
Currently, ufuncs involving an ndarray and ``NA`` will return an
938+
object-dtype filled with NA values.
939+
940+
.. ipython:: python
941+
942+
a = np.array([1, 2, 3])
943+
np.greater(a, pd.NA)
944+
945+
The return type here may change to return a different array type
946+
in the future.
947+
948+
See :ref:`dsintro.numpy_interop` for more on ufuncs.

pandas/_libs/missing.pyx

+48-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from pandas._libs.tslibs.np_datetime cimport (
1414
get_timedelta64_value, get_datetime64_value)
1515
from pandas._libs.tslibs.nattype cimport (
1616
checknull_with_nat, c_NaT as NaT, is_null_datetimelike)
17+
from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op
1718

1819
from pandas.compat import is_platform_32bit
1920

@@ -290,16 +291,29 @@ cdef inline bint is_null_period(v):
290291
# Implementation of NA singleton
291292

292293

293-
def _create_binary_propagating_op(name, divmod=False):
294+
def _create_binary_propagating_op(name, is_divmod=False):
294295

295296
def method(self, other):
296297
if (other is C_NA or isinstance(other, str)
297-
or isinstance(other, (numbers.Number, np.bool_))):
298-
if divmod:
298+
or isinstance(other, (numbers.Number, np.bool_))
299+
or isinstance(other, np.ndarray) and not other.shape):
300+
# Need the other.shape clause to handle NumPy scalars,
301+
# since we do a setitem on `out` below, which
302+
# won't work for NumPy scalars.
303+
if is_divmod:
299304
return NA, NA
300305
else:
301306
return NA
302307

308+
elif isinstance(other, np.ndarray):
309+
out = np.empty(other.shape, dtype=object)
310+
out[:] = NA
311+
312+
if is_divmod:
313+
return out, out.copy()
314+
else:
315+
return out
316+
303317
return NotImplemented
304318

305319
method.__name__ = name
@@ -369,8 +383,8 @@ class NAType(C_NAType):
369383
__rfloordiv__ = _create_binary_propagating_op("__rfloordiv__")
370384
__mod__ = _create_binary_propagating_op("__mod__")
371385
__rmod__ = _create_binary_propagating_op("__rmod__")
372-
__divmod__ = _create_binary_propagating_op("__divmod__", divmod=True)
373-
__rdivmod__ = _create_binary_propagating_op("__rdivmod__", divmod=True)
386+
__divmod__ = _create_binary_propagating_op("__divmod__", is_divmod=True)
387+
__rdivmod__ = _create_binary_propagating_op("__rdivmod__", is_divmod=True)
374388
# __lshift__ and __rshift__ are not implemented
375389

376390
__eq__ = _create_binary_propagating_op("__eq__")
@@ -397,6 +411,8 @@ class NAType(C_NAType):
397411
return type(other)(1)
398412
else:
399413
return NA
414+
elif isinstance(other, np.ndarray):
415+
return np.where(other == 0, other.dtype.type(1), NA)
400416

401417
return NotImplemented
402418

@@ -408,6 +424,8 @@ class NAType(C_NAType):
408424
return other
409425
else:
410426
return NA
427+
elif isinstance(other, np.ndarray):
428+
return np.where((other == 1) | (other == -1), other, NA)
411429

412430
return NotImplemented
413431

@@ -440,6 +458,31 @@ class NAType(C_NAType):
440458

441459
__rxor__ = __xor__
442460

461+
__array_priority__ = 1000
462+
_HANDLED_TYPES = (np.ndarray, numbers.Number, str, np.bool_)
463+
464+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
465+
types = self._HANDLED_TYPES + (NAType,)
466+
for x in inputs:
467+
if not isinstance(x, types):
468+
return NotImplemented
469+
470+
if method != "__call__":
471+
raise ValueError(f"ufunc method '{method}' not supported for NA")
472+
result = maybe_dispatch_ufunc_to_dunder_op(
473+
self, ufunc, method, *inputs, **kwargs
474+
)
475+
if result is NotImplemented:
476+
# For a NumPy ufunc that's not a binop, like np.logaddexp
477+
index = [i for i, x in enumerate(inputs) if x is NA][0]
478+
result = np.broadcast_arrays(*inputs)[index]
479+
if result.ndim == 0:
480+
result = result.item()
481+
if ufunc.nout > 1:
482+
result = (NA,) * ufunc.nout
483+
484+
return result
485+
443486

444487
C_NA = NAType() # C-visible
445488
NA = C_NA # Python-visible

pandas/_libs/ops_dispatch.pyx

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
DISPATCHED_UFUNCS = {
2+
"add",
3+
"sub",
4+
"mul",
5+
"pow",
6+
"mod",
7+
"floordiv",
8+
"truediv",
9+
"divmod",
10+
"eq",
11+
"ne",
12+
"lt",
13+
"gt",
14+
"le",
15+
"ge",
16+
"remainder",
17+
"matmul",
18+
"or",
19+
"xor",
20+
"and",
21+
}
22+
UFUNC_ALIASES = {
23+
"subtract": "sub",
24+
"multiply": "mul",
25+
"floor_divide": "floordiv",
26+
"true_divide": "truediv",
27+
"power": "pow",
28+
"remainder": "mod",
29+
"divide": "div",
30+
"equal": "eq",
31+
"not_equal": "ne",
32+
"less": "lt",
33+
"less_equal": "le",
34+
"greater": "gt",
35+
"greater_equal": "ge",
36+
"bitwise_or": "or",
37+
"bitwise_and": "and",
38+
"bitwise_xor": "xor",
39+
}
40+
41+
# For op(., Array) -> Array.__r{op}__
42+
REVERSED_NAMES = {
43+
"lt": "__gt__",
44+
"le": "__ge__",
45+
"gt": "__lt__",
46+
"ge": "__le__",
47+
"eq": "__eq__",
48+
"ne": "__ne__",
49+
}
50+
51+
52+
def maybe_dispatch_ufunc_to_dunder_op(
53+
object self, object ufunc, str method, *inputs, **kwargs
54+
):
55+
"""
56+
Dispatch a ufunc to the equivalent dunder method.
57+
58+
Parameters
59+
----------
60+
self : ArrayLike
61+
The array whose dunder method we dispatch to
62+
ufunc : Callable
63+
A NumPy ufunc
64+
method : {'reduce', 'accumulate', 'reduceat', 'outer', 'at', '__call__'}
65+
inputs : ArrayLike
66+
The input arrays.
67+
kwargs : Any
68+
The additional keyword arguments, e.g. ``out``.
69+
70+
Returns
71+
-------
72+
result : Any
73+
The result of applying the ufunc
74+
"""
75+
# special has the ufuncs we dispatch to the dunder op on
76+
77+
op_name = ufunc.__name__
78+
op_name = UFUNC_ALIASES.get(op_name, op_name)
79+
80+
def not_implemented(*args, **kwargs):
81+
return NotImplemented
82+
83+
if (method == "__call__"
84+
and op_name in DISPATCHED_UFUNCS
85+
and kwargs.get("out") is None):
86+
if isinstance(inputs[0], type(self)):
87+
name = f"__{op_name}__"
88+
return getattr(self, name, not_implemented)(inputs[1])
89+
else:
90+
name = REVERSED_NAMES.get(op_name, f"__r{op_name}__")
91+
result = getattr(self, name, not_implemented)(inputs[0])
92+
return result
93+
else:
94+
return NotImplemented

pandas/core/ops/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
from pandas._libs import Timedelta, Timestamp, lib
13+
from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op # noqa:F401
1314
from pandas.util._decorators import Appender
1415

1516
from pandas.core.dtypes.common import is_list_like, is_timedelta64_dtype
@@ -31,7 +32,6 @@
3132
)
3233
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY # noqa:F401
3334
from pandas.core.ops.common import unpack_zerodim_and_defer
34-
from pandas.core.ops.dispatch import maybe_dispatch_ufunc_to_dunder_op # noqa:F401
3535
from pandas.core.ops.dispatch import should_series_dispatch
3636
from pandas.core.ops.docstrings import (
3737
_arith_doc_FRAME,

pandas/core/ops/dispatch.py

+1-94
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
"""
22
Functions for defining unary operations.
33
"""
4-
from typing import Any, Callable, Union
4+
from typing import Any, Union
55

66
import numpy as np
77

8-
from pandas._typing import ArrayLike
9-
108
from pandas.core.dtypes.common import (
119
is_datetime64_dtype,
1210
is_extension_array_dtype,
@@ -126,94 +124,3 @@ def dispatch_to_extension_op(
126124
# on the ExtensionArray
127125
res_values = op(left, right)
128126
return res_values
129-
130-
131-
def maybe_dispatch_ufunc_to_dunder_op(
132-
self: ArrayLike, ufunc: Callable, method: str, *inputs: ArrayLike, **kwargs: Any
133-
):
134-
"""
135-
Dispatch a ufunc to the equivalent dunder method.
136-
137-
Parameters
138-
----------
139-
self : ArrayLike
140-
The array whose dunder method we dispatch to
141-
ufunc : Callable
142-
A NumPy ufunc
143-
method : {'reduce', 'accumulate', 'reduceat', 'outer', 'at', '__call__'}
144-
inputs : ArrayLike
145-
The input arrays.
146-
kwargs : Any
147-
The additional keyword arguments, e.g. ``out``.
148-
149-
Returns
150-
-------
151-
result : Any
152-
The result of applying the ufunc
153-
"""
154-
# special has the ufuncs we dispatch to the dunder op on
155-
special = {
156-
"add",
157-
"sub",
158-
"mul",
159-
"pow",
160-
"mod",
161-
"floordiv",
162-
"truediv",
163-
"divmod",
164-
"eq",
165-
"ne",
166-
"lt",
167-
"gt",
168-
"le",
169-
"ge",
170-
"remainder",
171-
"matmul",
172-
"or",
173-
"xor",
174-
"and",
175-
}
176-
aliases = {
177-
"subtract": "sub",
178-
"multiply": "mul",
179-
"floor_divide": "floordiv",
180-
"true_divide": "truediv",
181-
"power": "pow",
182-
"remainder": "mod",
183-
"divide": "div",
184-
"equal": "eq",
185-
"not_equal": "ne",
186-
"less": "lt",
187-
"less_equal": "le",
188-
"greater": "gt",
189-
"greater_equal": "ge",
190-
"bitwise_or": "or",
191-
"bitwise_and": "and",
192-
"bitwise_xor": "xor",
193-
}
194-
195-
# For op(., Array) -> Array.__r{op}__
196-
flipped = {
197-
"lt": "__gt__",
198-
"le": "__ge__",
199-
"gt": "__lt__",
200-
"ge": "__le__",
201-
"eq": "__eq__",
202-
"ne": "__ne__",
203-
}
204-
205-
op_name = ufunc.__name__
206-
op_name = aliases.get(op_name, op_name)
207-
208-
def not_implemented(*args, **kwargs):
209-
return NotImplemented
210-
211-
if method == "__call__" and op_name in special and kwargs.get("out") is None:
212-
if isinstance(inputs[0], type(self)):
213-
name = f"__{op_name}__"
214-
return getattr(self, name, not_implemented)(inputs[1])
215-
else:
216-
name = flipped.get(op_name, f"__r{op_name}__")
217-
return getattr(self, name, not_implemented)(inputs[0])
218-
else:
219-
return NotImplemented

0 commit comments

Comments
 (0)