10000 Stacked grouped bar chart by RenaudLN · Pull Request #4486 · plotly/plotly.py · GitHub
[go: up one dir, main page]

Skip to content

Stacked grouped bar chart #4486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add comments throughout
  • Loading branch information
RenaudLN committed Jan 20, 2024
commit c12b610415d6cb9c3f7f4bb5b4a4679e74063e4f
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
from copy import deepcopy
import colorsys
from plotly.graph_objects import Bar
from plotly.express._core import build_dataframe
from plotly.express._doc import make_docstring
from plotly.express._chart_types import bar
from copy import deepcopy
from typing import Union

import plotly.colors as pyc
import plotly.io as pio
import plotly.graph_objects as go
from plotly.express._core import build_dataframe
from plotly.express._doc import make_docstring
from plotly.express._chart_types import bar


def get_colors(base_color: Union[str, tuple], n_colors: int):
"""Get a palette of colors derived from base color.

This function leverages the HLS color space.
"""
hls_tuple = colorsys.rgb_to_hls(
*pyc.convert_colors_to_same_type(base_color, "tuple")[0][0]
)
if n_colors == 1:
return [base_color]

if n_colors == 2:
light = colorsys.hls_to_rgb(
hls_tuple[0], min(1, max(0.75, hls_tuple[1] + 0.2)), hls_tuple[2]
)
return pyc.sample_colorscale([light, base_color], n_colors)

light = colorsys.hls_to_rgb(
hls_tuple[0], min(1, max(0.8, hls_tuple[1] + 0.2)), hls_tuple[2]
)
dark = colorsys.hls_to_rgb(
hls_tuple[0], max(0, min(0.3, hls_tuple[1] - 0.2)), hls_tuple[2]
)
return pyc.sample_colorscale([light, base_color, dark], n_colors)


def create_grouped_stacked_bar(
Expand Down Expand Up @@ -41,17 +68,19 @@ def create_grouped_stacked_bar(
"""
Returns a bar chart with grouped and stacked bars.
"""

# Leverage the `build_dataframe` function twice to create the dataframe
# with color and stack_group columns
args = deepcopy(locals())
if data_frame is not None:
df_copy = deepcopy(data_frame)
if color is not None:
color_copy = deepcopy(color)
args = build_dataframe(args=args, constructor=Bar)
args = build_dataframe(args=args, constructor=go.Bar)
df_color = args["data_frame"].copy()
args["color"] = stack_group
args["data_frame"] = df_copy
args = build_dataframe(args=args, constructor=Bar)

args = build_dataframe(args=args, constructor=go.Bar)
color_col = color if isinstance(color, str) else "color"
group_col = stack_group if isinstance(stack_group, str) else "stack_group"
x_col = x if isinstance(x, str) else "x"
Expand All @@ -62,8 +91,12 @@ def create_grouped_stacked_bar(

data_frame = args.pop("data_frame").sort_values([color_col, group_col, x_col])
hover_data = args.pop("hover_data") or [group_col]

# Remove arguments that can't be passed to px.bar, they are used separately
args.pop("stack_group")
args.pop("hover_unified")

# Create the groups metadata, including their order and bar width
groups = list(data_frame[group_col].unique())
if category_orders is not None and group_col in category_orders:
groups = [g for g in category_orders[group_col] if g in groups] + [
Expand All @@ -75,40 +108,20 @@ def create_grouped_stacked_bar(
group_width = (1 - stack_group_gap - (n_groups - 1) * bar_gap) / n_groups
n_colors = data_frame[color_col].nunique()

def get_colors(base_color, n_colors):
hls_tuple = colorsys.rgb_to_hls(
*pyc.convert_colors_to_same_type(base_color, "tuple")[0][0]
)
if n_colors == 1:
return [base_color]

if n_colors == 2:
light = colorsys.hls_to_rgb(
hls_tuple[0], min(1, max(0.75, hls_tuple[1] + 0.2)), hls_tuple[2]
)
return pyc.sample_colorscale([light, base_color], n_colors)

light = colorsys.hls_to_rgb(
hls_tuple[0], min(1, max(0.8, hls_tuple[1] + 0.2)), hls_tuple[2]
)
dark = colorsys.hls_to_rgb(
hls_tuple[0], max(0, min(0.3, hls_tuple[1] - 0.2)), hls_tuple[2]
)
return pyc.sample_colorscale([light, base_color, dark], n_colors)

# Retrieve the template information to create groups with the right colors
if template is None:
if pio.templates.default is not None:
template = pio.templates.default
else:
template = "plotly"

try:
# retrieve the actual template if we were given a name
template = pio.templates[template]
except Exception:
# otherwise try to build a real template
template = go.layout.Template(template)

# `color_discrete_sequence` can be used to override the template colors
color_discrete_sequence = args.pop("color_discrete_sequence")
if color_discrete_sequence is None:
color_discrete_sequence = template.layout.colorway
Expand All @@ -117,13 +130,16 @@ def get_colors(base_color, n_colors):
pyc.get_colorscale(color_discrete_sequence)
)

# Manage the orientation
value_axis = "y"
base_axis = "x"
if orientation == "h":
value_axis = "x"
base_axis = "y"

fig = None
# Create the figures for each group then combine into one,
# with overlapping y-axis (or x-axis if horizontal)
for i, group in enumerate(groups):
group_df = data_frame.query(f"{group_col} == @group")
n_colors = group_df[color_col].nunique()
Expand All @@ -148,6 +164,7 @@ def get_colors(base_color, n_colors):
fig = group_fig
else:
fig.add_traces(group_fig.data)
# Ensure the y-axes (or x-axes) overlap and match
fig.update_layout(
**{
f"{value_axis}axis{i + 1}": {
Expand All @@ -158,7 +175,12 @@ def get_colors(base_color, n_colors):
}
}
)

# Set the base axis type to category to work well with groups
fig.update_layout(**{f"{base_axis}axis_type": "category"})

# Optionally unify the hover, with a modification of the hovertemplate
# to have a nice display
if hover_unified:
fig.update_layout(hovermode="x unified").update_traces(hovertemplate="%{y}")

Expand Down
0