Skip to content

Commit 20e5c43

Browse files
rsoklZac-HD
authored and committed
add advanced integer indexing strategy and tests
1 parent a943d18 commit 20e5c43

File tree

6 files changed

+183
-7
lines changed

6 files changed

+183
-7
lines changed

hypothesis-python/RELEASE.rst

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
RELEASE_TYPE: patch
2+
3+
This release adds the strategy :func:`~hypothesis.extra.numpy.integer_array_indices`,
4+
which generates tuples of NumPy arrays that can be used for
5+
`advanced indexing <http://www.pythonlikeyoumeanit.com/Module3_IntroducingNumpy/AdvancedIndexing.html#Integer-Array-Indexing>`_
6+
to select an array of a specified shape.

hypothesis-python/src/hypothesis/_strategies.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1318,8 +1318,8 @@ def everything_except(excluded_types):
13181318
13191319
For example, ``everything_except(int)`` returns a strategy that can
13201320
generate anything that ``from_type()`` can ever generate, except for
1321-
instances of :class:python:int, and excluding instances of types
1322-
added via :func:~hypothesis.strategies.register_type_strategy.
1321+
instances of :class:`python:int`, and excluding instances of types
1322+
added via :func:`~hypothesis.strategies.register_type_strategy`.
13231323
13241324
This is useful when writing tests which check that invalid input is
13251325
rejected in a certain way.

hypothesis-python/src/hypothesis/extra/numpy.py

+71-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
from typing import Any, Union, Sequence, Tuple, Optional # noqa
3838
from hypothesis.searchstrategy.strategies import T # noqa
3939

40+
Shape = Tuple[int, ...] # noqa
41+
4042
TIME_RESOLUTIONS = tuple("Y M D h m s ms us ns ps fs as".split())
4143

4244

