matplotlib 如何在3D投影图中删除偏移?

5gfr0r5j  于 2023-05-07  发布在  其他
关注(0)|答案(1)|浏览(120)

我意识到我用matplotlib在3D中绘制的图有轻微的“错位”。下面是一个MWE:

import numpy as np
from matplotlib import pyplot as plt
figure = plt.figure(figsize=(8,10.7))
ax = plt.gca(projection='3d')
ax.plot_surface(np.array([[0, 0], [30, 30]]),
                np.array([[10, 10], [10, 10]]),
                np.array([[10, 20], [10, 20]]), 
                rstride=1, cstride=1
)
ax.plot_surface(np.array([[0, 0], [30, 30]]),
                np.array([[20, 20], [20, 20]]),
                np.array([[10, 20], [10, 20]]), 
                rstride=1, cstride=1
)
plt.show()
plt.close()

很明显,面元没有正确地居中,因为曲面似乎从10.5开始,到20.5结束,而不是10和20。如何实现后者?

编辑:我担心建议的答案有问题。x轴没有黑色实线,默认情况下是这样的:

当我取出建议的 Package 时,我得到:

不幸的是,当我取出我正在绘制的东西时,这个问题在Jupyter笔记本中是不可重现的,但是,我想知道你是否可以向我指出我必须做什么,以便在我的情况下,x轴再次有一条黑线?

mrzz3bfm

mrzz3bfm1#

这是由matplotlib处理3D Axis坐标引起的。它故意移动minsmaxs来创建一些人工填充:
axis3d.py#L190-L194

class Axis(maxis.XAxis):
   ...
   def _get_coord_info(self, renderer):
       ...
       # Add a small offset between min/max point and the edge of the plot
       deltas = (maxs - mins) / 12
       mins -= 0.25 * deltas
       maxs += 0.25 * deltas
       ...
       return mins, maxs, centers, deltas, bounds_proj, highs

从v3.5.1开始,没有参数可以控制此行为。
但是,我们可以使用functools.wraps来创建一个围绕Axis._get_coord_info的 Package 器,该 Package 器取消minsmaxs的移位。为了防止此 Package 器多次取消移位(例如,当重新运行其Jupyter单元格时),track the wrapper state通过_unpadded属性:

from functools import wraps

def unpad(f): # where f will be Axis._get_coord_info
    @wraps(f)
    def wrapper(*args, **kwargs):
        mins, maxs, centers, deltas, bounds_proj, highs = f(*args, **kwargs)
        mins += 0.25 * deltas # undo original subtraction
        maxs -= 0.25 * deltas # undo original addition
        return mins, maxs, centers, deltas, bounds_proj, highs

    if getattr(f, '_unpadded', False): # bypass if already unpadded
        return f
    else:
        wrapper._unpadded = True # mark as unpadded
        return wrapper

打印前应用unpad Package 器:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axis3d import Axis

X = np.array([[0, 0], [30, 30]])
Z = np.array([[10, 20], [10, 20]])
y1, y2, y3 = 10, 16, 20

# wrap Axis._get_coord_info with our unpadded version
Axis._get_coord_info = unpad(Axis._get_coord_info)

fig, ax = plt.subplots(figsize=(8, 10.7), subplot_kw={'projection': '3d'})
ax.plot_surface(X, np.tile(y1, (2, 2)), Z, rstride=1, cstride=1)
ax.plot_surface(X, np.tile(y2, (2, 2)), Z, rstride=1, cstride=1)
ax.plot_surface(X, np.tile(y3, (2, 2)), Z, rstride=1, cstride=1)

plt.show()

相关问题