R语言 如何在使用tidymodels在折叠的每个训练部分上训练的每个交叉验证折叠中应用预处理?

hivapdat  于 2023-04-18  发布在  其他
关注(0)|答案(1)|浏览(121)

我尝试使用tidymodels R包来创建ml管道。(一个配方),并将其应用于我的交叉验证的每个重新样本。但这使用了(全局)训练数据来预处理褶皱。我认为正确的做法是在每个“分析”上定义一个预处理配方在一些实施例中,用户可以将其应用于折叠的“评估”(即,训练)部分,并将其应用于折叠的“评估”(即,测试)部分。
下面的代码给出了我的问题的一个例子:

library(tidyverse)
library(tidymodels)

set.seed(1000)

mtcars = mtcars |> select(mpg, hp)
init_split <- initial_split(mtcars, prop = 0.9)

preprocessing_recipe <- recipe(mpg ~ hp,
                           data = training(init_split)
) |>
step_normalize(all_predictors())
preprocessing_recipe = preprocessing_recipe %>% prep()
preprocessing_recipe

cv_folds <-  bake(preprocessing_recipe, new_data = training(init_split)) %>%
vfold_cv(v = 3)

## these resamples are not properly scaled:

training(cv_folds$splits[[1]]) %>% lapply(mean)

## $hp
## [1] 0.1442218

training(cv_folds$splits[[1]]) %>% lapply(sd)

## $hp
## [1] 1.167365

## while the preprocessing on the training data leads to exactly scaled data:

preprocessing_recipe$template %>% lapply(mean)

## $hp
## [1] -1.249001e-16

preprocessing_recipe$template %>% lapply(sd)

## $hp
## [1] 1

上面失败的原因很清楚。但是我如何改变上面的管道(高效,优雅)来定义fold的每个train部分的配方,并将其应用到测试部分?在我看来,这是避免数据泄漏的方法。我在任何帖子的文档中没有找到任何提示。谢谢!

nbewdwxp

nbewdwxp1#

当你使用一个配方时,你是整个生产线的一部分,您不太可能希望自己在诊断目的之外使用prep()bake()。我们推荐的是将配方与workflow()一起使用,以便能够将其附加到建模模型。这里我添加了一个线性回归规范。这两个可以一起使用fit()predict()。但是你也可以将它们放入你的交叉验证循环中,根据你的需要使用fit_resamples()tune_grid()
有关详细信息,请参阅:

library(tidyverse)
library(tidymodels)

set.seed(1000)

mtcars <- mtcars |> 
  select(mpg, hp)
init_split <- initial_split(mtcars, prop = 0.9)
mtcars_training <- training(init_split)

mtcars_folds <- vfold_cv(mtcars_training, v = 3)

preprocessing_recipe <- recipe(mpg ~ hp,
                               data = mtcars_training) |>
  step_normalize(all_predictors())

lm_spec <- linear_reg()

wf_spec <- workflow() |>
  add_recipe(preprocessing_recipe) |>
  add_model(lm_spec)

resampled_fits <- fit_resamples(
  wf_spec,
  resamples = mtcars_folds,
  control = control_resamples(extract = function(x) {
    tidy(x, "recipe", number = 1)
  })
)

通过查看配方的估计值,我们可以看到工作流适合每个折叠。我在control_resamples()extract参数中添加了一个函数,可以提取配方中计算的训练均值和sd。

resampled_fits |> 
  collect_extracts() |> 
  pull(.extracts)
#> [[1]]
#> # A tibble: 2 × 4
#>   terms statistic value id             
#>   <chr> <chr>     <dbl> <chr>          
#> 1 hp    mean      140.  normalize_x5pUR
#> 2 hp    sd         77.3 normalize_x5pUR
#> 
#> [[2]]
#> # A tibble: 2 × 4
#>   terms statistic value id             
#>   <chr> <chr>     <dbl> <chr>          
#> 1 hp    mean      144.  normalize_x5pUR
#> 2 hp    sd         57.4 normalize_x5pUR
#> 
#> [[3]]
#> # A tibble: 2 × 4
#>   terms statistic value id             
#>   <chr> <chr>     <dbl> <chr>          
#> 1 hp    mean      150.  normalize_x5pUR
#> 2 hp    sd         74.9 normalize_x5pUR

我们可以看到它们与原始褶皱的均值和sd相匹配

mtcars_folds$splits |>
  map(analysis) |>
  map(~ tibble(mean = mean(.x$hp), sd = sd(.x$hp)))
#> [[1]]
#> # A tibble: 1 × 2
#>    mean    sd
#>   <dbl> <dbl>
#> 1  140.  77.3
#> 
#> [[2]]
#> # A tibble: 1 × 2
#>    mean    sd
#>   <dbl> <dbl>
#> 1  144.  57.4
#> 
#> [[3]]
#> # A tibble: 1 × 2
#>    mean    sd
#>   <dbl> <dbl>
#> 1  150.  74.9

相关问题