numpy 如何在matplotlib.pyplot中删除子图之间的空格?

tvmytwxo  于 2022-12-04  发布在  其他
关注(0)|答案(5)|浏览(121)

我正在做一个项目,需要制作一个10行3列的绘图网格。虽然我已经能够制作绘图并安排子绘图,但我无法制作出一个没有白色的好绘图,比如下面的gridspec documentatation. x1c 0d1x。
我尝试了下面的帖子,但是仍然不能完全去除图片中的白色。有人能给予我一些指导吗?谢谢!

这是我的照片:

下面是我的代码.The full script is here on GitHub.注意:images_2和images_fool都是形状为(1032,10)的展平图像的numpy数组,而delta是形状为(28,28)的图像数组。

def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30)) 
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1]) 

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')
mjqavswn

mjqavswn1#

开头有一条注解:如果您想完全控制间距,请避免使用plt.tight_layout(),因为它会尝试将图形中的图排列成均匀且分布良好的图。这基本上是很好的,并且会产生令人满意的结果,但会随意调整间距。
您从Matplotlib示例库中引用的GridSpec示例之所以能很好地工作,是因为子图的纵横比不是预定义的。也就是说,子图将简单地在网格上展开,并使设置的间距(在本例中为wspace=0.0, hspace=0.0)独立于图形大小。
与此相反,您使用imshow来绘制图像,并且图像的纵横比默认设置为相等(相当于ax.set_aspect("equal"))。也就是说,您当然可以将set_aspect("auto")放置到每个绘图中(并额外添加wspace=0.0, hspace=0.0作为GridSpec的参数,如图库示例中所示),这将生成一个没有间距的绘图。
然而,当使用图像时,保持相等的纵横比是很有意义的,这样每个像素都一样宽一样高,并且正方形阵列显示为正方形图像。
接下来你需要做的就是调整图像大小和图边距以获得预期的结果。figure的figsize参数是以英寸为单位的图(宽度,高度),这里可以调整这两个数字的比率。子图参数wspace, hspace, top, bottom, left可以手动调整以给予预期的结果。下面是一个例子:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

nrow = 10
ncol = 3

fig = plt.figure(figsize=(4, 10)) 

gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 

for i in range(10):
    for j in range(3):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

#plt.tight_layout() # do not use this!!
plt.show()

编辑:

当然,不需要人工调整参数是理想的,因此可以根据行数和列数计算出一些最优的参数。

nrow = 7
ncol = 7

fig = plt.figure(figsize=(ncol+1, nrow+1)) 

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.show()
2uluyalo

2uluyalo2#

尝试在代码中添加以下行:

fig.subplots_adjust(wspace=0, hspace=0)

并为每一个轴对象设置:

ax.set_xticklabels([])
ax.set_yticklabels([])
mutmk8jj

mutmk8jj3#

下面的答案由ImportanceOfBeingErnest给出,但如果要使用plt.subplots及其特性:

fig, axes = plt.subplots(
    nrow, ncol,
    gridspec_kw=dict(wspace=0.0, hspace=0.0,
                     top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                     left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
    figsize=(ncol + 1, nrow + 1),
    sharey='row', sharex='col', #  optionally
)
lnlaulya

lnlaulya4#

如果您使用matplotlib.pyplot.subplots,您可以使用Axes数组显示任意数量的图像。您可以通过对matplotlib.pyplot.subplots配置进行一些调整来删除图像之间的空格。

import matplotlib.pyplot as plt

def show_dataset_overview(self, img_list):
"""show each image in img_list without space"""
    img_number = len(img_list)
    img_number_at_a_row = 3
    row_number = int(img_number /img_number_at_a_row) 
    fig_size = (15*(img_number_at_a_row/row_number), 15)
    _, axs = plt.subplots(row_number, 
                          img_number_at_a_row, 
                          figsize=fig_size , 
                          gridspec_kw=dict(
                                       top = 1, bottom = 0, right = 1, left = 0, 
                                       hspace = 0, wspace = 0
                                       )
                         )
    axs = axs.flatten()

    for i in range(img_number):
        axs[i].imshow(img_list[i])
        axs[i].set_xticks([])
        axs[i].set_yticks([])

由于我们首先在这里创建子绘图区,因此我们可以使用gridspec_kw参数(source)为grid_spec给予一些参数。这些参数包括“top = 1,bottom = 0,right = 1,left = 0,hspace = 0,wspace = 0”参数,这些参数将阻止图像间的间距。要查看其他参数,请访问此处。
在设置上面的figure_size时,我通常使用像(30,15)这样的数字大小。我把它概括了一下,并把它添加到代码中。如果你愿意,你可以在这里输入一个手动大小。

mrwjdhj3

mrwjdhj35#

下面是使用ImageGrid类(改编自this answer)的另一种简单方法。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

nrow = 5
ncol = 3
fig = plt.figure(figsize=(4, 10))
grid = ImageGrid(fig, 
                 111, # as in plt.subplot(111)
                 nrows_ncols=(nrow,ncol),
                 axes_pad=0,
                 share_all=True,)

for row in grid.axes_column:
    for ax in row:
        im = np.random.rand(28,28)
        ax.imshow(im)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

相关问题