R语言 将函数应用于元素的所有组合

8iwquhpp  于 2023-01-28  发布在  其他
关注(0)|答案(3)|浏览(155)

假设我有一个三变量函数

f <- function(x,y,z) x*y*z

现在,给定三组数据

X <- seq(1,10)
Y <- seq(11,20)
Z <- seq(21,30)

我想对X,Y,Z中的所有元素组合有效地计算f,最聪明的方法是什么?(实际上,我的函数更复杂,集合也很大)。

qij5mzcb

qij5mzcb1#

可能是这样的,用prod得到expand.grid中元素与行的乘积。

X <- 1:3
Y <- 4:6
Z <- 7:9

apply(expand.grid(X, Y, Z), 1, function(x) prod(x))
 [1]  28  56  84  35  70 105  42  84 126  32  64  96  40  80 120  48  96 144  36
[20]  72 108  45  90 135  54 108 162

使用任意函数
x一个一个一个一个x一个一个二个x
使用dplyr,使用crossing
注意crossingexpand_grid/expand.grid之间的差异
"crossing()"是"expand_grid()"的 Package 器,用于消除重复项并对其输入进行排序
所以它自然会慢一点,获得经常需要的功能。

library(dplyr)
library(tidyr)

crossing(X, Y, Z) %>% 
  rowwise() %>% 
  mutate(result = f(X, Y, Z))
# A tibble: 27 × 4
# Rowwise: 
       X     Y     Z result
   <int> <int> <int>  <int>
 1     1     4     7     28
 2     1     4     8     32
 3     1     4     9     36
 4     1     5     7     35
 5     1     5     8     40
 6     1     5     9     45
 7     1     6     7     42
 8     1     6     8     48
 9     1     6     9     54
10     2     4     7     56
# … with 17 more rows

最后是data.table方法,使用CJ
"CJ":* C * ross * J * oin."数据表"由向量的叉积形成。

library(data.table)

CJ(X, Y, Z)[, f(X, Y, Z)]
 [1]  28  32  36  35  40  45  42  48  54  56  64  72  70  80  90  84  96 108  84
[20]  96 108 105 120 135 126 144 162
plupiseo

plupiseo2#

我认为正确的答案可能取决于XYZ的大小。有几个选项可以扩展所有值的组合,expand.grid()来自基本R和expand_grid()来自tidyr包以及CJ()来自data.table包,正如@Andre维尔德贝格提到的。然后,您可以在扩展数据集apply()的行上使用for循环、在管道中使用mutate()或使用data.table方法来计算函数的结果。考虑上面提出的情况,其中每种方法的长度都是10。看看expand.grid()expand_grid()CJ()的基准测试,它们具有相似的量级,尽管expand.grid()CJ()方法平均起来更快,而不是expand_grid()方法。

library(dplyr)
library(tidyr)
library(data.table)
library(microbenchmark)
  

f <- function(x,y,z) x*y*z

X <- seq(1,10)
Y <- seq(11,20)
Z <- seq(21,30)

microbenchmark(
  expand_grid(X,Y,Z), 
  expand.grid(X,Y,Z), 
  CJ(X, Y, Z), times=25)
#> Unit: microseconds
#>                  expr     min      lq     mean  median      uq      max neval
#>  expand_grid(X, Y, Z) 245.750 259.667 361.7473 269.126 282.834 2507.043    25
#>  expand.grid(X, Y, Z) 102.750 107.084 147.9838 112.584 122.625  952.251    25
#>           CJ(X, Y, Z) 115.293 124.168 202.5123 132.209 137.251 1885.709    25
#>  cld
#>    a
#>    a
#>    a

在考虑计算结果的不同方法时,我实现了四种解决方案:

  • exp1()是带有expand.grid()的for循环
  • exp2()是具有expand.grid()apply()
  • exp3()是使用expand_grid()dplyr解决方案
  • exp4()是使用CJ()data.table解决方案。

data.table解决方案是dplyr解决方案的明显赢家,其优势大约为3倍(尽管根据CLD,差异在统计学上并不显著)。

exp1 <- function(){
df = expand.grid(X, Y, Z)
for (i in 1:nrow(df)) {
  df$prod = f(df[i,1], df[i,2], df[i,3])
}
}

exp2 <- function(){
  df = expand.grid(X, Y, Z)
  df$prod <- apply(df, 1, function(x)f(x[1], x[2], x[3]))
}

exp3 <- function(){
  df = expand_grid(X,Y,Z)
  df %>% mutate(prod = f(X,Y,Z))
}

exp4 <- function(){
  CJ(X,Y,Z)[,prod := f(X,Y,Z)]
}

microbenchmark(exp1(), exp2(), exp3(), exp4(), times = 25)
#> Unit: microseconds
#>    expr       min        lq       mean    median        uq       max neval cld
#>  exp1() 28588.667 29375.543 31602.4207 30595.001 32012.667 44312.917    25   c
#>  exp2()  3129.459  3254.584  3498.8607  3285.625  3341.584  6549.251    25  b 
#>  exp3()  1334.209  1411.834  2072.8457  1691.376  1825.626  9790.917    25 ab 
#>  exp4()   357.501   403.084   829.6241   481.459   559.668  7319.542    25 a

如果我们将变量的长度从10增加到30,我们可以看到情况有一些变化。CJ()大约是expand_grid()的两倍,而expand_grid()又大约是expand.grid()的两倍。

X <- seq(1,30)
Y <- seq(11,40)
Z <- seq(21,50)

microbenchmark(
  expand_grid(X,Y,Z), 
  expand.grid(X,Y,Z),
  CJ(X,Y,Z), times=25)
#> Unit: microseconds
#>                  expr     min      lq     mean  median      uq      max neval
#>  expand_grid(X, Y, Z) 298.834 321.542 485.6574 429.084 453.626 2516.376    25
#>  expand.grid(X, Y, Z) 610.584 729.084 750.1874 785.167 796.834  813.584    25
#>           CJ(X, Y, Z) 132.542 162.750 190.0006 201.292 214.126  260.959    25
#>  cld
#>   b 
#>    c
#>  a

在研究函数的不同计算方法时,结果同样清楚。data.table解比dplyr解快大约3倍,尽管在统计上也没有显著差异。dplyrdata.table解都比使用expand '.grid()的任一解快得多。

microbenchmark(exp1(), exp2(), exp3(), exp4(), times = 25)
#> Unit: microseconds
#>    expr         min          lq         mean      median          uq
#>  exp1() 1603137.001 1652454.959 1702111.1240 1666486.751 1699009.251
#>  exp2()   85320.126   89976.043   92332.3890   92130.709   94564.959
#>  exp3()    1604.750    1708.334    2302.5940    2199.292    2334.959
#>  exp4()     505.417     541.418     699.1174     723.834     793.001
#>          max neval cld
#>  2110589.459    25   c
#>   101803.250    25  b 
#>     4462.126    25 a  
#>      971.501    25 a

reprex package(v2.0.1)于2023年1月24日创建

rseugnpd

rseugnpd3#

使用expand.grid创建组合并迭代生成的数据框:

f <- function(x,y,z) x*y*z

X <- seq(1,10)
Y <- seq(11,20)
Z <- seq(21,30)

df = expand.grid(X, Y, Z)

for (i in 1:nrow(df)) {
    print(f(df[i, 1], df[i, 2], df[i, 3]))
}

相关问题