Using select functions and row sums in a Tidymodels recipe in R

vbopmzt1 posted on 2023-02-27 in Other

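I have a data frame in which I want to divide a specific set of predictors by the number of predictors from that same set that are greater than zero, counted per row. When I try to include this operation in a recipe, it seems to divide by the total number of predictors in the set instead, ignoring the greater-than-zero condition.
Example: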

library(tidymodels)

df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8, 0, 0, 32, 16, 0, 0, 0, 32), 4, 5))

  X1 X2 X3 X4 X5
1 16 32  0  0  0
2  8 16 32  0  0
3  4  8 16 32  0
4  2  4  8 16 32

vars <- names(df)[-1]

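# count, per row, how many of the selected predictors are positive
df_temp <- df %>% 
  mutate(pos_count = rowSums(df %>% select(all_of(vars)) > 0))

# divide each selected predictor by its own row's count
df_temp <- df_temp %>% 
  mutate(across(all_of(vars), .fns = ~./pos_count))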
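For comparison, the same per-row normalisation can be sketched in base R. This is only an illustration (the names `pos_count` and `scaled` are local to the sketch); it rebuilds the example data so it runs on its own and makes explicit that there is one count per row:

```r
# Rebuild the example data so this block runs on its own
df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8,
                          0, 0, 32, 16, 0, 0, 0, 32), 4, 5))
vars <- names(df)[-1]  # "X2" "X3" "X4" "X5"

# One count per row: how many of the selected predictors are > 0
pos_count <- rowSums(df[vars] > 0)

# Dividing a data frame by a vector of length nrow recycles the vector
# down each column, so row i is divided by pos_count[i]
scaled <- df
scaled[vars] <- df[vars] / pos_count

pos_count
scaled
```

X1 is left untouched because it is not in vars, matching what the mutate() calls above do.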

lm_recipe <- 
  recipe(X1 ~ X2 + X3 + X4 + X5, data = df_temp) 

lm_model <- 
  linear_reg(penalty = 0) %>%  
  set_engine("glmnet", lower.limits = rep(0, 5), upper.limits = rep(1, 5), intercept = FALSE)

lm_wflow <- 
  workflow() %>% 
  add_model(lm_model) %>%
  add_recipe(lm_recipe)

lm_fit <- fit(lm_wflow,  df_temp)
lm_fit %>% tidy()

  term        estimate penalty
1 (Intercept)   0            0
2 X2            0.492        0
3 X3            0.240        0
4 X4            0.112        0
5 X5            0.0256       0

This more or less works (the estimates should be 0, 1/2, 1/4, 1/8 and 1/16).
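The expected estimates can be checked without glmnet. With four observations, four normalised predictors and no intercept, the design matrix is square (and here lower triangular), so a plain linear solve gives the exact coefficients, which the glmnet fit at penalty = 0 only approximates. A minimal base-R sketch:

```r
df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8,
                          0, 0, 32, 16, 0, 0, 0, 32), 4, 5))
vars <- names(df)[-1]

# Normalised predictors: each row divided by its count of positive entries
X <- as.matrix(df[vars]) / rowSums(df[vars] > 0)
y <- df$X1

# Square, lower-triangular system with no intercept: solve it exactly
beta <- solve(X, y)
beta
```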
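However, when I move the data preparation into the recipe, every predictor is divided by the total number of predictors in the set (4 in this example):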

lm_recipe <- 
  recipe(X1 ~ X2 + X3 + X4 + X5, data = df) %>% 
  step_mutate(pos_count = sum(all_of(vars) > 0)) %>%
  step_mutate(across(all_of(vars), .fns = ~./pos_count)) 

lm_wflow <- 
  workflow() %>% 
  add_model(lm_model) %>%
  add_recipe(lm_recipe)

lm_fit <- fit(lm_wflow,  df)

lm_fit %>% tidy()

  term        estimate penalty
1 (Intercept)    0           0
2 X2             1           0
3 X3             0.478       0
4 X4             0           0
5 X5             0           0
6 pos_count      0           0

augment(lm_fit, df)

     X1    X2    X3    X4    X5 .pred
1    16    32     0     0     0  8   
2     8    16    32     0     0  7.82
3     4     8    16    32     0  3.91
4     2     4     8    16    32  1.96
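These predictions are consistent with every row having been divided by the same constant 4 rather than by its own count of positive predictors. A quick base-R check, plugging in the rounded coefficients from the tidy() output above (pos_count had an estimate of 0, so it is left out):

```r
df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8,
                          0, 0, 32, 16, 0, 0, 0, 32), 4, 5))
vars <- names(df)[-1]

# Hypothesis: every row was divided by the same constant 4
X_wrong <- as.matrix(df[vars]) / 4

# Coefficients copied from the tidy() output above
coefs <- c(X2 = 1, X3 = 0.478, X4 = 0, X5 = 0)

pred <- drop(X_wrong %*% coefs)
round(pred, 2)
```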

What do I need to change in the recipe to fix this? Thanks!

Answer #1, by eqqqjvef

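The problem is that inside step_mutate() you used sum() instead of the rowSums() you used earlier. sum() collapses the condition into a single number for the whole data set rather than one count per row. Also, all_of() only selects columns inside a selecting function such as select(); inside step_mutate() you can select the columns with pick() (available from dplyr 1.1.0):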
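A small base-R illustration of the difference: sum() returns a single total over the whole selection, while rowSums() returns one count per row. The last line rests on an assumption: that all_of(vars), evaluated outside a selecting function, does not select any columns and simply yields the character vector of names. Comparing those strings with 0 (coerced to "0") is TRUE for each name, which would explain the constant divisor of 4 seen in the question.

```r
df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8,
                          0, 0, 32, 16, 0, 0, 0, 32), 4, 5))
vars <- names(df)[-1]

# sum() collapses everything to one number: all positive cells in X2:X5
sum(df[vars] > 0)          # 10

# rowSums() keeps one count per row, which is what the division needs
rowSums(df[vars] > 0)      # 1 2 3 4

# Assumption: outside a selecting function all_of(vars) reduces to the
# names themselves; each string compares greater than "0", so this is 4
sum(vars > 0)              # 4
```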

df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8, 0, 0, 32, 16, 0, 0, 0, 32), 4, 5))

vars <- names(df)[-1]

library(recipes)

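lm_recipe <- 
  recipe(X1 ~ X2 + X3 + X4 + X5, data = df) %>% 
  # pick() selects the columns inside step_mutate(); rowSums() then
  # gives one count of positive predictors per row
  step_mutate(pos_count = rowSums(pick(any_of(vars)) > 0)) %>%
  # divide each selected predictor by its own row's count
  step_mutate(across(any_of(vars), .fns = ~./pos_count))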

prep(lm_recipe) |>
  bake(new_data = NULL)
#> # A tibble: 4 × 6
#>      X2    X3    X4    X5    X1 pos_count
#>   <dbl> <dbl> <dbl> <dbl> <dbl>     <dbl>
#> 1 32     0      0       0    16         1
#> 2  8    16      0       0     8         2
#> 3  2.67  5.33  10.7     0     4         3
#> 4  1     2      4       8     2         4

Created on 2023-02-17 with reprex v2.0.2
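One refinement beyond the answer, offered as a sketch: a column created by step_mutate() gets the predictor role by default, which is why pos_count showed up as a model term in the question. Adding step_rm(pos_count) after the division drops the helper column before the model sees the data:

```r
library(recipes)
library(dplyr)  # pick() needs dplyr >= 1.1.0

df <- data.frame(matrix(c(16, 8, 4, 2, 32, 16, 8, 4, 0, 32, 16, 8,
                          0, 0, 32, 16, 0, 0, 0, 32), 4, 5))
vars <- names(df)[-1]

lm_recipe <-
  recipe(X1 ~ X2 + X3 + X4 + X5, data = df) %>%
  # one count of positive predictors per row
  step_mutate(pos_count = rowSums(pick(any_of(vars)) > 0)) %>%
  # divide each selected predictor by its row's count
  step_mutate(across(any_of(vars), .fns = ~ .x / pos_count)) %>%
  # drop the helper so it is not passed to the model as a predictor
  step_rm(pos_count)

baked <- prep(lm_recipe) %>% bake(new_data = NULL)
baked
```

The recipe can then be passed to add_recipe() in the same workflow as before, so the model sees only X2 to X5.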
