pandas 为 Plotly 直方图分配颜色

b91juud3  于 11个月前  发布在  其他
关注(0)|答案(1)|浏览(109)

我正在尝试开始使用Dash,我有一个简单的问题。我有一个数据表

Index, Product, Customer_Age, Revenue
1, A, 12, 10
2, B, 99, 12


我想把每个产品的收入画成直方图。不过,我希望柱形能按不同的年龄组(比如每 10 岁一组)分成不同的颜色。我该如何做到这一点?目前,我的代码如下:

from dash import Dash, html, dash_table, dcc
import plotly.express as px, pandas as pd
# Load the product / customer-age / revenue table from disk.
data = pd.read_csv('data.csv')

# Revenue summed per product; every distinct Customer_Age value gets its own
# colour because the raw age column is passed as the colour dimension.
revenue_figure = px.histogram(
    data,
    x='Product',
    y='Revenue',
    histfunc='sum',
    color='Customer_Age',
)

app = Dash(__name__)
app.layout = html.Div([dcc.Graph(figure=revenue_figure)])

if __name__ == '__main__':
    app.run(debug=True)

这样一来,每个年龄(而不是每个年龄组)都有不同的颜色。此外,颜色相当随机,而不是一个漂亮的连续颜色序列。有没有一种优雅的方式来实现我想要的效果?

xytpbqjk

xytpbqjk1#

这里是一个示例解决方案,除了直方图之外,还使用条形图来可视化分组数据

在您的特定用例中,由于您强调的是类别,使用条形图而不是直方图来可视化数据可能更有意义。通常,直方图显示的是来自一个*连续*(或至少实际上被视为连续的)数值变量的一系列(有序样本)值的分箱分布。因此,要可视化您所描述的按颜色编码的年龄分组分布会有点复杂。(*要显示多少个直方图?每种颜色一个?*)
在这个例子中,我计算了收入的总和,并以条形图子图的形式显示:每个产品一个子图,条形按年龄段分组,年龄间隔设置为 10 岁。当鼠标悬停在应用中的任意条形上时,子图下方的 Dash 数据表(DataTable)会更新,仅显示所悬停条形对应的数据。

下面,我修改了这个演示应用程序,使其还会(响应每次用户悬停交互,*即通过回调触发*)更新一个单独的直方图,该直方图显示数据的*分布*,而当前悬停的条形表示的是对这些值计算出的*总和*。

示例代码如下:

import random

import dash
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from dash import Dash, html, dcc, dash_table
from dash.dependencies import Input, Output, State
from plotly.subplots import make_subplots

# Sample data: n synthetic orders, each with a random product, customer age
# and revenue value.
n = 500
ages = list(range(14, 83))
interval = 10
data = pd.DataFrame(
    {
        "Index": list(range(n)),
        "Product": random.choices(["A", "B"], k=n),
        "Customer_Age": random.choices(ages, k=n),
        "Revenue": random.choices(list(range(10, 40, 2)), k=n),
    }
)

# Bin ages into groups of `interval` years. The first edge sits one below the
# youngest age because pd.cut(..., right=True) builds bins of the form
# (lower, upper]: the lower edge is excluded and the upper edge is included.
bins = list(range(min(ages) - 1, max(ages) + interval, interval))

# Label every bin with the ages it actually contains: the bin (13, 23] holds
# ages 14..23, so its label is "14-23" (not "13-22"). Deriving both ends from
# the bin edges keeps the labels correct if `interval` is changed.
labels = [f"{lower + 1}-{upper}" for lower, upper in zip(bins[:-1], bins[1:])]
data["Age_Group"] = pd.cut(
    data["Customer_Age"], bins=bins, labels=labels, right=True
)

# Total revenue per (product, age group).
# - Only the "Revenue" column is selected before summing, so the unrelated
#   "Index" and "Customer_Age" columns are not aggregated and then discarded.
# - "Age_Group" is categorical; observed=False is spelled out so that every
#   age-group category yields a row (sum 0 when no order falls into it)
#   instead of depending on the library's default for categorical groupers.
grouped = (
    data.groupby(["Product", "Age_Group"], observed=False)["Revenue"]
    .sum()
    .reset_index()
)
print(grouped)

# Sort products alphabetically so the subplot order is deterministic.
products = sorted(data["Product"].unique())

# One subplot column per product, all sharing the same y-axis. The column
# count follows the data instead of being fixed at 2, so the layout still
# matches when the number of distinct products changes.
fig = make_subplots(
    rows=1, cols=len(products), subplot_titles=products, shared_yaxes=True
)

# Colour per age group: hues are spaced evenly over 0..360 degrees in label
# order, at full saturation and 50% lightness, giving a sequential ramp.
hues = [rank * 360 / len(labels) for rank in range(len(labels))]
color_map = dict(zip(labels, (f"hsl({hue}, 100%, 50%)" for hue in hues)))
# Alternative (qualitative palette instead of a ramp):
#   dict(zip(labels, px.colors.qualitative.Set1))

