|
| 1 | +import plotly.graph_objs as go |
| 2 | +from plotly.subplots import make_subplots |
| 3 | +from plotly.basedatatypes import _indexing_combinations, _unzip_pairs |
| 4 | +import pytest |
| 5 | + |
| 6 | +NROWS = 4 |
| 7 | +NCOLS = 5 |
| 8 | + |
| 9 | + |
| 10 | +@pytest.fixture |
| 11 | +def subplot_fig_fixture(): |
| 12 | + fig = make_subplots(NROWS, NCOLS) |
| 13 | + return fig |
| 14 | + |
| 15 | + |
| 16 | +@pytest.fixture |
| 17 | +def non_subplot_fig_fixture(): |
| 18 | + fig = go.Figure(go.Scatter(x=[1, 2, 3], y=[4, 3, 2])) |
| 19 | + return fig |
| 20 | + |
| 21 | + |
| 22 | +def test_invalid_validate_get_grid_ref(non_subplot_fig_fixture): |
| 23 | + with pytest.raises(Exception): |
| 24 | + _ = non_subplot_fig_fixture._validate_get_grid_ref() |
| 25 | + |
| 26 | + |
| 27 | +def test_get_subplot_coordinates(subplot_fig_fixture): |
| 28 | + assert set(subplot_fig_fixture._get_subplot_coordinates()) == set( |
| 29 | + [(r, c) for r in range(1, NROWS + 1) for c in range(1, NCOLS + 1)] |
| 30 | + ) |
| 31 | + |
| 32 | + |
| 33 | +def test_indexing_combinations_edge_cases(): |
| 34 | + # Although in theory _indexing_combinations works for any number of |
| 35 | + # dimensions, we're just interested in 2D for subplots so that's what we |
| 36 | + # test here. |
| 37 | + assert _indexing_combinations([], []) == [] |
| 38 | + with pytest.raises(ValueError): |
| 39 | + _ = _indexing_combinations([[1, 2], [3, 4, 5]], [[1, 2]]) |
| 40 | + |
| 41 | + |
| 42 | +# 18 combinations of input possible: |
| 43 | +# ('all', 'all', 'product=True'), |
| 44 | +# ('all', 'all', 'product=False'), |
| 45 | +# ('all', '<list>', 'product=True'), |
| 46 | +# ('all', '<list>', 'product=False'), |
| 47 | +# ('all', '<not-list>', 'product=True'), |
| 48 | +# ('all', '<not-list>', 'product=False'), |
| 49 | +# ('<list>', 'all', 'product=True'), |
| 50 | +# ('<list>', 'all', 'product=False'), |
| 51 | +# ('<list>', '<list>', 'product=True'), |
| 52 | +# ('<list>', '<list>', 'product=False'), |
| 53 | +# ('<list>', '<not-list>', 'product=True'), |
| 54 | +# ('<list>', '<not-list>', 'product=False'), |
| 55 | +# ('<not-list>', 'all', 'product=True'), |
| 56 | +# ('<not-list>', 'all', 'product=False'), |
| 57 | +# ('<not-list>', '<list>', 'product=True'), |
| 58 | +# ('<not-list>', '<list>', 'product=False'), |
| 59 | +# ('<not-list>', '<not-list>', 'product=True'), |
| 60 | +# ('<not-list>', '<not-list>', 'product=False') |
| 61 | +# For <not-list> we choose int because that's what the subplot indexing routines |
| 62 | +# will work with. |
| 63 | +all_rows = [1, 2, 3, 4] |
| 64 | +all_cols = [1, 2, 3, 4, 5] |
| 65 | + |
| 66 | + |
| 67 | +@pytest.mark.parametrize( |
| 68 | + "test_input,expected", |
| 69 | + [ |
| 70 | + ( |
| 71 | + dict(dims=["all", "all"], alls=[all_rows, all_cols], product=False), |
| 72 | + set(zip(all_rows, all_cols)), |
| 73 | + ), |
| 74 | + ( |
| 75 | + dict(dims=["all", "all"], alls=[all_rows, all_cols], product=True), |
| 76 | + set([(r, c) for r in all_rows for c in all_cols]), |
| 77 | + ), |
| 78 | + ( |
| 79 | + dict(dims=["all", [2, 4, 5]], alls=[all_rows, all_cols], product=False), |
| 80 | + set(zip(all_rows, [2, 4, 5])), |
| 81 | + ), |
| 82 | + ( |
| 83 | + dict(dims=["all", [2, 4, 5]], alls=[all_rows, all_cols], product=True), |
| 84 | + set([(r, c) for r in all_rows for c in [2, 4, 5]]), |
| 85 | + ), |
| 86 | + ( |
| 87 | + dict(dims=["all", 3], alls=[all_rows, all_cols], product=False), |
| 88 | + set([(all_rows[0], 3)]), |
| 89 | + ), |
| 90 | + ( |
| 91 | + dict(dims=["all", 3], alls=[all_rows, all_cols], product=True), |
| 92 | + set([(r, c) for r in all_rows for c in [3]]), |
| 93 | + ), |
| 94 | + ( |
| 95 | + dict(dims=[[1, 3], "all"], alls=[all_rows, all_cols], product=False), |
| 96 | + set(zip([1, 3], all_cols)), |
| 97 | + ), |
| 98 | + ( |
| 99 | + dict(dims=[[1, 3], "all"], alls=[all_rows, all_cols], product=True), |
| 100 | + set([(r, c) for r in [1, 3] for c in all_cols]), |
| 101 | + ), |
| 102 | + ( |
| 103 | + dict(dims=[[1, 3], [2, 4, 5]], alls=[all_rows, all_cols], product=False), |
| 104 | + set(zip([1, 3], [2, 4, 5])), |
| 105 | + ), |
| 106 | + ( |
| 107 | + dict(dims=[[1, 3], [2, 4, 5]], alls=[all_rows, all_cols], product=True), |
| 108 | + set([(r, c) for r in [1, 3] for c in [2, 4, 5]]), |
| 109 | + ), |
| 110 | + ( |
| 111 | + dict(dims=[[1, 3], 3], alls=[all_rows, all_cols], product=False), |
| 112 | + set([(1, 3)]), |
| 113 | + ), |
| 114 | + ( |
| 115 | + dict(dims=[[1, 3], 3], alls=[all_rows, all_cols], product=True), |
| 116 | + set([(r, c) for r in [1, 3] for c in [3]]), |
| 117 | + ), |
| 118 | + ( |
| 119 | + dict(dims=[2, "all"], alls=[all_rows, all_cols], product=False), |
| 120 | + set([(2, all_cols[0])]), |
| 121 | + ), |
| 122 | + ( |
| 123 | + dict(dims=[2, "all"], alls=[all_rows, all_cols], product=True), |
| 124 | + set([(r, c) for r in [2] for c in all_cols]), |
| 125 | + ), |
| 126 | + ( |
| 127 | + dict(dims=[2, [2, 4, 5]], alls=[all_rows, all_cols], product=False), |
| 128 | + set([(2, 2)]), |
| 129 | + ), |
| 130 | + ( |
| 131 | + dict(dims=[2, [2, 4, 5]], alls=[all_rows, all_cols], product=True), |
| 132 | + set([(r, c) for r in [2] for c in [2, 4, 5]]), |
| 133 | + ), |
| 134 | + (dict(dims=[2, 3], alls=[all_rows, all_cols], product=False), set([(2, 3)])), |
| 135 | + (dict(dims=[2, 3], alls=[all_rows, all_cols], product=True), set([(2, 3)])), |
| 136 | + ], |
| 137 | +) |
| 138 | +def test_indexing_combinations(test_input, expected): |
| 139 | + assert set(_indexing_combinations(**test_input)) == expected |
| 140 | + |
| 141 | + |
| 142 | +def _sort_row_col_lists(rows, cols): |
| 143 | + # makes sure that row and column lists are compared in the same order |
| 144 | + # sorted on rows |
| 145 | + si = sorted(range(len(rows)), key=lambda i: rows[i]) |
| 146 | + rows = [rows[i] for i in si] |
| 147 | + cols = [cols[i] for i in si] |
| 148 | + return (rows, cols) |
| 149 | + |
| 150 | + |
| 151 | +# _indexing_combinations tests most cases of the following function |
| 152 | +# we just need to test that setting rows or cols to 'all' makes product True, |
| 153 | +# and if not, we can still set product to True. |
| 154 | +@pytest.mark.parametrize( |
| 155 | + "test_input,expected", |
| 156 | + [ |
| 157 | + ( |
| 158 | + ("all", [2, 4, 5], False), |
| 159 | + _unzip_pairs([(r, c) for r in range(1, NROWS + 1) for c in [2, 4, 5]]), |
| 160 | + ), |
| 161 | + ( |
| 162 | + ([1, 3], "all", False), |
| 163 | + _unzip_pairs([(r, c) for r in [1, 3] for c in range(1, NCOLS + 1)]), |
| 164 | + ), |
| 165 | + ( |
| 166 | + ([1, 3], "all", True), |
| 167 | + _unzip_pairs([(r, c) for r in [1, 3] for c in range(1, NCOLS + 1)]), |
| 168 | + ), |
| 169 | + (([1, 3], [2, 4, 5], False), _unzip_pairs([(1, 2), (3, 4)])), |
| 170 | + ( |
| 171 | + ([1, 3], [2, 4, 5], True), |
| 172 | + _unzip_pairs([(r, c) for r in [1, 3] for c in [2, 4, 5]]), |
| 173 | + ), |
| 174 | + ], |
| 175 | +) |
| 176 | +def test_select_subplot_coordinates(subplot_fig_fixture, test_input, expected): |
| 177 | + rows, cols, product = test_input |
| 178 | + er, ec = _sort_row_col_lists(*expected) |
| 179 | + r, c = subplot_fig_fixture._select_subplot_coordinates(rows, cols, product=product) |
| 180 | + r, c = _sort_row_col_lists(r, c) |
| 181 | + assert (r == er) and (c == ec) |
0 commit comments