在matplotlib pcolor grid中通过使用值数组使用网格线绘制边框

monwx1rj  于 2023-03-30  发布在  其他
关注(0)|答案(1)|浏览(152)

我有两个数组,我使用颜色在一个单独的节点网格上绘制。一个表示一些集群,另一个表示一些其他值,我将其称为我的特征。示例代码:

import numpy as np
import matplotlib.pyplot as plt

clusters = np.array([[0,2,1], [0,3,1], [3,3,1]]) # make cluster data
features= np.array([[0,0.4,0.7], [0.1,0.3,0.7], [0.5,0.4,0.8]]) # make data of features

# plot clusters
plt.figure()
plt.pcolor(clusters, cmap='jet') # color in nodes
plt.colorbar(ticks=[i for i in range(0, np.amax(clusters)+1)]) # make colorbar legend per cluster
plt.show()

# plot feature grid
plt.figure()
plt.pcolor(features, cmap='bone', vmin=0, vmax=1) # color in nodes
plt.colorbar() # make colorbar legend
plt.show()

在第二个带有特征数据的灰色网格中,我希望通过网格线显示聚类之间的边界。预期结果如下所示,其中红线表示聚类之间的边界:

有没有什么方法可以使用集群数组的数据自动绘制这些网格线?任何帮助将不胜感激!

ecbunoof

ecbunoof1#

这是可行的:

vlines, hlines = [], []
# loop over array with clusters to obtain positions of vertical and horizontal lines
for row_idx, row in enumerate(clusters): # loop over array rows
    for col_idx, col in enumerate(row): # loop over array columns per row
        if col_idx+1 < clusters.shape[1]: # skip final column, has no right-side neighbouring node
            # save tuple if it requires a vertical line indicating a different cluster right of the node
            if clusters[row_idx, col_idx] != clusters[row_idx, col_idx+1]: vlines.append((row_idx, col_idx+1))
        if row_idx+1 < clusters.shape[0]: # skip final row, has no bottom neighbouring node
            # save a tuple if it requires a horizontal line indicating a different cluster below the node
            if clusters[row_idx, col_idx] != clusters[row_idx+1, col_idx]: hlines.append((row_idx+1, col_idx))

# make features plot
plt.figure()
plt.pcolor(features, cmap='bone', vmin=0, vmax=1) # color in nodes
plt.colorbar() # make colorbar legend
for vline in vlines: # plot all vertical lines
    plt.axvline(vline[1], ymin=vline[0]/clusters.shape[0], ymax=vline[0]/clusters.shape[0] + 1/clusters.shape[0], color='red')
for hline in hlines: # plot all horizontal lines
    plt.axhline(hline[0], xmin=hline[1]/clusters.shape[0], xmax=hline[1]/clusters.shape[0] + 1/clusters.shape[0], color='red')
plt.show()

相关问题