Scala Spark: a Map as an ADT for CaseWhen

Posted by mec1mxoz on 2022-11-29 in Scala

The bounty expires in 5 days. Answers to this question are eligible for a +50 reputation bounty. Fragan is looking for a canonical answer.

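This question is a follow-up to this one.
Quick context reminder:
Spark's CaseWhen takes a Seq[(Expression, Expression)], where the first expression is the condition and the second is the value to use when that condition holds: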

CaseWhen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None): ...
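
For reference, here is a minimal sketch of calling this constructor directly. It assumes Spark 2.x or 3.x, where Column exposes .expr and has a public constructor taking an Expression (the implicits further down in this question rely on the same thing). The object name CaseWhenSketch and the sample columns and values are illustrative only.

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression}
import org.apache.spark.sql.functions.{col, lit}

object CaseWhenSketch {
  // Each branch is a (condition, value) pair; branches are tried in order.
  val branches: Seq[(Expression, Expression)] = Seq(
    (col("column_one") === 1).expr -> lit(2).expr,
    (col("column_one") === 3).expr -> lit(4).expr
  )

  // elseValue is optional; with None, rows that match no branch yield null.
  val caseWhen: CaseWhen = CaseWhen(branches, Some(lit(0).expr))

  // Wrap the Catalyst expression in a Column so it can be used in select/withColumn.
  val result: Column = new Column(caseWhen)
}
```

No SparkSession is needed to build the expression; one is only needed once the Column is applied to a DataFrame.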

I would like to be able to build a Spark CaseWhen from a Map object.
The Map can be simple, for example:

val spec = Map(
    ($"column_one" === 1) -> lit(2),
    ($"column_one" === 2 && $"column_two" === 1) -> lit(1),
    ($"column_one" === 3) -> lit(4),
)

It can also be nested and simple at the same time:

val spec: Map[Column, Any] = Map(
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map(
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        ),
        ($"column_two" === 8) -> lit(524)
    ),
    ($"column_one" === 3) -> lit(4)
)

So I used the answer that @aminmal gave to the other question:

sealed trait ConditionValue  {
  def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)]
}
object ConditionValue {

  object implicits{

    implicit def test(condition: Column, value: Column): ConditionValue = {
        print("test")
        SingleLevelCaseWhen(Map(condition -> value))
        
    }
    
    implicit def testTuple(conditionValue: (Column, Column)): ConditionValue = {
        print("testTuple")
        SingleLevelCaseWhen(Map(conditionValue))
        
    }
        
    implicit def testNested(spec: Map[Column, ConditionValue]): ConditionValue = {
        print("testNested")
        NestedCaseWhen(spec)
        
    }
        
    implicit def testMap(spec: Map[Column, Column]): ConditionValue = {
        print("testMap")
        SingleLevelCaseWhen(spec)
        
    }
    
    implicit def expressionToColumn(expr: Expression): Column = new Column(expr)

    implicit def columnToExpression(col: Column): Expression = col.expr
  }
    
  import implicits._

  final case class SingleLevelCaseWhen(specificationMap: Map[Column, Column]) extends ConditionValue{
    override def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)] =
      specificationMap.map(x => (x._1.expr, x._2.expr)).toSeq
  }
  
  final case class NestedCaseWhen(specificationMap: Map[Column, ConditionValue]) extends ConditionValue{
    override def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)] =
      specificationMap.mapValues(_.enumerate(reduceFunc)).map{
        case (outerCondition, innerExpressions) => innerExpressions.map{
          case (innerCondition, innerValue) =>
            val conditions: Expression = reduceFunc(outerCondition, innerCondition)
            conditions -> innerValue

        }
      }.reduce(_ ++ _)
  }

}

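My problem now is how to convert a Map object into a ConditionValue object. As you can see, I have provided some implicits in the code:

  • test converts the condition and value parameters into a ConditionValue object. Not sure this function is useful
  • testTuple converts a (condition, value) tuple into a ConditionValue object. Not sure this one is useful either
  • testMap converts a single-level Map[Column, Column] into a ConditionValue object
  • testNested converts a nested Map into a ConditionValue object

It works well for single-level Maps: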

import ConditionValue.implicits._
val spec = Map(
    ($"column_one" === 1) -> lit(2),
    ($"column_one" === 2 && $"column_two" === 1) -> lit(1),
    ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue= spec
>> d: ConditionValue= SingleLevelCaseWhen(Map((column_one = 1) -> 2, ((column_one = 2) AND (column_two = 1)) -> 1, (column_one = 3) -> 4))

It also works for Maps that are nested only:

val spec = Map[Column, ConditionValue](
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map[Column, ConditionValue](
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        )
    )
)
val d: ConditionValue= spec
>>d: ConditionValue= NestedCaseWhen(Map((column_one = 1) -> SingleLevelCaseWhen(Map((column_two = 2) -> 54, (column_two = 5) -> 524)), (column_one = 2) -> NestedCaseWhen(Map((column_two = 7) -> SingleLevelCaseWhen(Map((whatever_column = whatever) -> 12, (whatever_column = whatever_two) -> 13))))))

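Now there are two things bothering me:

  • It doesn't work with mixed Maps (nested and simple at the same time)
  • When handling nested Maps, I have to explicitly specify the type of the Map, Map[Column, ConditionValue]

Can anyone help me with this?
Thank you,

EDIT

I kind of "fixed" the "It doesn't work with mixed Maps (nested and simple at the same time)" issue with this implicit: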

implicit def testVal(value: Column): ConditionValue = {
  testMap(Map(lit(true) -> value))
}

Not sure this is the best solution.

Answer #1, by s5a0g9ez

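I think you can solve the second problem by typing the spec as Map[Column, Any], so that the compiler infers the value types for you, and converting it at runtime. For example, you can define a function that takes a Map[Column, Any] and returns a ConditionValue: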

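// Assumes import ConditionValue._ is in scope.
def mapToConditionValue(spec: Map[Column, Any]): ConditionValue = {
  // Keep the converted map: the result of map must not be discarded.
  val converted: Map[Column, ConditionValue] = spec.map {
    case (condition, value: Column) =>
      (condition, SingleLevelCaseWhen(Map(condition -> value)))
    // Type arguments are erased at runtime, so match on Map[_, _] and cast.
    case (condition, value: Map[_, _]) =>
      (condition, mapToConditionValue(value.asInstanceOf[Map[Column, Any]]))
    case (_, other) =>
      throw new IllegalArgumentException(s"Unsupported value: $other")
  }
  NestedCaseWhen(converted)
}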

Then you can call it like this:

val spec: Map[Column, Any] = Map(
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map(
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        ),
        ($"column_two" === 8) -> lit(524)
    ),
    ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = mapToConditionValue(spec)

This works for both nested and single-level Maps.
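
As a possible alternative to Map[Column, Any], here is a minimal sketch of a small ADT that makes mixed nesting explicit, so neither casts nor type annotations on inner maps are needed. The names Spec, Leaf, Branch, flatten and toColumn are hypothetical and not part of Spark or of the code above. The sketch uses a Seq instead of a Map because CASE WHEN is order-sensitive when conditions overlap, and an immutable Map with more than four entries does not guarantee iteration order. It only uses the public when API, and col(...) stands in for the $"..." syntax.

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit, when}

// Hypothetical ADT: a spec is either a final value or a list of conditional cases.
sealed trait Spec
final case class Leaf(value: Column) extends Spec
final case class Branch(cases: Seq[(Column, Spec)]) extends Spec

object Spec {
  // Flatten the tree into (condition, value) pairs, AND-ing each outer
  // condition onto the conditions nested below it.
  def flatten(spec: Spec): Seq[(Column, Column)] = spec match {
    case Leaf(value) => Seq(lit(true) -> value)
    case Branch(cases) =>
      cases.flatMap { case (outer, inner) =>
        flatten(inner).map { case (cond, value) => (outer && cond) -> value }
      }
  }

  // Chain the flattened pairs as when(...).when(...); None if there are no cases.
  def toColumn(spec: Spec): Option[Column] = {
    val pairs = flatten(spec)
    pairs.headOption.map { case (firstCond, firstValue) =>
      pairs.tail.foldLeft(when(firstCond, firstValue)) {
        case (acc, (cond, value)) => acc.when(cond, value)
      }
    }
  }
}

object SpecExample {
  // A mixed spec: one nested case and one simple case at the same level.
  val spec: Spec = Branch(Seq(
    (col("column_one") === 1) -> Branch(Seq(
      (col("column_two") === 2) -> Leaf(lit(54)),
      (col("column_two") === 5) -> Leaf(lit(524))
    )),
    (col("column_one") === 3) -> Leaf(lit(4))
  ))

  val result: Option[Column] = Spec.toColumn(spec)
}
```

The trade-off is slightly more verbose call sites (Leaf and Branch wrappers) in exchange for compile-time checking of the structure.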
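Edit: this should still be possible in spite of type erasure. You can check the runtime type of each value in the Map, then use pattern matching to decide which kind of ConditionValue to return.
The following example determines the runtime type of each value in the Map and returns the corresponding ConditionValue: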

import org.apache.spark.sql.Column
import ConditionValue._

def mapToConditionValue(spec: Map[Column, Any]): ConditionValue = {
  // Keep the converted map: the result of map must not be discarded.
  val converted: Map[Column, ConditionValue] = spec.map {
    case (condition, value) =>
      // Match on the runtime type rather than on the class name: concrete
      // Map implementations are never named plain "scala.collection.immutable.Map".
      val cv: ConditionValue = value match {
        case nested: Map[_, _] =>
          mapToConditionValue(nested.asInstanceOf[Map[Column, Any]])
        case column: Column =>
          SingleLevelCaseWhen(Map(condition -> column))
        case other =>
          throw new IllegalArgumentException(s"Unsupported type: ${other.getClass.getTypeName}")
      }
      condition -> cv
  }
  NestedCaseWhen(converted)
}

val spec: Map[Column, Any] = Map(
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map(
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        ),
        ($"column_two" === 8) -> lit(524)
    ),
    ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = mapToConditionValue(spec)
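
A note on runtime type checks, as a small Spark-free sketch: the runtime class of an immutable Map is a concrete implementation class, so comparing its name against "scala.collection.immutable.Map" is fragile, whereas a type pattern such as Map[_, _] matches every implementation. The names RuntimeTypeSketch and kindOf are illustrative, and String stands in for Column here so the example needs no dependencies.

```scala
object RuntimeTypeSketch {
  // Type arguments are erased, so only "is it a Map at all" can be checked at runtime.
  def kindOf(value: Any): String = value match {
    case _: Map[_, _] => "map"
    case _: String    => "leaf"
    case other        => s"unsupported: ${other.getClass.getName}"
  }

  def main(args: Array[String]): Unit = {
    val nested: Any = Map("a" -> Map("b" -> "x"))
    // Prints the name of the concrete implementation class of this Map.
    println(nested.getClass.getName)
    println(kindOf(nested))
    println(kindOf("x"))
    println(kindOf(42))
  }
}
```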
