python pyspark rdd.map函数模拟作用域

kx5bkwkv  于 2024-01-05  发布在  Python
关注(0)|答案(1)|浏览(141)

这些转换函数中的一些尊重mock,另一些不尊重,我不知道为什么。
这里有一个名为etl_job.py的文件,它包含各种转换函数,这些函数使用rdd.map将一列添加到DataFrame中,并使用从dependencies.utils导入的get_random_bool函数。

  1. # etl_job.py
  2. from dependencies.utils import get_random_bool
  3. def transform_data1(df):
  4. return df.rdd.map(lambda row: row+(get_random_bool(),)).toDF()
  5. def transform_data2(df):
  6. g = lambda row: row+(get_random_bool(),)
  7. return df.rdd.map(lambda row: g(row)).toDF()
  8. def transform_data3(df):
  9. g = lambda row: row+(get_random_bool(),)
  10. h = lambda row: g(row)
  11. return df.rdd.map(lambda row: h(row)).toDF()
  12. def transform_data4(df):
  13. return df.rdd.map(lambda row: f(row)).toDF()
  14. def transform_data5(df):
  15. g = lambda row: f(row)
  16. return df.rdd.map(lambda row: g(row)).toDF()
  17. def f(row):
  18. return row+(get_random_bool(),)

字符串
此测试文件尝试修补jobs.etl_job.py中导入的get_random_bool函数。

  1. from pyspark.sql import SparkSession
  2. from unittest.mock import patch
  3. from jobs.etl_job import transform_data1, transform_data2, transform_data3, transform_data4, transform_data5
  4. # Create SparkSession
  5. spark = SparkSession.builder \
  6. .master('local[1]') \
  7. .appName('Time Tests') \
  8. .getOrCreate()
  9. spark.sparkContext.setLogLevel("WARN")
  10. df = spark.createDataFrame([["1"],["2"]])
  11. print('original dataframe')
  12. df.show()
  13. with patch('jobs.etl_job.get_random_bool') as f:
  14. f.return_value = 'notabool'
  15. df_t = transform_data1(df)
  16. print('with transform_data1')
  17. df_t.show()
  18. df_t = transform_data2(df)
  19. print('with transform_data2')
  20. df_t.show()
  21. df_t = transform_data3(df)
  22. print('with transform_data3')
  23. df_t.show()
  24. df_t = transform_data4(df)
  25. print('with transform_data4')
  26. df_t.show()
  27. df_t = transform_data5(df)
  28. print('with transform_data5')
  29. df_t.show()


这是输出。

  1. original dataframe
  2. +---+
  3. | _1|
  4. +---+
  5. | 1|
  6. | 2|
  7. +---+
  8. with transform_data1
  9. +---+--------+
  10. | _1| _2|
  11. +---+--------+
  12. | 1|notabool|
  13. | 2|notabool|
  14. +---+--------+
  15. with transform_data2
  16. +---+--------+
  17. | _1| _2|
  18. +---+--------+
  19. | 1|notabool|
  20. | 2|notabool|
  21. +---+--------+
  22. with transform_data3
  23. +---+--------+
  24. | _1| _2|
  25. +---+--------+
  26. | 1|notabool|
  27. | 2|notabool|
  28. +---+--------+
  29. with transform_data4
  30. +---+----+
  31. | _1| _2|
  32. +---+----+
  33. | 1|true|
  34. | 2|true|
  35. +---+----+
  36. with transform_data5
  37. +---+-----+
  38. | _1| _2|
  39. +---+-----+
  40. | 1|false|
  41. | 2|false|
  42. +---+-----+


前三个转换工作正常--添加所有值都是notabool的模拟列--但第四个和第五个转换不正常--它们添加了一列布尔值而不是模拟值。如果模拟函数在transform函数内调用,或者如果模拟函数由transform函数的内部函数调用,则它工作正常;但如果transform函数调用调用被模拟函数的外部函数,则使用实际的非模拟函数。
有人能解释这种行为吗?

uemypmqf

uemypmqf1#

我知道这是一个老问题,但这个答案可能会帮助其他遇到同样问题的人。
我花了很多时间对这个问题感到困惑。事实证明,由于某种原因,mocking不能很好地与Spark并行线程一起工作。记住,map()函数中的lambda在executor上运行,而不是驱动程序,即在不同的线程/进程上。因此mocking必须在executor线程的上下文中发生,而不是驱动程序线程。
我让它正常工作的唯一方法是将lambda Package 在一个可以模拟的函数中,然后在测试中,用一个执行修补并调用原始函数的函数来模拟该函数(这应该发生在工作线程的上下文中)。

  1. # etl_job.py
  2. from dependencies.utils import get_random_bool
  3. def transform_row(row):
  4. # will be called within the context of a worker thread
  5. return f(row)
  6. def transform_data4(df):
  7. return df.rdd.map(transform_row).toDF()
  8. def f(row):
  9. return row+(get_random_bool(),)

字符串
测试文件:

  1. from pyspark.sql import SparkSession
  2. from unittest.mock import patch
  3. from jobs.etl_job import ..., transform_data4, transform_row
  4. ...
  5. def mock_transform_row(row):
  6. # patch in the context of the worker thread
  7. with patch('jobs.etl_job.get_random_bool') as f:
  8. f.return_value = 'notabool'
  9. return transform_row(row) # delegate to the original function
  10. with patch('jobs.etl_job.transform_row', side_effect=mock_tranform_row):
  11. ...
  12. df_t = transform_data4(df)
  13. print('with transform_data4')
  14. df_t.show()

展开查看全部

相关问题