How to perform an iterative calculation with Spark and Scala

Posted by 72qzrwbm on 2021-05-27 in Spark

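In the example below, the code applies each invoice's calculation independently to the same set of original records. Instead, each calculation must use the values produced by the previous one to compute the subsequent quantities.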
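Stated as a recurrence, each owner quantity starts at q(0) = qtty and, for each invoice k in turn, becomes q(k) = 0.3 * q(k-1) + invoice(k). A minimal sketch of that recurrence with a plain Scala `foldLeft` (no Spark involved) follows; `ChainedUpdate` and `applySequentially` are hypothetical names, and the factor 0.3 is taken from the question's code.

```scala
object ChainedUpdate {
  // Hypothetical helper: applies the invoice quantities one after another,
  // feeding each result into the next step:
  //   q(k) = 0.3 * q(k-1) + invoice(k), starting from q(0) = owner quantity.
  def applySequentially(ownerQtty: Double, invoiceQttys: Seq[Double]): Double =
    invoiceQttys.foldLeft(ownerQtty)((acc, inv) => 0.3 * acc + inv)

  def main(args: Array[String]): Unit = {
    // Owner A/666 starts at 80; the invoices for car A are 15 and then 12.
    val afterFirst = applySequentially(80.0, Seq(15.0))       // 0.3 * 80 + 15
    val afterBoth  = applySequentially(80.0, Seq(15.0, 12.0)) // 0.3 * 39 + 12
    println(s"after first invoice: $afterFirst, after both: $afterBoth")
  }
}
```

The question's code effectively computes only the single-step values (one per invoice, each from the original 80), while the expected output corresponds to the two-step fold.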

package playground

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{KeyValueGroupedDataset, SparkSession}

object basic2 extends App {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  val spark = SparkSession
    .builder()
    .appName("Sample app")
    .master("local")
    .getOrCreate()

  import spark.implicits._

  final case class Owner(car: String, pcode: String, qtty: Double)
  final case class Invoice(car: String, pcode: String, qtty: Double)

  val data = Seq(
    Owner("A", "666", 80),
    Owner("B", "555", 20),
    Owner("A", "444", 50),
    Owner("A", "222", 20),
    Owner("C", "444", 20),
    Owner("C", "666", 80),
    Owner("C", "555", 120),
    Owner("A", "888", 100)
  )

  val fleet = Seq(Invoice("A", "666", 15), Invoice("A", "888", 12))

  val owners = spark.createDataset(data)
  val invoices = spark.createDataset(fleet)

  val gb: KeyValueGroupedDataset[Invoice, (Owner, Invoice)] = owners
    .joinWith(invoices, invoices("car") === owners("car"), "inner")
    .groupByKey(_._2)

  gb.flatMapGroups {
      case (fleet, group) ⇒
        val subOwner: Vector[Owner] = group.toVector.map(_._1)
        val calculatedRes = subOwner.filter(_.car == fleet.car)
        calculatedRes.map(c => c.copy(qtty = .3 * c.qtty + fleet.qtty))
    }
    .show()
}

/**
  * +---+-----+----+
  * |car|pcode|qtty|
  * +---+-----+----+
  * |  A|  666|39.0|
  * |  A|  444|30.0|
  * |  A|  222|21.0|
  * |  A|  888|45.0|
  * |  A|  666|36.0|
  * |  A|  444|27.0|
  * |  A|  222|18.0|
  * |  A|  888|42.0|
  * +---+-----+----+
  * 
  * +---+-----+----+
  * |car|pcode|qtty|
  * +---+-----+----+
  * |  A|  666|0.3 * 39.0 + 12|
  * |  A|  444|0.3 * 30.0 + 12|
  * |  A|  222|0.3 * 21.0 + 12|
  * |  A|  888|0.3 * 45.0 + 12|
  * +---+-----+----+
  */

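The second table above shows the expected output; the first table is what the code in this question actually produces.
How can I produce the expected output iteratively?
Note that the order in which the invoices are applied does not matter: a different order gives different results, but it is still a valid answer.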
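As a cross-check outside Spark, the expected rows can be reproduced on the question's sample data with plain Scala collections. This is a sketch, not the answer's method: `ChainedTable` and `chained` are hypothetical names, it assumes inner-join semantics (only cars that have invoices appear), and it applies the invoices in the order given. The same per-car fold could presumably be moved into Spark by grouping on the car rather than on the invoice, as the question's `groupByKey(_._2)` does.

```scala
object ChainedTable {
  final case class Owner(car: String, pcode: String, qtty: Double)
  final case class Invoice(car: String, pcode: String, qtty: Double)

  // Same sample data as in the question.
  val owners = Seq(
    Owner("A", "666", 80), Owner("B", "555", 20), Owner("A", "444", 50),
    Owner("A", "222", 20), Owner("C", "444", 20), Owner("C", "666", 80),
    Owner("C", "555", 120), Owner("A", "888", 100)
  )
  val invoices = Seq(Invoice("A", "666", 15), Invoice("A", "888", 12))

  // For every owner whose car has invoices (inner-join semantics), fold the
  // invoices in order so that each step starts from the previous result.
  def chained(owners: Seq[Owner], invoices: Seq[Invoice]): Seq[Owner] = {
    val byCar: Map[String, Seq[Double]] =
      invoices.groupBy(_.car).map { case (car, invs) => car -> invs.map(_.qtty) }
    owners.flatMap { o =>
      byCar.get(o.car).map { qs =>
        o.copy(qtty = qs.foldLeft(o.qtty)((acc, q) => 0.3 * acc + q))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    chained(owners, invoices).foreach(println)
    // Reversing the invoice order gives different, but equally valid, numbers.
    chained(owners, invoices.reverse).foreach(println)
  }
}
```

With the order 15 then 12 this gives 23.7, 21.0, 18.3 and 25.5 for pcodes 666, 444, 222 and 888; applying 12 first changes, for example, the 666 row to 25.8.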

Answer #1 (afdcj2ne)

Try the following code.

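// Assumes the question's owners, invoices and `import spark.implicits._` are in scope.
import org.apache.spark.sql.functions.{collect_list, first, udf}

// Final quantity: start from 0.3 * owner qtty + first invoice, then fold the
// remaining invoices so that each step uses the previous result.
val getQtty = udf((invoicesQtty: Seq[Double], ownersQtty: Double) => {
  invoicesQtty.tail.foldLeft(0.3 * ownersQtty + invoicesQtty.head)(
    (totalIQ, nextInvoiceQtty) => 0.3 * totalIQ + nextInvoiceQtty
  )
})

// Readable form of the calculation; earlier steps are nested in parentheses.
// Returns an empty string when there is only one invoice.
val getQttyStr = udf((invoicesQtty: Seq[Double], ownersQtty: Double) => {
  val totalIQ = 0.3 * ownersQtty + invoicesQtty.head
  invoicesQtty.tail.foldLeft("")(
    (data, nextInvoiceQtty) =>
      s"0.3 * ${if (data.isEmpty) totalIQ else s"($data)"} + $nextInvoiceQtty"
  )
})

owners
  .join(invoices, invoices("car") === owners("car"), "inner")
  // Intended to make collect_list gather invoices largest first; Spark does not
  // strictly guarantee collect_list order after a shuffle.
  .orderBy(invoices("qtty").desc)
  .groupBy(owners("car"), owners("pcode"))
  .agg(
    collect_list(invoices("qtty")).as("invoices_qtty"),
    first(owners("qtty")).as("owners_qtty")
  )
  .withColumn("qtty", getQtty($"invoices_qtty", $"owners_qtty"))
  .withColumn("qtty_str", getQttyStr($"invoices_qtty", $"owners_qtty"))
  .show(false)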

Result

+---+-----+-------------+-----------+----+-----------------+
|car|pcode|invoices_qtty|owners_qtty|qtty|qtty_str         |
+---+-----+-------------+-----------+----+-----------------+
|A  |666  |[15.0, 12.0] |80.0       |23.7|0.3 * 39.0 + 12.0|
|A  |888  |[15.0, 12.0] |100.0      |25.5|0.3 * 45.0 + 12.0|
|A  |444  |[15.0, 12.0] |50.0       |21.0|0.3 * 30.0 + 12.0|
|A  |222  |[15.0, 12.0] |20.0       |18.3|0.3 * 21.0 + 12.0|
+---+-----+-------------+-----------+----+-----------------+
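The two UDFs above fold over the same list separately, and `getQttyStr` returns an empty string when a car has only one invoice, because it folds over `tail` starting from `""`. Both also call `.head`, which would throw on an empty list (the inner join prevents that here). A minimal sketch of one alternative, assuming it is acceptable to show only the last step, is a single fold that carries the value and its description together; `QttyWithTrace` and `qttyWithLastStep` are hypothetical names. Wrapped in one `udf`, the tuple result would presumably come back as a struct column.

```scala
object QttyWithTrace {
  // Hypothetical helper: one fold returning both the final quantity and a string
  // for the last step, in the "0.3 * previous + invoice" form of qtty_str.
  // An empty invoice list yields (ownersQtty, "") instead of an exception.
  def qttyWithLastStep(invoicesQtty: Seq[Double], ownersQtty: Double): (Double, String) =
    invoicesQtty.foldLeft((ownersQtty, "")) { case ((acc, _), inv) =>
      (0.3 * acc + inv, s"0.3 * $acc + $inv")
    }

  def main(args: Array[String]): Unit = {
    println(qttyWithLastStep(Seq(15.0, 12.0), 80.0)) // two invoices, as for car A
    println(qttyWithLastStep(Seq(15.0), 80.0))       // single invoice still described
  }
}
```

For car A/666 this gives the same last-step string as the table above, "0.3 * 39.0 + 12.0", and with a single invoice it gives "0.3 * 80.0 + 15.0" rather than an empty string.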
