To speed up my bignum divisions, I need to speed up the operation y = x^2
for bigints, which are represented as dynamic arrays of unsigned DWORDs. To be clear:
DWORD x[n+1] = { LSW, ......, MSW };
- where n+1 is the number of used DWORDs
- so the value of the number is
x = x[0] + x[1]<<32 + ... + x[n]<<(32*n)
The question is: how do I compute y = x^2
as fast as possible without precision loss, using C++ with only integer arithmetic (32-bit with carry) at my disposal?
My current approach is to apply the multiplication y = x*x
while avoiding repeated multiplications.
For example:
x = x[0] + x[1]<<32 + ... + x[n]<<(32*n)
For simplicity, let me rewrite it as:
x = x0 + x1 + x2 + ... + xn
where the index represents the position inside the array (the shift by 32*index is implied), so:
y = x*x
y = (x0 + x1 + x2 + ... + xn)*(x0 + x1 + x2 + ... + xn)
y = x0*(x0 + x1 + ... + xn) + x1*(x0 + x1 + ... + xn) + x2*(x0 + x1 + ... + xn) + ... + xn*(x0 + x1 + ... + xn)
y0 = x0*x0
y1 = x1*x0 + x0*x1
y2 = x2*x0 + x1*x1 + x0*x2
y3 = x3*x0 + x2*x1 + x1*x2 + x0*x3
...
y(2n-2) = x(n-2)*x(n) + x(n-1)*x(n-1) + x(n)*x(n-2)
y(2n-1) = x(n-1)*x(n) + x(n)*x(n-1)
y(2n)   = x(n)*x(n)
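The same pattern can be written compactly for a general column k. Here x_i means the bare word x[i], B = 2^32, and [k even] is 1 for even k and 0 otherwise. Note that each y_k is a column sum that can exceed one DWORD, so it still has to be carried into the higher columns:

```latex
\begin{aligned}
y &= x^2 = \sum_{k=0}^{2n} y_k \, B^{k}, \qquad B = 2^{32} \\
y_k &= \sum_{\substack{i+j=k \\ 0 \le i,j \le n}} x_i x_j
     = [\,k \text{ even}\,]\; x_{k/2}^2 \;+\; 2 \sum_{\substack{i<j,\; i+j=k \\ 0 \le i,j \le n}} x_i x_j
\end{aligned}
```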
After a closer look, it is clear that almost every product xi*xj
appears twice (only the squares xi*xi appear once). This means that the N*N
multiplications (N = n+1 words) can be replaced by (N+1)*(N/2)
multiplications: N*(N-1)/2 distinct cross products, each computed once and then doubled, plus N squares. P.S. 32bit*32bit = 64bit,
so the result of every mul+add
operation is handled as 64+1 bits.
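The counting argument above can be sketched in portable C++. This is a minimal illustration, not the arbnum code. It assumes LSW-first storage (as in the equations, not the MSW-first layout of the arbnum code) and uses uint64_t in place of an ALU with carry. The name sqr_schoolbook is hypothetical.

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>
#include <cassert>

// Square an LSW-first array of 32-bit words using N*(N+1)/2 multiplications:
// every cross product x[i]*x[j] (i < j) is computed once, the partial sum is
// doubled by a 1-bit shift, and finally the squares x[i]*x[i] are added.
std::vector<uint32_t> sqr_schoolbook(const std::vector<uint32_t>& x) {
    const size_t N = x.size();
    std::vector<uint32_t> y(2 * N, 0);

    // 1) cross products, i < j: N*(N-1)/2 multiplications
    for (size_t i = 0; i < N; i++) {
        uint64_t carry = 0;
        for (size_t j = i + 1; j < N; j++) {
            uint64_t t = (uint64_t)x[i] * x[j] + y[i + j] + carry;  // fits in 64 bits
            y[i + j] = (uint32_t)t;
            carry = t >> 32;
        }
        y[i + N] = (uint32_t)carry;   // this word is still zero here
    }

    // 2) double the cross sum (it is < x^2 / 2, so the shift cannot overflow)
    uint32_t top = 0;
    for (size_t k = 0; k < 2 * N; k++) {
        uint32_t next = y[k] >> 31;
        y[k] = (y[k] << 1) | top;
        top = next;
    }
    assert(top == 0);

    // 3) add the diagonal squares x[i]^2 at word position 2i: N multiplications
    uint64_t carry = 0;
    for (size_t i = 0; i < N; i++) {
        uint64_t sq = (uint64_t)x[i] * x[i];
        uint64_t t = (uint64_t)y[2 * i] + (uint32_t)sq + carry;
        y[2 * i] = (uint32_t)t;
        t = (uint64_t)y[2 * i + 1] + (sq >> 32) + (t >> 32);
        y[2 * i + 1] = (uint32_t)t;
        carry = t >> 32;
    }
    assert(carry == 0);   // x^2 always fits into 2N words
    return y;
}
```

For example, x = {3, 5} (that is, 3 + 5*2^32) yields {9, 30, 25, 0}: the single cross product 15 is doubled to 30, and the squares 9 and 25 land in words 0 and 2.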
Is there a faster way to compute this? All I found while searching were sqrt algorithms, not sqr...
Fast sqr
!!! Beware that all numbers in my code are stored MSW first, not as in the equations above (there they are LSW first for simplicity; otherwise it would be an index mess).
Current functional fsqr implementation
void arbnum::sqr(const arbnum &x)
{
    // O((N+1)*N/2)
    arbnum c;
    DWORD h, l;
    int N, nx, nc, i, i0, i1, k;
    c._alloc(x.siz + x.siz + 1);
    nx = x.siz - 1;
    nc = c.siz - 1;
    N = nx + nx;
    for (i = 0; i <= nc; i++)
        c.dat[i] = 0;
    // Cross products x[i0]*x[i1], i0 < i1, each computed only once
    // (i0, i1 count from the LSW, dat[] is MSW first)
    for (i = 1; i < N; i++)
        for (i0 = 0; (i0 <= nx) && (i0 <= i); i0++)
        {
            i1 = i - i0;
            if (i0 >= i1)
                break;
            if (i1 > nx)
                continue;
            h = x.dat[nx - i0];
            if (!h)
                continue;
            l = x.dat[nx - i1];
            if (!l)
                continue;
            alu.mul(h, l, h, l);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k], l);
            k--;
            if (k >= 0)
                alu.adc(c.dat[k], c.dat[k], h);
            k--;
            for (; (alu.cy) && (k >= 0); k--)
                alu.inc(c.dat[k]);
        }
    c.shl(1);   // double the cross products
    // Add the squares x[i0]*x[i0]
    for (i = 0; i <= N; i += 2)
    {
        i0 = i >> 1;
        h = x.dat[nx - i0];
        if (!h)
            continue;
        alu.mul(h, l, h, h);
        k = nc - i;
        if (k >= 0)
            alu.add(c.dat[k], c.dat[k], l);
        k--;
        if (k >= 0)
            alu.adc(c.dat[k], c.dat[k], h);
        k--;
        for (; (alu.cy) && (k >= 0); k--)
            alu.inc(c.dat[k]);
    }
    c.bits = c.siz << 5;
    c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz) << 5) + 1;
    c.sig = x.sig * x.sig;   // a square is never negative (same as in sqr_NTT)
    *this = c;
}
Use of Karatsuba multiplication
(thanks to Calpis)
I implemented Karatsuba multiplication, but it is massively slower than even the simple O(N^2)
multiplication, probably because of that horrible recursion, which I can't see any way to avoid. Its break-even point must be at really large numbers (more than hundreds of digits), but even then there are a lot of memory transfers. Is there a way to avoid the recursive calls (a non-recursive variant; almost all recursive algorithms can be written that way)? Still, I will try to tweak things and see what happens (avoid normalizations, etc.; it could also be some silly mistake in the code). Anyway, after specializing Karatsuba for the case x*x,
there is not much performance gain.
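Splitting x = x_1 B + x_0 shows why a squaring-specific Karatsuba gains little. Here B = 2^{32h} and h is half the word count. The three half-size multiplications simply become three half-size squarings. The last line is a standard alternative identity that avoids the extra carry bit of x_0 + x_1. The original code does not use it.

```latex
\begin{aligned}
x^2 &= x_1^2 B^2 + 2 x_0 x_1 B + x_0^2 \\
2 x_0 x_1 &= (x_0 + x_1)^2 - x_0^2 - x_1^2 \\
          &= x_0^2 + x_1^2 - (x_0 - x_1)^2
\end{aligned}
```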
Optimized Karatsuba multiplication
Performance test for y = x^2, looped 1000 times, 0.9 < x < 1, ~32*98 bits:
x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication
x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]
x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]
After optimizing Karatsuba, the code is massively faster than before. Still, for smaller numbers it runs at slightly less than half the speed of my O(N^2)
multiplication. For bigger numbers it is faster, by the ratio given by the complexities of both multiplications. The threshold for multiplication is around 32*98 = 3136 bits and for sqr around 32*389 = 12448 bits. So if the sum of the input bits crosses this threshold, Karatsuba multiplication is used to speed up multiplication, and similarly for sqr.
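For comparison, here is a minimal recursive Karatsuba squaring sketch in portable C++. It is not the author's _mul_karatsuba: it uses LSW-first std::vector words, allocates freely, and does not need power-of-two sizes. The cutoff parameter plays the role of the threshold: at or below it, the schoolbook product is used. All names are hypothetical.

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>
#include <cassert>

using Limbs = std::vector<uint32_t>;  // LSW first, as in the equations above

// Schoolbook product (base case); the result has a.size() + b.size() words.
static Limbs mul_basic(const Limbs& a, const Limbs& b) {
    Limbs r(a.size() + b.size(), 0);
    for (size_t i = 0; i < a.size(); i++) {
        uint64_t carry = 0;
        for (size_t j = 0; j < b.size(); j++) {
            uint64_t t = (uint64_t)a[i] * b[j] + r[i + j] + carry;
            r[i + j] = (uint32_t)t;
            carry = t >> 32;
        }
        r[i + b.size()] = (uint32_t)carry;
    }
    return r;
}

// r += a * 2^(32*off); the caller guarantees that the true sum fits into r.
static void add_at(Limbs& r, const Limbs& a, size_t off) {
    uint64_t carry = 0;
    size_t k = off;
    for (size_t i = 0; i < a.size(); i++, k++) {
        assert(k < r.size());
        uint64_t t = (uint64_t)r[k] + a[i] + carry;
        r[k] = (uint32_t)t;
        carry = t >> 32;
    }
    for (; carry && k < r.size(); k++) {
        uint64_t t = (uint64_t)r[k] + carry;
        r[k] = (uint32_t)t;
        carry = t >> 32;
    }
    assert(carry == 0);
}

// r -= a, requires r >= a and a.size() <= r.size().
static void sub_from(Limbs& r, const Limbs& a) {
    assert(a.size() <= r.size());
    uint64_t borrow = 0;
    for (size_t i = 0; i < r.size(); i++) {
        uint64_t ai = i < a.size() ? a[i] : 0;
        uint64_t t = (uint64_t)r[i] - ai - borrow;   // wraps modulo 2^64 when negative
        r[i] = (uint32_t)t;
        borrow = (t >> 63) & 1;
    }
    assert(borrow == 0);
}

// x^2 = x1^2*B^2 + ((x0+x1)^2 - x0^2 - x1^2)*B + x0^2, with B = 2^(32*h).
Limbs karatsuba_sqr(const Limbs& x, size_t cutoff) {
    const size_t n = x.size();
    if (n <= cutoff || n <= 3)        // for n <= 3 the (x0+x1) part would not shrink
        return mul_basic(x, x);
    const size_t h = n / 2;           // x0 = low h words, x1 = high n-h words
    Limbs x0(x.begin(), x.begin() + h);
    Limbs x1(x.begin() + h, x.end());
    Limbs z0 = karatsuba_sqr(x0, cutoff);   // x0^2
    Limbs z2 = karatsuba_sqr(x1, cutoff);   // x1^2
    Limbs s = x1;
    s.push_back(0);                         // one extra word for the carry of x0+x1
    add_at(s, x0, 0);
    Limbs z1 = karatsuba_sqr(s, cutoff);    // (x0+x1)^2
    sub_from(z1, z0);
    sub_from(z1, z2);                       // z1 = 2*x0*x1
    Limbs r(2 * n, 0);                      // z1 at offset h ends at 2n-h+2 <= 2n (h >= 2)
    add_at(r, z0, 0);
    add_at(r, z2, 2 * h);
    add_at(r, z1, h);
    return r;
}
```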
BTW, the optimizations included:
- Minimized heap thrashing caused by too-big recursion arguments
- Avoided any bignum arithmetic (+,-); the 32-bit ALU with carry is used instead
- Ignored the 0*y, x*0 and 0*0 cases
- Padded the input x,y number sizes to a power of two to avoid reallocating
- Implemented modulo multiplication for z1 = (x0 + x1)*(y0 + y1) to minimize recursion
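The carry handling in the last item can be made explicit. With B = 2^{32·n/2}, each half sum overflows by at most one bit (c_x, c_y ∈ {0,1}). Only the n/2-word parts s and t therefore go into the recursive call, and the remaining terms are fixed up with additions:

```latex
\begin{aligned}
x_0 + x_1 &= c_x B + s, \qquad y_0 + y_1 = c_y B + t, \qquad 0 \le s, t < B,\; c_x, c_y \in \{0,1\} \\
(x_0 + x_1)(y_0 + y_1) &= s\,t + (c_x\, t + c_y\, s)\,B + c_x c_y B^2
\end{aligned}
```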
Modified Schönhage-Strassen multiplication for the sqr implementation
I have tested the use of FFT and NTT transforms to speed up the sqr computation. The results are these:
- FFT
Loses accuracy and therefore needs high-precision complex numbers. This slows things down considerably, so there is no speedup. The result is not precise (it can be wrongly rounded), so FFT is unusable (for now).
- NTT
NTT is a finite field DFT, so no accuracy loss occurs. It needs modular arithmetic on unsigned integers: modpow, modmul, modadd and modsub.
I use DWORDs (32-bit unsigned integers). The NTT input/output vector size is limited because of overflow issues!!! For 32-bit modular arithmetic, N
is limited to (2^32)/(max(input[])^2),
so the bigint
must be divided into smaller chunks. I use BYTEs,
so the maximum size of a bigint
that can be processed is
(2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs.
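The limit above comes from a simple requirement. The NTT computes the convolution only modulo the prime p, so every exact coefficient (at most N * (2^b - 1)^2 for b-bit chunks) must stay below p. Below is a small sketch of that bound; the function name is hypothetical. The prime 3221225473 = 3·2^30 + 1 in the usage example is an assumed, commonly used NTT prime below 2^32, since the modulus inside fourier_NTT is not shown.

```cpp
#include <cstdint>
#include <cassert>

// Largest transform length N for which every exact convolution coefficient,
// at most N * (2^chunk_bits - 1)^2, is <= p - 1 and therefore survives the
// reduction modulo p. This is the exact form of the N <= 2^32 / max(input)^2 rule.
uint64_t max_ntt_length(unsigned chunk_bits, uint64_t p) {
    assert(chunk_bits >= 1 && chunk_bits <= 31);
    const uint64_t m = (1ULL << chunk_bits) - 1;   // largest chunk value
    return (p - 1) / (m * m);
}
```

With 8-bit chunks and p = 3221225473 this gives 49538 bytes, in the same ballpark as the 48K limit noted in the sqr_NTT comments. With 16-bit chunks it gives 0, so 16-bit chunks cannot be used with a 32-bit prime at all.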
The sqr
uses only 1xNTT + 1xINTT,
instead of the 2xNTT + 1xINTT needed
for multiplication. But NTT is too slow, and the threshold number size is too large for practical use in my implementation (for mul
and also for sqr).
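Here is a self-contained sketch of NTT squaring of a byte array (LSB first), which shows why squaring needs only one forward transform. It is not the author's fourier_NTT class. It assumes the well-known prime p = 998244353 = 119·2^23 + 1 with primitive root 3, an iterative radix-2 transform, and the same 8-bit chunking. With this prime, coefficients stay exact only for inputs up to 15351 bytes (998244352 / 255^2). The names powmod, ntt and sqr_ntt are hypothetical.

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>
#include <utility>
#include <cassert>

// Assumed modulus: 998244353 = 119*2^23 + 1, primitive root 3 (not the author's modulus).
static const uint32_t P = 998244353u;
static const uint32_t G = 3u;

static uint32_t powmod(uint64_t b, uint64_t e) {
    uint64_t r = 1;
    b %= P;
    while (e) {
        if (e & 1) r = r * b % P;
        b = b * b % P;
        e >>= 1;
    }
    return (uint32_t)r;
}

// In-place iterative radix-2 NTT (inverse if invert); a.size() must be a power of two <= 2^23.
static void ntt(std::vector<uint32_t>& a, bool invert) {
    const size_t n = a.size();
    for (size_t i = 1, j = 0; i < n; i++) {          // bit-reversal permutation
        size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    for (size_t len = 2; len <= n; len <<= 1) {      // butterfly passes
        uint64_t w = powmod(G, (P - 1) / len);       // primitive len-th root of unity
        if (invert) w = powmod(w, P - 2);
        for (size_t i = 0; i < n; i += len) {
            uint64_t wk = 1;
            for (size_t k = 0; k < len / 2; k++) {
                uint32_t u = a[i + k];
                uint32_t v = (uint32_t)(a[i + k + len / 2] * wk % P);
                a[i + k] = (u + v) % P;              // u + v < 2P < 2^32
                a[i + k + len / 2] = (u + P - v) % P;
                wk = wk * w % P;
            }
        }
    }
    if (invert) {
        const uint64_t n_inv = powmod(n, P - 2);
        for (uint32_t& v : a) v = (uint32_t)(v * n_inv % P);
    }
}

// x = little-endian bytes (LSB first); returns x^2 in the same format.
std::vector<uint8_t> sqr_ntt(const std::vector<uint8_t>& x) {
    assert(x.size() <= 15351);                       // keeps every coefficient below P
    size_t n = 1;
    while (n < 2 * x.size()) n <<= 1;                // room for the full product, no wrap-around
    std::vector<uint32_t> a(n, 0);
    for (size_t i = 0; i < x.size(); i++) a[i] = x[i];
    ntt(a, false);                                   // the only forward transform
    for (uint32_t& v : a) v = (uint32_t)((uint64_t)v * v % P);   // pointwise square
    ntt(a, true);                                    // one inverse transform
    std::vector<uint8_t> y(n);
    uint64_t carry = 0;                              // carry propagation, 8 bits per element
    for (size_t i = 0; i < n; i++) {
        carry += a[i];
        y[i] = (uint8_t)(carry & 0xFF);
        carry >>= 8;
    }
    return y;                                        // final carry is 0 because x^2 < 256^n
}
```

A multiplication would need a second forward transform for the other operand before the pointwise product. Squaring reuses the single transform of x.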
It is possible that the threshold is even above the overflow limit, in which case 64-bit modular arithmetic would have to be used, slowing things down even more. So NTT is also unusable for my purposes.
Some measurements:
a = 0.98765588997654321000 | 389*32 bits
looped 1x
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simple mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul
My implementation:
void arbnum::sqr_NTT(const arbnum &x)
{
    // O(N*log(N)*(log(log(N)))) - 1x NTT
    // Schönhage-Strassen sqr
    // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
    int i, j, k, n;
    int s = x.sig * x.sig, exp0 = x.exp + x.exp - ((x.siz + x.siz) << 5) + 2;
    i = x.siz;
    for (n = 1; n < i; n <<= 1)
        ;
    if (n + n > 0x3000) {
        _error(_arbnum_error_TooBigNumber);
        zero();
        return;
    }
    n <<= 3;
    DWORD *xx, *yy, q, qq;
    xx = new DWORD[n + n];
#ifdef _mmap_h
    if (xx)
        mmap_new(xx, (n + n) << 2);
#endif
    if (xx == NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        zero();
        return;
    }
    yy = xx + n;
    // Zero padding (and split DWORDs to BYTEs, LSB first)
    for (i--, k = 0; i >= 0; i--)
    {
        q = x.dat[i];
        xx[k] = q & 0xFF; k++; q >>= 8;
        xx[k] = q & 0xFF; k++; q >>= 8;
        xx[k] = q & 0xFF; k++; q >>= 8;
        xx[k] = q & 0xFF; k++;
    }
    for (; k < n; k++)
        xx[k] = 0;
    // NTT
    fourier_NTT ntt;
    ntt.NTT(yy, xx, n);     // init NTT for n
    // Convolution: squaring needs only this one forward transform
    for (i = 0; i < n; i++)
        yy[i] = modmul(yy[i], yy[i], ntt.p);
    // INTT
    ntt.INTT(xx, yy);
    // Sum: carry propagation, 8 bits per element
    q = 0;
    for (i = 0, j = 0; i < n; i++) {
        qq = xx[i];
        q += qq & 0xFF;
        yy[n - i - 1] = q & 0xFF;
        q >>= 8;
        qq >>= 8;
        q += qq;
    }
    // Merge BYTEs to DWORDs and copy them to the result
    _alloc(n >> 2);
    for (i = 0, j = 0; i < siz; i++)
    {
        q  = (yy[j] << 24) & 0xFF000000; j++;
        q |= (yy[j] << 16) & 0x00FF0000; j++;
        q |= (yy[j] <<  8) & 0x0000FF00; j++;
        q |= (yy[j]      ) & 0x000000FF; j++;
        dat[i] = q;
    }
#ifdef _mmap_h
    if (xx)
        mmap_del(xx);
#endif
    delete[] xx;    // allocated with new[], so it must be released with delete[]
    bits = siz << 5;
    sig = s;
    exp = exp0 + (siz << 5) - 1;
    // _normalize();
}
Conclusion
For smaller numbers, my fast sqr
approach is the best option, and beyond the threshold Karatsuba multiplication is better. But I still think there should be something trivial that we have overlooked. Does anyone have other ideas?
NTT optimization
After massively intense optimizations (mostly of the NTT; see the Stack Overflow question Modular arithmetics and NTT (finite field DFT) optimizations), some values have changed:
a = 0.98765588997654321000 | 1553*32bits
looped 10x
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul
So now NTT multiplication is finally faster than Karatsuba beyond a threshold of about 1500*32 bits.
Some measurements and a spotted bug
a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[ 58.656 ms ] fast sqr
sqr2[ 13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simple mul
mul2[ 28.916 ms ] Karatsuba mul Error
mul3[ 19.470 ms ] NTT mul
I found out that my Karatsuba (over/under)flows the LSB of each DWORD
segment of the bignum. Once I have researched it, I will update the code...
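Also, after further NTT optimizations the thresholds changed. For NTT sqr the threshold is 310*32 bits = 9920 bits of the operand. For NTT mul it is 1396*32 bits = 44672 bits of the result (the sum of the operand bits).
Thanks to @greybeard for fixing the Karatsuba code.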
//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
    // Recursion for Karatsuba
    // z[2n] = x[n]*y[n];
    // n=2^m
    int i;
    for (i = 0; i < n; i++)
        if (x[i]) {
            i = -1;
            break;
        } // x==0 ?
    if (i < 0)
        for (i = 0; i < n; i++)
            if (y[i]) {
                i = -1;
                break;
            } // y==0 ?
    if (i >= 0) {
        for (i = 0; i < n + n; i++)
            z[i] = 0;
        return;
    } // 0.? = 0
    if (n == 1) {
        alu.mul(z[0], z[1], x[0], y[0]);
        return;
    }
    if (n < 1)
        return;
    int n2 = n >> 1;
    _mul_karatsuba(z + n, x + n2, y + n2, n2); // z0 = x0.y0
    _mul_karatsuba(z    , x     , y     , n2); // z2 = x1.y1
    DWORD *q = new DWORD[n << 1], *q0, *q1, *qq;
    BYTE cx, cy;
    if (q == NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        return;
    }
#define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i >= 0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
#define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i >= 0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
    qq = q;
    q0 = x + n2;
    q1 = x;
    i = n2 - 1;
    _add;
    cx = alu.cy; // =x0+x1
    qq = q + n2;
    q0 = y + n2;
    q1 = y;
    i = n2 - 1;
    _add;
    cy = alu.cy; // =y0+y1
    _mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1)
    if (cx) {
        qq = q + n;
        q0 = qq;
        q1 = q + n2;
        i = n2 - 1;
        _add;
        cx = alu.cy;
    } // += cx*(y0 + y1) << n2
    if (cy) {
        qq = q + n;
        q0 = qq;
        q1 = q;
        i = n2 - 1;
        _add;
        cy = alu.cy;
    } // += cy*(x0 + x1) << n2
    qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub; // -= z0
    qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub; // -= z2
    qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add; // z1 = (x0+x1)(y0+y1) - z0 - z2
    DWORD ccc = 0;
    if (alu.cy)
        ccc++; // Handle carry from last operation
    if (cx || cy)
        ccc++; // Handle carry from before last operation
    if (ccc)
    {
        i = n2 - 1;
        alu.add(z[i], z[i], ccc);
        for (i--; i >= 0; i--)
            if (alu.cy)
                alu.inc(z[i]);
            else
                break;
    }
    delete[] q;
#undef _add
#undef _sub
}
//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
    // O(3*(N)^log2(3)) ~ O(3*(N^1.585))
    // Karatsuba multiplication
    //
    int s = x.sig * y.sig;
    arbnum a, b;
    a = x;
    b = y;
    a.sig = +1;
    b.sig = +1;
    int i, n;
    for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
        ;
    a._realloc(n);
    b._realloc(n);
    _alloc(n + n);
    for (i = 0; i < siz; i++)
        dat[i] = 0;
    _mul_karatsuba(dat, a.dat, b.dat, n);
    bits = siz << 5;
    sig = s;
    exp = a.exp + b.exp + ((siz - a.siz - b.siz) << 5) + 1;
    // _normalize();
}
//---------------------------------------------------------------------------
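My arbnum
number representation:
// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
dat[siz]
is the mantissa; LSDW means least significant DWORD. exp
is the exponent of the MSB of dat[0];
the first nonzero bit is present in the mantissa!!!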
// |-----|---------------------------|---------------|------|
// | sig | MSB mantisa LSB | exponent | bits |
// |-----|---------------------------|---------------|------|
// | +1 | 0.(0 ... 0) | 2^0 | 0 | +zero
// | -1 | 0.(0 ... 0) | 2^0 | 0 | -zero
// |-----|---------------------------|---------------|------|
// | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +number
// | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -number
// |-----|---------------------------|---------------|------|
// | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +infinity
// | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -infinity
// |-----|---------------------------|---------------|------|
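3 Answers
Answer 1
If I understand your algorithm correctly, it seems to be O(n^2), where n is the number of digits. Have you looked at the Karatsuba algorithm? It uses divide and conquer to speed up multiplication. It might be worth a look.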
Answer 2
Great question, thanks!
I decided to implement a huge C++ solution for you from scratch, based on the Number Theoretic Transform (NTT) and the Discrete Fourier Transform.
To say it in advance: my FFT/NTT code achieves a 330x speedup on an old 2-core laptop compared to naive schoolbook multiplication, for arrays of 2^16 32-bit words. Even bigger arrays, above 2^20 words, give speedups of millions of times.
Squaring a number with 2^22 32-bit words (i.e. 4 million words) takes 7 seconds with my NTT and 13 seconds with my FFT, on an old 2 GHz 2-core laptop with only SSE2.
As a reminder, FFT and NTT give a multiplication time of O(N * log(N)), while the naive schoolbook algorithm has O(N^2) time. That's why I get the huge speedup described in the previous paragraph. Both transforms, together with code, are well described in this article, which mainly inspired the code below. Another good article is Nayuki's NTT article.
I was convinced that for quite large numbers these two transforms would beat any other methods, like Karatsuba.
Besides the basic approach described in the article, I also did dozens of optimizations:
- constexpr functions and values, and templated programming everywhere I can. Reducing runtime values to compile-time values where possible gives a lot of speedup.
- … W multiplier into a separate loop together with pre-computation/caching. This gave about a 2x speedup.
My code is self-contained: if you compile and run it, it runs tests that measure speed. Inside the test function you can see how to use my library. The test runs FFT, NTT and naive multiplication, measures their times, and checks that all multiplication results are correct, i.e. equal to the naive version.
Note: no matter how hard I tried to speed up the FFT with SIMD, my NTT is so optimized that it is still 1.3-1.8x faster than the FFT. As you know, FFT gives errors that grow with the size of the big number. Since my NTT is also faster, NTT is the only option for you!
It turned out that FFT can be used only for array sizes up to about 2^16 32-bit words, no more. Beyond that, the error becomes too large and destroys the final result. Alternatively, you can reduce the input numbers from 32 bits to 10-12 bits, which helps to reduce errors, yet even then you can't go beyond an array size of 2^18 without critical error. You have to measure the error experimentally to figure out what is best.
The code can be compiled with Clang/MSVC/GCC, and maybe other compilers too. It has no external library dependencies at all, except perhaps the OpenMP library, which is usually shipped with the compiler. Only the computation of primitive roots (NTT modulus) requires the Boost library, and only for MSVC, where it is used just for its 128-bit integer.
Because the code is 65 KB, I can't inline it in this post, as the Stack Overflow post size limit is 30,000 characters. Hence I'm providing my code in the GitHub Gist link below:
GitHub Gist source code
Example console output:
Answer 3
If you want to write a new and better exponential function, you may have to write it in assembly. This is the code from golang:
https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s