matplotlib 条件行颜色导致ValueError

8ljdwjyq  于 2023-10-24  发布在  其他
关注(0)|答案(2)|浏览(98)

我试图用matplotlib.pytplot绘制一个线图,并有一个形状为(42,7)的三角形框架df。三角形框架具有以下结构(仅显示相关列):

timepoint   value   point
2021-01-01   10      0
2021-02-01   20      0
....
2021-11-01   10      0
2021-12-01   50      1
2022-01-01   60      1
...

我尝试用以下方式绘制条件颜色(每个point=0的值是蓝色,每个point=1的值是红色):

import numpy as np
col = np.where(df['point'] == 0, 'b', 'r')

plt.plot(df['timepoint'], df['value'], c=col)
plt.show()

我得到错误消息:
ValueError:阵列('b ',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' b','b',' r','r',' r','r',' r','r',' r','r',' r','r',' r','r',' r','r',' r','r',' r','r',' r','r','r'],dtype='<U1')不是颜色的有效值
当我看这个问题ValueError: Invalid RGBA argument: What is causing this error?时,我没有找到任何解决方案,因为我的颜色数组的形状是:col.shape(42, )

8nuwlpux

8nuwlpux1#

我不认为你可以在这里使用np.where来定义你的线图的颜色(就像他们用here定义散点图一样),因为在前者中,c参数期望整条线都是一种颜色,而在后者中,你可以提供一个颜色数组来Map到点。
所以,这里有一个可能的选择来解决这个问题:

(_, p0), (_, others) = df.groupby(df["point"].eq(0), sort=False)

plt.plot(p0.timepoint, p0.value, c="b", label="0 Points")
plt.plot(others.timepoint, others.value, c="r", label="Others")

plt.legend()
plt.show();

或者使用pivot/plot

other_points = df["point"].loc[lambda s: s.ne(0)].unique()

(
    df.pivot(index="timepoint", columns="point", values="value")
        .plot(style={0: "b", **{k: "r" for k in other_points}})
)

输出量:

2ekbmq32

2ekbmq322#

您遇到的错误是因为plt.plot()函数中的c参数期望每个数据点的有效颜色规范,但您提供的颜色名称数组('b'表示蓝色,'r'表示红色)不是c参数的有效值。
要绘制线图的条件颜色,您可以使用循环遍历DataFrame行,并使用所需的颜色绘制每个线段。以下是您如何做到这一点(使用示例框架):

import pandas as pd
import matplotlib.pyplot as plt

# Sample DataFrame
data = {
    'timepoint': ['2021-01-01', '2021-02-01', '2021-11-01', '2021-12-01', '2022-01-01'],
    'value': [10, 20, 10, 50, 60],
    'point': [0, 0, 0, 1, 1]
}

df = pd.DataFrame(data)

# Convert 'timepoint' column to datetime type
df['timepoint'] = pd.to_datetime(df['timepoint'])

# Initialize variables to keep track of the previous point value and color
prev_point = None
prev_color = None

fig, ax = plt.subplots()

for index, row in df.iterrows():
    # Check the 'point' value for the current row
    current_point = row['point']

    # Set the color based on the 'point' value
    if current_point == 0:
        color = 'b'  # Blue for point=0
    else:
        color = 'r'  # Red for point=1

    # If the point value has changed, create a new segment in the plot
    if prev_point is not None and current_point != prev_point:
        ax.plot(df.loc[df.index[index - 1]:index - 1, 'timepoint'], df.loc[df.index[index - 1]:index - 1, 'value'], c=prev_color, label=f'Point {prev_point}', marker='o')  # Add marker='o' to display points

    ax.plot([row['timepoint']], [row['value']], c=color, marker='o')  # Add marker='o' to display points

    # Update previous point and color
    prev_point = current_point
    prev_color = color

# Add labels, legend, and show the plot
ax.set_xlabel('Timepoint')
ax.set_ylabel('Value')
ax.legend()
plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
plt.show()

我们也可以使用np.where方法。我们可以使用plt.scatter来实现,如下面的代码:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sample DataFrame
data = {
    'timepoint': ['2021-01-01', '2021-02-01', '2021-11-01', '2021-12-01', '2022-01-01'],
    'value': [10, 20, 10, 50, 60],
    'point': [0, 0, 0, 1, 1]
}

df = pd.DataFrame(data)

# Convert 'timepoint' column to datetime type
df['timepoint'] = pd.to_datetime(df['timepoint'])

# Use np.where to conditionally set colors
col = np.where(df['point'] == 0, 'b', 'r')

# Create the scatter plot with conditional colors
plt.scatter(df['timepoint'], df['value'], c=col, s=50)

# Show the plot
plt.xlabel('Timepoint')
plt.ylabel('Value')
plt.title('Conditional Scatter Plot')
plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
plt.grid(True)
plt.show()

这是更新后的代码,用于创建线条图,但使用循环来区分颜色:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sample DataFrame
data = {
    'timepoint': ['2021-01-01', '2021-02-01', '2021-11-01', '2021-12-01', '2022-01-01'],
    'value': [10, 20, 10, 50, 60],
    'point': [0, 0, 0, 1, 1]
}

df = pd.DataFrame(data)

# Convert 'timepoint' column to datetime type
df['timepoint'] = pd.to_datetime(df['timepoint'])

# Use np.where to conditionally set colors
col = np.where(df['point'] == 0, 'b', 'r')

# Create the line graph with conditional colors
fig, ax = plt.subplots()

for color, group in df.groupby(col):
    ax.plot(group['timepoint'], group['value'], label=color)

# Add labels, legend, and show the plot
ax.set_xlabel('Timepoint')
ax.set_ylabel('Value')
plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
plt.grid(True)
plt.legend()
plt.show()


你可以查看我的kaggle here

相关问题