NumPy: how to ensure that each unique label's count is at least a given value?

Posted by cqoc49vn on 2023-10-19 in Other

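I have a 1D array that contains labels starting from 0. The goal is to ensure that the count of each unique label is >= a constant (e.g., 10). If it is not, merge the nearest label into it until the count reaches at least 10.
Here is an example: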

import random
import numpy as np

data = np.concatenate(([0]*5, [1]*10, [2]*15, [3]*10, [4]*12, [5]*5))
random.Random(4).shuffle(data)
print(data)
array([2, 3, 4, 5, 2, 4, 2, 4, 2, 2, 0, 4, 1, 4, 4, 5, 4, 2, 1, 3, 1, 1,
       5, 0, 3, 3, 2, 2, 4, 4, 3, 3, 4, 4, 1, 2, 5, 1, 2, 2, 3, 3, 1, 0,
       2, 3, 5, 0, 0, 1, 1, 3, 2, 4, 1, 2, 2])
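Shuffling only reorders the elements, so the per-label counts that the merging rule depends on can be read off directly with np.bincount (a quick check added for illustration, not part of the original question):

```python
import random
import numpy as np

# Same construction as in the question: 5, 10, 15, 10, 12 and 5 copies of labels 0..5
data = np.concatenate(([0]*5, [1]*10, [2]*15, [3]*10, [4]*12, [5]*5))
random.Random(4).shuffle(data)

# np.bincount(data)[k] is the number of elements equal to k
counts = np.bincount(data)
print(counts)  # [ 5 10 15 10 12  5]
```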

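The logic should be as follows:
Start from label 0. Since its count is 5 (< 10), merge label 1 into it by replacing 1 with 0. Label 0 then has a sufficient count (15).
Then move on to the next label, 2, which already satisfies the condition.
The last label that needs merging is 5, which should be replaced with 4.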
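The rule above only needs the per-label counts, so one way to keep it fast (a sketch under stated assumptions, not the asker's code) is to make every merge decision on the small counts array, store the result in a lookup table from old label to new label, and relabel the large array once with mapping[data]. The helper name merge_small_labels is hypothetical; it assumes non-negative integer labels, walks the labels that actually occur in ascending order, closes a group once its running count reaches thresh, folds a too-small trailing group into the previous group, and lets each group keep the label of its smallest member.

```python
import numpy as np


def merge_small_labels(data, thresh=10):
    """Hypothetical helper: greedily merge neighbouring labels so that every
    resulting label occurs at least `thresh` times.

    Assumes `data` is a 1-D array of non-negative integers.
    """
    counts = np.bincount(data)          # one O(n) pass over the large array
    present = np.flatnonzero(counts)    # labels that actually occur, ascending
    mapping = np.arange(counts.size)    # lookup table: old label -> new label

    groups = []    # finished groups, each a list of labels
    current = []   # labels of the group being built
    running = 0    # total count of the group being built
    for label in present:
        current.append(label)
        running += counts[label]
        if running >= thresh:
            groups.append(current)
            current, running = [], 0

    if current:
        # The trailing group is too small: fold it into the previous group,
        # or keep it as-is if there is no previous group to merge with.
        if groups:
            groups[-1].extend(current)
        else:
            groups.append(current)

    for group in groups:
        mapping[group] = group[0]       # each group keeps its smallest label

    return mapping[data]                # second O(n) pass: relabel everything


data = np.concatenate(([0]*5, [1]*10, [2]*15, [3]*10, [4]*12, [5]*5))
out = merge_small_labels(data, thresh=10)
print(np.bincount(out))  # [15  0 15 10 17]
```

The Python loop runs once per distinct label, not once per element, so the large array is only touched by np.bincount and the final indexing step.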
I came up with this approach: loop over np.unique(data) and check the counts from np.bincount(data). However, this approach is slow when the data array is large.

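import random
import numpy as np

data = np.concatenate(([0]*5, [1]*10, [2]*15, [3]*10, [4]*12, [5]*5))
random.Random(4).shuffle(data)

counts = np.bincount(data)  # count of each label, computed once

new_label = 0
count_num = 0  # running count of the group currently being merged

for label in np.unique(data):
    if count_num > 0:
        # previous group is still too small: merge this label into label-1
        data[data==label] = label-1
    count_num += counts[label]

    if count_num >= 10:
        count_num = 0  # group is large enough, start a new one

    # the last label is too small on its own: merge it into the previous label
    if label == np.unique(data)[-1] and counts[label] < 10:
        data[data==label] = label-1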
array([2, 3, 4, 4, 2, 4, 2, 4, 2, 2, 0, 4, 0, 4, 4, 4, 4, 2, 0, 3, 0, 0,
       4, 0, 3, 3, 2, 2, 4, 4, 3, 3, 4, 4, 0, 2, 4, 0, 2, 2, 3, 3, 0, 0,
       2, 3, 4, 0, 0, 0, 0, 3, 2, 4, 0, 2, 2])

Any ideas on how to merge the data more efficiently? Thanks!

Answer 1 (by 3bygqnnd)

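Your loop is inefficient because you repeatedly index the whole original array (data[data==label]) and even recompute np.unique(data) on every iteration. Just run np.unique once and get the counts: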
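As a small illustration of the np.unique outputs this answer relies on (toy data, not from the question): with return_inverse=True, idx holds each element's position in vals, so vals[idx] rebuilds the input. Choosing a group index per unique value and then indexing with idx, as in vals[labels][idx], relabels the whole array in one step.

```python
import numpy as np

data = np.array([3, 1, 3, 7, 1, 3])  # toy data for illustration
vals, idx, cnt = np.unique(data, return_inverse=True, return_counts=True)

print(vals)  # [1 3 7]        sorted unique labels
print(idx)   # [1 0 1 2 0 1]  position of each element of data in vals
print(cnt)   # [2 3 1]        number of occurrences of each value in vals

# Indexing vals with the inverse indices gives the original array back.
print((vals[idx] == data).all())  # True

# One entry per unique value: 7 (position 2) is sent to position 1, i.e. merged into 3.
labels = [0, 1, 1]
print(vals[labels][idx])  # [3 1 3 3 1 3]
```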

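import random
import numpy as np

data = np.concatenate(([0]*5, [1]*10, [2]*15, [3]*10, [4]*12, [5]*5))
random.Random(4).shuffle(data)

thresh = 10

count_sum = 0  # running count of the group being built
label = 0      # index (into vals) of the first value of the current group
labels = []    # group index for each unique value

# vals: sorted unique labels, idx: position of each element in vals, cnt: counts
vals, idx, cnt = np.unique(data, return_inverse=True, return_counts=True)
for i, c in enumerate(cnt, start=1):
    labels.append(label)
    if count_sum+c >= thresh:
        label = i        # group complete: the next value starts a new group
        count_sum = 0
    else:
        count_sum += c
# fix last labels: if the last group is still too small,
# merge it into the group just before it
labels = np.array(labels, dtype=int)
if count_sum > 0 and labels[-1] > 0:
    last = labels[-1]
    labels[labels == last] = labels[last-1]

out = vals[labels][idx]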

Output:

array([2, 3, 4, 4, 2, 4, 2, 4, 2, 2, 0, 4, 0, 4, 4, 4, 4, 2, 0, 3, 0, 0,
       4, 0, 3, 3, 2, 2, 4, 4, 3, 3, 4, 4, 0, 2, 4, 0, 2, 2, 3, 3, 0, 0,
       2, 3, 4, 0, 0, 0, 0, 3, 2, 4, 0, 2, 2])
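To confirm that a result meets the requirement, it is enough to count the labels that remain. A minimal check (the helper name all_counts_at_least is hypothetical), applied to the output printed above:

```python
import numpy as np


def all_counts_at_least(arr, thresh):
    """Hypothetical helper: True if every label present in arr occurs at least thresh times."""
    _, cnt = np.unique(arr, return_counts=True)
    return bool((cnt >= thresh).all())


out = np.array([2, 3, 4, 4, 2, 4, 2, 4, 2, 2, 0, 4, 0, 4, 4, 4, 4, 2, 0, 3, 0, 0,
                4, 0, 3, 3, 2, 2, 4, 4, 3, 3, 4, 4, 0, 2, 4, 0, 2, 2, 3, 3, 0, 0,
                2, 3, 4, 0, 0, 0, 0, 3, 2, 4, 0, 2, 2])

vals, cnt = np.unique(out, return_counts=True)
print(dict(zip(vals.tolist(), cnt.tolist())))  # {0: 15, 2: 15, 3: 10, 4: 17}
print(all_counts_at_least(out, 10))            # True
```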

Timings

On 100k random values in the range 0-100:

# original approach
863 ms ± 3.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# this solution
20.1 ms ± 673 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

On 1M random values in the range 0-1000:

# original approach
# took too long, had to interrupt after 10 minutes

# this solution
525 ms ± 36.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Answer 2 (by 5uzkadbs)

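Unless you have a very large number of labels, the Python loop itself shouldn't cost much performance. You probably want to merge based on the labels that are actually present in the data (rather than adding or subtracting 1 from the label value). This can be done by applying +1/-1 to the current index in the list of unique labels instead of to the label value.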
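A toy example (made-up data, not from the question) of why stepping by position matters when some label values are missing: for label 3 below, 3 - 1 = 2 does not occur in the data, while the previous position in np.unique's output points at the nearest label that does occur.

```python
import numpy as np

# Made-up data with gaps: labels 1, 2, 4 and 5 never occur.
data = np.array([0]*12 + [3]*4 + [6]*11)

labels, counts = np.unique(data, return_counts=True)
print(labels, counts)  # [0 3 6] [12  4 11]

small = 3                                      # label 3 occurs only 4 times
pos = int(np.flatnonzero(labels == small)[0])  # its position among present labels

print((small - 1) in labels)  # False: "label - 1" does not exist in the data
print(labels[pos - 1])        # 0: the previous label that does exist

# Merge the small label into its previous present neighbour.
data[data == small] = labels[pos - 1]
print(np.unique(data, return_counts=True)[1])  # [16 11]
```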

import numpy as np

data = np.array([2, 3, 4, 5, 2, 4, 2, 4, 2, 2, 0, 4, 1, 4, 4, 5, 4, 2,
       1, 3, 1, 1, 5, 0, 3, 3, 2, 2, 4, 4, 3, 3, 4, 4, 1, 2, 5, 1, 2, 2,
       3, 3, 1, 0, 2, 3, 5, 0, 0, 1, 1, 3, 2, 4, 1, 2, 2])

print("before:",np.bincount(data))

# list of [label, count] pairs for the labels present in data
*counts, = map(list,zip(*np.unique(data, return_counts=True)))
minSize = 10
# loop until all labels are large enough (or only one label is left)
while len(counts) > 1 and any(c<minSize for _,c in counts):
    for i in range(len(counts)-1,-1,-1):   # go through indexes backward
        label,count = counts[i]
        if count >= minSize: continue          # merge small labels only
        if i>0:
            data[data==label] = counts[i-1][0] # merge with previous label
            counts[i-1][1] += counts.pop(i)[1]
        elif i+1<len(counts):
            data[data==label] = counts[i+1][0] # merge with next label
            counts[i+1][1] += counts[i][1]
            counts[i] = counts.pop(i+1)
        else:
            break                              # cannot merge to neighbour

print("after: ",np.bincount(data))

Output:

before: [ 5 10 15 10 12  5]
after:  [ 0 15 15 10 17]
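Both answers keep the original values of the surviving labels, so gaps can appear (here label 0 ends up empty, which is why np.bincount reports 0 for it). If consecutive labels 0..k-1 are needed afterwards, the return_inverse output of np.unique provides them directly; a short sketch rebuilt from the "after" counts above:

```python
import numpy as np

# Labels after merging, rebuilt from the "after" counts above: 1, 2, 3, 4 remain.
merged = np.repeat([1, 2, 3, 4], [15, 15, 10, 17])

# The inverse indices number the remaining labels 0, 1, 2, ... in ascending order.
_, compact = np.unique(merged, return_inverse=True)
print(np.bincount(compact))  # [15 15 10 17]
```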
