diff --git a/CHANGELOG.md b/CHANGELOG.md index 6204929c605..1e15bc6cfb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - `px.imshow` now supports `facet_col` and `animation_frame` arguments for visualizing 3-d and 4-d images [2746](https://github.com/plotly/plotly.py/pull/2746) +- `px.defaults` now supports `color_discrete_map`, `symbol_map`, `line_dash_map`, `labels` and `category_orders` as well as a `.reset()` method [2957](https://github.com/plotly/plotly.py/pull/2957) ### Fixed diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 1b82522a8d8..7ed26491afe 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -31,16 +31,16 @@ def scatter( error_y_minus=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, orientation=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, - symbol_map={}, + symbol_map=None, opacity=None, size_max=None, marginal_x=None, @@ -82,11 +82,11 @@ def density_contour( hover_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, orientation=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, marginal_x=None, marginal_y=None, trendline=None, @@ -151,8 +151,8 @@ def density_heatmap( hover_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, orientation=None, color_continuous_scale=None, range_color=None, @@ -227,13 +227,13 @@ def line( error_y_minus=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, orientation=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, line_dash_sequence=None, - line_dash_map={}, + line_dash_map=None, log_x=False, log_y=False, range_x=None, @@ -272,10 +272,10 @@ def area( facet_col_spacing=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, orientation=None, groupnorm=None, log_x=False, @@ -324,10 +324,10 @@ def bar( error_y_minus=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -375,10 +375,10 @@ def timeline( text=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -419,10 +419,10 @@ def histogram( hover_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, marginal=None, opacity=None, orientation=None, @@ -486,10 +486,10 @@ def violin( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, orientation=None, violinmode=None, log_x=False, @@ -535,10 +535,10 @@ def box( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, orientation=None, boxmode=None, log_x=False, @@ -587,10 +587,10 @@ def strip( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, orientation=None, stripmode=None, log_x=False, @@ -645,16 +645,16 @@ def scatter_3d( error_z_minus=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, size_max=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, - symbol_map={}, + symbol_map=None, opacity=None, log_x=False, log_y=False, @@ -697,12 +697,12 @@ def line_3d( error_z_minus=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, line_dash_sequence=None, - line_dash_map={}, + line_dash_map=None, log_x=False, log_y=False, log_z=False, @@ -738,15 +738,15 @@ def scatter_ternary( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, - symbol_map={}, + symbol_map=None, opacity=None, size_max=None, title=None, @@ -778,12 +778,12 @@ def line_ternary( text=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, line_dash_sequence=None, - line_dash_map={}, + line_dash_map=None, line_shape=None, title=None, template=None, @@ -813,15 +813,15 @@ def scatter_polar( text=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, - symbol_map={}, + symbol_map=None, opacity=None, direction="clockwise", start_angle=90, @@ -858,12 +858,12 @@ def line_polar( text=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, line_dash_sequence=None, - line_dash_map={}, + line_dash_map=None, direction="clockwise", start_angle=90, line_close=False, @@ -898,10 +898,10 @@ def bar_polar( base=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -950,10 +950,10 @@ def choropleth( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1003,15 +1003,15 @@ def scatter_geo( size=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, - symbol_map={}, + symbol_map=None, opacity=None, size_max=None, projection=None, @@ -1060,12 +1060,12 @@ def line_geo( line_group=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, line_dash_sequence=None, - line_dash_map={}, + line_dash_map=None, projection=None, scope=None, center=None, @@ -1102,10 +1102,10 @@ def scatter_mapbox( size=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1140,10 +1140,10 @@ def choropleth_mapbox( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1176,8 +1176,8 @@ def density_mapbox( custom_data=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1215,10 +1215,10 @@ def line_mapbox( line_group=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, zoom=8, center=None, mapbox_style=None, @@ -1246,15 +1246,15 @@ def scatter_matrix( hover_name=None, hover_data=None, custom_data=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, - symbol_map={}, + symbol_map=None, opacity=None, size_max=None, title=None, @@ -1280,7 +1280,7 @@ def parallel_coordinates( data_frame=None, dimensions=None, color=None, - labels={}, + labels=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1304,7 +1304,7 @@ def parallel_categories( data_frame=None, dimensions=None, color=None, - labels={}, + labels=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1332,11 +1332,11 @@ def pie( values=None, color=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, hover_name=None, hover_data=None, custom_data=None, - labels={}, + labels=None, title=None, template=None, width=None, @@ -1384,11 +1384,11 @@ def sunburst( range_color=None, color_continuous_midpoint=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, hover_name=None, hover_data=None, custom_data=None, - labels={}, + labels=None, title=None, template=None, width=None, @@ -1434,11 +1434,11 @@ def treemap( range_color=None, color_continuous_midpoint=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, hover_name=None, hover_data=None, custom_data=None, - labels={}, + labels=None, title=None, template=None, width=None, @@ -1488,10 +1488,10 @@ def funnel( text=None, animation_frame=None, animation_group=None, - category_orders={}, - labels={}, + category_orders=None, + labels=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, opacity=None, orientation=None, log_x=False, @@ -1519,11 +1519,11 @@ def funnel_area( values=None, color=None, color_discrete_sequence=None, - color_discrete_map={}, + color_discrete_map=None, hover_name=None, hover_data=None, custom_data=None, - labels={}, + labels=None, title=None, template=None, width=None, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d847a16e9c1..af416b7f6e1 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -45,21 +45,34 @@ class PxDefaults(object): "width", "height", "color_discrete_sequence", + "color_discrete_map", "color_continuous_scale", "symbol_sequence", + "symbol_map", "line_dash_sequence", + "line_dash_map", "size_max", + "category_orders", + "labels", ] def __init__(self): + self.reset() + + def reset(self): self.template = None self.width = None self.height = None self.color_discrete_sequence = None + self.color_discrete_map = {} self.color_continuous_scale = None self.symbol_sequence = None + self.symbol_map = {} self.line_dash_sequence = None + self.line_dash_map = {} self.size_max = 20 + self.category_orders = {} + self.labels = {} defaults = PxDefaults() @@ -848,11 +861,7 @@ def one_group(x): def apply_default_cascade(args): # first we apply px.defaults to unspecified args - for param in ( - ["color_discrete_sequence", "color_continuous_scale"] - + ["symbol_sequence", "line_dash_sequence", "template"] - + ["width", "height", "size_max"] - ): + for param in defaults.__slots__: if param in args and args[param] is None: args[param] = getattr(defaults, param) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py index 1a298eb484f..c3236de32ce 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py @@ -80,110 +80,132 @@ def test_labels(): def test_px_templates(): - import plotly.io as pio - import plotly.graph_objects as go - - tips = px.data.tips() - - # use the normal defaults - fig = px.scatter() - assert fig.layout.template == pio.templates[pio.templates.default] - - # respect changes to defaults - pio.templates.default = "seaborn" - fig = px.scatter() - assert fig.layout.template == pio.templates["seaborn"] - - # special px-level defaults over pio defaults - pio.templates.default = "seaborn" - px.defaults.template = "ggplot2" - fig = px.scatter() - assert fig.layout.template == pio.templates["ggplot2"] - - # accept names in args over pio and px defaults - fig = px.scatter(template="seaborn") - assert fig.layout.template == pio.templates["seaborn"] - - # accept objects in args - fig = px.scatter(template={}) - assert fig.layout.template == go.layout.Template(data_scatter=[{}]) - - # read colorway from the template - fig = px.scatter( - tips, - x="total_bill", - y="tip", - color="sex", - template=dict(layout_colorway=["red", "blue"]), - ) - assert fig.data[0].marker.color == "red" - assert fig.data[1].marker.color == "blue" - - # default colorway fallback - fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict()) - assert fig.data[0].marker.color == px.colors.qualitative.D3[0] - assert fig.data[1].marker.color == px.colors.qualitative.D3[1] - - # pio default template colorway fallback - pio.templates.default = "seaborn" - px.defaults.template = None - fig = px.scatter(tips, x="total_bill", y="tip", color="sex") - assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0] - assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1] - - # pio default template colorway fallback - pio.templates.default = "seaborn" - px.defaults.template = "ggplot2" - fig = px.scatter(tips, x="total_bill", y="tip", color="sex") - assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0] - assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1] - - # don't overwrite top margin when set in template - fig = px.scatter(title="yo") - assert fig.layout.margin.t is None - - fig = px.scatter() - assert fig.layout.margin.t == 60 - - fig = px.scatter(template=dict(layout_margin_t=2)) - assert fig.layout.margin.t is None - - # don't force histogram gridlines when set in template - pio.templates.default = "none" - px.defaults.template = None - fig = px.scatter( - tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram" - ) - assert fig.layout.xaxis2.showgrid - assert fig.layout.xaxis3.showgrid - assert fig.layout.yaxis2.showgrid - assert fig.layout.yaxis3.showgrid - - fig = px.scatter( - tips, - x="total_bill", - y="tip", - marginal_x="histogram", - marginal_y="histogram", - template=dict(layout_yaxis_showgrid=False), - ) - assert fig.layout.xaxis2.showgrid - assert fig.layout.xaxis3.showgrid - assert fig.layout.yaxis2.showgrid is None - assert fig.layout.yaxis3.showgrid is None - - fig = px.scatter( - tips, - x="total_bill", - y="tip", - marginal_x="histogram", - marginal_y="histogram", - template=dict(layout_xaxis_showgrid=False), - ) - assert fig.layout.xaxis2.showgrid is None - assert fig.layout.xaxis3.showgrid is None - assert fig.layout.yaxis2.showgrid - assert fig.layout.yaxis3.showgrid + try: + import plotly.io as pio + import plotly.graph_objects as go + + tips = px.data.tips() + + # use the normal defaults + fig = px.scatter() + assert fig.layout.template == pio.templates[pio.templates.default] + + # respect changes to defaults + pio.templates.default = "seaborn" + fig = px.scatter() + assert fig.layout.template == pio.templates["seaborn"] + + # special px-level defaults over pio defaults + pio.templates.default = "seaborn" + px.defaults.template = "ggplot2" + fig = px.scatter() + assert fig.layout.template == pio.templates["ggplot2"] + + # accept names in args over pio and px defaults + fig = px.scatter(template="seaborn") + assert fig.layout.template == pio.templates["seaborn"] + + # accept objects in args + fig = px.scatter(template={}) + assert fig.layout.template == go.layout.Template(data_scatter=[{}]) + + # read colorway from the template + fig = px.scatter( + tips, + x="total_bill", + y="tip", + color="sex", + template=dict(layout_colorway=["red", "blue"]), + ) + assert fig.data[0].marker.color == "red" + assert fig.data[1].marker.color == "blue" + + # default colorway fallback + fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict()) + assert fig.data[0].marker.color == px.colors.qualitative.D3[0] + assert fig.data[1].marker.color == px.colors.qualitative.D3[1] + + # pio default template colorway fallback + pio.templates.default = "seaborn" + px.defaults.template = None + fig = px.scatter(tips, x="total_bill", y="tip", color="sex") + assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0] + assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1] + + # pio default template colorway fallback + pio.templates.default = "seaborn" + px.defaults.template = "ggplot2" + fig = px.scatter(tips, x="total_bill", y="tip", color="sex") + assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0] + assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1] + + # don't overwrite top margin when set in template + fig = px.scatter(title="yo") + assert fig.layout.margin.t is None + + fig = px.scatter() + assert fig.layout.margin.t == 60 + + fig = px.scatter(template=dict(layout_margin_t=2)) + assert fig.layout.margin.t is None + + # don't force histogram gridlines when set in template + pio.templates.default = "none" + px.defaults.template = None + fig = px.scatter( + tips, + x="total_bill", + y="tip", + marginal_x="histogram", + marginal_y="histogram", + ) + assert fig.layout.xaxis2.showgrid + assert fig.layout.xaxis3.showgrid + assert fig.layout.yaxis2.showgrid + assert fig.layout.yaxis3.showgrid + + fig = px.scatter( + tips, + x="total_bill", + y="tip", + marginal_x="histogram", + marginal_y="histogram", + template=dict(layout_yaxis_showgrid=False), + ) + assert fig.layout.xaxis2.showgrid + assert fig.layout.xaxis3.showgrid + assert fig.layout.yaxis2.showgrid is None + assert fig.layout.yaxis3.showgrid is None + + fig = px.scatter( + tips, + x="total_bill", + y="tip", + marginal_x="histogram", + marginal_y="histogram", + template=dict(layout_xaxis_showgrid=False), + ) + assert fig.layout.xaxis2.showgrid is None + assert fig.layout.xaxis3.showgrid is None + assert fig.layout.yaxis2.showgrid + assert fig.layout.yaxis3.showgrid + finally: + # reset defaults to prevent all other tests from failing if this one does + px.defaults.reset() + + +def test_px_defaults(): + px.defaults.labels = dict(x="hey x") + px.defaults.category_orders = dict(color=["b", "a"]) + px.defaults.color_discrete_map = dict(b="red") + fig = px.scatter(x=[1, 2], y=[1, 2], color=["a", "b"]) + try: + assert fig.data[0].name == "b" + assert fig.data[0].marker.color == "red" + assert fig.layout.xaxis.title.text == "hey x" + finally: + # reset defaults to prevent all other tests from failing if this one does + px.defaults.reset() def assert_orderings(days_order, days_check, times_order, times_check):