# Add one bar trace per (product, age group). Each product is drawn in its
# own subplot column, in the same order as `products`.
for column, product in enumerate(products, start=1):
    rows_for_product = grouped[grouped["Product"] == product]
    # Legend entries are shown for the first subplot only; the matching
    # traces in the other subplots carry the same legendgroup value.
    in_first_subplot = column == 1
    for age_group in labels:
        match = rows_for_product[rows_for_product["Age_Group"] == age_group]
        # A missing (product, age group) row is plotted as a zero-height bar.
        revenue = 0 if match.empty else match["Revenue"].values[0]
        bar = go.Bar(
            x=[age_group],
            y=[revenue],
            name=age_group,
            marker_color=color_map[age_group],
            showlegend=in_first_subplot,
            legendgroup=age_group,
        )
        fig.add_trace(bar, row=1, col=column)

fig.update_layout(barmode="group", title="Revenue by Product and Age Group")

app = Dash(__name__)

# Row of control buttons shown between the main figure and the detail area.
button_row = html.Div(
    [
        html.Button("Reset Table", id="reset-button", n_clicks=0),
        html.Button("Show/Hide Sum", id="toggle-button", n_clicks=0),
    ],
    style={"textAlign": "center", "margin": "20px"},
)

# The table declares the columns of the raw data frame and starts out filled
# with the grouped revenue sums.
table_columns = [{"name": column, "id": column} for column in data.columns]
detail_table = dash_table.DataTable(
    id="data-table",
    columns=table_columns,
    data=grouped.to_dict("records"),
    style_table={"height": "500px", "overflowY": "auto"},
    style_cell={"textAlign": "center"},
)

# Two equal-width CSS grid columns: histogram on the left, table on the right.
detail_grid = html.Div(
    children=[dcc.Graph(id="histogram-plot"), detail_table],
    style={
        "display": "grid",
        "gridTemplateColumns": "1fr 1fr",
        "gap": "10px",
    },
)

app.layout = html.Div(
    [
        dcc.Graph(id="revenue-graph", figure=fig),
        button_row,
        detail_grid,
    ],
    style={"padding": "20px", "margin": "5%"},
)

# Number of bins requested for the hover histogram.
nbins = 10


@app.callback(
    [Output("data-table", "data"), Output("histogram-plot", "figure")],
    [
        Input("revenue-graph", "hoverData"),
        Input("reset-button", "n_clicks"),
        Input("toggle-button", "n_clicks"),
    ],
)
def update_output(hoverData, reset_clicks, toggle_clicks):
    """Refresh the table rows and the histogram after a user interaction.

    Args:
        hoverData: Hover payload of the "revenue-graph" figure; its first
            point supplies the hovered trace number and x value.
        reset_clicks: Click count of the reset button (only its firing
            matters; the value itself is not read).
        toggle_clicks: Click count of the toggle button; an odd count
            selects raw rows, an even count selects the grouped sums.

    Returns:
        A ``(table_rows, figure)`` pair. The figure is an empty dict for
        every trigger except a hover; a hover trigger without payload
        returns ``dash.no_update`` for both outputs.
    """
    trigger = dash.callback_context.triggered_id

    # No trigger recorded, or an explicit reset: show the grouped sums.
    if not trigger or trigger == "reset-button":
        return grouped.to_dict("records"), {}

    # Toggle: switch the table between raw rows and grouped sums.
    if trigger == "toggle-button":
        table_source = data if toggle_clicks % 2 == 1 else grouped
        return table_source.to_dict("records"), {}

    if trigger != "revenue-graph":
        return None

    if not hoverData:
        return dash.no_update, dash.no_update

    hovered = hoverData["points"][0]
    # Traces were added product by product, len(labels) traces per product,
    # so integer division of the trace number recovers the product.
    product = products[hovered["curveNumber"] // len(labels)]
    age_group = hovered["x"]
    in_product = data["Product"] == product
    in_age_group = data["Age_Group"] == age_group
    filtered_data = data[in_product & in_age_group]

    # Distribution of the individual revenue values behind the hovered bar.
    histogram = px.histogram(
        filtered_data,
        x="Revenue",
        color_discrete_sequence=["rgba(0, 0, 0, 0.1)"],
        nbins=nbins,
        histnorm="probability density",
    )
    histogram.update_traces(marker_line_color="black", marker_line_width=1)
    histogram.update_layout(yaxis_range=[0, 0.1], xaxis_range=[0, 50])

    if toggle_clicks % 2 == 1:
        table_rows = filtered_data
    else:
        table_rows = grouped[
            (grouped["Product"] == product)
            & (grouped["Age_Group"] == age_group)
        ]
    return table_rows.to_dict("records"), histogram

# Start the development server only when executed as a script.
# app.run is used instead of app.run_server: run_server is the deprecated
# spelling of the same entry point, and app.run accepts the same keyword
# arguments.
if __name__ == "__main__":
    app.run(debug=True, dev_tools_hot_reload=True)

结果:

相关问题