gcc 整数立方根

eufgjt7s  于 2022-11-30  发布在  其他
关注(0)|答案(8)|浏览(121)

我正在寻找64位(无符号)立方根的快速代码。(我使用C并用gcc编译,但我想大多数所需的工作将是语言和编译器不可知的。)我将用ulong表示64位unisgned整数。
给定一个输入n,我要求(整数)返回值r满足:

r * r * r <= n && n < (r + 1) * (r + 1) * (r + 1)

也就是说,我想要n的立方根,向下取整。

return (ulong)pow(n, 1.0/3);

是不正确的,因为它会向范围的末尾舍入。

ulong
cuberoot(ulong n)
{
    ulong ret = pow(n + 0.5, 1.0/3);
    if (n < 100000000000001ULL)
        return ret;
    if (n >= 18446724184312856125ULL)
        return 2642245ULL;
    if (ret * ret * ret > n) {
        ret--;
        while (ret * ret * ret > n)
            ret--;
        return ret;
    }
    while ((ret + 1) * (ret + 1) * (ret + 1) <= n)
        ret++;
    return ret;
}

会提供正确的结果,但比需要的速度慢。
这段代码是针对一个数学库的,它将被各种函数多次调用。速度很重要,但你不能指望一个热缓存(所以像2,642,245个条目的二分搜索这样的建议是正确的)。
为了进行比较,下面的代码 * 正确 * 计算整数平方根。

ulong squareroot(ulong a) {
    ulong x = (ulong)sqrt((double)a);
    if (x > 0xFFFFFFFF || x*x > a)
        x--;
    return x;
}
disbfnqx

disbfnqx1#

《黑客的喜悦》一书中有解决这个问题和许多其他问题的算法。代码在线here。* 编辑 *:这段代码不能在64位int上正常工作,书中关于如何在64位上修复它的说明有些混乱。一个正确的64位实现(包括测试用例)是在线的here
我怀疑您的squareroot函数是否能“正确”工作--参数应该是ulong a,而不是n:)(但是使用cbrt而不是sqrt也可以实现同样的方法,尽管不是所有的C数学库都有立方根函数)。

ma8fv8wu

ma8fv8wu2#

我已经将1.5.2(第k个根)中的算法应用到Modern Computer Arithmetic (Brent and Zimmerman)中。对于(k == 3)的情况,并给出了一个“相对”准确的初始猜测的高估--这个算法似乎比上面的“黑客的喜悦”代码更好。
不仅如此,MCA作为文本提供了理论背景以及正确性和终止准则的证明。
如果我们能够产生一个“相对”良好的初始高估,我还没有能够找到一个超过(7)迭代。(这与2^6位的64位值有效相关吗?)无论哪种方式,它都比HacDel代码中的(21)次迭代有所改进-具有线性O(b)收敛性,尽管循环体明显快得多。
我使用的初始估计是基于对值(x)中有效位数的“四舍五入”。给定(x)中的(b)个有效位,我们可以说:
2^(b - 1) <= x < 2^b
。我在没有证据的情况下声明(尽管应该相对容易证明):2^ceil(b / 3) > x^(1/3)

static inline uint32_t u64_cbrt (uint64_t x)
{
    uint64_t r0 = 1, r1;

    /* IEEE-754 cbrt *may* not be exact. */

    if (x == 0) /* cbrt(0) : */
        return (0);

    int b = (64) - __builtin_clzll(x);
    r0 <<= (b + 2) / 3; /* ceil(b / 3) */

    do /* quadratic convergence: */
    {
        r1 = r0;
        r0 = (2 * r1 + x / (r1 * r1)) / 3;
    }
    while (r0 < r1);

    return ((uint32_t) r1); /* floor(cbrt(x)); */
}

crbt调用可能并不是那么有用--不像sqrt调用,它可以在现代硬件上高效地实现。也就是说,我已经看到了2^53下的值集的增益(精确地用IEEE-754双精度数表示),这让我很惊讶。
唯一的缺点是除以:(r * r)-这可能会很慢,因为整数除法的延迟继续落后于ALU中的其他进步。除以常数:(3)在任何现代优化编译器上都是通过互逆方法处理的。
有趣的是,英特尔的“Icelake”微架构将显著改进整数除法--这一运算似乎已经被忽视了很长一段时间。在找到可靠的理论基础之前,我根本不会相信“黑客的喜悦”的答案。然后我必须找出哪个变体是“正确”的答案。

