用dplyr拟合几个回归模型

iyr7buue  于 2023-05-20  发布在  其他
关注(0)|答案(8)|浏览(111)

我想用dplyr为每个小时(因子变量)拟合一个模型,我得到了一个错误,我不太确定是什么问题。

# Simulated data: 24 hour levels with 21 rows each (504 rows in total);
# price, wind and temp are independent uniform random draws.
df.h <- data.frame( 
  hour     = factor(rep(1:24, each = 21)),
  price    = runif(504, min = -10, max = 125),
  wind     = runif(504, min = 0, max = 2500),
  temp     = runif(504, min = - 10, max = 25)  
)

# Convert to a dplyr tbl, then group it by hour.
df.h <- tbl_df(df.h)
df.h <- group_by(df.h, hour)

group_size(df.h) # sanity check: 21 observations in each hour group

# The attempts below are the ones that fail.
# Attempt 1: only a formula is handed to do(); no model-fitting function is
# called anywhere in this expression.
reg.models <- do(df.h, formula = price ~ wind + temp)

# Attempt 2: lm() is called inside do(), but `data = df.h` refers to the whole
# grouped table rather than to the rows of the current group.
reg.models <- do(df.h, .f = lm(price ~ wind + temp, data = df.h))

我试过各种不同的方法,但就是不管用。

4dc9hkyq

4dc9hkyq1#

截至2015年5月左右,最简单的方法是使用broom。broom包含三个函数,用于按组处理统计操作返回的复杂对象:tidy(处理按组统计操作得到的系数向量)、glance(处理按组统计操作得到的汇总统计量)和augment(处理按组统计操作得到的观测级别结果)。
下面演示了如何使用它将线性回归的各种结果按组提取到整齐的data_frame中。
1.tidy

library(dplyr)
library(broom)

# Simulated hourly data: 24 hour levels with 21 rows each (504 rows total).
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

# Fit one linear model per hour; each fit is kept in the column `fitHour`.
dfHour <- df.h %>%
  group_by(hour) %>%
  do(fitHour = lm(price ~ wind + temp, data = .))

# Gather the per-hour coefficient tables into one tidy data frame and show it.
dfHourCoef <- dfHour %>% tidy(fitHour)
dfHourCoef

其给出,

Source: local data frame [72 x 6]
    Groups: hour

hour        term     estimate   std.error  statistic     p.value
1     1 (Intercept) 53.336069324 21.33190104  2.5002961 0.022294293
2     1        wind -0.008475175  0.01338668 -0.6331053 0.534626575
3     1        temp  1.180019541  0.79178607  1.4903262 0.153453756
4     2 (Intercept) 77.737788772 23.52048754  3.3051096 0.003936651
5     2        wind -0.008437212  0.01432521 -0.5889765 0.563196358
6     2        temp -0.731265113  1.00109489 -0.7304653 0.474506855
7     3 (Intercept) 38.292039924 17.55361626  2.1814331 0.042655670
8     3        wind  0.005422492  0.01407478  0.3852630 0.704557388
9     3        temp  0.426765270  0.83672863  0.5100402 0.616220435
10    4 (Intercept) 30.603119492 21.05059583  1.4537888 0.163219027
..  ...         ...          ...         ...        ...         ...

2.augment

# Per-observation fitted values and residual diagnostics for every hourly fit.
dfHourPred <- dfHour %>% augment(fitHour)
dfHourPred

其给出,

Source: local data frame [504 x 11]
Groups: hour

