From 21c23245b25f7a3f3dd4aaac584d3ece508342d2 Mon Sep 17 00:00:00 2001
From: blueloveTH
Date: Thu, 15 Jun 2023 19:48:01 +0800
Subject: [PATCH] ...

---
 python/_long.py | 143 ++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 131 insertions(+), 12 deletions(-)

diff --git a/python/_long.py b/python/_long.py
index 65461ae1..43eecaba 100644
--- a/python/_long.py
+++ b/python/_long.py
@@ -1,9 +1,15 @@
 from c import sizeof
 
+# https://www.cnblogs.com/liuchanglc/p/14203783.html
 if sizeof('void_p') == 4:
-    PyLong_SHIFT = 28//2
+    PyLong_SHIFT = 28//2 - 1
+    PyLong_NTT_P = 12289
+    PyLong_NTT_PR = 11
 elif sizeof('void_p') == 8:
-    PyLong_SHIFT = 60//2
+    PyLong_SHIFT = 60//2 - 1
+    # 998244353 cannot be compiled on a 32-bit platform (even if it is not used)
+    PyLong_NTT_P = 998244353    # PyLong_NTT_P**2 should not overflow
+    PyLong_NTT_PR = 3
 else:
     raise NotImplementedError
 
@@ -12,6 +18,83 @@ PyLong_MASK = PyLong_BASE - 1
 PyLong_DECIMAL_SHIFT = 4
 PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT
 
+assert PyLong_NTT_P > PyLong_BASE
+
+#----------------------------------------------------------------------------#
+#                                                                            #
+#                         Number Theoretic Transform                         #
+#                                                                            #
+#----------------------------------------------------------------------------#
+
+def ibin(n, bits):
+    """Return `n` in binary (no "0b" prefix), zero-padded on the left
+    to at least `bits` characters."""
+    assert type(bits) is int and bits >= 0
+    return bin(n)[2:].rjust(bits, "0")
+
+def _number_theoretic_transform(a: list, p, pr, inverse):
+    """Return the number-theoretic transform of `a` modulo `p` as a new list.
+
+    `a` is not modified; the result holds residues in [0, p).
+    `p` must be a prime with len(a) dividing p - 1 and `pr` a primitive
+    root modulo `p`.  If `inverse` is true the inverse transform is
+    computed (including the final division by len(a)).
+    """
+    n = len(a)
+    # the butterfly network below needs a non-empty power-of-two length ...
+    assert n > 0 and n&(n - 1) == 0
+    # ... and an n-th root of unity of exact order n needs n | p - 1
+    assert (p - 1) % n == 0
+
+    a = [x % p for x in a]
+    b = n.bit_length() - 1
+
+    # bit-reversal permutation: swap a[i] with a[i's b bits reversed]
+    for i in range(1, n):
+        j = int(ibin(i, b)[::-1], 2)
+        if i < j:
+            a[i], a[j] = a[j], a[i]
+
+    # rt is an n-th root of unity (its inverse for the inverse transform)
+    rt = pow(pr, (p - 1) // n, p)
+    if inverse:
+        rt = pow(rt, p - 2, p)
+
+    # w[i] = rt**i mod p
+    w = [1]*(n // 2)
+    for i in range(1, n // 2):
+        w[i] = w[i - 1]*rt % p
+
+    # iterative butterflies on blocks of size h = 2, 4, ..., n
+    h = 2
+    while h <= n:
+        hf, ut = h // 2, n // h
+        for i in range(0, n, h):
+            for j in range(hf):
+                u = a[i + j]
+                v = a[i + j + hf] * w[ut * j] % p
+                a[i + j] = (u + v) % p
+                a[i + j + hf] = (u - v + p) % p
+        h *= 2
+
+    if inverse:
+        # scale by the modular inverse of n (n**(p-2) mod p for prime p)
+        rv = pow(n, p - 2, p)
+        a = [x*rv % p for x in a]
+
+    return a
+
+
+def ntt(a, p, pr):
+    """Forward number-theoretic transform of `a` modulo `p` (root `pr`)."""
+    return _number_theoretic_transform(a, p, pr, False)
+
+def intt(a, p, pr):
+    """Inverse number-theoretic transform of `a` modulo `p` (root `pr`)."""
+    return _number_theoretic_transform(a, p, pr, True)
+
+##############################################################
+
 def ulong_fromint(x: int):
     # return a list of digits and sign
     if x == 0: return [0], 1
@@ -109,16 +192,52 @@ def ulong_muli(a: list, b: int):
     return res
 
 def ulong_mul(a: list, b: list):
-    res = [0] * (len(a) + len(b))
-    for i in range(len(a)):
-        carry = 0
-        for j in range(len(b)):
-            carry += res[i+j] + a[i] * b[j]
-            res[i+j] = carry & PyLong_MASK
-            carry >>= PyLong_SHIFT
-        res[i+len(b)] = carry
-    ulong_unpad_(res)
-    return res
+    """Return the product of the digit lists `a` and `b` as a new digit list.
+
+    Digits are stored least-significant first, PyLong_SHIFT bits each.
+    Neither argument is modified.
+    """
+    N = len(a) + len(b)
+    # smallest power of two >= N: the transform length for the NTT path
+    limit = 1
+    while limit < N:
+        limit <<= 1
+    # The NTT computes the digit convolution modulo PyLong_NTT_P, so it is
+    # only exact when every true coefficient (at most
+    # min(len(a), len(b)) * max(a) * max(b)) is below the modulus and the
+    # transform length divides PyLong_NTT_P - 1.  Full-width digits exceed
+    # that bound, so they take the grade-school path below.
+    use_ntt = len(a) > 0 and len(b) > 0 and (PyLong_NTT_P - 1) % limit == 0
+    if use_ntt:
+        use_ntt = min(len(a), len(b)) * max(a) * max(b) < PyLong_NTT_P
+    if not use_ntt:
+        # use grade-school multiplication
+        res = [0] * N
+        for i in range(len(a)):
+            carry = 0
+            for j in range(len(b)):
+                carry += res[i+j] + a[i] * b[j]
+                res[i+j] = carry & PyLong_MASK
+                carry >>= PyLong_SHIFT
+            res[i+len(b)] = carry
+        ulong_unpad_(res)
+        return res
+    # use fast number-theoretic transform; pad copies so that the
+    # caller's lists are left untouched
+    fa = ntt(a + [0]*(limit - len(a)), PyLong_NTT_P, PyLong_NTT_PR)
+    fb = ntt(b + [0]*(limit - len(b)), PyLong_NTT_P, PyLong_NTT_PR)
+    # pointwise product in the transform domain == convolution of digits
+    c = [fa[i] * fb[i] % PyLong_NTT_P for i in range(limit)]
+    c = intt(c, PyLong_NTT_P, PyLong_NTT_PR)
+    # propagate carries; the product fits in N <= limit digits, so the
+    # carry is zero once the loop ends
+    carry = 0
+    for i in range(limit):
+        carry += c[i]
+        c[i] = carry & PyLong_MASK
+        carry >>= PyLong_SHIFT
+    ulong_unpad_(c)
+    return c
 
 def ulong_powi(a: list, b: int):
     # b >= 0