This documentation is automatically generated by online-judge-tools/verification-helper
import cp_library.math.fps.__header__
def fps_inv(P: list) -> list:
    """Return the first ``len(P)`` coefficients of the inverse series of ``P``.

    ``P`` lists the coefficients of a formal power series, lowest degree first,
    modulo ``mint.mod`` (assumed to be an NTT-friendly prime). The returned
    list ``Q`` (plain ints in ``[0, mint.mod)``) satisfies
    ``P * Q == 1 (mod x**len(P))``. Newton iteration doubles the number of
    known coefficients each round.

    Raises:
        ValueError: if ``P[0]`` is not invertible modulo ``mint.mod``.
    """
    if not P:
        # Nothing to invert; also avoids indexing P[0] / inv[0] on empty lists.
        return []
    ntt, inv, d = mint.ntt, [0]*(deg:=len(P)), 1
    # int(): mint overrides // as modular division, which would corrupt the
    # integer Euclid loop inside mod_inv.
    inv[0] = mod_inv(int(P[0]), mod := mint.mod)
    while d < deg:
        # Invariant: P * inv[:d] == 1 (mod x^d). This round fills inv[d:2d].
        sz, f, g = min(deg, z := d << 1), [0]*z, [0]*z
        f[:sz], g[:d] = P[:sz], inv[:d]
        # f <- P*g as a cyclic convolution of length z; only f[d:z] is exact.
        ntt.conv_half(f, gres := ntt.fntt(g))
        # Low d terms are 1,0,...,0 plus wrap-around from degrees >= z: drop them.
        f[:d] = [0]*d
        # f <- (P*g - 1) * g; again only f[d:z] is needed.
        ntt.conv_half(f, gres)
        # Newton step g - (P*g - 1)*g: the new coefficients are the negation.
        for j in range(d, sz):
            inv[j] = mod - f[j] if f[j] else 0
        d = z
    return inv
from cp_library.math.mod.mint_ntt_cls import mint
from cp_library.math.nt.mod_inv_fn import mod_inv
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
def fps_inv(P: list) -> list:
    """Return the first ``len(P)`` coefficients of the inverse series of ``P``.

    ``P`` lists the coefficients of a formal power series, lowest degree first,
    modulo ``mint.mod`` (assumed to be an NTT-friendly prime). The returned
    list ``Q`` (plain ints in ``[0, mint.mod)``) satisfies
    ``P * Q == 1 (mod x**len(P))``. Newton iteration doubles the number of
    known coefficients each round.

    Raises:
        ValueError: if ``P[0]`` is not invertible modulo ``mint.mod``.
    """
    if not P:
        # Nothing to invert; also avoids indexing P[0] / inv[0] on empty lists.
        return []
    ntt, inv, d = mint.ntt, [0]*(deg:=len(P)), 1
    # int(): mint overrides // as modular division, which would corrupt the
    # integer Euclid loop inside mod_inv.
    inv[0] = mod_inv(int(P[0]), mod := mint.mod)
    while d < deg:
        # Invariant: P * inv[:d] == 1 (mod x^d). This round fills inv[d:2d].
        sz, f, g = min(deg, z := d << 1), [0]*z, [0]*z
        f[:sz], g[:d] = P[:sz], inv[:d]
        # f <- P*g as a cyclic convolution of length z; only f[d:z] is exact.
        ntt.conv_half(f, gres := ntt.fntt(g))
        # Low d terms are 1,0,...,0 plus wrap-around from degrees >= z: drop them.
        f[:d] = [0]*d
        # f <- (P*g - 1) * g; again only f[d:z] is needed.
        ntt.conv_half(f, gres)
        # Newton step g - (P*g - 1)*g: the new coefficients are the negation.
        for j in range(d, sz):
            inv[j] = mod - f[j] if f[j] else 0
        d = z
    return inv