hour       price      wind      temp  .fitted  .se.fit     .resid       .hat   .sigma      .cooksd .std.resid
1     1  83.8414055   67.3780 -6.199231 45.44982 22.42649  38.391590 0.27955950 42.24400 0.1470891067  1.0663820
2     1   0.3061628 2073.7540 15.134085 53.61916 14.10041 -53.312993 0.11051343 41.43590 0.0735584714 -1.3327207
3     1  80.3790032  520.5949 24.711938 78.08451 20.03558   2.294497 0.22312869 43.64059 0.0003606305  0.0613746
4     1 121.9023855 1618.0864 12.382588 54.23420 10.31293  67.668187 0.05911743 40.23212 0.0566557575  1.6447224
5     1  -0.4039594 1542.8150 -5.544927 33.71732 14.53349 -34.121278 0.11740628 42.74697 0.0325125137 -0.8562896
6     1  29.8269832  396.6951  6.134694 57.21307 16.04995 -27.386085 0.14318542 43.05124 0.0271028701 -0.6975290
7     1  -7.1865483 2009.9552 -5.657871 29.62495 16.93769 -36.811497 0.15946292 42.54487 0.0566686969 -0.9466312
8     1  -7.8548693 2447.7092 22.043029 58.60251 19.94686 -66.457379 0.22115706 39.63999 0.2983443034 -1.7753911
9     1  94.8736726 1525.3144 24.484066 69.30044 15.93352  25.573234 0.14111563 43.12898 0.0231796755  0.6505701
10    1  54.4643001 2473.2234 -7.656520 23.34022 21.83043  31.124076 0.26489650 42.74790 0.0879837510  0.8558507
..  ...         ...       ...       ...      ...      ...        ...        ...      ...          ...        ...

3.glance

# One row of model-level summary statistics for each hourly fit.
dfHourSumm <- dfHour %>% glance(fitHour)
dfHourSumm

其给出,

Source: local data frame [24 x 12]
Groups: hour

hour  r.squared adj.r.squared    sigma statistic    p.value df    logLik      AIC      BIC deviance df.residual
1     1 0.12364561    0.02627290 42.41546 1.2698179 0.30487225  3 -106.8769 221.7538 225.9319 32383.29          18
2     2 0.03506944   -0.07214506 36.79189 0.3270961 0.72521125  3 -103.8900 215.7799 219.9580 24365.58          18
3     3 0.02805424   -0.07993974 39.33621 0.2597760 0.77406651  3 -105.2942 218.5884 222.7665 27852.07          18
4     4 0.17640603    0.08489559 41.37115 1.9277147 0.17434859  3 -106.3534 220.7068 224.8849 30808.30          18
5     5 0.12575453    0.02861615 42.27865 1.2945915 0.29833246  3 -106.8091 221.6181 225.7962 32174.72          18
6     6 0.08114417   -0.02095092 35.80062 0.7947901 0.46690268  3 -103.3164 214.6328 218.8109 23070.31          18
7     7 0.21339168    0.12599076 32.77309 2.4415266 0.11529934  3 -101.4609 210.9218 215.0999 19333.36          18
8     8 0.21655629    0.12950699 40.92788 2.4877430 0.11119114  3 -106.1272 220.2543 224.4324 30151.65          18
9     9 0.23388711    0.14876346 35.48431 2.7476160 0.09091487  3 -103.1300 214.2601 218.4381 22664.45          18
10   10 0.18326177    0.09251307 40.77241 2.0194425 0.16171339  3 -106.0472 220.0945 224.2726 29923.01          18
..  ...        ...           ...      ...       ...        ... ..       ...      ...      ...      ...         ...
63lcw9qa

63lcw9qa2#

截至2020年中期(并已于2022-04更新以适配dplyr 1.0+),tchakravarty's answer将会失败。为了绕开broom与dplyr之间这种新的交互方式,可以组合使用broom::tidy、broom::augment和broom::glance。我们只需要将它们与nest_by()和summarize()结合使用(以前是放在do()中,后来则是配合unnest())。

library(dplyr)
library(broom)
library(tidyr)

# Fix the RNG seed so the simulated data is reproducible.
set.seed(42)
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

# Nest the rows of each hour, fit one model per nested data frame, then
# expand every fit into its coefficient table.
df.h %>%
  nest_by(hour) %>%
  mutate(
    mod = list(lm(formula = price ~ wind + temp, data = data))
  ) %>%
  summarize(tidy(mod))
