matplotlib 如何使用ax正确编辑图表图例

pod7payv  于 2023-03-09  发布在  其他
关注(0)|答案(1)|浏览(189)

我有一些代码输出以下图像:

我想编辑左下子图中的图例,把它改成:
H_0:“星星”
H_1:“绿色”
两行上下排列,其中“星星”和“绿色”分别替换为图中实际使用的符号(红色星形标记和绿色圆点标记)。

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from mpl_toolkits import mplot3d
  4. import matplotlib.colors as colors
  5. # Define the function to plot
  6. def f(x, y):
  7. return np.sin(np.sqrt(x**2 + y**2))
  8. # Generate data for the x, y, and z coordinates
  9. x = np.linspace(-6, 6, 100)
  10. y = np.linspace(-6, 6, 100)
  11. X, Y = np.meshgrid(x, y)
  12. Z = f(X, Y)
  13. cmap = colors.ListedColormap(['black'])
  14. # Create a 3D figure and a contour plot side by side
  15. fig = plt.figure(figsize=(10, 8))
  16. ax1 = fig.add_subplot(221, projection='3d')
  17. ax2 = fig.add_subplot(222)
  18. ax3 = fig.add_subplot(223)
  19. ax4 = fig.add_subplot(224)
  20. # NEED A NAME FOR THIS SECTION
  21. Zero_dim_births = np.array([-1, -.25, .75])
  22. One_dim_births = np.array([-1,.75])
  23. # Plot the surface on the left subplot
  24. ax1.plot_surface(X, Y, Z, cmap='jet')
  25. i = 1 # intialize i
  26. level_set_speed = .075 # how quickly the level sets expand
  27. plane_speed = .05 # how quickly the plane moves up
  28. for a in np.arange(-1,1.05,plane_speed): # controls the movement of the plane
  29. i += level_set_speed
  30. #Plot the plane moving up the surface on the left
  31. ax1.cla()
  32. plane = np.zeros_like(X)
  33. plane = np.zeros_like(X) + a
  34. ax1.plot_surface(X, Y, Z, cmap='jet')
  35. ax1.plot_wireframe(X, Y, plane, color='black')
  36. # Plot the contour on the right subplot
  37. contour_levels = np.arange(Z.min(), Z.min()+i, i/2)
  38. ax2.contourf(X, Y, Z, levels=contour_levels, cmap=cmap, extend='min')\
  39. # Plot the persistence diagram
  40. ax3.cla()
  41. input = np.arange(-1.1,1.1,.1)
  42. id_fun = input
  43. ax3.plot(input,id_fun, color = 'blue')
  44. ax3.axhline(y = a, color = 'black', linestyle = '-')
  45. for index in range(0,3):
  46. ax3.plot(Zero_dim_births[index], 1, marker = "o", markersize = 7, markeredgecolor = "green", markerfacecolor = "green")
  47. for index in range(0,2):
  48. ax3.plot(One_dim_births[index], 1, marker = "*", markersize = 5, markeredgecolor = "red", markerfacecolor = "red")
  49. # Plot the barcodes
  50. ax4.cla()
  51. ax4.plot([-1,1], [1,1], color='green') # 0D component 1
  52. ax4.plot([-1,1], [2,2], color='red') # 1D component 1
  53. ax4.plot([-.25,1], [3,3], color='green') # 0D component 2
  54. ax4.plot([.75,1], [4,4], color='green') # 0D component 3
  55. ax4.plot([.75,1], [5,5], color='red') # 1D component 4
  56. ax4.axvline(x = a, color = 'black', linestyle = '-')
  57. # Labels for all the plots
  58. # plot 1
  59. ax1.set_xlabel('x')
  60. ax1.set_ylabel('y')
  61. ax1.set_zlabel('z')
  62. ax1.set_title(r'$f(x, y) = sin(sqrt(x^2 + y^2))$')
  63. # plot 2
  64. ax2.set_title('Sublevel set filtration')
  65. ax3.set_xlabel('Birth (height)')
  66. ax3.set_ylabel('Death')
  67. ax3.set_title('Persistence Diagram')
  68. ax3.legend(['NEED TO EDIT'], loc='lower right')
  69. # plot 4
  70. ax4.set_xlabel('Persistence')
  71. ax4.set_ylabel('Components')
  72. ax4.set_yticks([])
  73. ax4.set_title('Barcodes')
  74. # snapshot
  75. plt.pause(.1)
  76. # Show the plot
  77. plt.tight_layout()
  78. plt.show()

有谁能教我如何按我想要的样子重新设置这个图例吗?

cigdeys3

cigdeys31#

这里的技巧是使用 matplotlib.lines 中的 Line2D 创建一条不画在坐标轴上的“假”线(即代理艺术家,proxy artist),然后用它来创建图例。下面是一段代码:

  1. import matplotlib.lines as mlines
  2. # Create lines with markers
  3. star = mlines.Line2D([], [], color='white', marker='*', markerfacecolor='r', markeredgecolor='r',ls='', label='H_0')
  4. dot = mlines.Line2D([], [], color='white', marker='o', markerfacecolor='g', markeredgecolor='g',ls='', label='H_1')
  5. # Add legend
  6. ax3.legend(handles=[star, dot], loc='lower right')