class mint(int):
    """Residue modulo ``mint.mod`` stored as an ``int`` subclass.

    ``set_mod`` must run before values are constructed. ``+``, ``-`` and ``*``
    reduce their result modulo ``mod``; ``/`` and ``//`` both multiply by the
    modular inverse of the divisor; ``**`` is modular exponentiation.
    """
    mod: int
    zero: 'mint'
    one: 'mint'
    two: 'mint'
    cache: list['mint']

    def __new__(cls, *args, **kwargs):
        value = int(*args, **kwargs)
        if value < 0 or value > 2:
            return cls.fix(value)
        # 0, 1 and 2 come from the shared cache built by set_mod.
        return cls.cache[value]

    @classmethod
    def set_mod(cls, mod: int):
        """Install ``mod`` on both ``mint`` and ``cls`` and rebuild the 0/1/2 cache."""
        mint.mod = cls.mod = mod
        zero, one, two = cls.cast(0), cls.fix(1), cls.fix(2)
        mint.zero = cls.zero = zero
        mint.one = cls.one = one
        mint.two = cls.two = two
        mint.cache = cls.cache = [zero, one, two]

    @classmethod
    def fix(cls, x):
        """Reduce ``x`` modulo ``cls.mod`` and wrap the result."""
        return cls.cast(x % cls.mod)

    @classmethod
    def cast(cls, x):
        """Wrap ``x`` as an instance of ``cls`` with no reduction."""
        return super().__new__(cls, x)

    @classmethod
    def mod_inv(cls, x):
        """Return the inverse of ``x`` modulo ``cls.mod`` (extended Euclid).

        Raises ValueError when ``x`` has no inverse modulo ``cls.mod``.
        """
        r0, r1 = int(x), cls.mod
        s0, s1 = 1, 0
        while r1:
            q = r0 // r1
            r0, r1 = r1, r0 - q * r1
            s0, s1 = s1, s0 - q * s1
        if r0 != 1:
            raise ValueError(f"{x} is not invertible in mod {cls.mod}")
        return cls.fix(s0)

    @property
    def inv(self):
        return mint.mod_inv(self)

    def __add__(self, x):
        return mint.fix(super().__add__(x))

    def __radd__(self, x):
        return mint.fix(super().__radd__(x))

    def __sub__(self, x):
        return mint.fix(super().__sub__(x))

    def __rsub__(self, x):
        return mint.fix(super().__rsub__(x))

    def __mul__(self, x):
        return mint.fix(super().__mul__(x))

    def __rmul__(self, x):
        return mint.fix(super().__rmul__(x))

    # Both division operators mean "multiply by the modular inverse".
    def __floordiv__(self, x):
        return self * mint.mod_inv(x)

    def __rfloordiv__(self, x):
        return self.inv * x

    def __truediv__(self, x):
        return self * mint.mod_inv(x)

    def __rtruediv__(self, x):
        return self.inv * x

    def __pow__(self, x):
        return self.cast(super().__pow__(x, self.mod))

    def __neg__(self):
        return mint.mod - self

    def __pos__(self):
        return self

    def __abs__(self):
        return self
def mod_inv(x, mod):
    """Return the inverse of ``x`` modulo ``mod`` via the extended Euclidean algorithm.

    For positive ``mod`` the result is a plain int in ``[0, mod)``.

    Raises:
        ValueError: if ``x`` and ``mod`` are not coprime.
    """
    # Coerce to a plain int first: int subclasses such as mint redefine //
    # (as modular division), which would break the quotient computation below.
    a, b, s, t = int(x), mod, 1, 0
    while b:
        a, b, s, t = b, a % b, t, s - a // b * t
    if a == 1:
        return s % mod
    raise ValueError(f"{x} is not invertible in mod {mod}")
