python-3.x: How to optimize splitting overlapping ranges?

Posted by rbl8hiat on 2023-08-08 in Python

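I wrote this Python script to split overlapping ranges into unique ranges (last iteration). It produces the correct output and outperforms the version given in the answer. I tested it against the output of a known-correct method and of a brute-force method. I confirmed that all methods are correct and that my code is the most efficient.
Infinitely many boxes arranged in a row are numbered. Each box can hold only one object: whatever was put into it last. The boxes are initially empty. Now imagine a list of triples in which the first two elements of each triple are integers (the first not greater than the second), and each triple represents an instruction to put the third element into boxes. The triple (0, 10, 'A') means "put 'A' into boxes 0 through 10 (inclusive)", and after the instruction is executed each of boxes 0 to 10 contains an instance of 'A'. Empty boxes are ignored. The task is to describe the state of the boxes after all instructions have been executed, using the minimum number of triples. The rules are simple; there is a narrow case and a general case:
Narrow case: for any two triples (s1, e1, d1) and (s2, e2, d2), s1 < s2 < e1 < e2 is always False (every pair in the input satisfies this). There are four sub-cases: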

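  • Case 1: s1 = s2 and e2 < e1:

Whichever ends first wins. The boxes always keep the winner's value: given [(0, 10, 'A'), (0, 5, 'B')], the state of the boxes after execution is [(0, 5, 'B'), (6, 10, 'A')]

  • Case 2: s1 < s2 and e2 < e1:

Whichever ends first wins, for example: [(0, 10, 'A'), (5, 7, 'B')] -> [(0, 4, 'A'), (5, 7, 'B'), (8, 10, 'A')]

  • Case 3: s1 < s2 and e2 = e1:

Same rule as above, but here the ends tie. On a tie, whichever starts later wins: [(0, 10, 'A'), (6, 10, 'B')] -> [(0, 5, 'A'), (6, 10, 'B')]

  • Case 4: s1 = s2 and e1 = e2:

This one is special. In a true tie, whichever appears later in the input wins. In this case my code does not guarantee that the object in the boxes comes from the latest instruction (but it does guarantee that every box holds exactly one object).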
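The narrow-case precondition stated before the four sub-cases can be checked mechanically. The following is a minimal sketch, assuming the condition is exactly the one written in the question (s1 < s2 < e1 < e2 must never hold for any pair); `is_narrow` is a hypothetical helper name, and the pairwise loop is quadratic, so it is meant for validating test inputs rather than for production use.

```python
from itertools import combinations


def is_narrow(ranges):
    """Return True if no pair of triples satisfies s1 < s2 < e1 < e2."""
    for a, b in combinations(ranges, 2):
        # Order the pair by start so that (s1, e1) is the one that begins first.
        (s1, e1, _), (s2, e2, _) = sorted((a, b), key=lambda r: r[0])
        if s1 < s2 < e1 < e2:
            return False
    return True
```

For example, `[(0, 10, 'A'), (5, 7, 'B')]` passes the check, while `[(0, 10, 'A'), (5, 20, 'B')]` does not.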
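Other rules:

  • If nothing is updated, do nothing:

[(0, 10, 'A'), (5, 6, 'A')] -> [(0, 10, 'A')]

  • If there is a gap, leave it as is (those boxes are empty):

[(0, 10, 'A'), (15, 20, 'B')] -> [(0, 10, 'A'), (15, 20, 'B')]

  • If e1 + 1 = s2 and d1 = d2, merge them (fewest triples):

[(0, 10, 'A'), (11, 20, 'A')] -> [(0, 20, 'A')]

  • General case: when the basic condition of the narrow case does not hold and the ranges truly intersect, whichever starts later wins:

[(0, 10, 'A'), (5, 20, 'B')] -> [(0, 4, 'A'), (5, 20, 'B')]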
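Read together, the rules above behave like a single priority order. The sketch below assumes that order is: later start first, then earlier end, then later position in the input. `box_value` is a hypothetical helper (not part of the original script) that answers "what is in box n?" for one box at a time, which makes the individual rules easy to check by hand.

```python
def box_value(ranges, n):
    """Return the object left in box n, or None if the box stays empty."""
    covering = [
        (start, -end, index, data)
        for index, (start, end, data) in enumerate(ranges)
        if start <= n <= end
    ]
    if not covering:
        return None
    # Later start wins; on equal starts the earlier end wins; on a true tie
    # the triple that appears later in the input wins.
    return max(covering, key=lambda c: c[:3])[3]
```

With `[(0, 10, 'A'), (5, 20, 'B')]`, box 2 holds 'A' and box 7 holds 'B', matching the general-case example.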
The brute-force implementation is correct but slow; it handles the general case:

def brute_force_discretize(ranges):
    numbers = {}
    ranges.sort(key=lambda x: (x[0], -x[1]))
    for start, end, data in ranges:
        numbers |= {n: data for n in range(start, end + 1)}
    numbers = list(numbers.items())
    l = len(numbers)
    i = 0
    output = []
    while i < l:
        di = 0
        curn, curv = numbers[i]
        while i < l and curn + di == numbers[i][0] and curv == numbers[i][1]:
            i += 1
            di += 1
        output.append((curn, numbers[i-1][0], curv))
    return output

The high-performance smart implementation, which only handles the narrow case:

from typing import Any, List, Tuple

def get_nodes(ranges: List[Tuple[int, int, Any]]) -> List[Tuple[int, int, Any]]:
    nodes = []
    for ini, fin, data in ranges:
        nodes.extend([(ini, False, data), (fin, True, data)])
    return sorted(nodes)

def merge_ranges(data: List[List[int | Any]], range: List[int | Any]) -> None:
    if not data or range[2] != (last := data[-1])[2] or range[0] > last[1] + 1:
        data.append(range)
    else:
        last[1] = range[1]

def discretize_narrow(ranges):
    nodes = get_nodes(ranges)
    output = []
    stack = []
    actions = []
    for node, end, data in nodes:
        if not end:
            action = False
            if not stack or data != stack[-1]:
                if stack and start < node:
                    merge_ranges(output, [start, node - 1, stack[-1]])
                stack.append(data)
                start = node
                action = True
            actions.append(action)
        elif actions.pop(-1):
            if start <= node:
                merge_ranges(output, [start, node, stack.pop(-1)])
                start = node + 1
            else:
                stack.pop(-1)
    return output


Full script, which generates narrow cases. Performance:

In [518]: sample = make_sample(2048, 65536, 16)

In [519]: %timeit descretize(sample)
4.46 ms ± 32.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [520]: %timeit discretize_narrow(sample)
3.13 ms ± 34.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [521]: list(map(tuple, discretize_narrow(sample))) == descretize(sample)
Out[521]: True


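How can I make my code faster? The input is sorted in ascending order, but my code does not assume this. If I split the top-level data into non-overlapping ranges (taking advantage of the sorted input), accumulated triples on a stack while they overlap, discretized and cleared the stack when the overlap ends, and then concatenated the intermediate results, the code could be faster. I could not get this to work.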
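The splitting step described in this paragraph could look roughly like the following. This is only a sketch under the assumption that the input is already sorted by start; `split_clusters` is a hypothetical name, and it only groups the triples, leaving the per-group discretization to whichever function is being optimized.

```python
def split_clusters(sorted_ranges):
    """Yield lists of triples; ranges in different lists never share a box."""
    cluster = []
    reach = None  # largest end seen in the current cluster
    for triple in sorted_ranges:
        start, end, _ = triple
        if cluster and start > reach:
            # This range begins after everything seen so far has ended.
            yield cluster
            cluster = []
        cluster.append(triple)
        reach = end if len(cluster) == 1 else max(reach, end)
    if cluster:
        yield cluster
```

Each yielded group can then be processed on its own and the results concatenated. Pieces from neighbouring groups that touch (e1 + 1 = s2) and carry the same data would still need to be merged afterwards.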
How do I handle the general case? I think I need to know when each range ends (four states), but I cannot get it right:

def get_quadruples(ranges):
    nodes = []
    for ini, fin, data in ranges:
        nodes.extend([(ini, False, -fin, data), (fin, True, ini, data)])
    return sorted(nodes)


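I cannot post this on Code Review because it is a "how to" question (I have not achieved my goal yet). To demonstrate how slow an interval tree is, this answer on Code Review uses one: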

In [531]: %timeit discretize_narrow([(0, 10, 'A'), (0, 1, 'B'), (2, 5, 'C'), (3, 4, 'C'), (6, 7, 'C'), (8, 8, 'D'), (110, 150, 'E'), (250, 300, 'C'), (256, 270, 'D'), (295, 300, 'E'), (500, 600, 'F')])
14.3 µs ± 42.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [532]: %timeit merge_rows([(0, 10, 'A'), (0, 1, 'B'), (2, 5, 'C'), (3, 4, 'C'), (6, 7, 'C'), (8, 8, 'D'), (110, 150, 'E'), (250, 300, 'C'), (256, 270, 'D'), (295, 300, 'E'), (500, 600, 'F')])
891 µs ± 12.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [533]: data = [(0, 10, 'A'), (0, 1, 'B'), (2, 5, 'C'), (3, 4, 'C'), (6, 7, 'C'), (8, 8, 'D'), (110, 150, 'E'), (250, 300, 'C'), (256, 270, 'D'), (295, 300, 'E'), (500, 600, 'F')]

In [534]: merge_rows(data) == discretize_narrow(data)
Out[534]: True

In [535]: sample = make_sample(256, 65536, 16)

In [536]: merge_rows(sample) == discretize_narrow(sample)
Out[536]: True

In [537]: %timeit discretize_narrow(sample)
401 µs ± 3.33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [538]: %time result = merge_rows(sample)
CPU times: total: 78.1 ms
Wall time: 56 ms


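It is not efficient, but I do not think I could implement a more efficient version. Test cases can be generated with the code on the linked GitHub page. My logic is the same as a dictionary update, which is exactly what my brute-force implementation does (it processes the ranges with dictionary updates, so it is correct by definition). The correctness of my smart approach is verified against the output of the brute-force one.
Manual test cases: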
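The verification against the brute-force output can be automated with a small randomized harness. The sketch below is an assumption about how such a check might be written, not the code from the linked page: `first_mismatch` is a hypothetical helper that takes the function under test, a reference function, and a zero-argument case generator, and returns the first generated input on which the two disagree.

```python
import random


def first_mismatch(candidate, reference, make_case, trials=1000, seed=0):
    """Return the first generated input on which the two functions disagree."""
    random.seed(seed)  # make the generated cases reproducible
    for _ in range(trials):
        case = make_case()
        # Pass copies: an implementation may sort or extend its argument.
        expected = [tuple(row) for row in reference(list(case))]
        actual = [tuple(row) for row in candidate(list(case))]
        if actual != expected:
            return case
    return None
```

Rows are converted to tuples before comparing because one implementation may return lists and another tuples. A returned case can be pasted directly into a bug report.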

In [539]: discretize_narrow([(0, 10, 'A'), (0, 1, 'B'), (2, 5, 'C'), (3, 4, 'C'), (6, 7, 'C'), (8, 8, 'D'), (110, 150, 'E'), (250, 300, 'C'), (256, 270, 'D'), (295, 300, 'E'), (500, 600, 'F')])
Out[539]:
[[0, 1, 'B'],
 [2, 7, 'C'],
 [8, 8, 'D'],
 [9, 10, 'A'],
 [110, 150, 'E'],
 [250, 255, 'C'],
 [256, 270, 'D'],
 [271, 294, 'C'],
 [295, 300, 'E'],
 [500, 600, 'F']]

In [540]: discretize_narrow([(0, 100, 'A'), (10, 25, 'B'), (15, 25, 'C'), (20, 25, 'D'), (30, 50, 'E'), (40, 50, 'F'), (60, 80, 'G'), (150, 180, 'H')])
Out[540]:
[[0, 9, 'A'],
 [10, 14, 'B'],
 [15, 19, 'C'],
 [20, 25, 'D'],
 [26, 29, 'A'],
 [30, 39, 'E'],
 [40, 50, 'F'],
 [51, 59, 'A'],
 [60, 80, 'G'],
 [81, 100, 'A'],
 [150, 180, 'H']]


A machine-generated general case (with the correct output):

In [542]: ranges = []

In [543]: for _ in range(20):
     ...:     start = random.randrange(100)
     ...:     end = random.randrange(100)
     ...:     if start > end:
     ...:         start, end = end, start
     ...:     ranges.append([start, end, random.randrange(5)])

In [544]: ranges.sort()

In [545]: ranges
Out[545]:
[[0, 31, 0],
 [1, 47, 1],
 [1, 67, 0],
 [10, 68, 0],
 [15, 17, 2],
 [18, 39, 0],
 [19, 73, 3],
 [25, 32, 0],
 [26, 33, 1],
 [26, 72, 2],
 [26, 80, 2],
 [28, 28, 1],
 [29, 31, 4],
 [30, 78, 2],
 [36, 47, 0],
 [36, 59, 4],
 [44, 67, 3],
 [52, 61, 4],
 [58, 88, 1],
 [64, 92, 1]]

In [546]: brute_force_discretize(ranges)
Out[546]:
[(0, 0, 0),
 (1, 9, 1),
 (10, 14, 0),
 (15, 17, 2),
 (18, 18, 0),
 (19, 24, 3),
 (25, 25, 0),
 (26, 28, 1),
 (29, 29, 4),
 (30, 35, 2),
 (36, 43, 0),
 (44, 51, 3),
 (52, 57, 4),
 (58, 92, 1)]

Edit

A function that generates general cases:

import random

def make_generic_case(num, lim, dat):
    ranges = []

    for _ in range(num):
        start = random.randrange(lim)
        end = random.randrange(lim)
        if start > end:
            start, end = end, start
        ranges.append([start, end, random.randrange(dat)])
    
    ranges.sort()
    return ranges


An example input on which the code from the existing answer disagrees with the correct result:
Test case:

[
    [0, 31, 0],
    [1, 47, 1],
    [1, 67, 0],
    [10, 68, 0],
    [15, 17, 2],
    [18, 39, 0],
    [19, 73, 3],
    [25, 32, 0],
    [26, 33, 1],
    [26, 72, 2],
    [26, 80, 2],
    [28, 28, 1],
    [29, 31, 4],
    [30, 78, 2],
    [36, 47, 0],
    [36, 59, 4],
    [44, 67, 3],
    [52, 61, 4],
    [58, 88, 1],
    [64, 92, 1]
]


Output:

[
    (0, 0, 0),
    (1, 9, 1),
    (10, 14, 0),
    (15, 17, 2),
    (18, 18, 0),
    (19, 24, 3),
    (25, 25, 0),
    (26, 28, 1),
    (29, 29, 4),
    (30, 35, 2),
    (36, 43, 0),
    (44, 51, 3),
    (52, 57, 4),
    (58, 88, 1)
]


The output is almost the same as the correct one, except for the last range, which should be (58, 92, 1) rather than (58, 88, 1).
Here is another case:

[
    [4, 104, 4],
    [22, 463, 2],
    [24, 947, 2],
    [36, 710, 1],
    [37, 183, 1],
    [39, 698, 7],
    [51, 438, 4],
    [60, 450, 7],
    [120, 383, 2],
    [130, 193, 7],
    [160, 562, 5],
    [179, 443, 6],
    [186, 559, 6],
    [217, 765, 2],
    [221, 635, 2],
    [240, 515, 3],
    [263, 843, 3],
    [274, 759, 6],
    [288, 389, 5],
    [296, 298, 6],
    [333, 1007, 1],
    [345, 386, 5],
    [356, 885, 3],
    [377, 435, 5],
    [407, 942, 7],
    [423, 436, 1],
    [484, 926, 5],
    [496, 829, 0],
    [559, 870, 5],
    [610, 628, 1],
    [651, 787, 4],
    [735, 927, 1],
    [765, 1002, 1]
]


Output:

[(4, 21, 4),
 (22, 35, 2),
 (36, 38, 1),
 (39, 50, 7),
 (51, 59, 4),
 (60, 119, 7),
 (120, 129, 2),
 (130, 159, 7),
 (160, 178, 5),
 (179, 216, 6),
 (217, 239, 2),
 (240, 273, 3),
 (274, 287, 6),
 (288, 295, 5),
 (296, 298, 6),
 (299, 332, 5),
 (333, 344, 1),
 (345, 355, 5),
 (356, 376, 3),
 (377, 406, 5),
 (407, 422, 7),
 (423, 436, 1),
 (437, 483, 7),
 (484, 495, 5),
 (496, 558, 0),
 (559, 609, 5),
 (610, 628, 1),
 (629, 650, 5),
 (651, 734, 4),
 (735, 927, 1),
 (928, 942, 7),
 (943, 1007, 1)]


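Again, it is almost correct, but the last three ranges are wrong.
Here they are (735, 927, 1), (928, 942, 7), (943, 1007, 1), but they should be (735, 1007, 1).
I have tested some other inputs, and the proposed solution's output was correct for those, but I have identified some edge cases where it is not.
I have just implemented the algorithm found here, but it does not work correctly and fails on narrow cases: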

def discretize_gen(ranges):
    nodes = get_nodes(ranges)
    stack = []
    for (n1, e1, d1), (n2, e2, _) in zip(nodes, nodes[1:]):
        if e1:
            stack.remove(d1)
        else:
            stack.append(d1)
        start = n1 + e1
        end = n2 - (not e2)
        if start <= end and stack:
            yield start, end, stack[-1]


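Use it like this: list(merge(discretize_gen(ranges))). The merge function can be found in the answer below.
It fails for many inputs. I just wrote an ugly function to compare its output with the correct output: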

def compare_results(ranges):
    correct = brute_force_discretize(ranges)
    correct_set = set(correct)
    output = list(merge(discretize_gen(ranges)))
    output_set = set(output)
    errors = output_set ^ correct_set
    indices = [(i, correct.index(e)) for i, e in enumerate(output) if e not in errors]
    indices.append((len(output), None))
    comparison = [(a, b) if c - a == 1 else (slice(a, c), slice(b, d)) for (a, b), (c, d) in zip(indices, indices[1:])]
    result = {}
    for a, b in comparison:
        key = output[a]
        val = correct[b]
        if isinstance(key, list):
            key = tuple(key)
        result[key] = val
    return result


Failing test case:

[(73, 104, 3), (75, 98, 0), (78, 79, 3), (83, 85, 3), (88, 90, 2)]


Comparison:

{(73, 74, 3): (73, 74, 3),
 ((75, 77, 0), (78, 87, 3)): [(75, 77, 0),
  (78, 79, 3),
  (80, 82, 0),
  (83, 85, 3),
  (86, 87, 0)],
 ((88, 90, 2), (91, 104, 3)): [(88, 90, 2), (91, 98, 0), (99, 104, 3)]}


I do not know why it fails. That answer has a score of 11 and was accepted, yet somehow it does not work. How can I fix it?
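Tracing the failing case by hand suggests one cause: `stack.remove(d1)` removes the first element equal to `d1`, and the data value 3 belongs to three different ranges, so when (78, 79, 3) ends the entry of the outer (73, 104, 3) is removed instead and `stack[-1]` is still 3. The sketch below keeps the same node-pair loop but tracks each range by its rank rather than by its data. The rank is assumed to be the position after a stable sort by (start, -end), as in the brute-force version; `discretize_by_rank` is a hypothetical name, and `max(active)` is a deliberately simple choice that makes the sketch easy to verify rather than fast.

```python
def discretize_by_rank(ranges):
    """Split overlapping triples; the highest-ranked active range wins a box."""
    # Rank = position after a stable sort by (start, -end), the same order in
    # which the brute-force version lets later triples overwrite earlier ones.
    ordered = sorted(ranges, key=lambda r: (r[0], -r[1]))
    nodes = []
    for rank, (ini, fin, _) in enumerate(ordered):
        nodes.append((ini, False, rank))
        nodes.append((fin, True, rank))
    nodes.sort()

    active = set()  # ranks of the ranges open at the current position
    output = []
    for (n1, e1, r1), (n2, e2, _) in zip(nodes, nodes[1:]):
        if e1:
            active.discard(r1)  # drop exactly this range, not a look-alike
        else:
            active.add(r1)
        start = n1 + e1
        end = n2 - (not e2)
        if start <= end and active:
            data = ordered[max(active)][2]
            # Merge with the previous piece when adjacent and equal.
            if output and output[-1][2] == data and output[-1][1] + 1 == start:
                output[-1] = (output[-1][0], end, data)
            else:
                output.append((start, end, data))
    return output
```

Replacing the set and `max(active)` with a stack or heap that discards expired entries lazily would presumably reduce the per-step cost, at the price of slightly more bookkeeping.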

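Edit

The function make_generic_case had an indentation error that made the append statement part of the if block by accident, so a new range was appended only when start was greater than end (in which case start and end were swapped before appending).
My intention was to append a new range for every generated pair, and I fixed it by dedenting that line. I only just noticed the error. It was caused by copy-pasting: I pasted the code into Visual Studio Code, and its automatic indentation fix broke the indentation.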

bvjveswy (answer #1)

Here is an approach similar to the one in discretize_narrow. When I timed it with the make_sample(2048, 65536, 16) input, it seemed to perform slightly better (10-20%):

def solve(ranges):
    if not ranges:
        return []
        
    def disjoint_segments(ranges):
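        # Sort a copy: the dummy range appended below would otherwise stay in
        # the caller's list and make -x[1] fail (x[1] is None) on a second call
        ranges = sorted(ranges, key=lambda x: (x[0], -x[1]))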
        # In order to get extra loop iteration, add one dummy range
        ranges.append((float("inf"), None, None))  
        stack_end = []  # use separate stacks for ends, and data
        stack_data = []
        current = ranges[0][0]
        for start, end, data in ranges:
            # flush data from stack up to start - 1.
            while stack_end and stack_end[-1] < start:
                end2 = stack_end.pop()
                data2 = stack_data.pop()
                if current <= end2:
                    yield current, end2, data2
                    current = end2 + 1
            if stack_end and current < start:
                yield current, start - 1, stack_data[-1]
            # stack the current tuple's info
            current = start
            if not stack_end or stack_data[-1] != data or end > stack_end[-1]:
                stack_end.append(end)
                stack_data.append(data)

    def merge(segments):
        start, end, data = next(segments)  # keep one segment buffered
        for start2, end2, data2 in segments:
            if end + 1 == start2 and data == data2:  # adjacent & same data
                end = end2  # merge
            else:
                yield start, end, data
                start, end, data = start2, end2, data2
        yield start, end, data  # flush the buffer
    
    return list(merge(disjoint_segments(ranges)))

The function above should also produce correct results for the "general" case.
