This commit is contained in:
blueloveTH 2023-06-15 21:34:01 +08:00
parent 21c23245b2
commit ba248ae0f3

View File

@ -3,13 +3,8 @@ from c import sizeof
# https://www.cnblogs.com/liuchanglc/p/14203783.html # https://www.cnblogs.com/liuchanglc/p/14203783.html
# Select the digit width (in bits) and the NTT parameters from the native
# pointer size reported by `sizeof('void_p')`.
if sizeof('void_p') == 4:
    PyLong_SHIFT = 28//2 - 1        # 13-bit digits on 32-bit platforms
    PyLong_NTT_P = 12289
    PyLong_NTT_PR = 11
elif sizeof('void_p') == 8:
    PyLong_SHIFT = 60//2 - 1        # 29-bit digits on 64-bit platforms
    # 998244353 cannot be compiled on a 32-bit platform (even if it is unused),
    # which is why the 32-bit branch uses a smaller modulus.
    PyLong_NTT_P = 998244353        # PyLong_NTT_P**2 should not overflow
    PyLong_NTT_PR = 3
else:
    # Only 4-byte and 8-byte pointers are supported.
    raise NotImplementedError
@ -18,62 +13,6 @@ PyLong_MASK = PyLong_BASE - 1
# Decimal conversion works in chunks of PyLong_DECIMAL_SHIFT decimal digits.
PyLong_DECIMAL_SHIFT = 4
PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT

# Sanity check: the NTT modulus must be larger than the digit base.
assert PyLong_NTT_P > PyLong_BASE
#----------------------------------------------------------------------------#
# #
# Number Theoretic Transform #
# #
#----------------------------------------------------------------------------#
def ibin(n, bits):
    """Return the binary digits of `n` as a string, left-padded with "0".

    The result is at least `bits` characters long; it is longer when `n`
    needs more than `bits` binary digits. `bits` must be a non-negative
    `int` (checked with an assertion).

    NOTE(review): the "0b" prefix is removed by slicing off two characters,
    so a negative `n` (whose `bin()` starts with "-0b") is not handled --
    callers are expected to pass `n >= 0`.
    """
    assert type(bits) is int and bits >= 0
    return bin(n)[2:].rjust(bits, "0")
def _number_theoretic_transform(a: list, p, pr, inverse):
    """Compute the forward or inverse number-theoretic transform of `a` mod `p`.

    a:       list of integers; its length must be a power of two (asserted).
             The input list itself is not modified; a new list is returned.
    p:       modulus. Inverses are computed as pow(x, p - 2, p), which
             assumes `p` is prime -- TODO confirm at call sites.
    pr:      generator used to derive the n-th root of unity; presumably a
             primitive root modulo `p` with len(a) dividing p - 1.
    inverse: when true, use the inverse root and scale the result by 1/n.

    Returns the transformed list, every entry reduced into [0, p).
    An empty input returns an empty list.
    """
    n = len(a)
    assert n & (n - 1) == 0
    if n == 0:
        # `0 & -1 == 0` passes the power-of-two check, but (p - 1) // n
        # below would divide by zero; the transform of nothing is nothing.
        return []

    a = [x % p for x in a]

    # Bit-reversal permutation, tracking the reversed index incrementally
    # instead of round-tripping each index through a binary string.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # rt is an n-th root of unity; the inverse transform uses its inverse.
    rt = pow(pr, (p - 1) // n, p)
    if inverse:
        rt = pow(rt, p - 2, p)

    # Precompute the powers rt**0 .. rt**(n/2 - 1).
    w = [1] * (n // 2)
    for i in range(1, n // 2):
        w[i] = w[i - 1] * rt % p

    # Iterative butterflies over blocks of size h = 2, 4, ..., n.
    h = 2
    while h <= n:
        hf, ut = h // 2, n // h
        for i in range(0, n, h):
            for k in range(hf):
                u = a[i + k]
                v = a[i + k + hf] * w[ut * k] % p
                a[i + k] = (u + v) % p
                a[i + k + hf] = (u - v + p) % p
        h *= 2

    if inverse:
        # Scale by the modular inverse of n.
        rv = pow(n, p - 2, p)
        a = [x * rv % p for x in a]
    return a


def ntt(a, p, pr):
    """Forward number-theoretic transform of `a` modulo `p` with generator `pr`."""
    return _number_theoretic_transform(a, p, pr, False)


def intt(a, p, pr):
    """Inverse number-theoretic transform of `a` modulo `p` with generator `pr`."""
    return _number_theoretic_transform(a, p, pr, True)
############################################################## ##############################################################
def ulong_fromint(x: int): def ulong_fromint(x: int):
@ -174,49 +113,17 @@ def ulong_muli(a: list, b: int):
def ulong_mul(a: list, b: list):
    """Multiply two digit lists and return the product as a new digit list.

    `a` and `b` are lists of digits, least-significant digit first, as
    implied by the carry propagation below (each digit is masked with
    PyLong_MASK and the carry shifted right by PyLong_SHIFT). Neither input
    list is modified.

    Uses grade-school (schoolbook) multiplication. The earlier NTT-based
    path is intentionally not used: it extended the caller's lists in place
    and reduced convolution coefficients modulo PyLong_NTT_P, a modulus only
    slightly larger than one digit, while those coefficients are sums of
    digit-by-digit products.
    """
    N = len(a) + len(b)
    res = [0] * N
    for i, digit_a in enumerate(a):
        carry = 0
        for j, digit_b in enumerate(b):
            carry += res[i + j] + digit_a * digit_b
            res[i + j] = carry & PyLong_MASK
            carry >>= PyLong_SHIFT
        # Slot i + len(b) has not been written by this or earlier rows yet.
        res[i + len(b)] = carry
    ulong_unpad_(res)   # trim the result in place before returning it
    return res
def ulong_powi(a: list, b: int): def ulong_powi(a: list, b: int):
# b >= 0 # b >= 0