python 查找列表中每个字符串之间的最大差异

cngwdvgl  于 2023-03-28  发布在  Python
关注(0)|答案(3)|浏览(193)

这是一个面试的问题:
给定一个长度为l的n个字符串的列表,在任何给定的索引处都可以包含字符'a'或'b',我如何在尽可能低的时间复杂度下计算出每个n个字符串的最大差异(意味着每个索引处不同的字符数量)。
例如[“abaab”,“abbbb”,“babba”]
输出将是[5,3,5],因为第一个和第三个字符串之间的差是5(没有一个字符是相同的),所以字符串1和任何其他字符串之间的最大差值是5。同样,字符串2和字符串3之间的差值是3,大于字符串2和1之间的差值。字符串3和字符串1之间的差值是5。
l的极限是20
我尝试了下面的代码(其中l和n是相应命名的,字符串是字符串的列表)。这段代码中的变量是示例情况下的变量,但我希望有一个通用的解决方案。

n = 3
l = 5
abmap = {}
strings = ["abaab", "abbbb","babba"]

for i in range(l):
    abmap[i] = {"a": [], "b": []}

for i in range(n):
    for j in range(l):
        if strings[i][j] == "a":
            abmap[j]["a"].append(i)

        else:
            abmap[j]["b"].append(i)

for string in strings:
    differences = n * [0]
    for i in range(l):
        if string[i] == "a":
            for index in abmap[i]["b"]:
                differences[index] += 1

        else:
            for index in abmap[i]["a"]:
                differences[index] += 1

    print(max(differences))

然而,这个解决方案是O(n2·l)。面试官要求我进一步优化它(例如到O(l·n·log(n))。我如何才能做到这一点?
此操作的时间限制为15秒,并且n小于100000。

7uhlpewt

7uhlpewt1#

就像我之前的回答一样(现在已经不可行了),我把每个字符串转换成一个L位数。它们是2L ≤ 220个可能的L位数的子集。把它想象成一个2L个节点的图,如果两个节点相差一位,那么两个节点之间就有一条边,输入的数字是这些节点的子集。
现在……对于任何输入数,最远的输入数是多少?我们可以通过看反数来解决这个问题(所有L个位翻转,即,距离L),并询问与之 * 最接近 * 的输入数字是多少。因此,我们运行一个并行BFS(广度优先搜索)。我们将它们标记为距离为0。然后我们将距离为2的所有数字标记为一位变化。最后,对于每一个输入的数字,我们看一下倒数的距离,然后从L中减去它。
基准测试结果,最差情况需要约3秒,远低于15秒限值:

n=1000 L=20:
 0.14 s  solution1
 2.63 s  solution2

n=10000 L=20:
15.01 s  solution1
 3.01 s  solution2

n=100000 L=20:
 3.47 s  solution2

完整代码(solution1是我以前的代码,solution2是我上面介绍的代码):

def solution1(strings):
    table = str.maketrans('ab', '01')
    numbers = [
        int(s.translate(table), 2)
        for s in strings
    ]
    return [
        max((x ^ y).bit_count() for y in numbers)
        for x in numbers
    ]

def solution2(strings):
    table = str.maketrans('ab', '01')
    numbers = [
        int(s.translate(table), 2)
        for s in strings
    ]
    L = len(strings[0])
    bits = [2**i for i in range(L)]
    dist = [None] * 2**L
    for x in numbers:
        dist[x] = 0
    horizon = numbers
    d = 1
    while horizon:
        horizon = [
            y
            for x in horizon
            for bit in bits
            for y in [x ^ bit]
            if dist[y] is None
            for dist[y] in [d]
        ]
        d += 1
    return [L - dist[~x] for x in numbers]

funcs = solution1, solution2

import random
from time import time

# Generate random input
def gen(n, L):
    return [
        ''.join(random.choices('ab', k=L))
        for _ in range(n)
    ]

# Correctness
for _ in range(100):
    strings = gen(100, 10)
    expect = funcs[0](strings)
    for f in funcs:
        result = f(strings)
        assert result == expect

# Speed
def test(n, L, funcs):
    print(f'{n=} {L=}:')

    for _ in range(1):
        strings = gen(n, L)
        expect = None
        for f in funcs:
            t = time()
            result = f(strings)
            print(f'{time()-t:5.2f} s ', f.__name__)
            if expect is None:
                result = expect
            else:
                assert result == expect
            del result
    print()

test(1000, 20, [solution1, solution2])
test(10000, 20, [solution1, solution2])
test(100000, 20, [solution2])

Attempt This Online!

sd2nnvve

sd2nnvve2#

(With现在增加了n和时间的限制,这不再可行。请看我的新答案。)
L的限制为20(为了可读性而重新命名),这表明他们希望您将每个字符串转换为L位数,然后通过对它们进行异或并要求popcount来计算其中两个的差异,每个差异都有O(1),因此总体上是O(nL+n²)。

strings = ["abaab", "abbbb", "babba"]

table = str.maketrans('ab', '01')
numbers = [int(s.translate(table), 2)
           for s in strings]

for x in numbers:
    print(max((x ^ y).bit_count()
              for y in numbers))

Attempt This Online!

cbeh67ev

cbeh67ev3#

用一个位集来代表每支球队,其中每一位对应一头牛的品种(0代表根西岛,1代表荷斯坦),然后我们可以通过计算两支球队的位集的位异或,并计算集合位的数量(即两支球队不同位置的数量)来计算两支球队之间的差异。
为了找到每个团队的最大差异,我们可以迭代所有其他团队,并使用上述方法计算差异。我们跟踪到目前为止看到的最大差异,并在发现更大差异时更新它。
为了使代码更快,我们可以利用这样一个事实,即对于一个大小为C的团队,可能的位集的数量是2^C,最多是2^18 = 262144。这意味着我们可以预先计算所有位集对之间的差异,并将它们存储在查找表中。
然后我们可以使用这个查找表来快速计算任何两个团队之间的差异,只需在表中查找他们的位集差异。这将算法的时间复杂度降低到O(N^2 / 32),这比以前的方法快得多。
Python代码:

from collections import defaultdict

C, N = map(int, input().split())

# read in the teams and convert them to bitsets
teams = []
for i in range(N):
    team_str = input().strip()
    team_bits = int(''.join(['0' if c == 'G' else '1' for c in team_str]), 2)
    teams.append(team_bits)

# precompute the differences between all pairs of bitsets
lookup = defaultdict(dict)
for i in range(N):
    for j in range(i + 1, N):
        diff = bin(teams[i] ^ teams[j]).count('1')
        lookup[i][j] = diff
        lookup[j][i] = diff

# compute the maximum difference for each team
for i in range(N):
    max_diff = max(lookup[i].values())
    print(max_diff)

相关问题