matplotlib 根据条件着色

vxbzzdmp  于 2023-05-18  发布在  其他
关注(0)|答案(1)|浏览(106)

我使用以下代码来显示绘图。我希望根据阈值0.05有条件地格式化这些值。

from matplotlib.colors import to_rgba

# Generate x-axis
x = np.linspace(0, len(data_formula), len(data_formula))
colors = np.where(data_formula <= 0.05, "blue", "green")
plt.plot(x, data_formula, c=colors)

# Add labels and title
plt.ylabel('Volume')
plt.xlabel('time')
plt.title('Energy')

# Display the plot
plt.show()

不幸的是,我收到了错误:array(['blue', 'blue', 'blue', ..., 'blue', 'blue', 'blue'], dtype='<U5') is not a valid value for color,这表明我向color参数传递了错误的值。我试过用列表等。但好像不管用。哪里出了问题?参数color是不接受任何数据结构还是格式错误?
供参考:data_formula定义如下:

def energy(x):
    
    return (x**2)

data_formula = np.apply_along_axis(energy, axis=0, arr=data_normalized)

数据类型:numpy.ndarray

vsikbqxv

vsikbqxv1#

要使用不同的颜色,您应该使用其颜色分割数据段:

#%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, len(data_formula), len(data_formula))
crossings = np.where(np.diff(data_formula > 0.05))[0]
start = 0

# Plot segments
for ind in crossings:
    plt.plot(x[start:ind+1], data_formula[start:ind+1], 
             color='blue' if data_formula[start] <= 0.05 else 'green')
    start = ind+1

plt.plot(x[start:], data_formula[start:], 
         color='blue' if data_formula[start] <= 0.05 else 'green')

plt.ylabel('Volume')
plt.xlabel('time')
plt.title('Energy')
plt.show()

我假设你的数据是这样的:

data_formula = np.array([0, 0.01, 0.02, 0.04, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

如果您只需要点而不是线段,请执行以下操作:

import numpy as np
import matplotlib.pyplot as plt

for x, y in enumerate(data_formula):
    plt.scatter(x, y, color='blue' if y <= 0.05 else 'green')

    
plt.ylabel('Volume')
plt.xlabel('time')
plt.title('Energy')
plt.show()

这最后一个灵感来自stackoverflow: different colors series scatter plot on matplotlib
如果前面的答案需要很长时间,请尝试:

import numpy as np
import matplotlib.pyplot as plt

# data_formula = np.array([0, 0.01, 0.02, 0.04, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

xy_blue = [(x, y) for x, y in enumerate(data_formula) if y <= 0.05]
xy_gree = [(x, y) for x, y in enumerate(data_formula) if y > 0.05]

plt.scatter([t[0] for t in xy_blue], [t[1] for t in xy_blue], color='blue')
plt.scatter([t[0] for t in xy_gree], [t[1] for t in xy_gree], color='green')

    
plt.ylabel('Volume')
plt.xlabel('time')
plt.title('Energy')
plt.show()

无论如何,我从来没有见过一个图需要超过30秒,所以也许你必须重新加载你的环境,你有太多的数据绘制,应该作出不同的图,或者你有一个低RAM的问题。

相关问题