R语言 使用多个列值作为函数的输入来计算data.table中的新列

4nkexdtk  于 2022-12-20  发布在  其他
关注(0)|答案(2)|浏览(125)

我使用 data.table 来处理一个包含多个列的数据集。我需要使用其中的一些列值来计算每一行的新列。我知道我可以使用 SDcols 功能来实现一些简单的函数。但是,当我想使用自己的函数时,这就有点麻烦了,因为它对列值的处理方式不同。下面是我的示例:
下面是 data.table 的外观:

Training Age  Client1 Stim1.0 Stim1.1  Client2 Stim2.0 Stim2.1 Choice Val.00 Val.01
1:        0   1  absence       0       0  absence       0       0      2      0      0
2:        0   2  absence       0       0  absence       0       0      2      0      0
3:        0   3 Object 1       1       1 Object 2       2       2      1      0      0
4:        0   4 Object 2       2       1 Object 2       2       1      1      0      0
5:        0   5  absence       0       0  absence       0       0      2      0      0
6:        0   6  absence       0       0  absence       0       0      2      0      0
   Val.02 alpha.0 Val.10 Val.11 Val.12 alpha.1 V25
1: 0.0000   0.005      0 0.0000 0.0000   0.005  NA
2: 0.0000   0.005      0 0.0000 0.0000   0.005  NA
3: 0.0025   0.005      0 0.0000 0.0025   0.005  NA
4: 0.0050   0.005      0 0.0025 0.0025   0.005  NA
5: 0.0050   0.005      0 0.0025 0.0025   0.005  NA
6: 0.0050   0.005      0 0.0025 0.0025   0.005  NA

该函数使用以 Stim 开头的列的值来选择以 Val 开头的列中必须包含在新值的计算中的列。
Stim 和 * 瓦尔 * 列的数量较低时,分别为2和3,我可以使用 fcase 求解

rawData[,`:=`(Val.Client1=fcase(Stim1.0==0,Val.00,
                               Stim1.0==1,Val.01,Stim1.0==2,Val.02)+
                fcase(Stim1.1==0,Val.10,
                      Stim1.1==1,Val.11,Stim1.1==2,Val.12),
              Val.Clien2=fcase(Stim2.0==0,Val.00,
                               Stim2.0==1,Val.01,Stim2.0==2,Val.02)+
                fcase(Stim2.1==0,Val.10,
                      Stim2.1==1,Val.11,Stim2.1==2,Val.12))]

然而,我使用的不同数据集的列数是不同的,所以,我想独立于列数进行编码。
我已经成功地使用 .SDcolsapply 的组合使其工作,方法如下:

numSti<-2,numFeat<-2 # parameters to know the number of columns to expect
rawData[,Val.Client1:=apply(.SD,MARGIN = 1,FUN = function(x){
# I use apply to get a vector with alll the relevant values
  x<-as.numeric(x) # for some reason I must force it to be numeric 
  Stim1.tmp<-x[1:numSti]+1 # Choose the relevant values for the Stim columns
  vals<-x[(numSti*2+1): (numSti*2+numSti*(1+numFeat))] # choose the relevant values for the Val columns
  locVal<-Stim1.tmp+(numFeat+1)*(0:(numSti-1)) # map the Stim to the Val columns
  return(sum(vals[locVal])) # sum over the chosen values. 
}),.SDcols=patterns("Stim.|Val.")]

这段代码给了我正确的计算。但是它太慢了!你能帮我找一个更快的解决方案吗?
根据@jblood94的要求:dput(rawData)的输出

as.data.table(structure(list(Age = 1:6, Client1 = c(2L, 2L, 0L, 1L, 2L, 2L), 
    Stim1.0 = c(0L, 0L, 1L, 2L, 0L, 0L), Stim1.1 = c(0L, 0L, 
    1L, 1L, 0L, 0L), Client2 = c(2L, 2L, 1L, 1L, 2L, 2L), Stim2.0 = c(0L, 
    0L, 2L, 2L, 0L, 0L), Stim2.1 = c(0L, 0L, 2L, 1L, 0L, 0L), 
    Choice = c(2L, 2L, 1L, 1L, 2L, 2L), Val.00 = c(0, 0, 0, 0, 
    0, 0), Val.01 = c(0, 0, 0, 0, 0, 0), Val.02 = c(0, 0, 0.0025, 
    0.005, 0.005, 0.005), alpha.0 = c(0.005, 0.005, 0.005, 0.005, 
    0.005, 0.005), Val.10 = c(0, 0, 0, 0, 0, 0), Val.11 = c(0, 
    0, 0, 0.0025, 0.0025, 0.0025), Val.12 = c(0, 0, 0.0025, 0.0025, 
    0.0025, 0.0025), alpha.1 = c(0.005, 0.005, 0.005, 0.005, 
    0.005, 0.005), V25 = c(NA, NA, NA, NA, NA, NA)), row.names = c(NA, 
-6L), class = c("data.table", "data.frame")))
wwodge7n

