matplotlib 使用fig.canvas.tostring_rgb()时图像重复

mbzjlibv  于 2023-06-06  发布在  其他
关注(0)|答案(1)|浏览(391)

我使用matplotlib==3.3.4绘制3D数据:

fig = plt.figure(figsize=(15, 10))
ax = fig.gca(projection="3d")
ax.view_init(30, 0)

# facecolors is a 3D volume with some processing
ax.voxels(
    x, y, z, facecolors[:, :, :, -1] != 0, facecolors=facecolors, shade=False
)
fig.canvas.draw()
image_flat = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
image_shape = (*fig.canvas.get_width_height(), 3)  # (1500, 1000, 3)
ax.imshow(image_flat.reshape(*image_shape))
plt.show()

(我在BraTS20_3dUnet_3dAutoEncoder上做了一些改进,灵感来自Figure to image as a numpy array)。
然而,当我实际绘制图像时,有两个副本:

我做错了什么?我不知道第二个图像是从哪里来的。

mrphzbgm

mrphzbgm1#

NumPy数组的排序是(rows,cols,ch)。代码image_shape = (*fig.canvas.get_width_height(), 3)切换了rowscols,导致输出图像形状不正确,看起来像两个副本。
image_shape = (*fig.canvas.get_width_height(), 3)替换为:

image_shape = (*fig.canvas.get_width_height()[::-1], 3)

为了避免混淆,我们最好使用两行代码:

cols, rows = fig.canvas.get_width_height()
image_shape = (rows, cols, 3)

可重现的示例(使用来自here的数据):

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(15, 10))
ax = fig.gca(projection="3d")
ax.view_init(30, 0)

# https://stackoverflow.com/questions/76387953/image-duplicated-when-using-matplotlib-fig-canvas-tostring-rgb
# prepare some coordinates
x, y, z = np.indices((8, 8, 8))

# draw cuboids in the top left and bottom right corners, and a link between
# them
cube1 = (x < 3) & (y < 3) & (z < 3)
cube2 = (x >= 5) & (y >= 5) & (z >= 5)
link = abs(x - y) + abs(y - z) + abs(z - x) <= 2

# combine the objects into a single boolean array
voxelarray = cube1 | cube2 | link

# set the colors of each object
colors = np.empty(voxelarray.shape, dtype=object)
colors[link] = 'red'
colors[cube1] = 'blue'
colors[cube2] = 'green'

# and plot everything
#ax = plt.figure().add_subplot(projection='3d')
ax.voxels(voxelarray, facecolors=colors, edgecolor='k')

fig.canvas.draw()
image_flat = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
#image_shape = (*fig.canvas.get_width_height(), 3)  # (1500, 1000, 3)
#image_shape = (*fig.canvas.get_width_height()[::-1], 3)  # It should be (1000, 1500, 3) instead of (1500, 1000, 3)
cols, rows = fig.canvas.get_width_height()
image_shape = (rows, cols, 3)
img = image_flat.reshape(*image_shape)

plt.figure()
plt.imshow(img)
plt.show()

修复代码之前的输出图像:

修复代码后的输出图像:

相关问题