matplotlib bar3d图中的错误重叠

kqqjbcuj  于 2023-10-24  发布在  其他
关注(0)|答案(2)|浏览(149)

我已经做了这个3D条形图,但我发现一些条形图有错误的重叠,如下图中的绿色圆圈所示:

该图是由:

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.axes3d import Axes3D
import matplotlib.colors as colors

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')    
matrix = np.array([
[84 80 68 56 60 44 55 39 27 29]
[82 67 63 44 47 33 22 19  9  2]
[53 61 48 34  0 16  0  0  0  0]
[48 25  0  0  0  0  0  0  0  0]])

len_x, len_y = matrix.shape
_x = np.arange(len_x)
_y = np.arange(len_y)

xpos, ypos = np.meshgrid(_x, _y)
xpos = xpos.flatten('F')
ypos = ypos.flatten('F')
zpos = np.zeros_like(xpos)

dx = np.ones_like(zpos)
dy = dx.copy()
dz = matrix.flatten()

cmap=plt.cm.magma(plt.Normalize(0,100)(dz))

ax.bar3d(xpos+0.32, ypos-0.3, zpos, dx-0.6, dy-0.1, dz, zsort='max', color=cmap)

ax.set_xlabel('x')
ax.set_xticks(np.arange(len_x+1))
ax.set_xticklabels(['1000','500','100','50','0'])
ax.set_xlim(0,4)
ax.set_ylabel('y')
ax.set_yticks(np.arange(len_y+1))
ax.set_yticklabels(['0.5','1.','1.5','2.','2.5','3.','3.5','4.','4.5','5.'])
ax.set_ylim(-0.5,10)
ax.set_zlabel('z')
ax.set_zlim(0,100)
ax.view_init(ax.elev, ax.azim+100)

这是一个错误吗?为什么有些条严重重叠?我使用matplotlib版本2.1.0和anaconda python 3.6.3

yv5phkfx

yv5phkfx1#

正如@DavidG在评论中指出的那样,这是一个没有理想解决方案的问题:

我的3D图在某些视角下看起来不正确

这可能是mplot 3d最常报告的问题。问题是-从某些视角-一个3D对象会出现在另一个对象的前面,即使它实际上在它后面。这可能导致图看起来不“物理正确”。
不幸的是,虽然正在做一些工作来减少这种伪影的发生,但它目前是一个棘手的问题,并且在matplotlib支持其核心的3D图形渲染之前无法完全解决。
[来源]
然而,我能够通过玩弄情节的视角和减少酒吧之间的接触面积来大大减少这个问题。
例如,为了改变视角(“相机位置”),我使用:

ax.view_init(elev=30, azim=-60) # Changes the elevation and azimuth

更多细节在how to set “camera position” for 3d plots using python/matplotlib?
根据接触面积,这取决于你的图。在我的例子中,所有的条都在y轴旁边接触,所以我只是稍微减少了dy参数,在条之间留下一些间隙。

xoefb8l8

xoefb8l82#

我把它变成了一个相对通用的函数,它接受DataFrame并绘制它。这与最初的问题无关,但仍然很有用。

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.axes3d import Axes3D
import matplotlib.colors as colors

def plot_3d_bar(df):
    y_labels = list(df.columns)
    x_labels = list(df.index.to_series())
    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(111, projection='3d')

    matrix = df.values
    len_x, len_y = matrix.shape
    _x = np.arange(len_x)
    _y = np.arange(len_y)

    xpos, ypos = np.meshgrid(_x, _y)
    xpos = xpos.flatten('F')
    ypos = ypos.flatten('F')
    zpos = np.zeros_like(xpos)

    dx = np.ones_like(zpos)
    dy = dx.copy()
    dz = matrix.flatten()
    cmap=plt.cm.magma(plt.Normalize(0,max(dz))(dz))

    ax.bar3d(xpos+0.32, ypos-0.3, zpos, dx-0.6, dy-0.6, dz, zsort='max', color=cmap)
    
    ax.set_xlabel('x')
    ax.set_xticks(np.arange(len_x))
    ax.set_xticklabels(x_labels)
    ax.set_xlim(0, len_x)

    ax.set_ylabel('y')
    ax.set_yticks(np.arange(len_y))
    ax.set_yticklabels(y_labels)
    ax.set_ylim(-0.5, len_y)

    ax.set_zlabel('z')
    #ax.set_zlim(0,3000)

    ax.view_init(elev=30, azim=-60)

    plt.show()

相关问题