# # A tibble: 72 × 6
# # Groups:   hour [24]
#    hour  term        estimate std.error statistic   p.value
#    <fct> <chr>          <dbl>     <dbl>     <dbl>     <dbl>
# 1  1     (Intercept) 87.4       15.8        5.55  0.0000289
# 2  1     wind        -0.0129     0.0120    -1.08  0.296    
# 3  1     temp         0.588      0.849      0.693 0.497    
# 4  2     (Intercept) 92.3       21.6        4.27  0.000466 
# 5  2     wind        -0.0227     0.0134    -1.69  0.107    
# 6  2     temp        -0.216      0.841     -0.257 0.800    
# 7  3     (Intercept) 61.1       18.6        3.29  0.00409  
# 8  3     wind         0.00471    0.0128     0.367 0.718    
# 9  3     temp         0.425      0.964      0.442 0.664    
# 10 4     (Intercept) 31.6       15.3        2.07  0.0529   

# Same per-hour fits, expanded into observation-level fitted values and
# residual diagnostics.
df.h %>%
  nest_by(hour) %>%
  mutate(
    mod = list(lm(formula = price ~ wind + temp, data = data))
  ) %>%
  summarize(augment(mod))
# # A tibble: 504 × 10
# # Groups:   hour [24]
#    hour   price  wind   temp .fitted .resid   .hat .sigma  .cooksd .std.resid
#    <fct>  <dbl> <dbl>  <dbl>   <dbl>  <dbl>  <dbl>  <dbl>    <dbl>      <dbl>
#  1 1     113.    288. -1.75     82.7  30.8  0.123    37.8 0.0359       0.877 
#  2 1     117.   2234. 18.4      69.5  47.0  0.201    36.4 0.165        1.40  
#  3 1      28.6  1438.  4.75     71.7 -43.1  0.0539   37.1 0.0265      -1.18  
#  4 1     102.    366.  9.77     88.5  13.7  0.151    38.4 0.00926      0.395 
#  5 1      76.6  2257. -4.69     55.6  21.0  0.245    38.2 0.0448       0.644 
#  6 1      60.1   633. -3.18     77.4 -17.3  0.0876   38.4 0.00749     -0.484 
#  7 1      89.4   376. -4.16     80.1   9.31 0.119    38.5 0.00314      0.264 
#  8 1       8.18 1921. 19.2      74.0 -65.9  0.173    34.4 0.261       -1.93  
#  9 1      78.7   575. -6.11     76.4   2.26 0.111    38.6 0.000170     0.0640
# 10 1      85.2   763. -0.618    77.2   7.94 0.0679   38.6 0.00117      0.219 
# # … with 494 more rows

# Same per-hour fits, reduced to one row of model-level statistics per hour.
df.h %>%
  nest_by(hour) %>%
  mutate(
    mod = list(lm(formula = price ~ wind + temp, data = data))
  ) %>%
  summarize(glance(mod))
# # A tibble: 24 × 13
# # Groups:   hour [24]
#    hour  r.squared adj.r.squared sigma statistic p.value    df logLik   AIC
#    <fct>     <dbl>         <dbl> <dbl>     <dbl>   <dbl> <dbl>  <dbl> <dbl>
#  1 1        0.0679       -0.0357  37.5     0.655   0.531     2  -104.  217.
#  2 2        0.139         0.0431  42.7     1.45    0.261     2  -107.  222.
#  3 3        0.0142       -0.0953  43.1     0.130   0.879     2  -107.  222.
#  4 4        0.0737       -0.0293  36.7     0.716   0.502     2  -104.  216.
#  5 5        0.213         0.126   37.8     2.44    0.115     2  -104.  217.
#  6 6        0.0813       -0.0208  33.5     0.796   0.466     2  -102.  212.
#  7 7        0.0607       -0.0437  40.7     0.582   0.569     2  -106.  220.
#  8 8        0.153         0.0592  36.3     1.63    0.224     2  -104.  215.
#  9 9        0.166         0.0736  36.5     1.79    0.195     2  -104.  216.
# 10 10       0.110         0.0108  40.0     1.11    0.351     2  -106.  219.
# # … with 14 more rows, and 4 more variables: BIC <dbl>, deviance <dbl>,
# #   df.residual <int>, nobs <int>

感谢Bob Muenchen's Blog的启发。

5hcedyr0

5hcedyr03#

在dplyr 0.4中,您可以执行以下操作:

# Fit one lm per group of df.h and keep each fit in the column `model`.
# df.h has to be grouped already (e.g. by hour) for the fits to be per group.
do(df.h, model = lm(price ~ wind + temp, data = .))
kqlmhetl

