使用frollsum计算data.table中的滚动加权平均值

p8ekf7hl  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(92)

我们的目标是计算滚动加权平均值:窗口包含3行,权重依次为3、2、1,离当前行最近的行权重最大。这与另一个已有的问题类似,但这里的权重不是由某一列给出的。此外,我很希望使用frollsum(),因为我要处理大量数据,需要它的性能。
我有一个使用frollapply()的解决方案:

library(data.table)

# Example data: two groups of ten rows each, values rounded to whole numbers
set.seed(1)
DT <- data.table(
  group = rep(c(1, 2), each = 10),
  value = round(runif(n = 20, min = 1, max = 5))
)

# Window weights, ordered from the oldest row of the window to the newest,
# so the most recent row gets the largest weight.
weights <- 1:3
# Window length, derived from the weights so the two cannot drift apart.
k <- length(weights)

# Weighted mean of one rolling window.
#
# x: numeric vector holding one window, oldest value first.
# Returns sum(x * w) / sum(w), where w is the first length(x) entries of the
# global `weights`. An empty `x` yields NaN, like mean(numeric(0)).
weighted_average <- function(x) {
  # seq_along() instead of 1:length(x): for empty input 1:0 is c(1, 0), which
  # would silently select weights[1] and report an average of 0.
  w <- weights[seq_along(x)]
  sum(x * w) / sum(w)
}

# Rolling weighted average within each group. shift() lags the result by one
# row, so a row's wtavg is computed from the k rows that precede it.
DT[
  ,
  wtavg := {
    rolled <- frollapply(value, k, weighted_average, align = "right", fill = NA)
    shift(rolled)
  },
  by = group
]

DT
#>     group value    wtavg
#>  1:     1     2       NA
#>  2:     1     2       NA
#>  3:     1     3       NA
#>  4:     1     5 2.500000
#>  5:     1     2 3.833333
#>  6:     1     5 3.166667
#>  7:     1     5 4.000000
#>  8:     1     4 4.500000
#>  9:     1     4 4.500000
#> 10:     1     1 4.166667
#> 11:     2     2       NA
#> 12:     2     2       NA
#> 13:     2     4       NA
#> 14:     2     3 3.000000
#> 15:     2     4 3.166667
#> 16:     2     3 3.666667
#> 17:     2     4 3.333333
#> 18:     2     5 3.666667
#> 19:     2     3 4.333333
#> 20:     2     4 3.833333

字符串
创建于2023-11-27使用reprex v2.0.2

06odsfpq

06odsfpq1#

这可能不是最佳方式(我会再研究一下Rcpp),但你可以简单地调用三次frollsum来获得显著的速度提升:

# Meant to be used as the j expression of DT[, ..., by = group].
# Adding rolling sums of widths 3, 2 and 1 counts the newest value three
# times, the previous one twice and the oldest once, i.e. weights 1, 2, 3
# (which sum to 6). shift() then lags the result by one row.
shift((frollsum(value, 3) + frollsum(value, 2) + frollsum(value, 1)) / 6)

字符串
请注意,frollsum(value, 1)可以替换为value
另一个(看起来)更快更简单的选择:

# Same result from plain lags: the values 3, 2 and 1 rows back get weights
# 1, 2 and 3 respectively, and 6 is the sum of those weights.
(shift(value, 3) + 2 * shift(value, 2) + 3 * shift(value, 1)) / 6


基准测试

# Benchmark data: one million rows split evenly over 1000 groups
set.seed(1)
n <- 1e6
groups <- 1:1000
DT <- data.table(
  group = rep(groups, each = n / length(groups)),
  value = round(runif(n = n, 1, 5))
)

# A: frollapply() with the R-level weighted_average() and k from the question
# B: sum of frollsum() calls, with frollsum(value, 1) written as value
bench::mark(
  A = DT[
    ,
    shift(frollapply(value, k, weighted_average, align = "right", fill = NA)),
    by = group
  ],
  B = DT[
    ,
    shift((frollsum(value, 3) + frollsum(value, 2) + value) / 6),
    by = group
  ]
)

#   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
# 1 A             1.75s    1.75s     0.570    46.7MB     34.2     1    60      1.75s
# 2 B           80.38ms  85.72ms     9.24     77.4MB     12.9     5     7   541.13ms

7eumitmz

7eumitmz2#

我使用frollsum()、frollapply()和RcppRoll::roll_meanr(),在不同的滞后长度、组数和观测数下进行了更多的基准测试,每种方法都改写成了可以接受多个变量的形式。
RcppRoll::roll_meanr()是明显的赢家。

# Report how many threads data.table uses (12 in the run shown below)
data.table::getDTthreads()
#> [1] 12

# Function to create data
#
# n:        total number of rows.
# n_groups: number of groups the rows are split into.
# seed:     RNG seed, set before drawing so the data are reproducible.
# Returns a data.table with an integer `group` column (sorted) and two
# columns of rounded uniform draws, `value1` (1..5) and `value2` (1..6).
make_data <- function(n, n_groups, seed = 1) {
  set.seed(seed)
  data.table::data.table(
    # rep(..., each = n / n_groups) truncates a fractional `each`, which gives
    # a group column of the wrong length when n is not a multiple of
    # n_groups. Cycling the ids to length n and sorting is identical in the
    # divisible case and spreads the remainder over the first groups.
    group = sort(rep_len(seq_len(n_groups), n)),
    value1 = round(runif(n = n, 1, 5)),
    value2 = round(runif(n = n, 1, 6))
  )
}