wwodge7n1#

也许这个用户函数会有所帮助:

fun <- function(data, vals) {
  stimvals <- Map(function(V, levels) {
    match(paste0(sub("Stim[0-9]+\\.([0-9]+)", "Val.\\1", V), levels),
          names(data))
  }, setNames(nm = vals), lapply(vals, function(z) data[[z]]))
  Reduce(`+`, lapply(stimvals, function(z) as.data.frame(data)[cbind(seq_along(z), z)]))
}

stims <- grep("Stim.*", names(rawData), value = TRUE)
stims <- split(stims, sub("\\..*", "", stims))
names(stims) <- sub(".*([0-9]+)$", "Val.Client\\1", names(stims))
stims
# $Val.Client1
# [1] "Stim1.0" "Stim1.1"
# $Val.Client2
# [1] "Stim2.0" "Stim2.1"

rawData[, names(stims) := lapply(stims, fun, data = .SD)]
rawData
#      Age Client1 Stim1.0 Stim1.1 Client2 Stim2.0 Stim2.1 Choice Val.00 Val.01 Val.02 alpha.0 Val.10 Val.11 Val.12 alpha.1    V25 Val.Client1 Val.Client2
#    <int>   <int>   <int>   <int>   <int>   <int>   <int>  <int>  <num>  <num>  <num>   <num>  <num>  <num>  <num>   <num> <lgcl>       <num>       <num>
# 1:     1       2       0       0       2       0       0      2      0      0 0.0000   0.005      0 0.0000 0.0000   0.005     NA      0.0000      0.0000
# 2:     2       2       0       0       2       0       0      2      0      0 0.0000   0.005      0 0.0000 0.0000   0.005     NA      0.0000      0.0000
# 3:     3       0       1       1       1       2       2      1      0      0 0.0025   0.005      0 0.0000 0.0025   0.005     NA      0.0000      0.0050
# 4:     4       1       2       1       1       2       1      1      0      0 0.0050   0.005      0 0.0025 0.0025   0.005     NA      0.0075      0.0075
# 5:     5       2       0       0       2       0       0      2      0      0 0.0050   0.005      0 0.0025 0.0025   0.005     NA      0.0000      0.0000
# 6:     6       2       0       0       2       0       0      2      0      0 0.0050   0.005      0 0.0025 0.0025   0.005     NA      0.0000      0.0000

这是基于Stim*子级别动态查找Val*名称(例如,Stim2.1.1)和对应的Stim2.1中的 * 值 *。即,如果Stim2.1具有0的值,那么它应该从Val.10拉出。数据基于适当的Val*列进行索引,然后重新分配给从Stim*名称的第一个数字派生的名称(Stim2.12)。
因此,上面stims变量的生成是关键:将相应的Stim#.#*变量分组在一起(它们将被子集化/求和),并适当地命名它们。

uelo1irk

uelo1irk2#

这个方法看起来很有效,其思想是基于Stim的值索引Val列,将结果放入一个c(nrow(rawData), numSti, 2)维的数组中,然后置换该数组,以便沿着第二维与colSums求和。

numSti <- 2
numFeat <- 2

rawData[
  ,paste0("Val.Client", 1:2) := as.data.table(
    colSums(
      aperm(
        array(
          unlist(rawData[,.SD , .SDcols = grep("Val.", names(rawData), value = TRUE)])[
            1:.N + (unlist(rawData[,.SD , .SDcols = grep("Stim", names(rawData), value = TRUE)]) + rep(c(0, numFeat + 1), each = .N))*.N
          ],
          c(.N, numSti, 2)
        ),
        c(2, 1, 3)
      )
    )
  )
]

rawData[, Val.Client1:Val.Client2]
#>    Val.Client1 Val.Client2
#> 1:      0.0000      0.0000
#> 2:      0.0000      0.0000
#> 3:      0.0000      0.0050
#> 4:      0.0075      0.0075
#> 5:      0.0000      0.0000
#> 6:      0.0000      0.0000

它也可以概括为客户端的数量:

numCli <- 2

rawData[
  ,paste0("Val.Client", 1:numCli) := as.data.table(
    colSums(
      aperm(
        array(
          unlist(rawData[,.SD , .SDcols = grep("Val.", names(rawData), value = TRUE)])[
            1:.N + (
              unlist(
                rawData[,.SD , .SDcols = grep("Stim", names(rawData), value = TRUE)]
              ) + rep(
                seq(0, by = numFeat + 1, length.out = numCli),
                each = .N
              )
            )*.N
          ],
          c(.N, numSti, numCli)
        ),
        c(2, 1, 3)
      )
    )
  )
]

相关问题