Skip to content

Commit 7fca7de

Browse files
Merge branch 'issue-2140-b' into issue-2140
2 parents af9c758 + 366f5ec commit 7fca7de

File tree

3 files changed

+249
-51
lines changed

3 files changed

+249
-51
lines changed

.gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,4 @@ temp-plot.html
4646
doc/python/.ipynb_checkpoints
4747
doc/python/.mapbox_token
4848
doc/.ipynb_checkpoints
49-
5049
tags

packages/python/plotly/plotly/basedatatypes.py

+68-50
Original file line numberDiff line numberDiff line change
@@ -19,58 +19,41 @@
1919
Undefined = object()
2020

2121

22-
def _rcindex_type(d):
23-
all_flag = False
24-
if type(d) == type(tuple()):
25-
d, f = d
26-
if f == "all":
27-
all_flag = True
28-
if type(d) == type(range(1)):
29-
d = list(d)
30-
if type(d) == type(int()):
31-
return (d, "i", all_flag)
32-
elif type(d) == type(list()):
33-
return (d, "l", all_flag)
34-
elif d == "all":
35-
return (d, "a", all_flag)
36-
else:
37-
raise TypeError(
38-
"argument must be 'all', int or list, got {d_type}".format(
39-
d_type=str(type(d))
40-
)
41-
)
22+
def _unzip_pairs(pairs):
23+
pairs = list(pairs)
24+
return ([t[0] for t in pairs], [t[1] for t in pairs])
4225

4326

44-
def _rcsingle_index_to_list(d):
45-
if type(d) == type(int()):
46-
return [d]
47-
return d
48-
49-
50-
def _row_col_index_combinations(rows, cols, max_n_rows, max_n_cols):
51-
all_flag = False
52-
rows, rtype, f = _rcindex_type(rows)
53-
all_flag |= f
54-
cols, ctype, f = _rcindex_type(cols)
55-
all_flag |= f
56-
rows = _rcsingle_index_to_list(rows)
57-
cols = _rcsingle_index_to_list(cols)
58-
ptype = (rtype, ctype)
59-
all_rows = range(1, max_n_rows + 1)
60-
all_cols = range(1, max_n_cols + 1)
61-
if ptype == ("a", "a"):
62-
return list(itertools.product(all_rows, all_cols))
63-
elif ptype == ("l", "a") or ptype == ("i", "a"):
64-
return list(itertools.product(rows, all_cols))
65-
elif ptype == ("a", "l") or ptype == ("a", "i"):
66-
return list(itertools.product(all_rows, cols))
67-
elif ptype == ("l", "l"):
68-
if len(rows) == len(cols) and not all_flag:
69-
return list(zip(rows, cols))
70-
else:
71-
return list(itertools.product(rows, cols))
72-
elif ptype == ("l", "i") or ptype == ("i", "i") or ptype == ("i", "l"):
73-
return list(itertools.product(rows, cols))
27+
def _indexing_combinations(dims, alls, product=False):
28+
"""
29+
Gives indexing tuples specified by the coordinates in dims.
30+
If a member of dims is 'all' then it is replaced by the corresponding member
31+
in alls.
32+
If product is True, then the cartesian product of all the indices is
33+
returned, otherwise the zip (that means index lists of mis-matched length
34+
will yield a list of tuples whose length is the length of the shortest
35+
list).
36+
"""
37+
if len(dims) == 0:
38+
# this is because list(itertools.product(*[])) returns [()] which has non-zero
39+
# length!
40+
return []
41+
if len(dims) != len(alls):
42+
raise ValueError(
43+
"Must have corresponding values in alls for each value of dims. Got dims=%s and alls=%s."
44+
% (str(dims), str(alls))
45+
)
46+
r = []
47+
for d, a in zip(dims, alls):
48+
if d == "all":
49+
d = a
50+
elif type(d) != type(list()):
51+
d = [d]
52+
r.append(d)
53+
if product:
54+
return itertools.product(*r)
55+
else:
56+
return zip(*r)
7457

7558

7659
class BaseFigure(object):
@@ -1918,6 +1901,41 @@ def _validate_get_grid_ref(self):
19181901
)
19191902
return grid_ref
19201903

