Splitting the rows of a dataset based on a column value

idfiyjo8 · posted 2021-07-14 in Spark

I'm using Spark 3.1.1 with Java 8. I'm trying to split a Dataset<Row> based on the value of one of its numeric columns (greater than or less than a threshold), but the split should only be performed when certain string columns have the same value across the rows involved. I'm trying the following:

Iterator<Row> iter2 = partition.toLocalIterator();
while (iter2.hasNext()) {
    Row item = iter2.next();
    // getColVal is a function that gets the value of the given column
    String numValue = getColVal(item, dim);
    if (Integer.parseInt(numValue) < threshold)
        pl.add(item);
    else
        pr.add(item);
}

But how can I check, before splitting, that the other (string) column values of the rows involved are the same, so that the split can be performed?
PS: before splitting I tried grouping by columns, like this:

Dataset<Row> newDataset=oldDataset.groupBy("col1","col4").agg(col("col1"));

but it does not work.
Thanks for your help.
EDIT:
A sample dataset I want to split is:

abc,9,40,A
abc,7,50,A
cde,4,20,B
cde,3,25,B

If the threshold is 30, then the first two rows and the last two rows will form two datasets, because their first and fourth columns are identical; otherwise the split is not possible.
EDIT: the expected output is

abc,9,40,A
abc,7,50,A

cde,4,20,B
cde,3,25,B
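
To make the intent concrete, this is roughly the split I am after when written against the Dataset API instead of a local iterator (just a sketch of my intent, not tested code; the part I am missing is the check on the string columns):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import static org.apache.spark.sql.functions.col;

// sketch only: `partition`, `dim` and `threshold` are the same variables as in the code above;
// the cast is there in case the numeric column is read as a string
Dataset<Row> pl = partition.filter(col(dim).cast("int").lt(threshold));
Dataset<Row> pr = partition.filter(col(dim).cast("int").geq(threshold));
// TODO: this split should only happen when all rows involved have identical values
// in the string columns (col1 and col4 in the sample above) -- this is the check I don't know how to do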

nhjlsmyf1#

I mostly work in pyspark, but you can adapt this to your environment.


## imports / session assumed by the snippets below
import pandas as pd
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()

## could add some conditional logic or just always output 2 data frames where
## one would be empty
print("pdf - two dataframe")

## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[40,50,20,25],'col4':['A','A','B','B']})
print(pdf)

## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf)
sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc|   9|  40|   A|
# | abc|   7|  50|   A|
# | cde|   4|  20|   B|
# | cde|   3|  25|   B|
# +----+----+----+----+

## filter into two frames by the threshold, then aggregate per (col1, col4) group
pl = sdf.filter('col3 <= 30')\
        .groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))
pr = sdf.filter('col3 > 30')\
        .groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))

print("pl")
pl.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | cde|   B|    7|
# +----+----+-----+

print("pr")
pr.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | abc|   A|   16|
# +----+----+-----+
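
If you want the same filter-based split in the Java Dataset API from your question, it could look roughly like the sketch below (untested; I'm assuming sdf is your Dataset<Row> holding the sample data and that the columns are literally named col1..col4):

import static org.apache.spark.sql.functions.col;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

int threshold = 30;
// rows at or below the threshold vs. rows above it
Dataset<Row> pl = sdf.filter(col("col3").leq(threshold));
Dataset<Row> pr = sdf.filter(col("col3").gt(threshold));

// the split is only meaningful if no (col1, col4) combination appears on both sides
long overlap = pl.select("col1", "col4")
                 .intersect(pr.select("col1", "col4"))
                 .count();
boolean splitPossible = (overlap == 0);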

print("pdf - one dataframe")

## create pandas dataframe

pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[11,29,20,25],'col4':['A','A','B','B']})
print( pdf )

## move it to spark

print("sdf")
sdf = spark.createDataFrame(pdf) 
sdf.show()

# +----+----+----+----+

# |col1|col2|col3|col4|

# +----+----+----+----+

# | abc|   9|  11|   A|

# | abc|   7|  29|   A|

# | cde|   4|  20|   B|

# | cde|   3|  25|   B|

# +----+----+----+----+

pl = sdf.filter('col3 <= 30')\
        .groupBy("col1","col4").agg( F.sum('col2').alias('sumC2') )
pr = sdf.filter('col3 > 30')\
        .groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))

print("pl")
pl.show()

# +----+----+-----+

# |col1|col4|sumC2|

# +----+----+-----+

# | abc|   A|   16|

# | cde|   B|    7|

# +----+----+-----+

print("pr")
pr.show()

# +----+----+-----+

# |col1|col4|sumC2|

# +----+----+-----+

# +----+----+-----+
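
As shown above, one of the two frames can simply come out empty. In your Java code you could detect that before deciding whether a split actually happened, for example (sketch; Dataset.isEmpty() is available since Spark 2.4):

// treat the split as successful only when both halves contain rows
boolean splitHappened = !pl.isEmpty() && !pr.isEmpty();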

Filtering by a dynamic mean

print("pdf - filter by mean")

## create pandas dataframe

pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[40,50,20,25],'col4':['A','A','B','B']})
print( pdf )

## move it to spark

print("sdf")
sdf = spark.createDataFrame(pdf) 
sdf.show()

# +----+----+----+----+

# |col1|col2|col3|col4|

# +----+----+----+----+

# | abc|   9|  40|   A|

# | abc|   7|  50|   A|

# | cde|   4|  20|   B|

# | cde|   3|  25|   B|

# +----+----+----+----+

w = Window.partitionBy("col1").orderBy("col2")

## add another column, the mean of col2 partitioned by col1

sdf = sdf.withColumn('mean_c2', F.mean('col2').over(w))

## filter by the dynamic mean

pr = sdf.filter('col2 > mean_c2')
pr.show()

# +----+----+----+----+-------+

# |col1|col2|col3|col4|mean_c2|

# +----+----+----+----+-------+

# | cde|   4|  20|   B|    3.5|

# | abc|   9|  40|   A|    8.0|

# +----+----+----+----+-------+
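
A Java version of the dynamic-mean filter could look like the sketch below (untested). One caveat: because the window above has an orderBy, Spark's default frame makes mean_c2 a running mean rather than the overall group mean; if you want the plain per-group mean, partitionBy alone, as in this sketch, is enough. Column names are the ones from the sample data:

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.mean;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;

// per-group mean of col2, then keep only the rows above their group's mean
WindowSpec w = Window.partitionBy("col1");
Dataset<Row> withMean = sdf.withColumn("mean_c2", mean("col2").over(w));
Dataset<Row> pr = withMean.filter(col("col2").gt(col("mean_c2")));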
