Python: custom color scale with normalization for a Plotly confusion matrix

Posted by az31mfrm on 2023-08-02 in Python

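I am trying to create a confusion matrix where:

  • The neutral value is 0 and should be shown in white.
  • Positive values should be green: the larger the number, the more saturated the green; values close to 0 get a lighter green (blended with white).
  • Negative values should be red: the more negative the number, the deeper the red; values close to 0 get a lighter red (blended with white).

I want a red-white-green gradient centered at 0.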

import plotly.express as px  # the code below calls px.imshow
import numpy as np

# Matrix to plot (defined first, because the color scale below depends on it)
custom_cf_matrix = np.array([[395, -5], [-200, 20]])

# Define the colors for 0, positive, and negative values
zero_color = 'white'
positive_color = 'green'
negative_color = 'red'

colors = [
    (0, negative_color),
    (np.min(abs(custom_cf_matrix)) / np.max(abs(custom_cf_matrix)), zero_color),  # PROBABLY HERE IS WHAT I AM DOING WRONG
    (1, positive_color)
]

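group_values = [395, -5, -200, 20]

# NOTE: group_names is not defined in the original post; these names are placeholders
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]

# labels = [f"{v1}\n{v2}" for v1, v2 in zip(group_names, group_percentages)]

# Plotly annotations use <br> for line breaks
labels = [f"{v1}<br>{v2}" for v1, v2 in zip(group_names, group_values)]
labels = np.asarray(labels).reshape(2, 2)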

# Create the figure
fig = px.imshow(
    custom_cf_matrix,
    labels={"x": "Predicted Label", "y": "True Label"},
    color_continuous_scale=colors,
    range_color=[np.min(custom_cf_matrix), np.max(custom_cf_matrix)],
    width=500,
    height=500,
)

fig.update_xaxes(side="bottom")
fig.update_yaxes(side="left")

# Update the annotations to use black font color
annotations = [
    dict(
        text=text,
        x=col,
        y=row,
        font=dict(color="black", size=16),  # Set font color to black
        showarrow=False,
        xanchor="center",
        yanchor="middle",
    )
    for row in range(2)
    for col, text in zip(range(2), labels[row])
]

fig.update_layout(
    title="Value-Weighted Confusion Matrix",
    title_x=0.25,  # Center the title horizontally
    annotations=annotations,
)

fig.update_xaxes(tickvals=[0, 1], ticktext=["0", "1"], showticklabels=True)
fig.update_yaxes(tickvals=[0, 1], ticktext=["0", "1"], showticklabels=True)


xytpbqjk #1

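Maybe this is what you mean? Make the color range symmetric around 0, so that the midpoint of the color scale (0.5) corresponds to the value 0: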

smallest_number = custom_cf_matrix.min()
largest_number = custom_cf_matrix.max()

# Create a custom color scale
colors = [
    (0, negative_color),
    (0.5, zero_color),  # Normalize the midpoint value to 0.5
    (1, positive_color)
]

# Create the figure
fig = px.imshow(
    custom_cf_matrix,
    labels={"x": "Predicted Label", "y": "True Label"},
    color_continuous_scale=colors,
    range_color=[-max(largest_number, -smallest_number), max(largest_number, -smallest_number)],
    width=500,
    height=500,
)
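
To see why the symmetric range works, here is a minimal sketch that maps a data value to its position on the color scale. The helper `scale_position` and the variable `limit` are illustrative names, not part of Plotly; the sketch assumes a continuous scale maps `range_color` linearly onto 0..1.

```python
import numpy as np

custom_cf_matrix = np.array([[395, -5], [-200, 20]])

# Largest absolute value in the matrix; the range becomes [-limit, +limit]
limit = float(np.max(np.abs(custom_cf_matrix)))
range_color = [-limit, limit]


def scale_position(value, low, high):
    """Illustrative helper: linear position of `value` within [low, high], from 0 to 1."""
    return (value - low) / (high - low)


# With a symmetric range, the value 0 always lands on the midpoint of the scale
zero_position = scale_position(0, *range_color)
# The most negative value does not reach the bottom of the scale
min_position = scale_position(custom_cf_matrix.min(), *range_color)

print(range_color, zero_position, round(min_position, 3))
```

One consequence of this choice: because the range is set by the largest absolute value (395), the most negative cell (-200) sits only partway toward red rather than at full red.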


3hvapo4f #2

Solved!
The solution is to normalize using the maximum and minimum values, then find where 0 falls within that normalized range:

# Find the maximum and minimum values in the array
max_value = np.max(custom_cf_matrix)
min_value = np.min(custom_cf_matrix)

# Find the normalized value for 0
normalized_zero = (0 - min_value) / (max_value - min_value)

colors = [
    (0, negative_color),
    (normalized_zero, zero_color),  # Place white at the normalized position of 0
    (1, positive_color)
]
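
As a worked example with the matrix from the question, the sketch below wraps this calculation in a small function. `diverging_colorscale` is a hypothetical helper name, and the guard for matrices that do not straddle 0 is an added assumption, since the formula only gives a position strictly between 0 and 1 in that case.

```python
import numpy as np


def diverging_colorscale(matrix, negative="red", zero="white", positive="green"):
    """Hypothetical helper: three-stop color scale with `zero` anchored at the value 0."""
    min_value = float(np.min(matrix))
    max_value = float(np.max(matrix))
    # Assumption: the data must contain both negative and positive values,
    # otherwise 0 falls outside (or on the edge of) the [min, max] range
    if not (min_value < 0 < max_value):
        raise ValueError("matrix must contain both negative and positive values")
    # Position of 0 after min-max normalization of the data range
    normalized_zero = (0 - min_value) / (max_value - min_value)
    return [(0.0, negative), (normalized_zero, zero), (1.0, positive)]


custom_cf_matrix = np.array([[395, -5], [-200, 20]])
colors = diverging_colorscale(custom_cf_matrix)

# The question's original attempt, for comparison: min(|x|) / max(|x|)
original_attempt = np.min(np.abs(custom_cf_matrix)) / np.max(np.abs(custom_cf_matrix))

print(colors)
print(original_attempt)
```

Here the white stop lands at 200 / 595, roughly 0.336, while the expression used in the question gives 5 / 395, roughly 0.013, which pushes white almost to the red end. This scale is meant to be used with `range_color=[min, max]`, as in the question's `px.imshow` call, so that both extremes reach full red and full green.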
