PySpark unit testing: how do I mock the sql call (and only the sql call)?

ttisahbt · asked 2023-11-16 in Spark

I'm having trouble testing the following function:

# module1.py
from typing import List, Optional

from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

spark = SparkSession.getActiveSession()

def filter_big_table(
    conditions: Optional[List[str]] = None,
) -> SparkDataFrame:

    sdf = spark.sql(f"select * from SUPER_BIG_TABLE")

    if conditions:
        sdf = sdf.filter(f.col("condition_col").isin(conditions))
    return sdf

Here is how I tried to do that with unittest.mock and pytest:

#conftest.py
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    return SparkSession.builder.master("local").appName("unit_tests").getOrCreate()


In the actual test, I create a fake DataFrame from a tuple fixture so that I never actually have to run the spark.sql(f"select * from SUPER_BIG_TABLE") query:

# test.py
from typing import List, Tuple
from unittest.mock import patch

from pyspark.sql.types import StructType

from module1 import filter_big_table
from . import assert_sdf_equal  # Helper function

# Placeholders: the real fixture is a list of tuples plus a matching schema
DATA_FIXTURE = List[Tuple[str, int]]
DATA_SCHEMA = StructType

@patch("pyspark.sql.SparkSession.sql")
def test_filter_big_table(mock_data, spark):
    
    input_sdf = spark.createDataFrame(DATA_FIXTURE, DATA_SCHEMA)
    mock_data.return_value = input_sdf

    expected_sdf = spark.createDataFrame(DATA_FIXTURE, DATA_SCHEMA)
    output_sdf = filter_big_table()

    # Passing no conditions should return the same spark dataframe
    assert_sdf_equal(output_sdf, expected_sdf)


However, I keep getting the error AttributeError: 'NoneType' object has no attribute 'sql'.
I think this is because of the way module1.py grabs the current Spark session at import time, but I don't want to change the module just to make the test pass.

Edit

I tried @patch("module1.spark") instead; now, rather than an AttributeError, I get a Mock object back instead of the input_sdf fixture. I also tried replacing mock_data.return_value with mock_data.side_effects, but that didn't help.
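
(For reference, @patch("module1.spark") can in fact be made to work without touching the module: the decorator replaces the whole module-level spark object with a MagicMock, so the fake DataFrame has to be configured on the mock's sql attribute rather than on the mock itself. A minimal sketch, assuming the conftest.py spark fixture above; the data and schema here are hypothetical:

# test.py (sketch)
from typing import List, Tuple
from unittest.mock import patch

from pyspark.sql.types import StringType, StructField, StructType

from module1 import filter_big_table

DATA_FIXTURE: List[Tuple[str]] = [("a",), ("b",)]
DATA_SCHEMA = StructType([StructField("condition_col", StringType())])

@patch("module1.spark")
def test_filter_big_table(mock_spark, spark):
    input_sdf = spark.createDataFrame(DATA_FIXTURE, DATA_SCHEMA)
    # Configure sql() on the mocked session object, not on the mock itself
    mock_spark.sql.return_value = input_sdf

    # With no conditions, filter_big_table() should hand back the mocked frame
    assert filter_big_table() is input_sdf
)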

hl0ma9xz · answer 1

To make it work, I had to change the module :(
The session is now created inside the function.

# module1.py
from typing import List, Optional

from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

def filter_big_table(
    conditions: Optional[List[str]] = None,
) -> SparkDataFrame:
    spark = SparkSession.getActiveSession()
    sdf = spark.sql(f"select * from SUPER_BIG_TABLE")

    if conditions:
        sdf = sdf.filter(f.col("condition_col").isin(conditions))
    return sdf

The decorator is then @patch("module1.SparkSession.sql").
This way there is no name clash, and the mock successfully replaces the sql method.
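
The answer doesn't show the matching test; a minimal sketch of what it could look like, again assuming the conftest.py spark fixture from the question and hypothetical fixture data:

# test.py (sketch)
from typing import List, Tuple
from unittest.mock import patch

from pyspark.sql.types import StringType, StructField, StructType

from module1 import filter_big_table

# Hypothetical data; the schema only needs the condition_col column
DATA_FIXTURE: List[Tuple[str]] = [("a",), ("b",)]
DATA_SCHEMA = StructType([StructField("condition_col", StringType())])

@patch("module1.SparkSession.sql")
def test_filter_big_table(mock_sql, spark):
    input_sdf = spark.createDataFrame(DATA_FIXTURE, DATA_SCHEMA)
    mock_sql.return_value = input_sdf

    # getActiveSession() finds the fixture's real session, but its sql()
    # method is patched, so the mocked frame comes back unchanged
    assert filter_big_table() is input_sdf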

ymdaylpp · answer 2

I still had to modify your module1 a little, but this should still be instructive.
The line:

spark = SparkSession.getActiveSession()

does not necessarily give you a Spark object with a sql attribute: when no session is active, getActiveSession() returns None, and the result is:

Exception has occurred: AttributeError
None does not have the attribute 'sql'
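
For context, SparkSession.getActiveSession() returns None until a session has actually been created in the current thread, which is exactly what produces the error above. A quick illustration, assuming a fresh interpreter with no session yet:

from pyspark.sql import SparkSession

assert SparkSession.getActiveSession() is None   # no session exists yet
spark = SparkSession.builder.master("local").appName("demo").getOrCreate()
assert SparkSession.getActiveSession() is spark  # now one is active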


If spark is initialized differently:

spark = SparkSession.builder.appName('UnitTests').getOrCreate()


then it does have the sql attribute.
You can then point patch at the spark object in the module where it was initialized, outside the test function, as in the example below.
module1 stays the same as in your original post, apart from how spark is created:

# module1.py
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

# CHANGED FROM GETACTIVESESSION
spark = SparkSession.builder.appName('UnitTests').getOrCreate()

def filter_big_table(
    conditions = None,
) -> SparkDataFrame:

    sdf = spark.sql(f"select * from SUPER_BIG_TABLE")

    if conditions:
        sdf = sdf.filter(f.col("condition_col").isin(conditions))
    return sdf


Your test then reads:

# test.py
from unittest.mock import patch    
import module1

def test_filter_big_table():
    mock_data = ["list of data"]
    with patch("module1.spark.sql", return_value=mock_data):
        
        output_sdf = module1.filter_big_table() 
        # the call that reads...
        # sdf = spark.sql(f"select * from SUPER_BIG_TABLE") 
        # within this module should now return mock_data

    assert mock_data == output_sdf
    
    
if __name__ == "__main__":
    test_filter_big_table()
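
This works because module1.spark is now a real SparkSession created at import time, so patch("module1.spark.sql", ...) swaps out just that one object's sql method for the duration of the with block and leaves the rest of the session intact. Note that mock_data here is a plain list rather than a DataFrame, so only the no-conditions code path is exercised.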
