python 在不使用for循环的情况下为每个唯一键找到最小值

kb5ga3dv  于 12个月前  发布在  Python
关注(0)|答案(3)|浏览(71)

我有一个带键的numpy数组(例如[1, 2, 2, 3, 3, 2])和一个带值的数组(例如[0.2, 0.6, 0.8, 0.4, 0.9, 0.3])。我想在不使用for循环的情况下找到与每个唯一键关联的最小值。在这个例子中,答案是{1: 0.2, 2: 0.3, 3: 0.4}。我问了ChatGPT和New Bing,但他们一直给我错误的答案。那么,真的可以在不使用for循环的情况下做到这一点吗?
编辑1:我想达到的是最快的速度。而且,在我的例子中,大多数键都是唯一的。我考虑使用np.unique来获取每个键,然后计算每个键的最小值,但显然这需要一个for循环和一个二次时间。我还考虑按键对数组进行排序,并对每个键的值应用np.min,另外,根据评论,pandas.DataFrame有一个groupby方法可能会有帮助,但我不确定它是否是最快的(也许我会自己尝试)。
编辑二:我不一定需要dict作为输出;它可以是唯一键的数组和最小值的数组,键的顺序无关紧要。

mpgws1up

mpgws1up1#

我不认为没有循环是可能的。

keys = [1, 2, 2, 3, 3, 2]
values = [0.2, 0.6, 0.8, 0.4, 0.9, 0.3]
keys = sorted(keys)

values= sorted(values)

dct,i = {}, 1

while i<len(keys):
    x, y = keys[i], keys[i-1]
    if x != y :
        if y not in dct.keys():
            dct[y] = values[i-1]
            dct[x] = values[i]
        else:
            dct[x] = values[i]
    else:
        if y not in dct.keys():
            dct[y] = values[i-1]
   
    i += 1
print(dct)
#Output : {1: 0.2, 2: 0.3, 3: 0.8}

字符串

aij0ehis

aij0ehis2#

简单的Python解决方案类似于:

result = {}

for key, value in zip(keys, values):
    current = result.get(key)
    if current is not None:
        result[key] = min(current, value)
    else:
        result[key] = value

字符串
它应该是相对较快的。
如果您确实需要从中挤出性能,则应该使用numba

import numba

@numba.jit(nopython=True)
def group_min(keys, values):
    result = {}

    for key, value in zip(keys, values):
        current = result.get(key)
        if current is not None:
            result[key] = min(current, value)
        else:
            result[key] = value

    return result


请务必仔细阅读numba docs,以了解如何尽可能地提高性能。

3htmauhk

3htmauhk3#

试试看:

keys = np.array([1, 2, 2, 3, 3, 2])
values = np.array([0.2, 0.6, 0.8, 0.4, 0.9, 0.3])

out = dict(
    map(
        lambda x: (x[0], np.min(x[1])),
        zip(
            (rv := np.unique(keys[(idx := np.argsort(keys))], return_index=True))[0],
            np.split(values[idx], rv[1][1:]),
        ),
    )
)
print(out)

字符串
打印:

{1: 0.2, 2: 0.3, 3: 0.4}


快速基准:

import numba
import perfplot

def get_dict_andrej(keys, values):
    return dict(
        map(
            lambda x: (x[0], np.min(x[1])),
            zip(
                (rv := np.unique(keys[(idx := np.argsort(keys))], return_index=True))[
                    0
                ],
                np.split(values[idx], rv[1][1:]),
            ),
        )
    )

def get_dict_juan(keys, values):
    result = {}

    for key, value in zip(keys, values):
        current = result.get(key, float("inf"))
        if key is not None:
            result[key] = min(current, value)
        else:
            result[key] = value

    return result

@numba.jit(nopython=True)
def get_dict_juan_numba(keys, values):
    result = {}

    for key, value in zip(keys, values):
        current = result.get(key)
        if current is not None:
            result[key] = min(current, value)
        else:
            result[key] = value

    return result

# compile the function
get_dict_juan_numba(
    np.array([1], dtype=np.int32),
    np.array([0.1], dtype=float),
)

np.random.seed(42)

perfplot.show(
    setup=lambda n: (
        np.random.randint(1, n // 1.2, size=n, dtype=np.int32),
        np.random.random(size=n),
    ),
    kernels=[get_dict_andrej, get_dict_juan, get_dict_juan_numba],
    labels=["Andrej", "Juan", "Juan_numba"],
    n_range=[2**k for k in range(10, 22)],
    xlabel="N",
    equality_check=None,
    logx=True,
    logy=True,
)


创建此图表:


的数据

相关问题