# Used inside fapplysum() below
#
# Weighted mean of one rolling window.
# x:       numeric vector holding one window.
# weights: numeric weights; the first length(x) entries are paired with x.
# Returns sum(x * w) / sum(w). An empty `x` yields NaN, like mean(numeric(0)).
weighted_average <- function(x, weights) {
  # seq_along() instead of 1:length(x): for empty input 1:0 is c(1, 0), which
  # would silently select weights[1] and report an average of 0.
  w <- weights[seq_along(x)]
  sum(x * w) / sum(w)
}

# Weighted average using RcppRoll::roll_meanr()
#
# dt:      data.table with columns `group`, `value1` and `value2`.
# weights: window weights handed to roll_meanr().
# Returns, per group, the rolling weighted mean of both value columns,
# lagged by one row with data.table::shift().
rcpp_roll <- function(dt, weights) {
  dt[
    ,
    {
      # roll_meanr() works on a matrix, so both columns are rolled in one call
      rolled <- RcppRoll::roll_meanr(as.matrix(.SD), weights = weights)
      data.table::shift(as.data.frame(rolled))
    },
    by = group,
    .SDcols = c("value1", "value2")
  ]
}

# Weighted average using frollsum()
#
# dt:      data.table with columns `group`, `value1` and `value2`.
# weights: numeric weights for one window, oldest row first; the window
#          length is length(weights).
# Returns, per group, the rolling weighted mean of both value columns,
# lagged by one row with data.table::shift().
#
# Adding plain rolling sums of widths 1..k only reproduces the weights 1:k.
# Scaling the width-m sum by the difference between successive weights
# (coefs below) makes the total equal sum(weights * window) for any weights:
# the coefficients covering a given row telescope to that row's weight.
# For weights = 1:k every coefficient is 1, so the result is unchanged.
fsum <- function(dt, weights) {
  stopifnot(is.numeric(weights), length(weights) >= 1L)
  k <- length(weights)
  # coefs[m] multiplies the rolling sum over the m most recent rows
  coefs <- rev(diff(c(0, weights)))
  weight_total <- sum(weights)

  dt[
    ,
    {
      averaged <- lapply(.SD, function(column) {
        scaled_sums <- lapply(
          seq_len(k),
          function(m) coefs[m] * data.table::frollsum(column, n = m)
        )
        # The width-k sum is NA for the first k - 1 rows, so incomplete
        # windows stay NA in the total
        Reduce(`+`, scaled_sums) / weight_total
      })
      # Drop names so the columns come out unnamed, as before
      data.table::shift(unname(averaged))
    },
    .SDcols = c("value1", "value2"),
    by = group
  ]
}

# Weighted average using sum() and frollapply()
#
# dt:      data.table with columns `group`, `value1` and `value2`.
# weights: window weights; the window length is length(weights).
# Returns, per group, weighted_average() applied over each rolling window of
# both value columns, lagged by one row with data.table::shift().
fapplysum <- function(dt, weights) {
  dt[
    ,
    {
      # `weights = weights` is forwarded by frollapply() to weighted_average()
      rolled <- data.table::frollapply(
        .SD,
        length(weights),
        weighted_average,
        weights = weights
      )
      data.table::shift(rolled)
    },
    .SDcols = c("value1", "value2"),
    by = group
  ]
}

# Create benchmarking template
#
# dt:      data passed to each implementation.
# weights: window weights passed to each implementation.
# iter:    number of iterations bench::mark() runs per expression.
# Returns the bench::mark() result comparing the three implementations.
benchmark <- function(dt, weights, iter = 10) {
  bench::mark(
    rcpp_roll = rcpp_roll(dt, weights),
    fsum = fsum(dt, weights),
    fapplysum = fapplysum(dt, weights),
    iterations = iter,
    filter_gc = FALSE
  )
}

# Create benchmarking data: every combination of row count, group count and
# window length
combos <- expand.grid(
  n = c(1e3, 1e4, 5e4, 7.5e4, 1e5),
  n_groups = c(1, 10, 50),
  k = c(3, 10, 50)
)

# One weight vector 1..k per row, stored as a list column. lapply() always
# returns a list, whereas sapply() would collapse to a matrix if every k were
# equal.
combos$weights <- lapply(combos$k, seq_len)

# One data set per row of `combos`
data <- purrr::map2(
  .x = combos$n,
  .y = combos$n_groups,
  .f = make_data
)

# Benchmark rolling weighted average functions
results <- purrr::map2(
  .x = data,
  .y = combos$weights,
  .f = benchmark,
  .progress = TRUE
)

# Plot results
library(ggplot2)

# Stack the benchmark tables; `id` is each table's position in `results`
results2 <- dplyr::bind_rows(results, .id = "id") |>
  dplyr::mutate(id = as.numeric(id)) |>
  dplyr::select(id, expression, median)

plot_data <- results2 |>
  # attach the parameters of each combination by row position
  dplyr::inner_join(
    combos |>
      dplyr::select(-weights) |>
      dplyr::mutate(id = dplyr::row_number()),
    by = "id"
  ) |>
  # average over different lag lengths and number of groups
  dplyr::summarise(
    time_avg = mean(as.numeric(median)),
    .by = c("expression", "n")
  ) |>
  dplyr::mutate(expression = as.factor(as.character(expression)))

ggplot(data = plot_data) +
  geom_line(aes(x = n, y = time_avg, color = expression)) +
  labs(
    title = "Comparing execution time for different rolling weighted average functions",
    x = "Number of Obs.",
    y = "Avg. Time (s)"
  )

字符串
(此处原为上述代码生成的基准测试结果图)
创建于2023-12-09带有reprex v2.0.2

相关问题