关键在于 ls='':它让图例中只显示标记(红色星形和绿色圆点),而不会画出连线;color='white' 只影响那条本就不会被绘制的线,并不是必需的。

以下是完整的代码,以防万一:

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from mpl_toolkits import mplot3d
  4. import matplotlib.colors as colors
  5. import matplotlib.lines as mlines
  6. # Define the function to plot
  7. def f(x, y):
  8. return np.sin(np.sqrt(x**2 + y**2))
  9. # Generate data for the x, y, and z coordinates
  10. x = np.linspace(-6, 6, 100)
  11. y = np.linspace(-6, 6, 100)
  12. X, Y = np.meshgrid(x, y)
  13. Z = f(X, Y)
  14. cmap = colors.ListedColormap(['black'])
  15. # Create a 3D figure and a contour plot side by side
  16. fig = plt.figure(figsize=(10, 8))
  17. ax1 = fig.add_subplot(221, projection='3d')
  18. ax2 = fig.add_subplot(222)
  19. ax3 = fig.add_subplot(223)
  20. ax4 = fig.add_subplot(224)
  21. # NEED A NAME FOR THIS SECTION
  22. Zero_dim_births = np.array([-1, -.25, .75])
  23. One_dim_births = np.array([-1,.75])
  24. # Plot the surface on the left subplot
  25. ax1.plot_surface(X, Y, Z, cmap='jet')
  26. i = 1 # intialize i
  27. level_set_speed = .075 # how quickly the level sets expand
  28. plane_speed = .05 # how quickly the plane moves up
  29. for a in np.arange(-1,1.05,plane_speed): # controls the movement of the plane
  30. i += level_set_speed
  31. #Plot the plane moving up the surface on the left
  32. ax1.cla()
  33. plane = np.zeros_like(X)
  34. plane = np.zeros_like(X) + a
  35. ax1.plot_surface(X, Y, Z, cmap='jet')
  36. ax1.plot_wireframe(X, Y, plane, color='black')
  37. # Plot the contour on the right subplot
  38. contour_levels = np.arange(Z.min(), Z.min()+i, i/2)
  39. ax2.contourf(X, Y, Z, levels=contour_levels, cmap=cmap, extend='min') \
  40. \
  41. # Plot the persistence diagram
  42. ax3.cla()
  43. input = np.arange(-1.1,1.1,.1)
  44. id_fun = input
  45. ax3.plot(input,id_fun, color = 'blue')
  46. ax3.axhline(y = a, color = 'black', linestyle = '-')
  47. for index in range(0,3):
  48. ax3.plot(Zero_dim_births[index], 1, marker = "o", markersize = 7, markeredgecolor = "green",
  49. markerfacecolor = "green", label='H_1')
  50. for index in range(0,2):
  51. ax3.plot(One_dim_births[index], 1, marker = "*", markersize = 5, markeredgecolor = "red",
  52. markerfacecolor = "red", label='H_0')
  53. # Plot the barcodes
  54. ax4.cla()
  55. ax4.plot([-1,1], [1,1], color='green') # 0D component 1
  56. ax4.plot([-1,1], [2,2], color='red') # 1D component 1
  57. ax4.plot([-.25,1], [3,3], color='green') # 0D component 2
  58. ax4.plot([.75,1], [4,4], color='green') # 0D component 3
  59. ax4.plot([.75,1], [5,5], color='red') # 1D component 4
  60. ax4.axvline(x = a, color = 'black', linestyle = '-')
  61. # Labels for all the plots
  62. # plot 1
  63. ax1.set_xlabel('x')
  64. ax1.set_ylabel('y')
  65. ax1.set_zlabel('z')
  66. ax1.set_title(r'$f(x, y) = sin(sqrt(x^2 + y^2))$')
  67. # plot 2
  68. ax2.set_title('Sublevel set filtration')
  69. ax3.set_xlabel('Birth (height)')
  70. ax3.set_ylabel('Death')
  71. ax3.set_title('Persistence Diagram')
  72. # Add legend
  73. star = mlines.Line2D([], [], color='white', marker='*', markerfacecolor='r', markeredgecolor='r',
  74. ls='', label='H_0')
  75. dot = mlines.Line2D([], [], color='white', marker='o', markerfacecolor='g', markeredgecolor='g',
  76. ls='', label='H_1')
  77. ax3.legend(handles=[star, dot], loc='lower right')
  78. # plot 4
  79. ax4.set_xlabel('Persistence')
  80. ax4.set_ylabel('Components')
  81. ax4.set_yticks([])
  82. ax4.set_title('Barcodes')
  83. # snapshot
  84. plt.pause(.1)
  85. # Show the plot
  86. plt.tight_layout()
  87. plt.show()

希望这对你有帮助,干杯!

展开查看全部

相关问题