@@ -295,7 +297,7 @@ def fill_for(elements, unique, fill, name=""):
295297
@st.defines_strategy
296298
def arrays(
297299
dtype, # type: Any
298-
shape, # type: Union[int, Sequence[int], st.SearchStrategy[Sequence[int]]]
300+
shape, # type: Union[int, Shape, st.SearchStrategy[Shape]]
299301
elements=None, # type: st.SearchStrategy[Any]
300302
fill=None, # type: st.SearchStrategy[Any]
301303
unique=False, # type: bool
@@ -401,7 +403,7 @@ def arrays(
401403

402404
@st.defines_strategy
403405
def array_shapes(min_dims=1, max_dims=None, min_side=1, max_side=None):
404-
# type: (int, int, int, int) -> st.SearchStrategy[Tuple[int, ...]]
406+
# type: (int, int, int, int) -> st.SearchStrategy[Shape]
405407
"""Return a strategy for array shapes (tuples of int >= 1)."""
406408
check_type(integer_types, min_dims, "min_dims")
407409
check_type(integer_types, min_side, "min_side")
@@ -672,7 +674,7 @@ def nested_dtypes(
672674

673675
@st.defines_strategy
674676
def valid_tuple_axes(ndim, min_size=0, max_size=None):
675-
# type: (int, int, int) -> st.SearchStrategy[Tuple[int, ...]]
677+
# type: (int, int, int) -> st.SearchStrategy[Shape]
676678
"""Return a strategy for generating permissible tuple-values for the
677679
``axis`` argument for a numpy sequential function (e.g.
678680
:func:`numpy:numpy.sum`), given an array of the specified
@@ -763,7 +765,7 @@ def do_draw(self, data):
763765

764766
@st.defines_strategy
765767
def broadcastable_shapes(shape, min_dims=0, max_dims=None, min_side=1, max_side=None):
766-
# type: (Sequence[int], int, Optional[int], int, Optional[int]) -> st.SearchStrategy[Tuple[int, ...]]
768+
# type: (Shape, int, int, int, int) -> st.SearchStrategy[Shape]
767769
"""Return a strategy for generating shapes that are broadcast-compatible
768770
with the provided shape.
769771
@@ -846,3 +848,68 @@ def broadcastable_shapes(shape, min_dims=0, max_dims=None, min_side=1, max_side=
846848
min_side=min_side,
847849
max_side=max_side,
848850
)
851+
852+
853+
@st.defines_strategy
854+
def integer_array_indices(shape, result_shape=array_shapes(), dtype="int"):
855+
# type: (Shape, SearchStrategy[Shape], np.dtype) -> st.SearchStrategy[Tuple[np.ndarray, ...]]
856+
"""Return a search strategy for tuples of integer-arrays that, when used
857+
to index into an array of shape ``shape``, given an array whose shape
858+
was drawn from ``result_shape``.
859+
860+
Examples from this strategy shrink towards the tuple of index-arrays::
861+
862+
len(shape) * (np.zeros(drawn_result_shape, dtype), )
863+
864+
* ``shape`` a tuple of integers that indicates the shape of the array,
865+
whose indices are being generated.
866+
* ``result_shape`` a strategy for generating tuples of integers, which
867+
describe the shape of the resulting index arrays. The default is
868+
:func:`~hypothesis.extra.numpy.array_shapes`. The shape drawn from
869+
this strategy determines the shape of the array that will be produced
870+
when the corresponding example from ``integer_array_indices`` is used
871+
as an index.
872+
* ``dtype`` the integer data type of the generated index-arrays. Negative
873+
integer indices can be generated if a signed integer type is specified.
874+
875+
Recall that an array can be indexed using a tuple of integer-arrays to
876+
access its members in an arbitrary order, producing an array with an
877+
arbitrary shape. For example:
878+
879+
.. code-block:: pycon
880+
881+
>>> from numpy import array
882+
>>> x = array([-0, -1, -2, -3, -4])
883+
>>> ind = (array([[4, 0], [0, 1]]),) # a tuple containing a 2D integer-array
884+
>>> x[ind] # the resulting array is commensurate with the indexing array(s)
885+
array([[-4, 0],
886+
[0, -1]])
887+
888+
Note that this strategy does not accommodate all variations of so-called
889+
'advanced indexing', as prescribed by NumPy's nomenclature. Combinations
890+
of basic and advanced indexes are too complex to usefully define in a
891+
standard strategy; we leave application-specific strategies to the user.
892+
Advanced-boolean indexing can be defined as ``arrays(shape=..., dtype=bool)``,
893+
and is similarly left to the user.
894+
"""
895+
check_type(tuple, shape, "shape")
896+
check_argument(
897+
shape and all(isinstance(x, integer_types) and x > 0 for x in shape),
898+
"shape=%r must be a non-empty tuple of integers > 0" % (shape,),
899+
)
900+
check_type(SearchStrategy, result_shape, "result_shape")
901+
check_argument(
902+
np.issubdtype(dtype, np.integer), "dtype=%r must be an integer dtype" % (dtype,)
903+
)
904+
signed = np.issubdtype(dtype, np.signedinteger)
905+
906+
def array_for(index_shape, size):
907+
return arrays(
908+
dtype=dtype,
909+
shape=index_shape,
910+
elements=st.integers(-size if signed else 0, size - 1),
911+
)
912+
913+
return result_shape.flatmap(
914+
lambda index_shape: st.tuples(*[array_for(index_shape, size) for size in shape])
915+
)

hypothesis-python/tests/numpy/test_argument_validation.py

+5
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def e(a, **kwargs):
116116
min_side=2,
117117
max_side=3,
118118
),
119+
e(nps.integer_array_indices, shape=()),
120+
e(nps.integer_array_indices, shape=(2, 0)),
121+
e(nps.integer_array_indices, shape="a"),
122+
e(nps.integer_array_indices, shape=(2,), result_shape=(2, 2)),
123+
e(nps.integer_array_indices, shape=(2,), dtype=float),
119124
],
120125
)
121126
def test_raise_invalid_argument(function, kwargs):

hypothesis-python/tests/numpy/test_gen_data.py

