Pandas groupby slicing,but in numpy

xzv2uavs  于 2023-11-18  发布在  其他
关注(0)|答案(1)|浏览(133)

我有一个只包含float dtypes和id列的嵌套框架。我想将嵌套框架限制为每个id值的前10行。直接的方法是df.groupby('id').apply(lambda minidf: minidf.iloc[:k]),但它似乎有点慢,我想知道是否有更快的方法获得相同的输出。
所以我想问一下,既然Antrame是由所有的浮点数组成的,有没有一个numpy方法我可以使用它来等效于上面的代码行?或者我可以在任何其他库中实现相同的结果,但是时间要短得多。提前感谢!

xxb16uws

xxb16uws1#

下面是一个使用Numba的快速解决方案:

import numba as nb
import numpy as np
import pandas as pd

@nb.njit('bool_[:](int64[:],)')
def compute_mask(ids, max_per_group):
    unique_ids = {}
    res = np.empty(ids.size, dtype=np.bool_)
    for i in range(ids.size):
        item = ids[i]
        found = unique_ids.get(item)
        if found is None:
            unique_ids[item] = 1
            res[i] = True
        elif found < max_per_group:
            unique_ids[item] = found + 1
            res[i] = True
        else:
            res[i] = False  # Found too many times
    return res

# df = ...

mask = compute_mask(df['id'].to_numpy())

# Use `mask` here to filter the dataframe either using Numpy or Pandas.
# Creating a new dataframe in Pandas is certainly slower, but simpler.
# It can be done using Pandas with df[mask]. 
# It can be done using Numpy with df['values'].to_numpy()[mask]

字符串
这个想法是使用一个简单的哈希Map来计算重复的项目的数量,并有效地构建一个掩码。Numba中的循环速度很快,哈希Map访问也比CPython快。当ID经常重复多次时,这个解决方案是有效的。如果所有ID往往是不同的,(并行)基于排序的解决方案可以更快。请注意,您可能需要根据实际的嵌套框ID列类型为Numba函数使用签名bool_[:](int32[:],)

相关问题