import numba as nb
import numpy as np
import pandas as pd
# The eager signature must list one type per parameter: the id array and the
# integer per-group limit. Adapt `int64[:]` to the actual dtype of the id column.
@nb.njit('bool_[:](int64[:], int64)')
def compute_mask(ids, max_per_group):
    """Build a boolean mask keeping at most ``max_per_group`` entries per id.

    The array is scanned once, in order, while a hash map counts how many
    times each id has already been kept. An entry is marked ``True`` as long
    as its id has been kept fewer than ``max_per_group`` times, so the
    earliest occurrences of every id win.

    Args:
        ids: 1-D integer array of group ids (``int64`` per the signature).
        max_per_group: Maximum number of entries to keep for each distinct id.
            A value of zero or less keeps nothing.

    Returns:
        A boolean array of the same length as ``ids``; ``True`` marks the
        entries to keep.
    """
    kept_counts = {}
    res = np.empty(ids.size, dtype=np.bool_)
    for i in range(ids.size):
        item = ids[i]
        found = kept_counts.get(item)
        if found is None:
            # First time this id is seen. The literal 1 stored here gives the
            # dict a concrete value type.
            kept_counts[item] = 1
            # Keep the first occurrence only if the limit allows at least one.
            res[i] = max_per_group > 0
        elif found < max_per_group:
            kept_counts[item] = found + 1
            res[i] = True
        else:
            res[i] = False  # This id already reached the per-group limit.
    return res
# df = ...
# max_per_group = ...  # maximum number of rows to keep for each distinct id
# `compute_mask` takes both the id array and the per-group limit.
mask = compute_mask(df['id'].to_numpy(), max_per_group)
# Use `mask` here to filter the dataframe either using Numpy or Pandas.
# Creating a new dataframe in Pandas is certainly slower, but simpler.
# It can be done using Pandas with df[mask].
# It can be done using Numpy with df['values'].to_numpy()[mask]
1条答案
按热度按时间xxb16uws1#
上面的代码是一个使用 Numba 的快速解决方案。
这个想法是使用一个简单的哈希表来统计每个 ID 已经出现的次数,并高效地构建一个掩码。Numba 中的循环速度很快,哈希表的访问也比 CPython 快。当 ID 经常重复多次时,这个解决方案很高效。如果所有 ID 往往各不相同,(并行的)基于排序的解决方案可能更快。请注意,您可能需要根据 DataFrame 中 ID 列的实际类型调整 Numba 函数的签名,例如对于 int32 类型的 ID 列使用 `bool_[:](int32[:], int64)`。