kqlmhetl4#

do的文档:
.f:应用于每个片段的函数。提供给.f的第一个未命名参数将是一个数据框(data frame)。
所以:

# Hand do() a function via `.f`; the function receives one piece of df.h as
# `data` and returns the linear model fitted on it.
reg.models <- do(
  df.h,
  .f = function(data) lm(price ~ wind + temp, data = data)
)

把模型所对应的小时(hour)一并保存下来可能也会有帮助:

# As above, but each fitted model also records which hour it was fitted on.
reg.models <- do(
  df.h,
  .f = function(data) {
    fit <- lm(price ~ wind + temp, data = data)
    fit$hour <- unique(data$hour)  # tag the fit with the hour value(s) of its piece
    fit
  }
)
qlfbtfca

qlfbtfca5#

我相信有一个比loki's answer更简洁的答案,它放弃了已被取代(superseded)的do():

library(dplyr)
library(broom)
library(tidyr)

# Fit one model per hour, then keep all three broom summaries of each fit as
# list-columns (one row per hour).
h.lm <- df.h %>%
  nest_by(hour) %>%
  mutate(
    fitHour = list(lm(formula = price ~ wind + temp, data = data))
  ) %>%
  summarise(
    tidy_out    = list(tidy(fitHour)),
    glance_out  = list(glance(fitHour)),
    augment_out = list(augment(fitHour))
  ) %>%
  ungroup()

h.lm
# # A tibble: 24 x 4
#    hour  tidy_out         glance_out        augment_out
#    <fct> <list>           <list>            <list>
#  1 1     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  2 2     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  3 3     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  4 4     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  5 5     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  6 6     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  7 7     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  8 8     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
#  9 9     <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
# 10 10    <tibble [3 × 5]> <tibble [1 × 12]> <tibble [21 × 9]>
# # … with 14 more rows

与他们的回答类似,要访问结果,只需对所需的组件进行解嵌套(unnest)即可:

# Expand only the per-hour coefficient tables. The list-column is named
# explicitly through `cols`, which tidyr (>= 1.0.0) expects; omitting it
# triggers a "`cols` is now required" warning.
unnest(select(h.lm, hour, tidy_out), cols = tidy_out)
# # A tibble: 72 x 6
#    hour  term        estimate std.error statistic p.value
#    <fct> <chr>          <dbl>     <dbl>     <dbl>   <dbl>
#  1 1     (Intercept) 63.2       20.9        3.02  0.00728
#  2 1     wind        -0.00237    0.0139    -0.171 0.866
#  3 1     temp        -0.266      0.950     -0.280 0.783
#  4 2     (Intercept) 65.1       23.0        2.83  0.0111
#  5 2     wind         0.00691    0.0129     0.535 0.599
#  6 2     temp        -0.448      0.877     -0.510 0.616
#  7 3     (Intercept) 65.2       17.8        3.67  0.00175
#  8 3     wind         0.00515    0.0112     0.458 0.652
#  9 3     temp        -1.87       0.695     -2.69  0.0148
# 10 4     (Intercept) 49.7       17.6        2.83  0.0111
# # … with 62 more rows
dfty9e19

dfty9e196#

我认为你可以用更合适的方式使用dplyr,不需要像@fabians的回答那样定义函数。

# Chained form: group df.h by hour, then pass lm (wrapped in
# failwith(NULL, ...)) to do(), with the model formula as an extra argument.
# NOTE(review): `%.%`, failwith() and passing a function to do() look like
# early-dplyr API -- confirm they exist in the installed dplyr version.
# presumably failwith(NULL, lm) yields NULL for a group whose fit errors; verify.
results<-df.h %.% 
group_by(hour) %.% 
do(failwith(NULL, lm), formula = price ~ wind + temp)

# The same computation written as nested calls instead of a chain.
results<-do(group_by(tbl_df(df.h), hour),
failwith(NULL, lm), formula = price ~ wind + temp)

**编辑:**当然,没有failwith也可以工作

