在Matplotlib中如何确保abline在每个子图中以1:1为中心,并具有单独的轴限制?

5cnsuln7  于 2023-08-06  发布在  其他
关注(0)|答案(1)|浏览(113)

我有一个2x2的子图网格,其中每个子图包含一个具有不同数据点的散点图。我试图在每个子图中绘制一条公共的abline(斜率=1,截距=0),以可视化数据点之间的关系。但是,由于每个子图中的数据范围不同,因此在所有子图中,abline并不以1:1居中显示。
我想确保abline在每个子图中以1:1居中,同时根据该特定子图中的数据点为每个图保持单独的轴限制。换句话说,我希望abline通过每个子图数据点的中心,而不会扭曲数据。
有人能指导我如何在每个子图中实现abline的正确居中,同时根据该子图中的数据点保持单独的轴限制吗?
这就是代码:

timesteps = [185, 159, 53, 2]

def abline(ax, slope, intercept):
    """Plot a line from slope and intercept"""
    x_vals = np.array(ax.get_xlim())
    y_vals = intercept + slope * x_vals
    ax.plot(x_vals, y_vals, 'r--')

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

for i, timestep in enumerate(timesteps):
    mask = np.where(nan_mask[timestep, :, :] == 0)
    data_tmwm_values = data_tmwm[timestep, :, :][mask]
    ds_plot_values = ds_og_red[timestep, :, :][mask]

    row = i // 2  # Integer division to get the row index
    col = i % 2  # Modulo operation to get the column index
    
    ax = axs[row, col]
    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')

    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]

    r_squared = r_value ** 2
    abline(ax, 1, 0)
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')

plt.tight_layout()
plt.show()

字符串
这就是图像:


的数据
我已经尝试过使用get_xlim()和get_ylim()函数来设置每个子图的轴限制,但它不会导致abline的正确居中。

cnh2zyt3

cnh2zyt31#

看起来您想要一个身份线,但您正在尝试线性拟合。线性拟合可能对您仍然有用,因为您计算了各种相关性度量并叠加R2。
下面的示例显示了如何添加线性拟合以及恒等(y=x)线。


的数据

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

timesteps = [185, 159, 53, 2]

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

for timestep, ax in zip(timesteps, axs.flatten()):
    #Synthetic data
    data_tmwm_values = np.random.randn(200) * 10 + timestep / 2
    ds_plot_values = np.random.randn(200) * 20 + timestep / 2

    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')

    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]
    r_squared = r_value ** 2
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')
    
    #Fit a straight line
    slope, intercept = np.polyfit(data_tmwm_values, ds_plot_values, deg=1)
    #Add the line to the plot, preserving the x and y ranges of the data
    x_low, x_high, y_low, y_high = ax.axis() #Get axis limits
    ax.plot([x_low, x_high], slope * np.array([x_low, x_high]) + intercept, 'r--', label='linear fit of data')
    
    #add identity line, in case that is what you wanted
    lim_low = min(x_low, y_low)
    lim_high = max(x_high, y_high)
    ax.plot([lim_low, lim_high], [lim_low, lim_high], '-k', linewidth=2, label='y=x identity line')
    
    #add legend for a plot, to clarify what the lines represent
    if ax is axs[0, 1]: ax.legend(loc='upper right') 
    
    #optional - clip limits to remove some padding
    ax.axis([lim_low, lim_high, lim_low, lim_high])
    
plt.tight_layout()
plt.show()

字符串

相关问题