Python matplotlib 3d plot有两个坐标轴?

ewm0tg9j  于 2023-10-24  发布在  Python
关注(0)|答案(1)|浏览(128)

我试图创建一个类似于下面的this paper的图,本质上是一个有两个不同y轴的3d图。根据this blog的指导,我创建了一个最小的例子。

模块

from mpl_toolkits import mplot3d
import numpy as np
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

创建一些数据

def f(x, y):
    return np.sin(np.sqrt(x ** 2 + y ** 2))

x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)
Z2 = Z*100+100

阴谋

这就创建了一个很好的3d图,但显然只有一个y轴。我在网上找不到任何关于如何使用python的建议,尽管有一些是matlab的。

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z2, rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')
ax.set_title('surface');
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');

代码给出:

参考图:

iibxawm4

iibxawm41#

这并不容易。一种可能的变通方法如下:

  • 根据您共享的参考图,我认为您的意思实际上是希望实现第二个z轴,而不是y轴。
  • 3d图的axes对象仍然是单一的和共享的(出于必要性/明显的matplotlib 3d图限制),但是数据被绘制为相同的连续值尺度 *,但是轴刻度和标签被自定义覆盖以反映不同的值尺度 *。

例如,在一个示例中,

import numpy as np
import matplotlib.pyplot as plt

def f(x, y):
    return np.sin(np.sqrt(x ** 2 + y ** 2))

def g(x, y):
    return -np.cos(np.sqrt(x ** 2 + y ** 2))

x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
Z_new = g(X, Y)
offset = 5
Z_new_offset = Z_new + Z.max() + offset

fig = plt.figure(figsize=(16, 12))
ax = fig.add_subplot(111, projection="3d")

surf1 = ax.plot_surface(
    X, Y, Z, rstride=1, cstride=1, cmap="viridis", edgecolor="none", alpha=0.7
)

surf2 = ax.plot_surface(
    X,
    Y,
    Z_new_offset,
    rstride=1,
    cstride=1,
    cmap="plasma",
    edgecolor="none",
    alpha=0.7,
)

z_ticks_original = np.linspace(Z.min(), Z.max(), 5)

# Add custom tick labels and tick marks for the new plot on the left
z_ticks_new = np.linspace(Z_new_offset.min(), Z_new_offset.max(), 5)
for z_tick in z_ticks_new:
    ax.text(
        X.min() - 0.5,
        Y.min() - 2.5,
        z_tick + 0.25,
        f"{z_tick - (offset+1):.1f}",
        color="k",
        verticalalignment="center",
    )
    ax.plot(
        [X.min() - 0.5, X.min()],
        [Y.min() - 0.5, Y.min()],
        [z_tick, z_tick],
        color="k",
    )

ax.set_zticks(np.block([z_ticks_original, z_ticks_new]))

fig.canvas.draw()
labels = []
for lab, tick in zip(ax.get_zticklabels(), ax.get_zticks()):
    if float(tick) >= 1.0:
        lab.set_text("")
    labels += [lab]
ax.set_zticklabels(labels)

# Draw the left Z-axis line
ax.plot(
    [X.min() - 0.5] * 2,
    [Y.min() - 0.5] * 2,
    [z_ticks_new.min(), z_ticks_new.max()],
    color="k",
)

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")

cbar1 = fig.colorbar(
    surf1, ax=ax, pad=-0.075, orientation="vertical", shrink=0.5
)
cbar1.set_label("Z Values (primary z-axis)")
cbar2 = fig.colorbar(
    surf2,
    ax=ax,
    pad=0.12,
    orientation="vertical",
    shrink=0.5,
    ticks=z_ticks_original,
)
cbar2.set_label("Z Values (secondary z-axis)")

plt.show()

生产:x1c 0d1x

相关问题