# Chained form: group df.h by hour and pass lm to do() directly, with the
# model formula as an extra argument (no failwith() wrapper).
# NOTE(review): `%.%` and passing a function to do() look like early-dplyr
# API -- confirm they exist in the installed dplyr version.
results<-df.h %.% 
    group_by(hour) %.% 
    do(lm, formula = price ~ wind + temp)

# The same computation written as nested calls instead of a chain.
results<-do(group_by(tbl_df(df.h), hour),
lm, formula = price ~ wind + temp)
fcwjkofz

fcwjkofz7#

在tidyverse后来的几个版本中,do()已被取代,我们可以少用一行代码来为每组拟合一个模型。

library("broom")
library("tidyverse")

# Simulated hourly data: 24 hour levels with 21 rows each (504 rows total).
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

# Fit one model per hour group and return its coefficient table.
# Swap `tidy` for `glance` or `augment` to extract different information
# from the fitted models.
df.h %>%
  group_by(hour) %>%
  group_modify(~ tidy(lm(price ~ wind + temp, data = .x)))
#> # A tibble: 72 × 6
#> # Groups:   hour [24]
#>    hour  term        estimate std.error statistic  p.value
#>    <fct> <chr>          <dbl>     <dbl>     <dbl>    <dbl>
#>  1 1     (Intercept) 73.9      16.3         4.52  0.000266
#>  2 1     wind        -0.0256    0.0119     -2.15  0.0456  
#>  3 1     temp         1.72      0.861       2.00  0.0604  
#>  4 2     (Intercept) 81.5      18.4         4.42  0.000331
#>  5 2     wind        -0.0111    0.00973    -1.14  0.270   
#>  6 2     temp        -1.60      0.763      -2.09  0.0506  
#>  7 3     (Intercept) 59.9      16.1         3.73  0.00154 
#>  8 3     wind         0.00358   0.0102      0.349 0.731   
#>  9 3     temp        -1.82      0.664      -2.74  0.0134  
#> 10 4     (Intercept) 49.6      18.5         2.69  0.0151  
#> # … with 62 more rows

reprex package(v2.0.1)于2022-04-20创建

hjqgdpho

hjqgdpho8#

从dplyr 1.0.0开始,group_split为这个操作提供了一个方便的快捷方式:

library(dplyr)
#> Warning: package 'dplyr' was built under R version 4.2.3
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(broom)
#> Warning: package 'broom' was built under R version 4.2.2
library(purrr)
#> Warning: package 'purrr' was built under R version 4.2.2
# Simulated hourly data: 24 hour levels with 21 rows each (504 rows total).
df.h <- data.frame(
  hour  = factor(rep(1:24, each = 21)),
  price = runif(504, min = -10, max = 125),
  wind  = runif(504, min = 0, max = 2500),
  temp  = runif(504, min = -10, max = 25)
)

# Split the data into a list with one data frame per hour.
df.g <- df.h |> group_split(hour)

# Fit each piece, tidy its coefficients, label the rows with that piece's
# hour, and row-bind everything into a single tibble.
map_dfr(df.g, function(x) {
  coefs <- tidy(lm(price ~ wind + temp, data = x))
  mutate(coefs, hour = x$hour[[1]])
})
#> # A tibble: 72 × 6
#>    term          estimate std.error statistic  p.value hour 
#>    <chr>            <dbl>     <dbl>     <dbl>    <dbl> <fct>
#>  1 (Intercept) 115.         25.4       4.53   0.000260 1    
#>  2 wind         -0.00627     0.0129   -0.487  0.632    1    
#>  3 temp         -2.57        1.26     -2.04   0.0568   1    
#>  4 (Intercept)  71.0        16.6       4.28   0.000455 2    
#>  5 wind          0.00262     0.0112    0.233  0.818    2    
#>  6 temp         -0.824       0.834    -0.989  0.336    2    
#>  7 (Intercept)  39.3        22.5       1.74   0.0984   3    
#>  8 wind          0.000342    0.0137    0.0250 0.980    3    
#>  9 temp         -0.248       0.964    -0.257  0.800    3    
#> 10 (Intercept)  56.1        21.6       2.59   0.0184   4    
#> # ℹ 62 more rows

创建于2023-05-15带有reprex v2.0.2

相关问题