numpy 优化含多个条件分支的 Python SmoothStep 函数,以便 Numba 向量化

azpvetkf  于 2023-10-19  发布在  Python
关注(0)|答案(1)|浏览(99)

我实现了一个使用SmoothStep创建平滑矩形函数的函数:

  1. import numpy as np
  2. from numba import jit, njit
  3. import matplotlib.pyplot as plt
  @njit
  def GenSmoothStep( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
      """Fill vY with a smooth rectangular window evaluated at the samples in vX.

      The window is 1 on [lowVal, highVal], 0 below lowClip and above highClip,
      and follows a cubic smoothstep on the two roll-off intervals in between.
      lowClip / highClip are lowVal - rollOffWidth and highVal + rollOffWidth,
      limited to 0 and 1 respectively. vY is written in place through its flat
      view; nothing is returned.
      """
      lowClip = max(lowVal - rollOffWidth, 0)
      highClip = min(highVal + rollOffWidth, 1)
      for idx in range(vX.size):
          sample = vX.flat[idx]
          # Plateau value unless one of the branches below applies.
          result = 1.0
          if sample < lowClip:
              result = 0.0
          elif sample < lowVal:
              # Rising edge: smoothstep over [lowClip, lowVal]
              t = (lowVal - sample) / (lowVal - lowClip)
              result = 1 - (t * t * (3 - (2 * t)))
          elif sample > highClip:
              result = 0.0
          elif sample > highVal:
              # Falling edge: smoothstep over [highVal, highClip]
              t = (sample - highVal) / (highClip - highVal)
              result = 1 - (t * t * (3 - (2 * t)))
          vY.flat[idx] = result
  # Demo: evaluate the window on a uniform grid over [0, 1] and plot it.
  numGridPts = 1000
  lowVal, highVal = 0.15, 0.75
  rollOffWidth = 0.3

  vX = np.linspace(0, 1, numGridPts)
  vY = np.empty_like(vX)  # output buffer, filled in place by GenSmoothStep
  GenSmoothStep(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)

  plt.plot(vX, vY)

该函数包含多个条件分支,因此对向量化并不友好。
我想知道是否有一些简单的方法,可以让这个函数对 Numba 更友好。

更新

我采用了 @AndrejKesely 的代码并做了修改,以处理我的代码中的边界情况(lowVal = 0.0 和/或 highVal = 1.0)。
我还添加了一个不使用分支的钳位(clamp)变体。以下是当前的状态:

  1. # %%
  2. import numpy as np
  3. from numba import jit, njit
  4. import matplotlib.pyplot as plt
  5. from timeit import timeit
  @njit
  def GenSmoothStep000( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
      """Reference (branching) smooth rectangular window.

      Writes into vY, element by element through the flat views, a window that
      is 1 on [lowVal, highVal], 0 outside [lowClip, highClip], and a cubic
      smoothstep on the roll-off intervals. lowClip / highClip are the edges
      moved outwards by rollOffWidth and limited to 0 and 1. Returns nothing.
      """
      lowClip = max(lowVal - rollOffWidth, 0)
      highClip = min(highVal + rollOffWidth, 1)
      for idx in range(vX.size):
          sample = vX.flat[idx]
          # Start from the plateau value; the branches below override it.
          result = 1.0
          if sample < lowClip:
              result = 0.0
          elif sample < lowVal:
              # Rising edge: smoothstep over [lowClip, lowVal]
              t = (lowVal - sample) / (lowVal - lowClip)
              result = 1 - (t * t * (3 - (2 * t)))
          elif sample > highClip:
              result = 0.0
          elif sample > highVal:
              # Falling edge: smoothstep over [highVal, highClip]
              t = (sample - highVal) / (highClip - highVal)
              result = 1 - (t * t * (3 - (2 * t)))
          vY.flat[idx] = result
  @njit
  def Clamp001( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
      """Limit x to [lowBound, highBound] using explicit comparisons.

      The lower bound is tested first, so it wins if the bounds are inverted.
      """
      return lowBound if x < lowBound else (highBound if x > highBound else x)
  @njit
  def Clamp002( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
      """Limit x to [lowBound, highBound] with min/max instead of if statements."""
      cappedAbove = min(x, highBound)
      return max(cappedAbove, lowBound)
  @njit
  def SmoothStep001( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
      """Cubic smoothstep of x between lowBound and highBound (uses Clamp001).

      x is mapped linearly so that lowBound -> 0 and highBound -> 1, clamped to
      [0, 1], and passed through the polynomial 3t^2 - 2t^3.
      The caller must ensure highBound != lowBound (the mapping divides by
      their difference).
      """
      t = (x - lowBound) / (highBound - lowBound)
      t = Clamp001(t, 0.0, 1.0)
      return t * t * (3.0 - 2.0 * t)
  @njit
  def SmoothStep002( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
      """Cubic smoothstep of x between lowBound and highBound (uses Clamp002).

      x is mapped linearly so that lowBound -> 0 and highBound -> 1, clamped to
      [0, 1], and passed through the polynomial 3t^2 - 2t^3.
      The caller must ensure highBound != lowBound (the mapping divides by
      their difference).
      """
      t = (x - lowBound) / (highBound - lowBound)
      t = Clamp002(t, 0.0, 1.0)
      return t * t * (3.0 - 2.0 * t)
  @njit
  def GenSmoothStep001( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
      """Smooth rectangular window built as a product of two smoothsteps.

      vY[ii] = rise(vX[ii]) * fall(vX[ii]), where rise is SmoothStep001 over
      [lowClip, lowVal] and fall is 1 - SmoothStep001 over [highVal, highClip].
      lowClip / highClip are lowVal - rollOffWidth and highVal + rollOffWidth,
      limited to 0 and 1. vY is written in place; nothing is returned.

      An edge whose roll-off interval has no positive width (lowVal <= lowClip
      or highClip <= highVal, e.g. lowVal = 0.0, highVal = 1.0 or
      rollOffWidth = 0) cannot be passed to SmoothStep001 because it divides by
      the interval width. Such an edge is replaced by a hard step at the clip
      value, which is what the piecewise reference GenSmoothStep000 produces.
      """
      lowClip = max(lowVal - rollOffWidth, 0)
      highClip = min(highVal + rollOffWidth, 1)
      # Guard on the actual interval widths rather than on lowVal == 0.0 /
      # highVal == 1.0: those equalities miss rollOffWidth == 0 and edges
      # lying outside [0, 1], all of which make the width zero or negative.
      hasRise = lowVal > lowClip
      hasFall = highClip > highVal
      for ii in range(vX.size):
          valX = vX[ii]
          if hasRise:
              rise = SmoothStep001(valX, lowClip, lowVal)
          else:
              # Collapsed rising edge: hard step at lowClip.
              rise = 0.0 if valX < lowClip else 1.0
          if hasFall:
              fall = 1.0 - SmoothStep001(valX, highVal, highClip)
          else:
              # Collapsed falling edge: hard step at highClip.
              fall = 0.0 if valX > highClip else 1.0
          vY[ii] = rise * fall
  @njit
  def GenSmoothStep002( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
      """Smooth rectangular window built as a product of two smoothsteps.

      Same construction as the SmoothStep001-based variant, but using
      SmoothStep002 (min/max clamp): vY[ii] = rise(vX[ii]) * fall(vX[ii]),
      where rise is SmoothStep002 over [lowClip, lowVal] and fall is
      1 - SmoothStep002 over [highVal, highClip]. lowClip / highClip are
      lowVal - rollOffWidth and highVal + rollOffWidth, limited to 0 and 1.
      vY is written in place; nothing is returned.

      An edge whose roll-off interval has no positive width (lowVal <= lowClip
      or highClip <= highVal, e.g. lowVal = 0.0, highVal = 1.0 or
      rollOffWidth = 0) cannot be passed to SmoothStep002 because it divides by
      the interval width. Such an edge is replaced by a hard step at the clip
      value, which is what the piecewise reference GenSmoothStep000 produces.
      """
      lowClip = max(lowVal - rollOffWidth, 0)
      highClip = min(highVal + rollOffWidth, 1)
      # Guard on the actual interval widths rather than on lowVal == 0.0 /
      # highVal == 1.0: those equalities miss rollOffWidth == 0 and edges
      # lying outside [0, 1], all of which make the width zero or negative.
      hasRise = lowVal > lowClip
      hasFall = highClip > highVal
      for ii in range(vX.size):
          valX = vX[ii]
          if hasRise:
              rise = SmoothStep002(valX, lowClip, lowVal)
          else:
              # Collapsed rising edge: hard step at lowClip.
              rise = 0.0 if valX < lowClip else 1.0
          if hasFall:
              fall = 1.0 - SmoothStep002(valX, highVal, highClip)
          else:
              # Collapsed falling edge: hard step at highClip.
              fall = 0.0 if valX > highClip else 1.0
          vY[ii] = rise * fall
  # %%
  # Validation + JIT Compilation
  # Calling each variant once also triggers its Numba compilation, so the
  # timing cell afterwards measures compiled code only.
  numGridPts = 10_000
  lowVal = 0.35
  highVal = 0.55
  rollOffWidth = 0.3

  vX = np.linspace(0, 1, numGridPts)

  hF, vHa = plt.subplots(nrows = 1, ncols = 3, figsize = (16, 5))

  # One subplot per implementation, each drawn from a fresh output buffer.
  for hA, genFun in zip(vHa, (GenSmoothStep000, GenSmoothStep001, GenSmoothStep002)):
      vY = np.empty_like(vX)
      genFun(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)
      hA.plot(vX, vY)
  # %%
  # Check Performance
  # Every variant is timed with the same statement; only the function name differs.
  callTemplate = "{}(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)"

  time000 = timeit(callTemplate.format("GenSmoothStep000"), number = 10_000, globals = globals())
  time001 = timeit(callTemplate.format("GenSmoothStep001"), number = 10_000, globals = globals())
  time002 = timeit(callTemplate.format("GenSmoothStep002"), number = 10_000, globals = globals())

  for elapsed in (time000, time001, time002):
      print(elapsed)

输出为(在我的计算机上,Intel Core i7-6800K):

  1. 0.23776450000877958
  2. 0.23713289998704568
  3. 0.23025239999697078

所以看起来还是很接近的。

xzlaal3s

xzlaal3s1#

如果我理解正确(IIUC),你只是想把两个 smoothstep 组合起来:

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from numba import njit
  @njit
  def smoothstep(edge0, edge1, x):
      """Cubic smoothstep of x between edge0 and edge1.

      x is rescaled so that edge0 -> 0 and edge1 -> 1, clipped to [0, 1] with
      np.clip, and passed through the polynomial 3t^2 - 2t^3.
      edge0 and edge1 must differ (the rescaling divides by edge1 - edge0).
      """
      t = (x - edge0) / (edge1 - edge0)
      t = np.clip(t, 0, 1)
      return t * t * (3.0 - 2.0 * t)
  numGridPts = 1000
  lowVal, highVal = 0.15, 0.75

  vX = np.linspace(0, 1, numGridPts)

  # Window = rising edge over [0, lowVal] times the complement of the
  # falling edge over [highVal, 1].
  risingEdge = smoothstep(0, lowVal, vX)
  fallingEdge = smoothstep(highVal, 1, vX)
  vY = risingEdge * (1 - fallingEdge)

  plt.plot(vX, vY)
  plt.show()

显示此图表:

编辑:新版本(思路仍然是组合两个 smoothstep),没有额外的内存分配,也没有 if 分支(支持 rollOffWidth):

  @njit
  def clamp(x, lowerlimit=0.0, upperlimit=1.0):
      """Limit x to [lowerlimit, upperlimit]; the lower limit is tested first."""
      return lowerlimit if x < lowerlimit else (upperlimit if x > upperlimit else x)
  @njit
  def smoothstep(edge0, edge1, x):
      """Cubic smoothstep of the scalar x between edge0 and edge1.

      x is rescaled so that edge0 -> 0 and edge1 -> 1, clamped to [0, 1], and
      passed through the polynomial 3t^2 - 2t^3.
      edge0 and edge1 must differ (the rescaling divides by edge1 - edge0).
      """
      t = (x - edge0) / (edge1 - edge0)
      t = clamp(t, 0, 1)
      return t * t * (3.0 - 2.0 * t)
  @njit
  def GenSmoothStep2(
      vX: np.ndarray,
      lowVal: float,
      highVal: float,
      vY: np.ndarray,
      rollOffWidth: float = 0.1,
  ):
      """Fill vY with a smooth rectangular window evaluated at the samples in vX.

      vY[i] = rise(vX[i]) * fall(vX[i]), where rise is smoothstep over
      [lowClip, lowVal] and fall is 1 - smoothstep over [highVal, highClip].
      lowClip / highClip are lowVal - rollOffWidth and highVal + rollOffWidth,
      limited to 0 and 1. vY is written in place; nothing is returned.

      An edge whose roll-off interval has no positive width (lowVal <= lowClip
      or highClip <= highVal, e.g. lowVal = 0.0, highVal = 1.0 or
      rollOffWidth = 0) would make smoothstep divide by zero, so it is replaced
      by a hard step at the clip value instead.
      """
      lowClip = max(lowVal - rollOffWidth, 0)
      highClip = min(highVal + rollOffWidth, 1)
      # Loop-invariant flags: is each roll-off interval wide enough to divide by?
      hasRise = lowVal > lowClip
      hasFall = highClip > highVal
      for i in range(len(vX)):
          x = vX[i]
          if hasRise:
              rise = smoothstep(lowClip, lowVal, x)
          else:
              # Collapsed rising edge: hard step at lowClip.
              rise = 0.0 if x < lowClip else 1.0
          if hasFall:
              fall = 1.0 - smoothstep(highVal, highClip, x)
          else:
              # Collapsed falling edge: hard step at highClip.
              fall = 0.0 if x > highClip else 1.0
          vY[i] = rise * fall

基准:

  from timeit import timeit

  numGridPts = 1000
  lowVal, highVal = 0.45, 0.65
  rollOffWidth = 0.3

  vX = np.linspace(0, 1, numGridPts)
  vY = np.empty_like(vX)

  # warm up numba: the first call compiles, so keep it out of the timings
  GenSmoothStep(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)
  GenSmoothStep2(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)

  # Both functions are timed with the same call; only the name differs.
  callTemplate = "{}(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)"
  t1 = timeit(callTemplate.format("GenSmoothStep"), number=10_000, globals=globals())
  t2 = timeit(callTemplate.format("GenSmoothStep2"), number=10_000, globals=globals())

  print(t1)
  print(t2)

在我的计算机(AMD 5700X)上的输出:

  1. 0.010718749952502549
  2. 0.007840660051442683

因此,新函数的速度快了约 36%。

展开查看全部

相关问题