如何在Spark UDF中设置decimal返回类型的精度和小数位数?

46qrfjad  于 2023-10-23  发布在  Apache
关注(0)|答案(3)|浏览(263)

这是我的示例代码。我期望从UDF返回decimal(16,4),但它是decimal(38,18)。
有没有更好的解决办法?
我并不期待答案“cast(price as decimal(16,4))",因为我的UDF中有一些其他的业务逻辑,而不仅仅是casting。
先谢了。

import scala.util.Try
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.Decimal
val spark = SparkSession.builder().master("local[*]").appName("Test").getOrCreate()
import spark.implicits._

val stringToDecimal = udf((s:String, precision:Int, scale: Int) => {
  Try(Decimal(BigDecimal(s), precision, scale)).toOption
})

spark.udf.register("stringToDecimal", stringToDecimal)

val inDf = Seq(
  ("1", "864.412"),
  ("2", "1.600"),
  ("3", "2,56")).toDF("id", "price")

val outDf = inDf.selectExpr("id", "stringToDecimal(price, 16, 4) as price")
outDf.printSchema()
outDf.show()

------------------output----------------
root
  |-- id: string (nullable = true)
  |-- price: decimal(38,18) (nullable = true)

+---+--------------------+
| id|               price|
+---+--------------------+
|  1|864.4120000000000...|
|  2|1.600000000000000000|
|  3|                null|
+---+--------------------+
kuhbmx9i

kuhbmx9i1#

对于Spark 3.0及更低版本,您不能设置Spark用户定义函数(UDF)返回的十进制精度和小数位数,因为精度和小数位数在创建UDF时会被删除。

说明

要创建一个UDF,无论是通过调用函数udf并将lambda/函数作为参数,还是通过使用sparkSession.udf.register方法直接将lambda/函数注册为UDF,Spark都需要转换参数类型并将lambda/函数的类型返回到Spark's DataType
为此,Spark将使用类ScalaReflection中的方法schemaFor将scala类型Map到Spark的DataType。
对于BigDecimalDecimal类型,Map如下所示:

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
  Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
  Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[Decimal]) =>
  Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)

这意味着当lambda/函数返回BigDecimalDecimal时,UDF的返回类型将是DecimalType。SYSTEM_DEFAULT. DecimalType.SYSTEM_DEFAULT类型是Decimal,精度为38,小数位数为18:

val MAX_PRECISION = 38
...
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)

总结

因此,每次将lambda或返回DecimalBigDecimal的函数转换为Spark的UDF时,精度和小数位数都会被删除,默认精度为38,小数位数为18。
因此,您唯一的方法是遵循previous answer并在调用时强制转换UDF的返回值

kuhbmx9i

kuhbmx9i2#

Spark将Decimaldecimal(38, 18)关联。您需要显式强制转换

$"price".cast(DataTypes.createDecimalType(32,2))
jexiocij

jexiocij3#

对于pyspark用途:

from pysprak.sql.types import DecimalType
def your_func(value):
    ...
spark.udf.register("your_func", your_func, DecimalType(precision=25, scale=10))

相关问题