+98
Original file line numberDiff line numberDiff line change
@@ -702,3 +702,101 @@ def test_broadcastable_shape_can_generate_arbitrary_ndims(shape, max_dims, data)
702702
lambda x: len(x) == desired_ndim,
703703
settings(max_examples=10 ** 6),
704704
)
705+
706+
707+
@settings(deadline=None)
708+
@given(
709+
shape=nps.array_shapes(min_dims=1, min_side=1),
710+
dtype=st.one_of(nps.unsigned_integer_dtypes(), nps.integer_dtypes()),
711+
data=st.data(),
712+
)
713+
def test_advanced_integer_index_is_valid_with_default_result_shape(shape, dtype, data):
714+
index = data.draw(nps.integer_array_indices(shape, dtype=dtype))
715+
x = np.zeros(shape)
716+
out = x[index] # raises if the index is invalid
717+
assert not np.shares_memory(x, out) # advanced indexing should not return a view
718+
assert all(dtype == x.dtype for x in index)
719+
720+
721+
@settings(deadline=None)
722+
@given(
723+
shape=nps.array_shapes(min_dims=1, min_side=1),
724+
min_dims=st.integers(0, 3),
725+
min_side=st.integers(0, 3),
726+
dtype=st.one_of(nps.unsigned_integer_dtypes(), nps.integer_dtypes()),
727+
data=st.data(),
728+
)
729+
def test_advanced_integer_index_is_valid_and_satisfies_bounds(
730+
shape, min_dims, min_side, dtype, data
731+
):
732+
max_side = data.draw(st.integers(min_side, min_side + 2), label="max_side")
733+
max_dims = data.draw(st.integers(min_dims, min_dims + 2), label="max_dims")
734+
index = data.draw(
735+
nps.integer_array_indices(
736+
shape,
737+
result_shape=nps.array_shapes(
738+
min_dims=min_dims,
739+
max_dims=max_dims,
740+
min_side=min_side,
741+
max_side=max_side,
742+
),
743+
dtype=dtype,
744+
)
745+
)
746+
x = np.zeros(shape)
747+
out = x[index] # raises if the index is invalid
748+
assert all(min_side <= s <= max_side for s in out.shape)
749+
assert min_dims <= out.ndim <= max_dims
750+
assert not np.shares_memory(x, out) # advanced indexing should not return a view
751+
assert all(dtype == x.dtype for x in index)
752+
753+
754+
@settings(deadline=None)
755+
@given(
756+
shape=nps.array_shapes(min_dims=1, min_side=1),
757+
min_dims=st.integers(0, 3),
758+
min_side=st.integers(0, 3),
759+
dtype=st.sampled_from(["uint8", "int8"]),
760+
data=st.data(),
761+
)
762+
def test_advanced_integer_index_minimizes_as_documented(
763+
shape, min_dims, min_side, dtype, data
764+
):
765+
max_side = data.draw(st.integers(min_side, min_side + 2), label="max_side")
766+
max_dims = data.draw(st.integers(min_dims, min_dims + 2), label="max_dims")
767+
result_shape = nps.array_shapes(
768+
min_dims=min_dims, max_dims=max_dims, min_side=min_side, max_side=max_side
769+
)
770+
smallest = minimal(
771+
nps.integer_array_indices(shape, result_shape=result_shape, dtype=dtype)
772+
)
773+
desired = len(shape) * (np.zeros(min_dims * [min_side]),)
774+
assert len(smallest) == len(desired)
775+
for s, d in zip(smallest, desired):
776+
np.testing.assert_array_equal(s, d)
777+
778+
779+
@settings(deadline=None, max_examples=10)
780+
@given(
781+
shape=nps.array_shapes(min_dims=1, max_dims=2, min_side=1, max_side=3),
782+
data=st.data(),
783+
)
784+
def test_advanced_integer_index_can_generate_any_pattern(shape, data):
785+
# ensures that generated index-arrays can be used to yield any pattern of elements from an array
786+
x = np.arange(np.product(shape)).reshape(shape)
787+
788+
target = data.draw(
789+
nps.arrays(
790+
shape=nps.array_shapes(min_dims=1, max_dims=2, min_side=1, max_side=2),
791+
elements=st.sampled_from(x.flatten()),
792+
dtype=x.dtype,
793+
),
794+
label="target",
795+
)
796+
find_any(
797+
nps.integer_array_indices(
798+
shape, result_shape=st.just(target.shape), dtype=np.dtype("int8")
799+
),
800+
lambda index: np.all(target == x[index]),
801+
settings(max_examples=10 ** 6),
802+
)

tooling/src/hypothesistooling/projects/hypothesispython.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def upload_distribution():
193193
entries = [i for i, l in enumerate(lines) if CHANGELOG_HEADER.match(l)]
194194
changelog_body = "".join(lines[entries[0] + 2 : entries[1]]).strip() + (
195195
"\n\n*[The canonical version of these notes (with links) is on readthedocs.]"
196-
"(https://hypothesis.readthedocs.io/en/latest/changes.html#v%s).*"
196+
"(https://hypothesis.readthedocs.io/en/latest/changes.html#v%s)*"
197197
% (current_version().replace(".", "-"),)
198198
)
199199

0 commit comments

Comments
 (0)