mspsb9vt

mspsb9vt3#

您可以尝试牛顿步来修正舍入误差:

ulong r = (ulong)pow(n, 1.0/3);
if(r==0) return r; /* avoid divide by 0 later on */
ulong r3 = r*r*r;
ulong slope = 3*r*r;

ulong r1 = r+1;
ulong r13 = r1*r1*r1;

/* making sure to handle unsigned arithmetic correctly */
if(n >= r13) r+= (n - r3)/slope;
if(n < r3)   r-= (r3 - n)/slope;

一个牛顿步应该就足够了,但是你可能会有一个(或者更多?)的错误。你可以使用最后的检查和增量步骤来检查/修复这些错误,就像你的OQ中一样:

while(r*r*r > n) --r;
while((r+1)*(r+1)*(r+1) <= n) ++r;

或类似的。
(我承认我懒;正确的方法是仔细检查,以确定哪些(如果有的话)检查和增量的东西实际上是必要的...)

am46iovg

am46iovg4#

如果pow的开销太大,您可以使用计数前导零指令来获得结果的近似值,然后使用查找表,再用一些牛顿步骤来完成它。

int k = __builtin_clz(n); // counts # of leading zeros (often a single assembly insn)
int b = 64 - k;           // # of bits in n
int top8 = n >> (b - 8);  // top 8 bits of n (top bit is always 1)
int approx = table[b][top8 & 0x7f];

给定btop8,你可以使用一个查找表(在我的代码中,有8K个条目)来找到cuberoot(n)的一个很好的近似值。

yws3nbqq

yws3nbqq5#

// On my pc: Math.Sqrt 35 ns, cbrt64 <70ns, cbrt32 <25 ns, (cbrt12 < 10ns)

// cbrt64(ulong x) is a C# version of:
// http://www.hackersdelight.org/hdcodetxt/acbrt.c.txt     (acbrt1)

// cbrt32(uint x) is a C# version of:
// http://www.hackersdelight.org/hdcodetxt/icbrt.c.txt     (icbrt1)

// Union in C#:
// http://www.hanselman.com/blog/UnionsOrAnEquivalentInCSairamasTipOfTheDay.aspx

using System.Runtime.InteropServices;  
[StructLayout(LayoutKind.Explicit)]  
public struct fu_32   // float <==> uint
{
[FieldOffset(0)]
public float f;
[FieldOffset(0)]
public uint u;
}

private static uint cbrt64(ulong x)
{
    if (x >= 18446724184312856125) return 2642245;
    float fx = (float)x;
    fu_32 fu32 = new fu_32();
    fu32.f = fx;
    uint uy = fu32.u / 4;
    uy += uy / 4;
    uy += uy / 16;
    uy += uy / 256;
    uy += 0x2a5137a0;
    fu32.u = uy;
    float fy = fu32.f;
    fy = 0.33333333f * (fx / (fy * fy) + 2.0f * fy);
    int y0 = (int)                                      
        (0.33333333f * (fx / (fy * fy) + 2.0f * fy));    
    uint y1 = (uint)y0;                                 

    ulong y2, y3;
    if (y1 >= 2642245)
    {
        y1 = 2642245;
        y2 = 6981458640025;
        y3 = 18446724184312856125;
    }
    else
    {
        y2 = (ulong)y1 * y1;
        y3 = y2 * y1;
    }
    if (y3 > x)
    {
        y1 -= 1;
        y2 -= 2 * y1 + 1;
        y3 -= 3 * y2 + 3 * y1 + 1;
        while (y3 > x)
        {
            y1 -= 1;
            y2 -= 2 * y1 + 1;
            y3 -= 3 * y2 + 3 * y1 + 1;
        }
        return y1;
    }
    do
    {
        y3 += 3 * y2 + 3 * y1 + 1;
        y2 += 2 * y1 + 1;
        y1 += 1;
    }
    while (y3 <= x);
    return y1 - 1;
}