1904+
def _get_subplot_rows_columns(self):
1905+
"""
1906+
Returns a pair of lists, the first containing all the row indices and
1907+
the second all the column indices.
1908+
"""
1909+
# currently, this just iterates over all the rows and columns (because
1910+
# self._grid_ref is currently always rectangular)
1911+
grid_ref = self._validate_get_grid_ref()
1912+
nrows = len(grid_ref)
1913+
ncols = len(grid_ref[0])
1914+
return (range(1, nrows + 1), range(1, ncols + 1))
1915+
1916+
def _get_subplot_coordinates(self):
1917+
"""
1918+
Returns an iterator over (row,col) pairs representing all the possible
1919+
subplot coordinates.
1920+
"""
1921+
return itertools.product(*self._get_subplot_rows_columns())
1922+
1923+
def _select_subplot_coordinates(self, rows, cols, product=False):
1924+
"""
1925+
Allows selecting all or a subset of the subplots.
1926+
If any of rows or columns is 'all', product is set to True. This is
1927+
probably the expected behaviour, so that rows=1,cols='all' selects all
1928+
the columns in row 1 (otherwise it would just select the subplot in the
1929+
first row and first column).
1930+
"""
1931+
product |= any([s == "all" for s in [rows, cols]])
1932+
t = _indexing_combinations(
1933+
[rows, cols], list(self._get_subplot_rows_columns()), product=product,
1934+
)
1935+
t = list(t)
1936+
r, c = _unzip_pairs(t)
1937+
return (r, c)
1938+
19211939
def get_subplot(self, row, col, secondary_y=False):
19221940
"""
19231941
Return an object representing the subplot at the specified row
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import plotly.graph_objs as go
2+
from plotly.subplots import make_subplots
3+
from plotly.basedatatypes import _indexing_combinations, _unzip_pairs
4+
import pytest
5+
6+
NROWS = 4
7+
NCOLS = 5
8+
9+
10+
@pytest.fixture
11+
def subplot_fig_fixture():
12+
fig = make_subplots(NROWS, NCOLS)
13+
return fig
14+
15+
16+
@pytest.fixture
17+
def non_subplot_fig_fixture():
18+
fig = go.Figure(go.Scatter(x=[1, 2, 3], y=[4, 3, 2]))
19+
return fig
20+
21+
22+
def test_invalid_validate_get_grid_ref(non_subplot_fig_fixture):
23+
with pytest.raises(Exception):
24+
_ = non_subplot_fig_fixture._validate_get_grid_ref()
25+
26+
27+
def test_get_subplot_coordinates(subplot_fig_fixture):
28+
assert set(subplot_fig_fixture._get_subplot_coordinates()) == set(
29+
[(r, c) for r in range(1, NROWS + 1) for c in range(1, NCOLS + 1)]
30+
)
31+
32+
33+
def test_indexing_combinations_edge_cases():
34+
# Although in theory _indexing_combinations works for any number of
35+
# dimensions, we're just interested in 2D for subplots so that's what we
36+
# test here.
37+
assert _indexing_combinations([], []) == []
38+
with pytest.raises(ValueError):
39+
_ = _indexing_combinations([[1, 2], [3, 4, 5]], [[1, 2]])
40+
41+
42+
# 18 combinations of input possible:
43+
# ('all', 'all', 'product=True'),
44+
# ('all', 'all', 'product=False'),
45+
# ('all', '<list>', 'product=True'),
46+
# ('all', '<list>', 'product=False'),
47+
# ('all', '<not-list>', 'product=True'),
48+
# ('all', '<not-list>', 'product=False'),
49+
# ('<list>', 'all', 'product=True'),
50+
# ('<list>', 'all', 'product=False'),
51+
# ('<list>', '<list>', 'product=True'),
52+
# ('<list>', '<list>', 'product=False'),
53+
# ('<list>', '<not-list>', 'product=True'),
54+
# ('<list>', '<not-list>', 'product=False'),
55+
# ('<not-list>', 'all', 'product=True'),
56+
# ('<not-list>', 'all', 'product=False'),
57+
# ('<not-list>', '<list>', 'product=True'),
58+
# ('<not-list>', '<list>', 'product=False'),
59+
# ('<not-list>', '<not-list>', 'product=True'),
60+
# ('<not-list>', '<not-list>', 'product=False')
61+
# For <not-list> we choose int because that's what the subplot indexing routines
62+
# will work with.
63+
all_rows = [1, 2, 3, 4]
64+
all_cols = [1, 2, 3, 4, 5]
65+
66+
67+
@pytest.mark.parametrize(
68+
"test_input,expected",
69+
[
70+
(
71+
dict(dims=["all", "all"], alls=[all_rows, all_cols], product=False),
72+
set(zip(all_rows, all_cols)),
73+
),
74+
(
75+
dict(dims=["all", "all"], alls=[all_rows, all_cols], product=True),
76+
set([(r, c) for r in all_rows for c in all_cols]),
77+
),
78+
(
79+
dict(dims=["all", [2, 4, 5]], alls=[all_rows, all_cols], product=False),
80+
set(zip(all_rows, [2, 4, 5])),
81+
),
82+
(
83+
dict(dims=["all", [2, 4, 5]], alls=[all_rows, all_cols], product=True),
84+
set([(r, c) for r in all_rows for c in [2, 4, 5]]),
85+
),
86+
(
87+
dict(dims=["all", 3], alls=[all_rows, all_cols], product=False),
88+
set([(all_rows[0], 3)]),
89+
),
90+
(
91+
dict(dims=["all", 3], alls=[all_rows, all_cols], product=True),
92+
set([(r, c) for r in all_rows for c in [3]]),
93+
),
94+
(
95+
dict(dims=[[1, 3], "all"], alls=[all_rows, all_cols], product=False),
96+
set(zip([1, 3], all_cols)),
97+
),
98+
(
99+
dict(dims=[[1, 3], "all"], alls=[all_rows, all_cols], product=True),
100+
set([(r, c) for r in [1, 3] for c in all_cols]),
101+
),
102+
(
103+
dict(dims=[[1, 3], [2, 4, 5]], alls=[all_rows, all_cols], product=False),
104+
set(zip([1, 3], [2, 4, 5])),
105+
),
106+
(
107+
dict(dims=[[1, 3], [2, 4, 5]], alls=[all_rows, all_cols], product=True),
108+
set([(r, c) for r in [1, 3] for c in [2, 4, 5]]),
109+
),
110+
(
111+
dict(dims=[[1, 3], 3], alls=[all_rows, all_cols], product=False),
112+
set([(1, 3)]),
113+
),
114+
(
115+
dict(dims=[[1, 3], 3], alls=[all_rows, all_cols], product=True),
116+
set([(r, c) for r in [1, 3] for c in [3]]),
117+
),
118+
(
119+
dict(dims=[2, "all"], alls=[all_rows, all_cols], product=False),
120+
set([(2, all_cols[0])]),
121+
),
122+
(
123+
dict(dims=[2, "all"], alls=[all_rows, all_cols], product=True),
124+
set([(r, c) for r in [2] for c in all_cols]),
125+
),
126+
(
127+
dict(dims=[2, [2, 4, 5]], alls=[all_rows, all_cols], product=False),
128+
set([(2, 2)]),
129+
),
130+
(
131+
dict(dims=[2, [2, 4, 5]], alls=[all_rows, all_cols], product=True),
132+
set([(r, c) for r in [2] for c in [2, 4, 5]]),
133+
),
134+
(dict(dims=[2, 3], alls=[all_rows, all_cols], product=False), set([(2, 3)])),
135+
(dict(dims=[2, 3], alls=[all_rows, all_cols], product=True), set([(2, 3)])),
136+
],
137+
)
138+
def test_indexing_combinations(test_input, expected):
139+
assert set(_indexing_combinations(**test_input)) == expected
140+
141+
142+
def _sort_row_col_lists(rows, cols):
143+
# makes sure that row and column lists are compared in the same order
144+
# sorted on rows
145+
si = sorted(range(len(rows)), key=lambda i: rows[i])
146+
rows = [rows[i] for i in si]
147+
cols = [cols[i] for i in si]
148+
return (rows, cols)
149+
150+
151+
# _indexing_combinations tests most cases of the following function
152+
# we just need to test that setting rows or cols to 'all' makes product True,
153+
# and if not, we can still set product to True.
154+
@pytest.mark.parametrize(
155+
"test_input,expected",
156+
[
157+
(
158+
("all", [2, 4, 5], False),
159+
_unzip_pairs([(r, c) for r in range(1, NROWS + 1) for c in [2, 4, 5]]),
160+
),
161+
(
162+
([1, 3], "all", False),
163+
_unzip_pairs([(r, c) for r in [1, 3] for c in range(1, NCOLS + 1)]),
164+
),
165+
(
166+
([1, 3], "all", True),
167+
_unzip_pairs([(r, c) for r in [1, 3] for c in range(1, NCOLS + 1)]),
168+
),
169+
(([1, 3], [2, 4, 5], False), _unzip_pairs([(1, 2), (3, 4)])),
170+
(
171+
([1, 3], [2, 4, 5], True),
172+
_unzip_pairs([(r, c) for r in [1, 3] for c in [2, 4, 5]]),
173+
),
174+
],
175+
)
176+
def test_select_subplot_coordinates(subplot_fig_fixture, test_input, expected):
177+
rows, cols, product = test_input
178+
er, ec = _sort_row_col_lists(*expected)
179+
r, c = subplot_fig_fixture._select_subplot_coordinates(rows, cols, product=product)
180+
r, c = _sort_row_col_lists(r, c)
181+
assert (r == er) and (c == ec)

0 commit comments

Comments
 (0)