Rust: how do I calculate the standard deviation and mean across different columns, row by row?

t5zmwmid  posted 11 months ago  in Other

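For every row, I want to calculate the mean and standard deviation of the columns whose names start with Sum of EBIT [CY 2]. I can get the mean by adding the 10 columns together and dividing by 10.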

As shown below:

pub fn industry_beta_f(raw_data:DataFrame, marginal_tax_rate:Expr) -> DataFrame{

    let df = raw_data.clone().lazy()
                                .with_columns([
                                    ((col("Sum of EBIT [CY 2011] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2012] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2013] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2014] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2015] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2016] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2017] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2018] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2019] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2020] ($USDmm, Historical rate)")) / lit(10.0)).alias("moments_mean"),
                                    (col("Sum of EBIT [CY 2011] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2012] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2013] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2014] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2015] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2016] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2017] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2018] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2019] ($USDmm, Historical rate)") + col("Sum of EBIT [CY 2020] ($USDmm, Historical rate)")).std(1).alias("moments_std"),

                                ])
                                .with_columns([
                                    when(col("moments_mean").gt(lit(0.0)))
                                    .then(col("moments_std") / col("moments_mean"))
                                    .otherwise(f64::NAN)
                                    .alias("Standard deviation in operating income (last 10 years)")
                                ])
                                .select([col("Industry Name"),
                                        col("Number of firms"),
                                        col("Standard deviation in operating income (last 10 years)")])
                                .collect()
                                .unwrap();
    return df
}

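I am having trouble calculating the standard deviation, for every row, of the columns that start with Sum of EBIT [CY 2]. When I use std(), it computes the standard deviation of each column (over all rows) rather than across the columns within each row.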
Current Output
Expected Output
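There is a large gap between the two outputs: in the current output, std is computed down each column, whereas in the expected output it is computed across the columns of each row.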

8hhllhi2


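You can calculate the mean with DataFrame's built-in mean_horizontal method.
The standard deviation is not supported out of the box, so it is a little trickier. First calculate the mean, then the sum of squared errors, then divide that by the number of columns minus 1 and take the square root, as follows: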

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let mut df = df! (
        "col_1" => &[1, 2, 3, 4, 5],
        "col_2" => &[2, 4, 6, 8, 5],
        "col_3" => &[10, 8, 6, 4, 5],
    )?;

    let n = df.get_column_names().len() as i32;

    let col_mean = df.mean_horizontal(polars::frame::NullStrategy::Ignore)?;

    let df_w_mean: &mut DataFrame;
    if let Some(mean) = col_mean {
        df_w_mean = df.with_column(mean.with_name("col_mean"))?;
    } else {
        return Err(PolarsError::ComputeError("No mean can be calculated".into()));
    }
    
    let sse = df_w_mean
        .clone()
        .lazy()
        .with_column((col("*") - col("col_mean")).pow(2))
        .collect()?
        .sum_horizontal(polars::frame::NullStrategy::Ignore)?;
    
    let df_w_sse: &mut DataFrame;
    if let Some(sse) = sse {
        df_w_sse = df_w_mean.with_column(sse.with_name("col_std"))?;
    } else {
        return Err(PolarsError::ComputeError("No sse can be calculated".into()));
    }

    let result = df_w_sse.clone().lazy().with_column((col("col_std") / lit(n - 1)).sqrt()).collect()?;
    println!("{:?}", result);

    Ok(())
}
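
To sanity-check the numbers that the Polars program above should produce, the same arithmetic can be sketched in plain Rust without any crates. This is only an illustration: the helper name row_mean_std is hypothetical, and the rows are the three example columns from the answer read across (col_1, col_2, col_3).

```rust
/// Sample mean and sample standard deviation (n - 1 denominator) of one row.
/// Hypothetical helper for illustration; returns None when the row has
/// fewer than two values, because the n - 1 denominator would be zero.
fn row_mean_std(row: &[f64]) -> Option<(f64, f64)> {
    let n = row.len();
    if n < 2 {
        return None;
    }
    // Step 1: the row mean.
    let mean = row.iter().sum::<f64>() / n as f64;
    // Step 2: the sum of squared errors around that mean.
    let sse: f64 = row.iter().map(|x| (x - mean).powi(2)).sum();
    // Step 3: divide by n - 1 and take the square root.
    let std = (sse / (n as f64 - 1.0)).sqrt();
    Some((mean, std))
}

fn main() {
    // Each inner array is one row: [col_1, col_2, col_3].
    let rows: [[f64; 3]; 5] = [
        [1.0, 2.0, 10.0],
        [2.0, 4.0, 8.0],
        [3.0, 6.0, 6.0],
        [4.0, 8.0, 4.0],
        [5.0, 5.0, 5.0],
    ];

    for row in &rows {
        if let Some((mean, std)) = row_mean_std(row) {
            println!("{:?} -> mean = {:.4}, std = {:.4}", row, mean, std);
        }
    }
}
```

In the Polars version, n is taken before the col_mean column is added, so it counts only the data columns. The col_mean column is also swept up by col("*"), but its deviation from itself is zero, so it does not change the sum of squared errors.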
