matplotlib 以特定的行和列格式保存多个图形

zvms9eto  于 2023-05-23  发布在  其他
关注(0)|答案(1)|浏览(123)

我正在尝试以特定格式保存多个图形。
这将是关于保存3个数字的行在一个垂直的数字。
不管它是奇数还是偶数,因为长度随着函数的执行而变化。我试过把几种解决办法结合起来,但都没有成功。
这是函数的代码:

def save_macro_figures(self):
    num_figures = len(gv.figures)
    rows = num_figures // 2  # Número de filas, asumiendo dos figuras por fila
    if num_figures % 2 != 0:
        rows += 1  # Agregar una fila adicional si hay un número impar de figuras

    macro_figures = []
    for i in range(rows):
        # Obtener las dos figuras de la fila
        row_figures = gv.figures[i*2: (i+1)*2]
        row_figures_images = [
            np.array(fig.canvas.renderer.buffer_rgba()) for fig in row_figures]
        row_figures_concatenated = np.concatenate(
            row_figures_images, axis=0)
        macro_figures.append(row_figures_concatenated)

    macro_figures_combined = np.concatenate(macro_figures, axis=0)

    Image.fromarray(macro_figures_combined).save(
        './data/outpouts/png/zero-macrofigure.png')
    matplotlib.pyplot.close()

目前我所取得的是保存在1单行的一切

  • 但我希望它是3位数每行1单列 *
aydmsdu9

aydmsdu91#

如果要使用串联,则必须单独处理每一行,然后串联行。在这里,我只是预先分配组合图,然后计算出组合图中每个单独图的索引。

#!/usr/bin/env python
"""
Combine several figures into one while specifying only the number of columns.
"""

import matplotlib as mpl; mpl.use("TKAgg")

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

def save_as_multi_column_figure(figures, filepath, columns=3):
    figure_arrays = [np.array(fig.canvas.renderer.buffer_rgba()) for fig in figures]
    figure_height, figure_width, figure_channels = figure_arrays[0].shape # assumes all figures have the same shape

    total_figures = len(figures)
    multi_column_width_in_figures = columns
    multi_column_height_in_figures = int(np.ceil(total_figures / multi_column_width_in_figures))

    multi_column_width_in_pixel = figure_width * multi_column_width_in_figures
    multi_column_height_in_pixel = figure_height * multi_column_height_in_figures
    multi_column = np.zeros((multi_column_height_in_pixel, multi_column_width_in_pixel, figure_channels), dtype=np.uint8)

    for ii, arr in enumerate(figure_arrays):
        col = ii % multi_column_width_in_figures
        row = int(ii / multi_column_width_in_figures)
        multi_column[row * figure_height : (row + 1) * figure_height, col * figure_width : (col + 1) * figure_width] = arr

    Image.fromarray(multi_column).save(filepath)

if __name__ == '__main__':

    figures = []
    for ii in range(8):
        fig, ax = plt.subplots()
        ax.plot(np.random.rand(10))
        fig.canvas.draw()
        figures.append(fig)

    save_as_multi_column_figure(figures, "test.png")
    plt.close()

相关问题