How do I use Spark agg and filter in this example?

toiithl6 · asked 2021-05-27 · in Spark

I'm exploring Spark DataFrame methods and I'm stuck on how to produce the result below.

Spark SQL (this works):

q = """
select breed, 
       avg(weight) as avg_wt, 
       avg(weight) filter (where age > 1) avg_wt_age_gt1
from cats 
group by breed 
order by breed
"""
spark.sql(q).show()

Question: how can I get the same result using PySpark DataFrame methods?

My attempt

(sdf.groupBy("breed").agg(
    F.avg('weight').alias('avg_wt')

# ,F.avg('weight').where(F.col('age')>1).alias('avg_wt')

)
.show()
)

Required output table
+-----------------+-----------------+--------------+
|            breed|           avg_wt|avg_wt_age_gt1|
+-----------------+-----------------+--------------+
|British Shorthair|              4.5|           4.5|
|       Maine Coon|            5.575|         5.575|
|          Persian|4.566666666666666|          4.75|
|          Siamese|              5.8|           5.5|
+-----------------+-----------------+--------------+

Setup and data

import numpy as np
import pandas as pd

import pyspark
from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark import SparkConf, SparkContext, SQLContext
spark = pyspark.sql.SparkSession.builder.appName('app').getOrCreate()
sc = spark.sparkContext
sqlContext = SQLContext(sc)
sqc = sqlContext

# sdf = sqlContext.createDataFrame(df)

df = pd.DataFrame({
    'name': [
        'Molly', 'Ashes', 'Felix', 'Smudge', 'Tigger', 'Alfie', 'Oscar',
        'Millie', 'Misty', 'Puss', 'Smokey', 'Charlie'
    ],
    'breed': [
        'Persian', 'Persian', 'Persian', 'British Shorthair',
        'British Shorthair', 'Siamese', 'Siamese', 'Maine Coon', 'Maine Coon',
        'Maine Coon', 'Maine Coon', 'British Shorthair'
    ],
    'weight': [4.2, 4.5, 5.0, 4.9, 3.8, 5.5, 6.1, 5.4, 5.7, 5.1, 6.1, 4.8],
    'color': [
        'Black', 'Black', 'Tortoiseshell', 'Black', 'Tortoiseshell', 'Brown',
        'Black', 'Tortoiseshell', 'Brown', 'Tortoiseshell', 'Brown', 'Black'
    ],
    'age': [1, 5, 2, 4, 2, 5, 1, 5, 2, 2, 4, 4]
})

schema = StructType([
    StructField('name', StringType(), True),
    StructField('breed', StringType(), True),
    StructField('weight', DoubleType(), True),
    StructField('color', StringType(), True),
    StructField('age', IntegerType(), True),
])

sdf = sqlContext.createDataFrame(df, schema)
sdf.createOrReplaceTempView("cats")

Answer #1 (by 9cbw7uwe)

You can put a `when` condition inside the aggregate function: rows that fail the condition evaluate to NULL, and `avg` ignores NULLs, so only cats with `age > 1` contribute to the second average.

from pyspark.sql.functions import avg, when, col

sdf.groupBy("breed").agg(
    avg('weight').alias('avg_wt'),
    # rows where age <= 1 become NULL and are skipped by avg
    avg(when(col('age') > 1, col('weight'))).alias('avg_wt_age_gt1')
).orderBy('breed').show()
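As an aside (a sketch, not part of the original answer): if you are on Spark 3.0 or later, where the SQL `FILTER` clause used in the question's query is supported, you can also push that clause through `F.expr` inside `agg`, which keeps the DataFrame code closer to the working SQL:

from pyspark.sql import functions as F

# Sketch, assuming Spark 3.0+ where the aggregate FILTER clause is accepted by the SQL parser.
(sdf.groupBy("breed")
    .agg(
        F.avg("weight").alias("avg_wt"),
        # reuse the same FILTER syntax that works in the spark.sql() query
        F.expr("avg(weight) FILTER (WHERE age > 1)").alias("avg_wt_age_gt1"),
    )
    .orderBy("breed")
    .show())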
