Python -计算从1到N的K个数字的所有组合,其总和等于N

9nvpjoqh  于 2023-05-27  发布在  Python
关注(0)|答案(4)|浏览(280)

如何计算从1到n中选取k个互不相同的数字、且它们的和等于n的组合数量?例如,对于n = 10,k = 3,共有4种组合:(1,2,7),(1,3,6),(1,4,5),(2,3,5)
我试过使用itertools.combinations,但对于较大的数字,需要枚举的组合数量增长得非常快

rkttyhzu

rkttyhzu1#

使用缓存的递归方法可以在合理的时间内产生结果:

from functools import lru_cache

@lru_cache(maxsize=None)
def countNK(n, k, t=None):
    """Count the k-element subsets of {1, ..., n} whose elements sum to t.

    When ``t`` is omitted the target sum is ``n`` itself. Results are
    memoized because the recursion revisits the same (n, k, t) states.
    """
    if t is None:
        t = n  # top-level call: the target sum is n
    if k == 0:
        # Nothing left to pick: valid only if the target is used up exactly.
        return 1 if t == 0 else 0
    if t < 1 or k > n:
        # Target already exhausted, or fewer than k candidate values remain.
        return 0
    without_n = countNK(n - 1, k, t)       # subsets that skip the value n
    with_n = countNK(n - 1, k - 1, t - n)  # subsets that contain the value n
    return without_n + with_n
  • 递归需要使用逐渐变小的n值来瞄准目标
  • 在从目标中删除每个值后,它还需要对较短的组合执行此操作
  • 这会多次重复相同的计算,因此需要缓存

...
输出:

# Example calls; each trailing comment shows the count that is printed.
print(countNK(10,3))     # 4

print(countNK(200,10))   # 98762607
  • 如果需要处理较大的n值(例如500+),你需要增加递归限制,或者将函数转换为迭代循环
vdzxcuhz

vdzxcuhz2#

基准测试n=100,所有k从0到100,Kelly*是我的解决方案:

2.5 ±  0.1 ms  Kelly
  2.8 ±  0.2 ms  Kelly2
  3.5 ±  0.2 ms  Dave_translated_by_Kelly
295.0 ± 23.7 ms  Alain

令c(n,k)表示由k个互不相同的正整数(即不小于1的数)组成、且和为n的组合的数量。
我们得到:c(n, k) = c(n-k, k) + c(n-k, k-1)
你想用k个互不相同且不小于1的数凑出和n。数字1要么被使用,要么不被使用。

  • 如果不使用1,那么需要用k个互不相同且不小于2的数凑出和n。假设已有这样的k个数,把它们每个都减去1,就得到k个互不相同且不小于1的数,其和为n-k。这就是c(n-k,k)。
  • 如果使用1,那么剩余的和n-1需要用k-1个互不相同且不小于2的数凑出。假设已有这样的k-1个数,把它们每个都减去1,就得到k-1个互不相同且不小于1的数,其和为(n-1)-(k-1) = n-k。这就是c(n-k,k-1)。

Dave的情况n=9000,k=100的更快的解决方案:

469.1 ±  9.2 ms  Kelly2
478.8 ± 17.0 ms  Kelly
673.4 ± 18.8 ms  Dave_translated_by_Kelly

代码(Attempt This Online!):

def Kelly(n, k):
    """Count the ways to write n as a sum of k distinct positive integers.

    Uses the recurrence c(n, k) = c(n-k, k) + c(n-k, k-1): subtracting 1
    from every part either keeps all k parts (no part was 1) or drops the
    part that was 1.
    """
    if k == 0:
        # Only the empty sum is possible, and it equals 0.
        return int(n == 0)

    @cache
    def ways(total, parts):
        # 1 + 2 + ... + parts is the smallest sum of `parts` distinct positives.
        smallest = parts * (parts + 1) // 2
        if total < smallest:
            return 0
        if parts == 1:
            return 1
        rest = total - parts
        return ways(rest, parts) + ways(rest, parts - 1)

    return ways(n, k)

# Precompute the bounds for the "n < ..." base case
def Kelly2(n, k):
    """Count the ways to write n as a sum of k distinct positive integers.

    Same recurrence as ``Kelly``; the only difference is that the lower
    bounds used by the "sum too small" base case are tabulated up front.
    """
    if k == 0:
        # Only the empty sum is possible, and it equals 0.
        return int(n == 0)

    # smallest_sum[j] == 0 + 1 + ... + j, the least total reachable with
    # j distinct positive integers.
    smallest_sum = [j * (j + 1) // 2 for j in range(k + 1)]

    @cache
    def ways(total, parts):
        if total < smallest_sum[parts]:
            return 0
        if parts == 1:
            return 1
        rest = total - parts
        return ways(rest, parts) + ways(rest, parts - 1)

    return ways(n, k)

def Alain(n, k):
    """Count the k-element subsets of {1, ..., n} whose elements sum to n.

    Wraps Alain's memoized recursion in a function so that every call
    starts with a fresh cache.
    """
    @lru_cache(maxsize=None)
    def subsets(limit, size, target):
        if size == 0:
            # Nothing left to pick: valid only if the target is used up.
            return 1 if target == 0 else 0
        if target < 1 or size > limit:
            # Target exhausted, or fewer than `size` candidates remain.
            return 0
        skipping = subsets(limit - 1, size, target)
        taking = subsets(limit - 1, size - 1, target - limit)
        return skipping + taking

    return subsets(n, k, n)

def Dave_translated_by_Kelly(n, k):
    """Count the ways to write n as a sum of k distinct positive integers.

    Python port of Dave's Ruby solution: remove the minimal sum
    1 + 2 + ... + k == choose(k+1, 2) from n, then count the ways to split
    the remainder into k non-negative parts (order ignored).
    """

    def choose(n, k):
        """Return the binomial coefficient C(n, k), or 0 when k > n."""
        if k > n:
            return 0
        result = 1
        for d in range(1, k + 1):
            result *= n
            result //= d  # exact: the running product is divisible by d
            n -= 1
        return result

    # Memo for count_partitions_nozeroes, keyed by (n, k). It lives in the
    # closure instead of a mutable default argument, so it is visibly
    # created once per outer call and never shared by accident.
    memo = {}

    def count_partitions_nozeroes(n, k):
        """Count the ways to write n as a sum of exactly k positive parts."""
        if k == 0 and n == 0:
            return 1
        if n <= 0 or k <= 0:
            return 0
        key = (n, k)
        if key not in memo:
            # Either one part equals 1 (drop it), or every part is >= 2
            # (subtract 1 from each of the k parts).
            memo[key] = (count_partitions_nozeroes(n - 1, k - 1)
                         + count_partitions_nozeroes(n - k, k))
        return memo[key]

    def count_partitions_zeros(n, k):
        """Count the ways to write n as a sum of k non-negative parts."""
        # Adding 1 to each part maps this onto exactly-k positive parts.
        return count_partitions_nozeroes(n + k, k)

    def solve(n, k):
        r = n - choose(k + 1, 2)  # what is left after taking 1, 2, ..., k
        return count_partitions_zeros(r, k)

    return solve(n, k)

# True: time the single large case f(9000, 100).
# False: time f(100, k) for every k in 0..100.
big = False

funcs = (Alain, Kelly, Kelly2, Dave_translated_by_Kelly)

# The big run leaves out the first entry (Alain) -- presumably because it is
# far too slow at that size.
funcs = funcs[1:] if big else funcs

from functools import lru_cache, cache
from itertools import accumulate
from time import perf_counter as time
from statistics import mean, stdev
import sys
import gc

# Correctness: every solution must agree with the first entry of funcs
# for all 0 <= n <= 50 and 0 <= k <= 50.
for n in range(51):
    for k in range(51):
        expect = funcs[0](n, k)
        assert all(f(n, k) == expect for f in funcs[1:])

# Speed
sys.setrecursionlimit(20000)  # the recursive solutions nest deeply for large n
times = {f: [] for f in funcs}  # function -> wall-clock seconds of each round


def _best_ms(f):
    """Return the five fastest recorded timings of *f*, in milliseconds."""
    return [t * 1e3 for t in sorted(times[f])[:5]]


def stats(f):
    """Format mean ± standard deviation of the five fastest runs of *f*."""
    ts = _best_ms(f)
    return f'{mean(ts):5.1f} ± {stdev(ts):4.1f} ms '


for _ in range(25):
    for f in funcs:
        gc.collect()  # collect before timing; presumably to reduce GC noise
        t0 = time()
        if big:
            f(9000, 100)
        else:
            for k in range(101):
                f(100, k)
        times[f].append(time() - t0)

# Rank by the numeric mean. Sorting on the formatted string compares text,
# which misorders the results once a mean reaches 1000.0 ms.
for f in sorted(funcs, key=lambda f: mean(_best_ms(f))):
    print(stats(f), f.__name__)
cunj1qz1

cunj1qz13#

我们可以用k个不同的正整数组成的最小数是choose(k+1,2)。
设r(n,k)= n - choose(k+1,2).
那么从k个不同的整数形成n的方法的计数等于将k个非负的不一定不同的整数求和以得到r(n,k)的方法的计数。我们的想法是从1,2,3,...,k开始,然后以非递减的方式将r(n,k)分配给这些起始整数。
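例如,n=10、k=3:r(10,3) = 10 - 6 = 4。从1,2,3开始,以非递减的方式把4分配上去:(0,0,4)得到(1,2,7),(0,1,3)得到(1,3,6),(0,2,2)得到(1,4,5),(1,1,2)得到(2,3,5),共4种。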
所以我们把问题简化为计算k个非负整数求和得到r(n,k)的方法的数量。已回答here
Ruby代码(包括util函数):

# Binomial coefficient C(n, k); returns 0 when k > n.
def choose(n, k)
  return 0 if k > n

  # Multiply by the falling factors n, n-1, ... and divide by 1, 2, ... in
  # lockstep, exactly as a running product would.
  (1..k).inject(1) { |product, d| product * (n - d + 1) / d }
end

# Number of ways to write n as a sum of exactly k positive integers
# (order ignored). Results are memoized in +cache+, keyed by [n, k].
def count_partitions_nozeroes(n, k, cache = {})
  return 1 if n == 0 && k == 0
  return 0 unless n > 0 && k > 0

  # Either one part equals 1 (drop it), or every part is at least 2
  # (subtract 1 from each of the k parts). Stored counts are Integers,
  # which are always truthy, so ||= only computes missing entries.
  cache[[n, k]] ||= count_partitions_nozeroes(n - 1, k - 1, cache) +
                    count_partitions_nozeroes(n - k, k, cache)
end

# Number of ways to write n as a sum of k non-negative integers (order
# ignored). Adding 1 to each of the k parts turns this into writing n + k
# as a sum of exactly k positive integers.
def count_partitions_zeros(n, k)
  shifted_total = n + k
  count_partitions_nozeroes(shifted_total, k)
end

# Number of ways to write n as a sum of k distinct positive integers.
def solve(n, k)
  # 1 + 2 + ... + k is the smallest sum of k distinct positive integers.
  minimal_sum = choose(k + 1, 2)
  count_partitions_zeros(n - minimal_sum, k)
end

样品结果

> solve(10,3)
=> 4

> solve(200,10)
=> 98762607

> solve(2000,10)
=> 343161146717017732

> solve(2000,100) # correct that there's no solution
=> 0

> solve(2000,40)
=> 2470516759655914864269838818691

> solve(5000,50)
=> 961911722856534054414857561149346788190620561928079

> solve(9000,100)
=> 74438274524772625088229884845232647085568457172246625852148213

这里有一个更简单的Ruby版本,它避免了递归(其他方法不变)。它给出了与上面相同的结果。下面显示了较大数字的一些结果。外层循环执行n次,每次最多处理k+1个状态,因此时间复杂度为O(n·k)(k固定时随n线性增长)。

# Iterative version of count_partitions_nozeroes: instead of recursing from
# (n, k) down to the base case, it pushes counts forward from the starting
# state (n, k) and returns whatever accumulates at state (0, 0).
def count_partitions_nozeroes(n, k)
  # Two-level hash, keyed by n and then by k. A missing entry is created on
  # first access with a count of 0.
  n_to_k_to_count = Hash.new{|h, n| h[n] = Hash.new{|h2, k| h2[k] = 0}}
  # Seed: the starting state (n, k) is reached in exactly one way.
  n_to_k_to_count[n][k] = 1
  
  # Walk n downwards; both transitions below only write to levels <= cur_n.
  (n).downto(1) do |cur_n|
    n_to_k_to_count.delete(cur_n + 1) # delete old keys to save space
    # .keys takes a snapshot, so the inner hash can be written to while looping.
    n_to_k_to_count[cur_n].keys.each do |cur_k|
      # Same two sub-problems the recursive version sums:
      # (cur_n - 1, cur_k - 1) ...
      n_to_k_to_count[cur_n - 1][cur_k - 1] += n_to_k_to_count[cur_n][cur_k] if cur_n >= 1 && cur_k >= 1
      # ... and (cur_n - cur_k, cur_k). For cur_k == 0 this targets the
      # state's own entry.
      n_to_k_to_count[cur_n - cur_k][cur_k] += n_to_k_to_count[cur_n][cur_k] if cur_n >= cur_k && cur_k >= 0
    end
  end
  # Total count that reached state (0, 0), the recursive version's "return 1" case.
  return n_to_k_to_count[0][0] 
end

样品结果

> solve(10_000, 100)
=> 274235043379646744332574760930015102932669961381003514201948469288939

> solve(20_000, 100)
=> 7299696028160228272878582999080106323327610318395689691894033570930310212378988634117070675146218304092757

> solve(30_000, 100)
=> 272832080760303721646457320315409638838332197621252917061852201523368622283328266190355855228845140740972789576932357443034296

> solve(40_000, 200)
=> 1207940070190155086319681977786735094825631330761751426889808559216057614938892266960158470822904722575922933920904751545295375665942760497367

> solve(100_000, 200)
=> 13051215883535384859396062192804954511590479767894013629996324213956689010966899432038449004533035681835942448619230013858515264041486939129111486281204426757510182253404556858519289275662797170197384965998425620735381780708992863774464769

> solve(1_000_000, 200) # getting painfully slow; 3.5 mins
=> 42888085617859871072014862493356049406160707924757355757377806772267059145453158292921778894240787681100326388859698107659554647376742676484705287095709871992089520633323366183055674466048100639306064833776787643422680599710237129079050538847275806415974795879584513402381125673297339438303953873226899382823803432464875135708283442981500695089121425622135472568284901515995857775659213466818843464541496090119445962587194304280691087464026800781
6qqygrtg

6qqygrtg4#

让我们引入一个函数:f(n,k,s) = 从1到n中选取k个不同的数、且它们的总和为s的组合数。
为了解决这个问题,我们需要计算f(n,k,n)
可以递归地计算该函数。所有组合可分为两组:有和没有最大值。这就是f(n,k,s)=f(n-1,k-1,s-n)+f(n-1,k,s)。在以下情况下,递归可能会停止:

  • n<k -> 0(我们没有足够的数字)
  • k=1,s>n -> 0(每个数字都太小)
  • k=1,s<1 -> 0(每个数字都太大)
  • k=1,1<=s<=n -> 1(只有一个合适的数字)

共有N^2*k种可能的参数组合,因此如果我们缓存已经计算过的值,时间复杂度不超过O(N^3)。

相关问题