Skip to content

Commit 87faec1

Browse files
For custom-sized subplots, _add_annotation_like works with "all"
The implementation doesn't try to put shapes on subplots at are None in the grid_ref. A test is there that goes with it.
1 parent caff20c commit 87faec1

File tree

3 files changed

+103
-12
lines changed

3 files changed

+103
-12
lines changed

packages/python/plotly/plotly/basedatatypes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,10 +1136,9 @@ def _add_annotation_like(
11361136

11371137
# Address multiple subplots
11381138
if row is not None and _is_select_subplot_coordinates_arg(row, col):
1139-
# TODO add product argument
1140-
rows, cols = self._select_subplot_coordinates(row, col)
1141-
# TODO do we have to unzip the row and columns, just to zip them again?
1142-
for r, c in zip(rows, cols):
1139+
# TODO product argument could be added
1140+
rows_cols = self._select_subplot_coordinates(row, col)
1141+
for r, c in rows_cols:
11431142
self._add_annotation_like(
11441143
prop_singular,
11451144
prop_plural,
@@ -1732,9 +1731,8 @@ def add_trace(self, trace, row=None, col=None, secondary_y=None):
17321731
# Address multiple subplots
17331732
if row is not None and _is_select_subplot_coordinates_arg(row, col):
17341733
# TODO add product argument
1735-
rows, cols = self._select_subplot_coordinates(row, col)
1736-
# TODO do we have to unzip the row and columns, just to zip them again?
1737-
for r, c in zip(rows, cols):
1734+
rows_cols = self._select_subplot_coordinates(row, col)
1735+
for r, c in rows_cols:
17381736
self.add_trace(trace, row=r, col=c, secondary_y=secondary_y)
17391737
return self
17401738

@@ -1978,8 +1976,10 @@ def _select_subplot_coordinates(self, rows, cols, product=False):
19781976
[rows, cols], list(self._get_subplot_rows_columns()), product=product,
19791977
)
19801978
t = list(t)
1981-
r, c = _unzip_pairs(t)
1982-
return (r, c)
1979+
# remove rows and cols where the subplot is "None"
1980+
grid_ref = self._validate_get_grid_ref()
1981+
t = list(filter(lambda u: grid_ref[u[0] - 1][u[1] - 1] is not None, t))
1982+
return t
19831983

19841984
def get_subplot(self, row, col, secondary_y=False):
19851985
"""

packages/python/plotly/plotly/tests/test_core/test_update_objects/test_paper_span_shapes.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,96 @@ def test_non_subplot_add_span_shape(test_input, expected, non_subplot_fig_fixtur
383383
_check_figure_shapes(test_input, expected, non_subplot_fig_fixture)
384384

385385

386-
def test_invalid_subplot_address(subplot_fig_fixture):
386+
@pytest.mark.parametrize(
387+
"test_input",
388+
[
389+
(go.Figure.add_hline, dict(y=10, row=4, col=5)),
390+
# valid row, invalid column
391+
(go.Figure.add_hline, dict(y=10, row=1, col=5)),
392+
],
393+
)
394+
def test_invalid_subplot_address(test_input, subplot_fig_fixture):
395+
f, kwargs = test_input
387396
with pytest.raises(IndexError):
388-
subplot_fig_fixture.add_hline(y=10, row=4, col=5)
397+
f(subplot_fig_fixture, **kwargs)
398+
399+
400+
def _check_figure_shapes_custom_sized(test_input, expected, fig):
401+
# look up domains in fig
402+
corrects = []
403+
for d, ax in expected:
404+
dom = fig["layout"][ax]["domain"]
405+
if ax[: len("xaxis")] == "xaxis":
406+
d["x0"], d["x1"] = dom
407+
elif ax[: len("yaxis")] == "yaxis":
408+
d["y0"], d["y1"] = dom
409+
else:
410+
raise ValueError("bad axis")
411+
corrects.append(d)
412+
f, kwargs = test_input
413+
f(fig, **kwargs)
414+
ret = True
415+
for s, d in zip(fig.layout.shapes, corrects):
416+
ret &= _cmp_partial_dict(s, d)
417+
assert ret
418+
419+
420+
@pytest.mark.parametrize(
421+
"test_input,expected",
422+
# test_input: (function,kwargs)
423+
# expected: list of dictionaries with key:value pairs we expect in the added shapes
424+
[
425+
(
426+
(go.Figure.add_vline, dict(x=1.5, row="all", col=2)),
427+
[
428+
(
429+
{
430+
"type": "line",
431+
"x0": 1.5,
432+
"x1": 1.5,
433+
"xref": "x2",
434+
"yref": "paper",
435+
},
436+
"yaxis2",
437+
),
438+
(
439+
{
440+
"type": "line",
441+
"x0": 1.5,
442+
"x1": 1.5,
443+
"xref": "x6",
444+
"yref": "paper",
445+
},
446+
"yaxis6",
447+
),
448+
],
449+
),
450+
(
451+
(go.Figure.add_hline, dict(y=1.5, row=5, col="all")),
452+
[
453+
(
454+
{
455+
"type": "line",
456+
"yref": "y5",
457+
"y0": 1.5,
458+
"y1": 1.5,
459+
"xref": "paper",
460+
},
461+
"xaxis5",
462+
),
463+
(
464+
{
465+
"type": "line",
466+
"yref": "y6",
467+
"y0": 1.5,
468+
"y1": 1.5,
469+
"xref": "paper",
470+
},
471+
"xaxis6",
472+
),
473+
],
474+
),
475+
],
476+
)
477+
def test_custom_sized_subplots(test_input, expected, custom_sized_subplots):
478+
_check_figure_shapes_custom_sized(test_input, expected, custom_sized_subplots)

packages/python/plotly/plotly/tests/test_core/test_update_objects/test_select_subplots.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def _sort_row_col_lists(rows, cols):
176176
def test_select_subplot_coordinates(subplot_fig_fixture, test_input, expected):
177177
rows, cols, product = test_input
178178
er, ec = _sort_row_col_lists(*expected)
179-
r, c = subplot_fig_fixture._select_subplot_coordinates(rows, cols, product=product)
179+
t = subplot_fig_fixture._select_subplot_coordinates(rows, cols, product=product)
180+
r, c = _unzip_pairs(t)
180181
r, c = _sort_row_col_lists(r, c)
181182
assert (r == er) and (c == ec)

0 commit comments

Comments
 (0)