class NTT:
def __init__(self, mod = 998244353) -> None:
self.mod = m = mod
self.g = g = self.primitive_root(m)
self.rank2 = rank2 = ((m-1)&(1-m)).bit_length() - 1
self.root = root = [0] * (rank2 + 1)
root[rank2] = pow(g, (m - 1) >> rank2, m)
self.iroot = iroot = [0] * (rank2 + 1)
iroot[rank2] = pow(root[rank2], m - 2, m)
for i in range(rank2 - 1, -1, -1):
root[i] = root[i+1] * root[i+1] % m
iroot[i] = iroot[i+1] * iroot[i+1] % m
def rates(s):
r8,ir8 = [0]*max(0,rank2-s+1), [0]*max(0,rank2-s+1)
p = ip = 1
for i in range(rank2-s+1):
r, ir = root[i+s], iroot[i+s]
p,ip,r8[i],ir8[i]= p*ir%m,ip*r%m,r*p%m,ir*ip%m
return r8, ir8
self.rate2, self.irate2 = rates(2)
self.rate3, self.irate3 = rates(3)
def primitive_root(self, m):
if m == 2: return 1
if m == 167772161: return 3
if m == 469762049: return 3
if m == 754974721: return 11
if m == 998244353: return 3
divs = [0] * 20
cnt, divs[0], x = 1, 2, (m - 1) // 2
while x % 2 == 0: x //= 2
i=3
while i*i <= x:
if x%i == 0:
divs[cnt],cnt = i,cnt+1
while x%i==0:x//=i
i+=2
if x > 1: divs[cnt],cnt = x,cnt+1
for g in range(2,m):
for i in range(cnt):
if pow(g,(m-1)//divs[i],m)==1:break
else:return g
def fntt(self, A: list[int]):
im, r8, m, h = self.root[2],self.rate3,self.mod,(len(A)-1).bit_length()
for L in range(0,h-1,2):
p, r = 1<<(h-L-2),1
for s in range(1 << L):
r3,of=(r2:=r*r%m)*r%m,s<<(h-L)
for i in range(p):
i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
a0,a1,a2,a3 = A[i0],A[i1]*r,A[i2]*r2,A[i3]*r3
a0,a1,a2,a3 = a0+a2,a1+a3,a0-a2,(a1-a3)%m*im
A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a0-a1)%m,(a2+a3)%m,(a2-a3)%m
r=r*r8[(~s&-~s).bit_length()-1]%m
if h&1:
r, r8 = 1, self.rate2
for s in range(1<<(h-1)):
i1=(i0:=s<<1)+1
al,ar = A[i0],A[i1]*r%m
A[i0],A[i1] = (al+ar)%m,(al-ar)%m
r=r*r8[(~s&-~s).bit_length()-1]%m
return A
def _ifntt(self, A: list[int]):
im, r8, m, h = self.iroot[2],self.irate3,self.mod,(len(A)-1).bit_length()
for L in range(h,1,-2):
p,r = 1<<(h-L),1
for s in range(1<<(L-2)):
r3,of=(r2:=r*r%m)*r%m,s<<(h-L+2)
for i in range(p):
i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
a0,a1,a2,a3 = A[i0],A[i1],A[i2],A[i3]
a0,a1,a2,a3 = a0+a1,a2+a3,a0-a1,(a2-a3)*im%m
A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a2+a3)*r%m,(a0-a1)*r2%m,(a2-a3)*r3%m
r=r*r8[(~s&-~s).bit_length()-1]%m
if h&1:
for i0 in range(p:=1<<(h-1)):
al,ar = A[i0],A[i1:=i0+p]
A[i0],A[i1] = (al+ar)%m,(al-ar)%m
return A
def ifntt(self, A: list[int]):
self._ifntt(A)
iz = mod_inv(N:=len(A),mod:=self.mod)
for i in range(N): A[i]=A[i]*iz%mod
return A
def conv_naive(self, A, B, N):
n, m, mod = len(A),len(B),self.mod
C = [0]*N
if n < m: A,B,n,m = B,A,m,n
for i,a in enumerate(A):
for j in range(min(m,N-i)):
C[ij]=(C[ij:=i+j]+a*B[j])%mod
return C
def conv_fntt(self, A, B, N):
n,m,mod=len(A),len(B),self.mod
z=1<<(n+m-2).bit_length()
self.fntt(A:=A+[0]*(z-n)), self.fntt(B:=B+[0]*(z-m))
for i, b in enumerate(B): A[i] = A[i] * b % mod
self.ifntt(A)
del A[N:]
return A
def deconv(self, C, B, N = None):
n, m = len(C), len(B)
if N is None: N = n - m + 1
z = 1 << (n + m - 2).bit_length()
self.fntt(C := C+[0]*(z-n)), self.fntt(B := B+[0]*(z - m))
A = [0] * z
for i in range(z):
if B[i] == 0:
raise ValueError("Division by zero in NTT domain - deconvolution not possible")
b_inv = mod_inv(B[i], self.mod)
A[i] = (C[i] * b_inv) % self.mod
self.ifntt(A)
return A[:N]
def conv_half(self, A, Bres):
mod = self.mod
self.fntt(A)
for i, b in enumerate(Bres): A[i] = A[i] * b % mod
self.ifntt(A)
return A
def conv(self, A, B, N = None):
n,m = len(A), len(B)
N = n+m-1 if N is None else N
if min(n,m) <= 60: return self.conv_naive(A, B, N)
return self.conv_fntt(A, B, N)
def cycle_conv(self, A, B):
n,m,mod=len(A),len(B),self.mod
assert n == m
if n==0:return[]
con,res=self.conv(A,B),[0]*n
for i in range(n-1):res[i]=(con[i]+con[i+n])%mod
res[n-1]=con[n-1]
return res
class mint(mint):
    """Extends the base ``mint`` with an ``NTT`` instance for the active modulus."""
    ntt: NTT  # rebuilt by set_mod whenever the modulus changes
    @classmethod
    def set_mod(cls, mod: int):
        """Set the modulus via the base class, then build ``cls.ntt = NTT(mod)``."""
        super().set_mod(mod)
        cls.ntt = NTT(mod)