java 基于行数的Split Spark数据集

dw1jzc5e  于 2023-06-28  发布在  Java
关注(0)|答案(1)|浏览(92)

我从dynamo db读取数据并将其存储在SparkDataset中,如下所示:

// Building a dataset
            Dataset citations = sparkSession.read()
                .option("tableName", "Covid19Citation")
                .option("region", "eu-west-1")
                .format("dynamodb")
                .load();

我想要的是根据行数分割这个数据集。
例如,如果数据集有超过500行,我想拆分它,并将每个数据集保存为单独的csv文件。因此,我想保存的每个数据集最多应该有500行。如果数据库中有1600行,输出应该是四个xml文件:
第一个xml文件包含500行
第二个xml文件也包含500行
第三个xml文件也包含500行,最后是
第四个xml文件,包含100行。
这是我目前为止尝试过的,但它不起作用:

List<Dataset> datasets = new ArrayList<>();
            while (citations.count() > 0) {
                  Dataset splitted = citations.limit(400);
                  datasets.add(splitted);
                  citations = citations.except(splitted);
            }

            System.out.println("datasets : " + datasets.size());
            for (Dataset d : datasets) {
                  code
                  d.coalesce(1)
                      .write()
                      .format("com.databricks.spark.xml")
                      .option("rootTag", "citations")
                      .option("rowTag", "citation")
                      .mode("overwrite")
                      .save("s3a://someoutputfolder/");
            }

任何帮助都将不胜感激。
谢谢

vfh0ocws

vfh0ocws1#

您可以利用:

  • row_numbermod:将数据集拆分为500个部分
  • repartition:为每个分区生成一个文件
  • partitionBy:为每个分区写入一个xml

这里是scala / parquet中的一个例子(但是你也可以使用xml

val citations = spark.range(1, 2000000).selectExpr("id", "hash(id) value")

// calculate the number of buckets
val total = citations.count
val mod = (total.toFloat / 500).ceil.toInt

citations
.withColumn("id", expr("row_number() over(order by monotonically_increasing_id())"))
.withColumn("bucket", expr(f"mod(id, ${mod})"))
.repartition('bucket)
.write
.partitionBy("bucket")
.format("parquet")
.mode("overwrite")
.save("/tmp/foobar")

// now check the results
val resultDf = spark.read.format("parquet").load("/tmp/foobar")

// as a result you get at most 500 rows
scala> resultDf.groupBy("bucket").count.show
+------+-----+
|bucket|count|
+------+-----+
|  1133|  500|
|  1771|  500|
|  1890|  500|
|  3207|  500|
|  3912|  500|
|  1564|  500|
|  2823|  500|
+------+-----+

// there is no file with more than 500 rows
scala> resultDf.groupBy("bucket").count.filter("count > 500").show
+------+-----+
|bucket|count|
+------+-----+
+------+-----+

// now check there is only one file per bucket
scala> spark.sparkContext.parallelize(resultDf.inputFiles).toDF
.withColumn("part", expr("regexp_extract(value,'(bucket=([0-9]+))')"))
.groupBy("part").count.withColumnRenamed("count", "nb_files")
.orderBy(desc("nb_files")).show(5)
+-----------+--------+
|       part|nb_files|
+-----------+--------+
|bucket=3209|       1|
|bucket=1290|       1|
|bucket=3354|       1|
|bucket=2007|       1|
|bucket=2816|       1|
+-----------+--------+

相关问题