private static uint cbrt32(uint x)
{
    uint y = 0, z = 0, b = 0;
    int s = x < 1u << 24 ? x < 1u << 12 ? x < 1u << 06 ? x < 1u << 03 ? 00 : 03 :
                                                         x < 1u << 09 ? 06 : 09 :
                                          x < 1u << 18 ? x < 1u << 15 ? 12 : 15 :
                                                         x < 1u << 21 ? 18 : 21 :
                           x >= 1u << 30 ? 30 : x < 1u << 27 ? 24 : 27;
    do
    {
        y *= 2;
        z *= 4;
        b = 3 * y + 3 * z + 1 << s;
        if (x >= b)
        {
            x -= b;
            z += 2 * y + 1;
            y += 1;
        }
        s -= 3;
    }
    while (s >= 0);
    return y;
}

private static uint cbrt12(uint x) // x < ~255
{
    uint y = 0, a = 0, b = 1, c = 0;
    while (a < x)
    {
        y++;
        b += c;
        a += b;
        c += 6;
    }
    if (a != x) y--;
    return y;
}
50few1ms

50few1ms6#

从Fabian Giesen的答案中GitHub内的代码开始,我得到了以下更快的实现:

#include <stdint.h>

static inline uint64_t icbrt(uint64_t x) {
  uint64_t b, y, bits = 3*21;
  int s;
  for (s = bits - 3; s >= 0; s -= 3) {
    if ((x >> s) == 0)
      continue;
    x -= 1 << s;
    y = 1;
    for (s = s - 3; s >= 0; s -= 3) {
      y += y;
      b = 1 + 3*y*(y + 1);
      if ((x >> s) >= b) {
        x -= b << s;
        y += 1;
      }
    }
    return y;
  }
  return 0;
}

虽然上面的方法仍然比依赖于GNU特定__builtin_clzll的方法慢一些,但是上面的方法没有使用编译器的特定信息,因此是完全可移植的。

bits常数

常数bits越小,计算速度越快,但函数给出正确结果的最大数x(1 << bits) - 1。另外,bits必须是3的倍数,最大值为64,这意味着它的最大值实际上是3*21 == 63。对于bits = 3*21icbrt()因此对输入x <= 9223372036854775807起作用。如果我们知道一个程序在有限的x下工作,比如x < 1000000,那么我们可以通过设置bits = 3*7来加速方根的计算,因为(1 << 3*7) - 1 = 2097151 >= 1000000

64位与32位整数

虽然上面是针对64位整数编写的,但对于32位整数,逻辑是相同的:

#include <stdint.h>

static inline uint32_t icbrt(uint32_t x) {
  uint32_t b, y, bits = 3*7;  /* or whatever is appropriate */
  int s;
  for (s = bits - 3; s >= 0; s -= 3) {
    if ((x >> s) == 0)
      continue;
    x -= 1 << s;
    y = 1;
    for (s = s - 3; s >= 0; s -= 3) {
      y += y;
      b = 1 + 3*y*(y + 1);
      if ((x >> s) >= b) {
        x -= b << s;
        y += 1;
      }
    }
    return y;
  }
  return 0;
}
cld4siwp

cld4siwp7#

我将research how to do it by hand,然后将其转换为计算机算法,以2为基数而不是以10为基数工作。
我们最终得到一个类似于(伪代码)的算法:

Find the largest n such that (1 << 3n) < input.
result = 1 << n.
For i in (n-1)..0:
    if ((result | 1 << i)**3) < input:
        result |= 1 << i.

我们可以优化(result | 1 << i)**3的计算,方法是观察到按位或等于加法,重构为result**3 + 3 * i * result ** 2 + 3 * i ** 2 * result + i ** 3,在迭代之间缓存result**3result**2的值,并使用移位而不是乘法。

dauxcl2d

dauxcl2d8#

您可以尝试采用以下C算法:

#include <limits.h>

// return a number that, when multiplied by itself twice, makes N. 
unsigned cube_root(unsigned n){
    unsigned a = 0, b;
    for (int c = sizeof(unsigned) * CHAR_BIT / 3 * 3 ; c >= 0; c -= 3) {
        a <<= 1;
        b = a + (a << 1), b = b * a + b + 1 ;
        if (n >> c >= b)
            n -= b << c, ++a;
    }
    return a;
}

还存在着:

// return the number that was multiplied by itself to reach N.
unsigned square_root(const unsigned num) {
    unsigned a, b, c, d;
    for (b = a = num, c = 1; a >>= 1; ++c);
    for (c = 1 << (c & -2); c; c >>= 2) {
        d = a + c;
        a >>= 1;
        if (b >= d)
            b -= d, a += c;
    }
    return a;
}

Source

相关问题