In Spark, is there an alternative to the union() function for appending new rows?

dldeef67 asked on 2021-05-27 in Spark
Follow (0) | Answers (3) | Views (370)

In my code, table_df has a number of columns on which I compute statistics such as min, max, mean, and so on, and I want to build a new DataFrame new_df with the specified schema new_df_schema. In my logic I write a Spark SQL query per column and append each newly generated row to the initially empty new_df, so that in the end new_df holds the computed values for every column.

The problem is that this performs badly as the number of columns grows: every loop iteration issues a separate full aggregation query over table_df, and the repeated union calls keep deepening the query plan. Can this be done without the union() function, or with some other approach that performs better?

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import sparkSession.sqlContext.implicits._ // assumes a SparkSession instance named sparkSession is already in scope

    val table_df = Seq(
      (10, 20, 30, 40, 50),
      (100, 200, 300, 400, 500),
      (111, 222, 333, 444, 555),
      (1123, 2123, 3123, 4123, 5123),
      (1321, 2321, 3321, 4321, 5321)
    ).toDF("col_1", "col_2", "col_3", "col_4", "col_5")
    table_df.show(false)

    table_df.createOrReplaceTempView("table_df")

     val new_df_schema = StructType(
      StructField("Column_Name", StringType, false) ::
        StructField("number_of_values", LongType, false) ::
        StructField("number_of_distinct_values", LongType, false) ::
        StructField("distinct_count_with_nan", LongType, false) ::
        StructField("distinct_count_without_nan", LongType, false) ::
        StructField("is_unique", BooleanType, false) ::
        StructField("number_of_missing_values", LongType, false) ::
        StructField("percentage_of_missing_values", DoubleType, false) ::
        StructField("percentage_of_unique_values", DoubleType, false) ::
        StructField("05_PCT", DoubleType, false) ::
        StructField("25_PCT", DoubleType, false) ::
        StructField("50_PCT", DoubleType, false) ::
        StructField("75_PCT", DoubleType, false) ::
        StructField("95_PCT", DoubleType, false) ::
        StructField("max", DoubleType, false) ::
        StructField("min", DoubleType, false) ::
        StructField("mean", DoubleType, false) ::
        StructField("std", DoubleType, false) ::
        StructField("skewness", DoubleType, false) ::
        StructField("kurtosis", DoubleType, false) ::
        StructField("range", DoubleType, false) ::
        StructField("variance", DoubleType, false) :: Nil
    )
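    // Start from an empty DataFrame with the target schema; one stats row per column is unioned onto it below.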
    var new_df = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[Row], new_df_schema)

    for (c <- table_df.columns) {
      val num = sparkSession.sql(
        s"""SELECT
           | '$c' AS Column_Name,
           | COUNT(${c}) AS number_of_values,
           | COUNT(DISTINCT ${c}) AS number_of_distinct_values,
           | COUNT(DISTINCT ${c}) AS distinct_count_with_nan,
           | (COUNT(DISTINCT ${c}) - 1) AS distinct_count_without_nan,
           | (COUNT(${c}) == COUNT(DISTINCT ${c})) AS is_unique,
           | (COUNT(*) - COUNT(${c})) AS number_of_missing_values,
           | ((COUNT(*) - COUNT(${c}))/COUNT(*)) AS percentage_of_missing_values,
           | (COUNT(DISTINCT ${c})/COUNT(*)) AS percentage_of_unique_values,
           | APPROX_PERCENTILE($c,0.05) AS 05_PCT,
           | APPROX_PERCENTILE($c,0.25) AS 25_PCT,
           | APPROX_PERCENTILE($c,0.50) AS 50_PCT,
           | APPROX_PERCENTILE($c,0.75) AS 75_PCT,
           | APPROX_PERCENTILE($c,0.95) AS 95_PCT,
           | MAX($c) AS max,
           | MIN($c) AS min,
           | MEAN($c) AS mean,
           | STD($c) AS std,
           | SKEWNESS($c) AS skewness,
           | KURTOSIS($c) AS kurtosis,
           | (MAX($c) - MIN($c)) AS range,
           | VARIANCE($c) AS variance
           | FROM
           | table_df""".stripMargin)
        .toDF()
      new_df = new_df.union(num) // this causes a performance issue when table_df has many columns
    }
    new_df.show(false)

==================================================
table_df:
+-----+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|col_5|
+-----+-----+-----+-----+-----+
|10   |20   |30   |40   |50   |
|100  |200  |300  |400  |500  |
|111  |222  |333  |444  |555  |
|1123 |2123 |3123 |4123 |5123 |
|1321 |2321 |3321 |4321 |5321 |
+-----+-----+-----+-----+-----+

new_df:

+-----------+----------------+-------------------------+-----------------------+--------------------------+---------+------------------------+----------------------------+---------------------------+------+------+------+------+------+------+----+------+------------------+-------------------+-------------------+------+-----------------+
|Column_Name|number_of_values|number_of_distinct_values|distinct_count_with_nan|distinct_count_without_nan|is_unique|number_of_missing_values|percentage_of_missing_values|percentage_of_unique_values|05_PCT|25_PCT|50_PCT|75_PCT|95_PCT|max   |min |mean  |std               |skewness           |kurtosis           |range |variance         |
+-----------+----------------+-------------------------+-----------------------+--------------------------+---------+------------------------+----------------------------+---------------------------+------+------+------+------+------+------+----+------+------------------+-------------------+-------------------+------+-----------------+
|col_1      |5               |5                        |5                      |4                         |true     |0                       |0.0                         |1.0                        |10.0  |100.0 |111.0 |1123.0|1321.0|1321.0|10.0|533.0 |634.0634826261484 |0.4334269738367067 |-1.7463346405299973|1311.0|402036.5         |
|col_2      |5               |5                        |5                      |4                         |true     |0                       |0.0                         |1.0                        |20.0  |200.0 |222.0 |2123.0|2321.0|2321.0|20.0|977.2 |1141.1895986206673|0.4050513738738682 |-1.799741951675132 |2301.0|1302313.7        |
|col_3      |5               |5                        |5                      |4                         |true     |0                       |0.0                         |1.0                        |30.0  |300.0 |333.0 |3123.0|3321.0|3321.0|30.0|1421.4|1649.399072389699 |0.3979251063785061 |-1.8119558312496054|3291.0|2720517.3        |
|col_4      |5               |5                        |5                      |4                         |true     |0                       |0.0                         |1.0                        |40.0  |400.0 |444.0 |4123.0|4321.0|4321.0|40.0|1865.6|2157.926620624529 |0.39502047381456235|-1.8165124206347685|4281.0|4656647.3        |
|col_5      |5               |5                        |5                      |4                         |true     |0                       |0.0                         |1.0                        |50.0  |500.0 |555.0 |5123.0|5321.0|5321.0|50.0|2309.8|2666.59027598917  |0.3935246673563026 |-1.8186685628112493|5271.0|7110703.699999999|
+-----------+----------------+-------------------------+-----------------------+--------------------------+---------+------------------------+----------------------------+---------------------------+------+------+------+------+------+------+----+------+------------------+-------------------+-------------------+------+-----------------+

Answer 1 (4ktjp1zp)

def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions._
    import sparkSession.sqlContext.implicits._

    val df = Seq(
      (10, 20, 30, 40, 50),
      (100, 200, 300, 400, 500),
      (10, 222, 333, 444, 555),
      (1123, 2123, 3123, 4123, 5123),
      (1321, 2321, 3321, 4321, 5321)
    ).toDF("col_1", "col_2", "col_3", "col_4", "col_5")
    df.show(false)

    val descExpr = array(
      df.columns.map(c => struct(
        lit(c).cast("string").as("Column_Name"),
        count(col(c)).cast("string").as("number_of_values"),
        countDistinct(col(c)).cast("string").as("number_of_distinct_values"),
        countDistinct(col(c)).cast("string").as("distinct_count_with_nan"),
        (countDistinct(col(c)) - 1).cast("string").as("distinct_count_without_nan"),
        (count(col(c)) === countDistinct(col(c))).cast("string").as("is_unique"),
        (count("*") - count(col(c))).cast("string").as("number_of_missing_values"),
        ((count("*") - count(col(c))) / count("*")).cast("string").as("percentage_of_missing_values"),
        (countDistinct(col(c)) / count("*")).cast("string").as("percentage_of_unique_values"),
        max(col(c)).cast("string").as("max"),
        min(col(c)).cast("string").as("min"),
        mean(col(c)).cast("string").as("mean"),
        (max(col(c)) - min(col(c))).cast("string").as("range"),
        stddev(col(c)).cast("string").as("std"),
        skewness(col(c)).cast("string").as("skewness"),
        kurtosis(col(c)).cast("string").as("kurtosis"),
        variance(col(c)).cast("string").as("variance")
      )
      ): _*
    ).as("data")

    val columns = Seq("Column_Name", "number_of_values", "number_of_distinct_values", "distinct_count_with_nan", "distinct_count_without_nan", "is_unique", "number_of_missing_values", "percentage_of_missing_values", "percentage_of_unique_values", "max", "min", "mean", "range", "std", "skewness", "kurtosis", "variance")
      .map(c => if (c != "Column_Name" && c != "is_unique") col(c).cast("double").as(c) else col(c))

    var df1 = df
      .select(descExpr)
      .selectExpr("explode(data) as data")
      .select("data.*")
      .select(columns: _*)

    df1 = df1
      .withColumn("is_unique", col("is_unique").cast(BooleanType))

    val approxQuantileDF = df
      .columns
      .map(c => (c, df.stat.approxQuantile(c, Array(0.05, 0.25, 0.5, 0.75, 0.95), 0.0))) // five quantiles, matching the 05/25/50/75/95_PCT columns selected below
      .toList
      .toDF("Column_Name", "approx_quantile")
      .select(
        expr("Column_Name"),
        expr("approx_quantile[0] as 05_PCT"),
        expr("approx_quantile[1] as 25_PCT"),
        expr("approx_quantile[2] as 50_PCT"),
        expr("approx_quantile[3] as 75_PCT"),
        expr("approx_quantile[4] as 95_PCT")
      )

    df1 = df1.join(approxQuantileDF, Seq("Column_Name"), "left")
    df1.show(false)
  }

I hope this helps. This code is just an extension of the answer provided by @srinivas.
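As a further variant (a sketch, not part of the original answer), the percentiles can also be folded into the same single-pass aggregation via Spark's built-in percentile_approx SQL function, which avoids the one-job-per-column cost of df.stat.approxQuantile; only two percentiles are shown for brevity:

import org.apache.spark.sql.functions.{array, expr, lit, struct}

// Build one struct of aggregates per column, percentiles included,
// so a single select evaluates everything in one pass.
val statsWithPct = array(
  df.columns.map(c => struct(
    lit(c).as("Column_Name"),
    expr(s"cast(percentile_approx($c, 0.05) as double)").as("05_PCT"),
    expr(s"cast(percentile_approx($c, 0.95) as double)").as("95_PCT")
  )): _*
).as("data")

df.select(statsWithPct)
  .selectExpr("explode(data) as data")
  .select("data.*")
  .show(false)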


Answer 2 (yqkkidmi)

Another option is the summary() API on Dataset, which computes basic stats in the format below:

ds.summary("count", "min", "25%", "75%", "max").show()

    // output:
    // summary age   height
    // count   10.0  10.0
    // min     18.0  163.0
    // 25%     24.0  176.0
    // 75%     32.0  180.0
    // max     92.0  192.0
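For example, applied to the question's table_df (a minimal sketch):

table_df.summary("count", "min", "25%", "75%", "max").show(false)

Note that summary() returns one row per statistic and one column per input column, i.e. transposed relative to the desired new_df layout.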

Similarly, you can enrich the DataFrame API to get the stats in the desired format, as shown below.

Define the RichDataFrame class and the implicits to be used:

import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{NumericType, StringType, StructField, StructType}

import scala.language.implicitConversions

class RichDataFrame(ds: DataFrame) {
  def statSummary(statistics: String*): DataFrame = {
    val defaultStatistics = Seq("max", "min", "mean", "std", "skewness", "kurtosis")
    val statFunctions = if (statistics.nonEmpty) statistics else defaultStatistics
    val selectedCols = ds.schema
      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
      .map(_.name)

    val percentiles = statFunctions.filter(a => a.endsWith("%")).map { p =>
      try {
        p.stripSuffix("%").toDouble / 100.0
      } catch {
        case e: NumberFormatException =>
          throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
      }
    }
    require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
    val aggExprs = selectedCols.flatMap(c => {
      var percentileIndex = 0
      statFunctions.map { stats =>
        if (stats.endsWith("%")) {
          val index = percentileIndex
          percentileIndex += 1
          expr(s"cast(percentile_approx($c, array(${percentiles.mkString(", ")}))[$index] as string)")
        } else {
          expr(s"cast($stats($c) as string)")
        }
      }
    })
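    // aggExprs now holds statFunctions.length expressions for each selected column.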

    val aggResult = ds.select(aggExprs: _*).head()
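    // aggResult is a single wide row; regroup it into one row per source column.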

    val r = aggResult.toSeq.grouped(statFunctions.length).toArray
      .zip(selectedCols)
      .map{case(seq, column) => column +: seq }
      .map(Row.fromSeq)

    val output = StructField("columns", StringType) +: statFunctions.map(c => StructField(c, StringType))

    val spark = ds.sparkSession
    spark.createDataFrame(spark.sparkContext.parallelize(r), StructType(output))
  }
}

object RichDataFrame {

  trait Enrichment {
    implicit def enrichMetadata(ds: DataFrame): RichDataFrame =
      new RichDataFrame(ds)
  }

  object implicits extends Enrichment

}

Test it with the test data provided below:

    val table_df = Seq(
      (10, 20, 30, 40, 50),
      (100, 200, 300, 400, 500),
      (111, 222, 333, 444, 555),
      (1123, 2123, 3123, 4123, 5123),
      (1321, 2321, 3321, 4321, 5321)
    ).toDF("col_1", "col_2", "col_3", "col_4", "col_5")
    table_df.show(false)
    table_df.printSchema()

    /**
      * +-----+-----+-----+-----+-----+
      * |col_1|col_2|col_3|col_4|col_5|
      * +-----+-----+-----+-----+-----+
      * |10   |20   |30   |40   |50   |
      * |100  |200  |300  |400  |500  |
      * |111  |222  |333  |444  |555  |
      * |1123 |2123 |3123 |4123 |5123 |
      * |1321 |2321 |3321 |4321 |5321 |
      * +-----+-----+-----+-----+-----+
      *
      * root
      * |-- col_1: integer (nullable = false)
      * |-- col_2: integer (nullable = false)
      * |-- col_3: integer (nullable = false)
      * |-- col_4: integer (nullable = false)
      * |-- col_5: integer (nullable = false)
      */

    import RichDataFrame.implicits._
    table_df.statSummary()
      .show(false)

    /**
      * +-------+----+---+------+------------------+------------------+-------------------+
      * |columns|max |min|mean  |std               |skewness          |kurtosis           |
      * +-------+----+---+------+------------------+------------------+-------------------+
      * |col_1  |1321|10 |533.0 |634.0634826261484 |0.4334269738367066|-1.7463346405299973|
      * |col_2  |2321|20 |977.2 |1141.1895986206675|0.405051373873868 |-1.7997419516751323|
      * |col_3  |3321|30 |1421.4|1649.399072389699 |0.3979251063785061|-1.8119558312496056|
      * |col_4  |4321|40 |1865.6|2157.926620624529 |0.3950204738145622|-1.816512420634769 |
      * |col_5  |5321|50 |2309.8|2666.5902759891706|0.3935246673563024|-1.81866856281125  |
      * +-------+----+---+------+------------------+------------------+-------------------+
      */

You can also specify the desired functions, as shown below:

import RichDataFrame.implicits._
table_df.statSummary("sum", "count", "25%", "75%")
  .show(false)

    /**
      * +-------+-----+-----+---+----+
      * |columns|sum  |count|25%|75% |
      * +-------+-----+-----+---+----+
      * |col_1  |2665 |5    |100|1123|
      * |col_2  |4886 |5    |200|2123|
      * |col_3  |7107 |5    |300|3123|
      * |col_4  |9328 |5    |400|4123|
      * |col_5  |11549|5    |500|5123|
      * +-------+-----+-----+---+----+
      */

Update: as requested, the new df can be computed as below:

val count_star = table_df.count()
table_df.statSummary("count", "approx_count_distinct", "5%", "25%", "50%", "75%", "95%",
    "max", "min", "mean", "std", "SKEWNESS", "KURTOSIS", "VARIANCE")
      .withColumn("count_star", lit(count_star))
      .selectExpr(
        "columns AS Column_Name",
        "COUNT AS number_of_values",
        "approx_count_distinct AS number_of_distinct_values",
        "approx_count_distinct AS distinct_count_with_nan",
        "(approx_count_distinct - 1) AS distinct_count_without_nan",
        "(count == approx_count_distinct) AS is_unique",
        "(count_star - count) AS number_of_missing_values",
        "((count_star - count)/count) AS percentage_of_missing_values",
        "(approx_count_distinct/count) AS percentage_of_unique_values",
        "`5%` AS 05_PCT",
        "`25%` AS 25_PCT",
        "`50%` AS 50_PCT",
        "`75%` AS 75_PCT",
        "`95%` AS 95_PCT",
        "MAX AS max",
        "MIN AS min",
        "MEAN AS mean",
        "STD AS std",
        "SKEWNESS AS skewness",
        "KURTOSIS AS kurtosis",
        "(MAX - MIN) AS range",
        "VARIANCE AS variance"
      ).show(false)

    /**
      * +-----------+----------------+-------------------------+-----------------------+--------------------------+---------+------------------------+----------------------------+---------------------------+------+------+------+------+------+----+---+------+------------------+------------------+-------------------+------+------------------+
      * |Column_Name|number_of_values|number_of_distinct_values|distinct_count_with_nan|distinct_count_without_nan|is_unique|number_of_missing_values|percentage_of_missing_values|percentage_of_unique_values|05_PCT|25_PCT|50_PCT|75_PCT|95_PCT|max |min|mean  |std               |skewness          |kurtosis           |range |variance          |
      * +-----------+----------------+-------------------------+-----------------------+--------------------------+---------+------------------------+----------------------------+---------------------------+------+------+------+------+------+----+---+------+------------------+------------------+-------------------+------+------------------+
      * |col_1      |5               |5                        |5                      |4.0                       |true     |0.0                     |0.0                         |1.0                        |10    |100   |111   |1123  |1321  |1321|10 |533.0 |634.0634826261484 |0.4334269738367066|-1.7463346405299973|1311.0|402036.5          |
      * |col_2      |5               |5                        |5                      |4.0                       |true     |0.0                     |0.0                         |1.0                        |20    |200   |222   |2123  |2321  |2321|20 |977.2 |1141.1895986206675|0.405051373873868 |-1.7997419516751323|2301.0|1302313.7000000002|
      * |col_3      |5               |5                        |5                      |4.0                       |true     |0.0                     |0.0                         |1.0                        |30    |300   |333   |3123  |3321  |3321|30 |1421.4|1649.399072389699 |0.3979251063785061|-1.8119558312496056|3291.0|2720517.3         |
      * |col_4      |5               |5                        |5                      |4.0                       |true     |0.0                     |0.0                         |1.0                        |40    |400   |444   |4123  |4321  |4321|40 |1865.6|2157.926620624529 |0.3950204738145622|-1.816512420634769 |4281.0|4656647.3         |
      * |col_5      |5               |5                        |5                      |4.0                       |true     |0.0                     |0.0                         |1.0                        |50    |500   |555   |5123  |5321  |5321|50 |2309.8|2666.5902759891706|0.3935246673563024|-1.81866856281125  |5271.0|7110703.7         |
      * +-----------+----------------+-------------------------+-----------------------+--------------------------+---------+------------------------+----------------------------+---------------------------+------+------+------+------+------+----+---+------+------------------+------------------+-------------------+------+------------------+
      */

Answer 3 (bgibtngc)

An alternative to union. Check the following code.

scala> df.show(false)
+-----+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|col_5|
+-----+-----+-----+-----+-----+
|10   |20   |30   |40   |50   |
|100  |200  |300  |400  |500  |
|111  |222  |333  |444  |555  |
|1123 |2123 |3123 |4123 |5123 |
|1321 |2321 |3321 |4321 |5321 |
+-----+-----+-----+-----+-----+

Generate the required expressions.

scala> val descExpr = array(
    df.columns
    .map(c => struct(
        lit(c).cast("string").as("column_name"),
        max(col(c)).cast("string").as("max"),
        min(col(c)).cast("string").as("min"),
        mean(col(c)).cast("string").as("mean"),
        stddev(col(c)).cast("string").as("std"),
        skewness(col(c)).cast("string").as("skewness"),
        kurtosis(col(c)).cast("string").as("kurtosis")
        )
    ):_*
).as("data")

The required columns:

val columns = Seq("column_name","max","min","mean","std","skewness","kurtosis")
 .map(c => if(c != "column_name") col(c).cast("double").as(c) else col(c))

Final output:

scala> df
 .select(descExpr)
 .selectExpr("explode(data) as data")
 .select("data.*")
 .select(columns:_*)
 .show(false)

+-----------+------+----+------+------------------+-------------------+-------------------+
|column_name|max   |min |mean  |std               |skewness           |kurtosis           |
+-----------+------+----+------+------------------+-------------------+-------------------+
|col_1      |1321.0|10.0|533.0 |634.0634826261484 |0.43342697383670664|-1.7463346405299978|
|col_2      |2321.0|20.0|977.2 |1141.1895986206673|0.4050513738738679 |-1.7997419516751327|
|col_3      |3321.0|30.0|1421.4|1649.3990723896993|0.397925106378506  |-1.8119558312496056|
|col_4      |4321.0|40.0|1865.6|2157.9266206245293|0.3950204738145622 |-1.8165124206347691|
|col_5      |5321.0|50.0|2309.8|2666.5902759891706|0.3935246673563026 |-1.81866856281125  |
+-----------+------+----+------+------------------+-------------------+-------------------+

Update

scala> val finalDF = df.select(descExpr).selectExpr("explode(data) as data").select("data.*").select(columns:_*)

Create a new DataFrame using approxQuantile for all columns:

scala> val approxQuantileDF = df
.columns
.map(c => (c,df.stat.approxQuantile(c,Array(0.25,0.5,0.75),0.0)))
.toList
.toDF("column_name","approx_quantile")
scala> finalDF
        .join(approxQuantileDF,
              Seq("column_name"),
              "left"
    ).show(false)
+-----------+------+----+------+------------------+-------------------+-------------------+----------------------+
|column_name|max   |min |mean  |std               |skewness           |kurtosis           |approx_quantile       |
+-----------+------+----+------+------------------+-------------------+-------------------+----------------------+
|col_1      |1321.0|10.0|533.0 |634.0634826261484 |0.43342697383670664|-1.7463346405299978|[100.0, 111.0, 1123.0]|
|col_2      |2321.0|20.0|977.2 |1141.1895986206673|0.4050513738738679 |-1.7997419516751327|[200.0, 222.0, 2123.0]|
|col_3      |3321.0|30.0|1421.4|1649.3990723896993|0.397925106378506  |-1.8119558312496056|[300.0, 333.0, 3123.0]|
|col_4      |4321.0|40.0|1865.6|2157.9266206245293|0.3950204738145622 |-1.8165124206347691|[400.0, 444.0, 4123.0]|
|col_5      |5321.0|50.0|2309.8|2666.5902759891706|0.3935246673563026 |-1.81866856281125  |[500.0, 555.0, 5123.0]|
+-----------+------+----+------+------------------+-------------------+-------------------+----------------------+
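If separate percentile columns are preferred, the quantile array can be unpacked by index (a sketch; the indices follow the Array(0.25, 0.5, 0.75) passed to approxQuantile above):

scala> finalDF
        .join(approxQuantileDF, Seq("column_name"), "left")
        .withColumn("25_PCT", col("approx_quantile")(0))  // element 0 -> 0.25 quantile
        .withColumn("50_PCT", col("approx_quantile")(1))  // element 1 -> 0.50 quantile
        .withColumn("75_PCT", col("approx_quantile")(2))  // element 2 -> 0.75 quantile
        .drop("approx_quantile")
        .show(false)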
