Skip to content

Commit 2425b0a

Browse files
authored
Merge branch 'master' into version-6-migration
2 parents 2d192e7 + 84f4e66 commit 2425b0a

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

packages/python/plotly/plotly/express/_core.py

+21-21
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,6 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
652652

653653

654654
def configure_cartesian_marginal_axes(args, fig, orders):
655-
if "histogram" in [args["marginal_x"], args["marginal_y"]]:
656-
fig.layout["barmode"] = "overlay"
657-
658655
nrows = len(fig._grid_ref)
659656
ncols = len(fig._grid_ref[0])
660657

@@ -1497,17 +1494,14 @@ def build_dataframe(args, constructor):
14971494
# If data_frame is provided, we parse it into a narwhals DataFrame, while accounting
14981495
# for compatibility with pandas specific paths (e.g. Index/MultiIndex case).
14991496
if df_provided:
1500-
15011497
# data_frame is pandas-like DataFrame (pandas, modin.pandas, cudf)
15021498
if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]):
1503-
15041499
columns = args["data_frame"].columns # This can be multi index
15051500
args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True)
15061501
is_pd_like = True
15071502

15081503
# data_frame is pandas-like Series (pandas, modin.pandas, cudf)
15091504
elif nw.dependencies.is_pandas_like_series(args["data_frame"]):
1510-
15111505
args["data_frame"] = nw.from_native(
15121506
args["data_frame"], series_only=True
15131507
).to_frame()
@@ -1861,7 +1855,6 @@ def _check_dataframe_all_leaves(df: nw.DataFrame) -> None:
18611855
for row_idx, row in zip(
18621856
null_indices_mask, null_mask.filter(null_indices_mask).iter_rows()
18631857
):
1864-
18651858
i = row.index(True)
18661859

18671860
if not all(row[i:]):
@@ -1990,7 +1983,6 @@ def process_dataframe_hierarchy(args):
19901983

19911984
if args["color"]:
19921985
if discrete_color:
1993-
19941986
discrete_aggs.append(args["color"])
19951987
agg_f[args["color"]] = nw.col(args["color"]).max()
19961988
agg_f[f'{args["color"]}{n_unique_token}'] = (
@@ -2045,7 +2037,6 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFra
20452037
).drop([f"{col}{n_unique_token}" for col in discrete_aggs])
20462038

20472039
for i, level in enumerate(path):
2048-
20492040
dfg = (
20502041
df.group_by(path[i:], drop_null_keys=True)
20512042
.agg(**agg_f)
@@ -2422,7 +2413,6 @@ def get_groups_and_orders(args, grouper):
24222413
# figure out orders and what the single group name would be if there were one
24232414
single_group_name = []
24242415
unique_cache = dict()
2425-
grp_to_idx = dict()
24262416

24272417
for i, col in enumerate(grouper):
24282418
if col == one_group:
@@ -2440,27 +2430,28 @@ def get_groups_and_orders(args, grouper):
24402430
else:
24412431
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
24422432

2443-
grp_to_idx = {k: i for i, k in enumerate(orders)}
2444-
24452433
if len(single_group_name) == len(grouper):
24462434
# we have a single group, so we can skip all group-by operations!
24472435
groups = {tuple(single_group_name): df}
24482436
else:
2449-
required_grouper = list(orders.keys())
2437+
required_grouper = [group for group in orders if group in grouper]
24502438
grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__())
2451-
sorted_group_names = list(grouped.keys())
24522439

2453-
for i, col in reversed(list(enumerate(required_grouper))):
2454-
sorted_group_names = sorted(
2455-
sorted_group_names,
2456-
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
2457-
)
2440+
sorted_group_names = sorted(
2441+
grouped.keys(),
2442+
key=lambda values: [
2443+
orders[group].index(value) if value in orders[group] else -1
2444+
for group, value in zip(required_grouper, values)
2445+
],
2446+
)
24582447

24592448
# calculate the full group_names by inserting "" in the tuple index for one_group groups
24602449
full_sorted_group_names = [
24612450
tuple(
24622451
[
2463-
"" if col == one_group else sub_group_names[grp_to_idx[col]]
2452+
""
2453+
if col == one_group
2454+
else sub_group_names[required_grouper.index(col)]
24642455
for col in grouper
24652456
]
24662457
)
@@ -2487,6 +2478,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
24872478
constructor = go.Bar
24882479
args = process_dataframe_timeline(args)
24892480

2481+
# If we have marginal histograms, set barmode to "overlay"
2482+
if "histogram" in [args.get("marginal_x"), args.get("marginal_y")]:
2483+
layout_patch["barmode"] = "overlay"
2484+
24902485
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
24912486
args, constructor, trace_patch, layout_patch
24922487
)
@@ -2558,7 +2553,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
25582553
legendgroup=trace_name,
25592554
showlegend=(trace_name != "" and trace_name not in trace_names),
25602555
)
2561-
if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]:
2556+
2557+
# Set 'offsetgroup' only in group barmode (or if no barmode is set)
2558+
barmode = layout_patch.get("barmode")
2559+
if trace_spec.constructor in [go.Bar, go.Box, go.Violin, go.Histogram] and (
2560+
barmode == "group" or barmode is None
2561+
):
25622562
trace.update(alignmentgroup=True, offsetgroup=trace_name)
25632563
trace_names.add(trace_name)
25642564

packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py

+21
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,27 @@ def test_orthogonal_orderings(backend, days, times):
289289
assert_orderings(backend, days, days, times, times)
290290

291291

292+
def test_category_order_with_category_as_x(backend):
293+
# https://github.com/plotly/plotly.py/issues/4875
294+
tips = nw.from_native(px.data.tips(return_type=backend))
295+
fig = px.bar(
296+
tips,
297+
x="day",
298+
y="total_bill",
299+
color="smoker",
300+
barmode="group",
301+
facet_col="sex",
302+
category_orders={
303+
"day": ["Thur", "Fri", "Sat", "Sun"],
304+
"smoker": ["Yes", "No"],
305+
"sex": ["Male", "Female"],
306+
},
307+
)
308+
assert fig["layout"]["xaxis"]["categoryarray"] == ("Thur", "Fri", "Sat", "Sun")
309+
for trace in fig["data"]:
310+
assert set(trace["x"]) == {"Thur", "Fri", "Sat", "Sun"}
311+
312+
292313
def test_permissive_defaults():
293314
msg = "'PxDefaults' object has no attribute 'should_not_work'"
294315
with pytest.raises(AttributeError, match=msg):

0 commit comments

Comments
 (0)