numpy 如何让嵌套的for循环在python中执行得更快?

flseospp  于 2024-01-08  发布在  Python
关注(0)|答案(2)|浏览(199)

以下是我的脚本:

  1. for a in range(-100, 101):
  2. for b in range(-100, 101):
  3. for c in range(-100, 101):
  4. for d in range(-100, 101):
  5. if abs(2**a*3**b*5**c*7**d-0.3048) <= 10**(-6):
  6. print('a=',a, ', b=', b, ', c=', c,', d=', d,', the number=', 2**a*3**b*5**c*7**d, ', error=', abs(2**a*3**b*5**c*7**d-.3048))

字符串
在python中执行上面的脚本花了27分15秒。我知道它经历了201^4次表达式计算,但我需要更快地运行这些计算(因为我想尝试range(-200,201)等等)。
我想知道是否有可能使上面的代码执行得更快。我认为使用numpy数组会有所帮助,但不确定如何应用它,以及它是否真的有效。

1szpjjfi

1szpjjfi1#

对于这些类型的计算,您可以尝试numba JIT:

  1. from numba import njit
  2. @njit
  3. def fn():
  4. for a in range(-100, 101):
  5. for b in range(-100, 101):
  6. for c in range(-100, 101):
  7. for d in range(-100, 101):
  8. n = (2.0**a) * (3.0**b) * (5.0**c) * (7.0**d)
  9. v = n - 0.3048
  10. if abs(v) <= 1e-06:
  11. print(
  12. "a=",
  13. a,
  14. ", b=",
  15. b,
  16. ", c=",
  17. c,
  18. ", d=",
  19. d,
  20. ", the number=",
  21. n,
  22. ", error=",
  23. abs(n - 3.048),
  24. )
  25. fn()

字符串
在我的机器(AMD 5700 X)上运行这段代码需要大约57秒(包括编译步骤)。相比之下,如果没有@njit(只是普通的Python),这只需要4分钟。

  1. a= -78 , b= -89 , c= -14 , d= 89 , the number= 0.3047994427888104 , error= 2.7432005572111895
  2. a= -78 , b= -57 , c= 50 , d= 18 , the number= 0.30479915330101043 , error= 2.7432008466989894
  3. a= -69 , b= -85 , c= 87 , d= 0 , the number= 0.3047993420932106 , error= 2.7432006579067894
  4. a= -63 , b= 42 , c= -99 , d= 80 , the number= 0.3048005478488736 , error= 2.7431994521511265
  5. a= -63 , b= 74 , c= -35 , d= 9 , the number= 0.3048002583600241 , error= 2.743199741639976
  6. a= -54 , b= 14 , c= -62 , d= 62 , the number= 0.3048007366419375 , error= 2.7431992633580626
  7. a= -54 , b= 46 , c= 2 , d= -9 , the number= 0.30480044715290866 , error= 2.7431995528470914
  8. a= -54 , b= 78 , c= 66 , d= -80 , the number= 0.3048001576641548 , error= 2.7431998423358452
  9. a= -45 , b= -14 , c= -25 , d= 44 , the number= 0.30480092543511833 , error= 2.7431990745648815
  10. a= -45 , b= 18 , c= 39 , d= -27 , the number= 0.3048006359459102 , error= 2.7431993640540897
  11. a= -36 , b= -10 , c= 76 , d= -45 , the number= 0.30480082473902875 , error= 2.7431991752609712
  12. a= 5 , b= -44 , c= -72 , d= 82 , the number= 0.30479914163960603 , error= 2.743200858360394
  13. a= 14 , b= -72 , c= -35 , d= 64 , the number= 0.304799330431799 , error= 2.743200669568201
  14. a= 14 , b= -40 , c= 29 , d= -7 , the number= 0.3047990409441057 , error= 2.743200959055894
  15. a= 23 , b= -100 , c= 2 , d= 46 , the number= 0.30479951922410875 , error= 2.7432004807758914
  16. a= 23 , b= -68 , c= 66 , d= -25 , the number= 0.30479922973623635 , error= 2.7432007702637637
  17. a= 29 , b= 91 , c= -56 , d= -16 , the number= 0.30480014600271205 , error= 2.743199853997288
  18. a= 38 , b= 31 , c= -83 , d= 37 , the number= 0.30480062428444915 , error= 2.743199375715551
  19. a= 38 , b= 63 , c= -19 , d= -34 , the number= 0.30480033479552704 , error= 2.743199665204473
  20. a= 47 , b= 3 , c= -46 , d= 19 , the number= 0.30480081307756046 , error= 2.7431991869224395
  21. a= 47 , b= 35 , c= 18 , d= -52 , the number= 0.30480052358845894 , error= 2.743199476411541
  22. a= 56 , b= 7 , c= 55 , d= -70 , the number= 0.3048007123815079 , error= 2.7431992876184923
  23. a= 65 , b= -21 , c= 92 , d= -88 , the number= 0.3048009011746738 , error= 2.7431990988253263
  24. a= 97 , b= -27 , c= -93 , d= 57 , the number= 0.3047990292827057 , error= 2.7432009707172944
  25. real 0m57,939s
  26. user 0m0,009s
  27. sys 0m0,009s


看看你的代码,你可以使用parallel rangeprange)来进一步加快速度:

  1. from numba import njit, prange
  2. @njit(parallel=True)
  3. def fn():
  4. for a in prange(-100, 101):
  5. i_a = 2.0**a
  6. for b in prange(-100, 101):
  7. i_b = i_a * 3.0**b
  8. for c in prange(-100, 101):
  9. i_c = i_b * 5.0**c
  10. for d in prange(-100, 101):
  11. n = i_c * (7.0**d)
  12. v = n - 0.3048
  13. if abs(v) <= 1e-06:
  14. print(
  15. "a=",
  16. a,
  17. ", b=",
  18. b,
  19. ", c=",
  20. c,
  21. ", d=",
  22. d,
  23. ", the number=",
  24. n,
  25. ", error=",
  26. abs(n - 3.048),
  27. )
  28. fn()


在我的8 C/16 T机器上只需~2.7秒。
@EDIT:添加了存储中间结果。谢谢@yotheguitou

展开查看全部
puruo6ea

puruo6ea2#

几分钟后就开始了。
主要的速度改进只是预先计算所有的权力。我怀疑itertools实际上给了我任何东西。
你可能不是故意在一个位置使用.3048,在打印消息中使用3.048。我把两者都改成了.3048。也许你是指另一个。

  1. import itertools
  2. aa = {i: 2 ** i for i in range(-100, 101)}
  3. bb = {i: 3 ** i for i in range(-100, 101)}
  4. cc = {i: 5 ** i for i in range(-100, 101)}
  5. dd = {i: 7 ** i for i in range(-100, 101)}
  6. for (a, avalue), (b, bvalue), (c, cvalue), (d, dvalue) in itertools.product(aa.items(), bb.items(), cc.items(), dd.items()):
  7. if abs(avalue * bvalue * cvalue * dvalue - .3048) <= 1e-6:
  8. value = avalue * bvalue * cvalue * dvalue
  9. print('a=',a, ', b=', b, ', c=', c,', d=', d,', the number=', value, ', error=', abs(value - .3048))

字符串

相关问题