|
37 | 37 | from typing import Any, Union, Sequence, Tuple, Optional # noqa
|
38 | 38 | from hypothesis.searchstrategy.strategies import T # noqa
|
39 | 39 |
|
| 40 | + Shape = Tuple[int, ...] # noqa |
| 41 | + |
40 | 42 | TIME_RESOLUTIONS = tuple("Y M D h m s ms us ns ps fs as".split())
|
41 | 43 |
|
42 | 44 |
|
@@ -295,7 +297,7 @@ def fill_for(elements, unique, fill, name=""):
|
295 | 297 | @st.defines_strategy
|
296 | 298 | def arrays(
|
297 | 299 | dtype, # type: Any
|
298 |
| - shape, # type: Union[int, Sequence[int], st.SearchStrategy[Sequence[int]]] |
| 300 | + shape, # type: Union[int, Shape, st.SearchStrategy[Shape]] |
299 | 301 | elements=None, # type: st.SearchStrategy[Any]
|
300 | 302 | fill=None, # type: st.SearchStrategy[Any]
|
301 | 303 | unique=False, # type: bool
|
@@ -401,7 +403,7 @@ def arrays(
|
401 | 403 |
|
402 | 404 | @st.defines_strategy
|
403 | 405 | def array_shapes(min_dims=1, max_dims=None, min_side=1, max_side=None):
|
404 |
| - # type: (int, int, int, int) -> st.SearchStrategy[Tuple[int, ...]] |
| 406 | + # type: (int, int, int, int) -> st.SearchStrategy[Shape] |
405 | 407 | """Return a strategy for array shapes (tuples of int >= 1)."""
|
406 | 408 | check_type(integer_types, min_dims, "min_dims")
|
407 | 409 | check_type(integer_types, min_side, "min_side")
|
@@ -672,7 +674,7 @@ def nested_dtypes(
|
672 | 674 |
|
673 | 675 | @st.defines_strategy
|
674 | 676 | def valid_tuple_axes(ndim, min_size=0, max_size=None):
|
675 |
| - # type: (int, int, int) -> st.SearchStrategy[Tuple[int, ...]] |
| 677 | + # type: (int, int, int) -> st.SearchStrategy[Shape] |
676 | 678 | """Return a strategy for generating permissible tuple-values for the
|
677 | 679 | ``axis`` argument for a numpy sequential function (e.g.
|
678 | 680 | :func:`numpy:numpy.sum`), given an array of the specified
|
@@ -763,7 +765,7 @@ def do_draw(self, data):
|
763 | 765 |
|
764 | 766 | @st.defines_strategy
|
765 | 767 | def broadcastable_shapes(shape, min_dims=0, max_dims=None, min_side=1, max_side=None):
|
766 |
| - # type: (Sequence[int], int, Optional[int], int, Optional[int]) -> st.SearchStrategy[Tuple[int, ...]] |
| 768 | + # type: (Shape, int, int, int, int) -> st.SearchStrategy[Shape] |
767 | 769 | """Return a strategy for generating shapes that are broadcast-compatible
|
768 | 770 | with the provided shape.
|
769 | 771 |
|
@@ -846,3 +848,68 @@ def broadcastable_shapes(shape, min_dims=0, max_dims=None, min_side=1, max_side=
|
846 | 848 | min_side=min_side,
|
847 | 849 | max_side=max_side,
|
848 | 850 | )
|
| 851 | + |
| 852 | + |
| 853 | +@st.defines_strategy |
| 854 | +def integer_array_indices(shape, result_shape=array_shapes(), dtype="int"): |
| 855 | + # type: (Shape, SearchStrategy[Shape], np.dtype) -> st.SearchStrategy[Tuple[np.ndarray, ...]] |
| 856 | + """Return a search strategy for tuples of integer-arrays that, when used |
| 857 | + to index into an array of shape ``shape``, given an array whose shape |
| 858 | + was drawn from ``result_shape``. |
| 859 | +
|
| 860 | + Examples from this strategy shrink towards the tuple of index-arrays:: |
| 861 | +
|
| 862 | + len(shape) * (np.zeros(drawn_result_shape, dtype), ) |
| 863 | +
|
| 864 | + * ``shape`` a tuple of integers that indicates the shape of the array, |
| 865 | + whose indices are being generated. |
| 866 | + * ``result_shape`` a strategy for generating tuples of integers, which |
| 867 | + describe the shape of the resulting index arrays. The default is |
| 868 | + :func:`~hypothesis.extra.numpy.array_shapes`. The shape drawn from |
| 869 | + this strategy determines the shape of the array that will be produced |
| 870 | + when the corresponding example from ``integer_array_indices`` is used |
| 871 | + as an index. |
| 872 | + * ``dtype`` the integer data type of the generated index-arrays. Negative |
| 873 | + integer indices can be generated if a signed integer type is specified. |
| 874 | +
|
| 875 | + Recall that an array can be indexed using a tuple of integer-arrays to |
| 876 | + access its members in an arbitrary order, producing an array with an |
| 877 | + arbitrary shape. For example: |
| 878 | +
|
| 879 | + .. code-block:: pycon |
| 880 | +
|
| 881 | + >>> from numpy import array |
| 882 | + >>> x = array([-0, -1, -2, -3, -4]) |
| 883 | + >>> ind = (array([[4, 0], [0, 1]]),) # a tuple containing a 2D integer-array |
| 884 | + >>> x[ind] # the resulting array is commensurate with the indexing array(s) |
| 885 | + array([[-4, 0], |
| 886 | + [0, -1]]) |
| 887 | +
|
| 888 | + Note that this strategy does not accommodate all variations of so-called |
| 889 | + 'advanced indexing', as prescribed by NumPy's nomenclature. Combinations |
| 890 | + of basic and advanced indexes are too complex to usefully define in a |
| 891 | + standard strategy; we leave application-specific strategies to the user. |
| 892 | + Advanced-boolean indexing can be defined as ``arrays(shape=..., dtype=bool)``, |
| 893 | + and is similarly left to the user. |
| 894 | + """ |
| 895 | + check_type(tuple, shape, "shape") |
| 896 | + check_argument( |
| 897 | + shape and all(isinstance(x, integer_types) and x > 0 for x in shape), |
| 898 | + "shape=%r must be a non-empty tuple of integers > 0" % (shape,), |
| 899 | + ) |
| 900 | + check_type(SearchStrategy, result_shape, "result_shape") |
| 901 | + check_argument( |
| 902 | + np.issubdtype(dtype, np.integer), "dtype=%r must be an integer dtype" % (dtype,) |
| 903 | + ) |
| 904 | + signed = np.issubdtype(dtype, np.signedinteger) |
| 905 | + |
| 906 | + def array_for(index_shape, size): |
| 907 | + return arrays( |
| 908 | + dtype=dtype, |
| 909 | + shape=index_shape, |
| 910 | + elements=st.integers(-size if signed else 0, size - 1), |
| 911 | + ) |
| 912 | + |
| 913 | + return result_shape.flatmap( |
| 914 | + lambda index_shape: st.tuples(*[array_for(index_shape, size) for size in shape]) |
| 915 | + ) |
0 commit comments