This documentation is automatically generated by online-judge-tools/verification-helper
import cp_library.math.fps.__header__
from cp_library.math.fps.fps_deriv_fn import fps_deriv
def fps_exp(P: list) -> list:
    """Return the first len(P) coefficients of exp(P) as a formal power series.

    Arithmetic is modulo ``mint.mod`` and the transforms come from ``mint.ntt``,
    so ``mint.set_mod`` must have been called beforehand. ``P[0]`` is never
    read: the result always starts with 1, i.e. the constant term of P is
    treated as 0.

    Newton iteration: each round doubles the working length p (2, 4, 8, ...)
    and fills in R[p:2p], while E is kept as the inverse series of R to half
    the current precision. ``Rres``/``Eres`` hold forward transforms that are
    reused by several ``conv_half`` calls.
    """
    # max_sz: smallest power of two >= deg (for deg >= 1).
    max_sz = 1 << ((deg := len(P))-1).bit_length()
    # Make modcomb.inv[i] (inverse of i mod mint.mod) available for i <= max_sz.
    modcomb.extend_inv(max_sz)
    inv, mod, ntt = modcomb.inv, mint.mod, mint.ntt
    fntt, ifntt, conv_half = ntt.fntt, ntt.ifntt, ntt.conv_half
    # P' zero-padded to length max_sz so the dP[:p-1] slices never run short.
    dP = fps_deriv(P) + [0]*(max_sz-deg+1)
    # exp(P) mod x^2 = 1 + P[1]x; E = [1] inverts R mod x;
    # Eres = [1, 1] is the length-2 transform of [1, 0].
    R, E, Eres = [1, (P[1] if 1 < deg else 0)], [1], [1, 1]
    reserve(R, max_sz), reserve(E, max_sz)
    p = 2
    while p < deg:
        # Length-2p transform of R. Its first p entries are reused below as the
        # length-p transform of R (relies on fntt's bit-reversed output order).
        Rres = fntt(R + [0]*p)
        # Inverse Newton step E <- E + E*(1 - R*E), extending E from h = p/2 to
        # p coefficients: x = -(R*E) cyclically, and zeroing its first h
        # entries drops the part that is already correct.
        x = ifntt([Rres[i]*-e%mod for i, e in enumerate(Eres)])
        x[:h] = [0]*(h:=p>>1)
        E[h:] = conv_half(x, Eres)[h:]
        Eres = fntt(E + [0]*p)
        # x = P'(first p-1 terms) * R - R', as a cyclic product of length p.
        x = conv_half(dP[:p-1]+[0], Rres[:p])
        for i in range(1,p): x[i-1] -= R[i]*i % mod
        # The low degrees of P'R - R' vanish once R matches exp(P) mod x^p, so
        # the wrapped entries x[0..p-2] carry degrees p..2p-2: move them up,
        # then multiply by E (~ 1/R).
        x += [0] * p
        for i in range(p-1): x[p+i],x[i] = x[i],0
        conv_half(x,Eres)
        # Integrate and add P: for p <= i < min(deg, 2p), x[i] = (P - log R)[i].
        for i in range(min(deg, p<<1)-1,p-1,-1): x[i] = P[i]+x[i-1]*inv[i]%mod
        x[:p] = [0] * p
        # R <- R * (1 + P - log R); only the new high half is written back.
        R[p:] = conv_half(x,Rres)[p:]
        p <<= 1
    return R[:deg]
from cp_library.ds.reserve_fn import reserve
from cp_library.math.table.modcomb_cls import modcomb
from cp_library.math.mod.mint_ntt_cls import mint
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
def fps_deriv(P: list[int]):
    """Return the formal derivative of the power series ``P``.

    Entry k of the result is (k+1) * P[k+1] reduced modulo ``mint.mod``, so the
    output has len(P) - 1 entries (none when len(P) <= 1).
    """
    m = mint.mod
    return [c * k % m for k, c in enumerate(P[1:], 1)]
class mint(int):
    """Integer residue modulo the class-wide modulus ``mint.mod``.

    ``set_mod`` must run before any instance is built, because the constructor
    consults the ``cache`` table (and ``mod``) that it installs. Every
    arithmetic result is reduced into ``[0, mod)``. Both ``/`` and ``//``
    multiply by the modular inverse of the divisor.
    """
    mod: int
    zero: 'mint'
    one: 'mint'
    two: 'mint'
    cache: list['mint']

    def __new__(cls, *args, **kwargs):
        value = int(*args, **kwargs)
        if value < 0 or value > 2:
            return cls.fix(value)
        # 0, 1 and 2 are shared instances pre-built by set_mod.
        return cls.cache[value]

    @classmethod
    def set_mod(cls, mod: int):
        """Install ``mod`` and rebuild the shared 0/1/2 constants and cache."""
        mint.mod = cls.mod = mod
        zero = cls.cast(0)
        mint.zero = cls.zero = zero
        one = cls.fix(1)
        mint.one = cls.one = one
        two = cls.fix(2)
        mint.two = cls.two = two
        table = [zero, one, two]
        mint.cache = cls.cache = table

    @classmethod
    def fix(cls, x):
        """Reduce ``x`` modulo ``cls.mod`` and wrap the result."""
        return cls.cast(x % cls.mod)

    @classmethod
    def cast(cls, x):
        """Wrap ``x`` as an instance of ``cls`` without reducing it."""
        return int.__new__(cls, x)

    @classmethod
    def mod_inv(cls, x):
        """Return the inverse of ``x`` modulo ``cls.mod`` (extended Euclid).

        Raises ValueError when ``x`` is not invertible.
        """
        r0, r1, s0, s1 = int(x), cls.mod, 1, 0
        while r1:
            r0, r1, s0, s1 = r1, r0 % r1, s1, s0 - r0 // r1 * s1
        if r0 != 1:
            raise ValueError(f"{x} is not invertible in mod {cls.mod}")
        return cls.fix(s0)

    @property
    def inv(self):
        """Modular inverse of this value."""
        return mint.mod_inv(self)

    # Arithmetic defers to int, then reduces through the module-level name
    # ``mint`` (resolved at call time).
    def __add__(self, x): return mint.fix(int.__add__(self, x))
    def __radd__(self, x): return mint.fix(int.__radd__(self, x))
    def __sub__(self, x): return mint.fix(int.__sub__(self, x))
    def __rsub__(self, x): return mint.fix(int.__rsub__(self, x))
    def __mul__(self, x): return mint.fix(int.__mul__(self, x))
    def __rmul__(self, x): return mint.fix(int.__rmul__(self, x))

    def __floordiv__(self, x): return self * mint.mod_inv(x)
    def __rfloordiv__(self, x): return self.inv * x
    def __truediv__(self, x): return self * mint.mod_inv(x)
    def __rtruediv__(self, x): return self.inv * x

    def __pow__(self, x):
        # Three-argument int pow; a negative exponent uses the modular inverse.
        return self.cast(int.__pow__(self, x, self.mod))

    def __neg__(self): return mint.fix(-int(self))
    def __pos__(self): return self
    def __abs__(self): return self
def fps_exp(P: list) -> list:
    """Return the first len(P) coefficients of exp(P) as a formal power series.

    Arithmetic is modulo ``mint.mod`` and the transforms come from ``mint.ntt``,
    so ``mint.set_mod`` must have been called beforehand. ``P[0]`` is never
    read: the result always starts with 1, i.e. the constant term of P is
    treated as 0.

    Newton iteration: each round doubles the working length p (2, 4, 8, ...)
    and fills in R[p:2p], while E is kept as the inverse series of R to half
    the current precision. ``Rres``/``Eres`` hold forward transforms that are
    reused by several ``conv_half`` calls.
    """
    # max_sz: smallest power of two >= deg (for deg >= 1).
    max_sz = 1 << ((deg := len(P))-1).bit_length()
    # Make modcomb.inv[i] (inverse of i mod mint.mod) available for i <= max_sz.
    modcomb.extend_inv(max_sz)
    inv, mod, ntt = modcomb.inv, mint.mod, mint.ntt
    fntt, ifntt, conv_half = ntt.fntt, ntt.ifntt, ntt.conv_half
    # P' zero-padded to length max_sz so the dP[:p-1] slices never run short.
    dP = fps_deriv(P) + [0]*(max_sz-deg+1)
    # exp(P) mod x^2 = 1 + P[1]x; E = [1] inverts R mod x;
    # Eres = [1, 1] is the length-2 transform of [1, 0].
    R, E, Eres = [1, (P[1] if 1 < deg else 0)], [1], [1, 1]
    reserve(R, max_sz), reserve(E, max_sz)
    p = 2
    while p < deg:
        # Length-2p transform of R. Its first p entries are reused below as the
        # length-p transform of R (relies on fntt's bit-reversed output order).
        Rres = fntt(R + [0]*p)
        # Inverse Newton step E <- E + E*(1 - R*E), extending E from h = p/2 to
        # p coefficients: x = -(R*E) cyclically, and zeroing its first h
        # entries drops the part that is already correct.
        x = ifntt([Rres[i]*-e%mod for i, e in enumerate(Eres)])
        x[:h] = [0]*(h:=p>>1)
        E[h:] = conv_half(x, Eres)[h:]
        Eres = fntt(E + [0]*p)
        # x = P'(first p-1 terms) * R - R', as a cyclic product of length p.
        x = conv_half(dP[:p-1]+[0], Rres[:p])
        for i in range(1,p): x[i-1] -= R[i]*i % mod
        # The low degrees of P'R - R' vanish once R matches exp(P) mod x^p, so
        # the wrapped entries x[0..p-2] carry degrees p..2p-2: move them up,
        # then multiply by E (~ 1/R).
        x += [0] * p
        for i in range(p-1): x[p+i],x[i] = x[i],0
        conv_half(x,Eres)
        # Integrate and add P: for p <= i < min(deg, 2p), x[i] = (P - log R)[i].
        for i in range(min(deg, p<<1)-1,p-1,-1): x[i] = P[i]+x[i-1]*inv[i]%mod
        x[:p] = [0] * p
        # R <- R * (1 + P - log R); only the new high half is written back.
        R[p:] = conv_half(x,Rres)[p:]
        p <<= 1
    return R[:deg]
def reserve(A: list, est_len: int) -> None:
    """Hint that list ``A`` is expected to grow to about ``est_len`` items.

    This signature is only a typed placeholder. The name is rebound below to
    ``__pypy__.resizelist_hint`` on PyPy, or to a no-op fallback elsewhere.
    """
try:
    # PyPy-only capacity hint; the import fails on other interpreters.
    from __pypy__ import resizelist_hint
except ImportError:
    # Narrow catch: a bare except would also swallow KeyboardInterrupt/SystemExit.
    def resizelist_hint(A: list, est_len: int):
        """Fallback no-op used when ``__pypy__.resizelist_hint`` is unavailable."""
        pass
reserve = resizelist_hint
def mod_inv(x, mod):
    """Return the inverse of ``x`` modulo ``mod`` (extended Euclidean algorithm).

    For positive ``mod`` the result lies in ``[0, mod)``.
    Raises ValueError when gcd(x, mod) != 1.
    """
    r0, r1 = x, mod     # remainder sequence
    s0, s1 = 1, 0       # coefficient of x in each remainder
    while r1:
        q = r0 // r1
        r0, r1 = r1, r0 % r1
        s0, s1 = s1, s0 - q * s1
    if r0 != 1:
        raise ValueError(f"{x} is not invertible in mod {mod}")
    return s0 % mod
from itertools import accumulate
class modcomb():
    """Factorial/inverse tables and combinatorial counts modulo ``mint.mod``.

    All tables live on the class and are shared by every caller:
      fact[i]     = i! mod p        (built by ``precomp``)
      fact_inv[i] = (i!)^-1 mod p   (built by ``precomp``)
      inv[i]      = i^-1 mod p      (grown on demand by ``extend_inv``;
                                     inv[0] is a placeholder 0)
    The count helpers index ``fact``/``fact_inv`` directly, so ``precomp(N)``
    must have been called with N large enough for every index they touch.
    NOTE(review): nothing resets these tables when the modulus changes.
    """
    fact: list[int]
    fact_inv: list[int]
    inv: list[int] = [0,1]

    @staticmethod
    def precomp(N):
        """Build fact[0..N] and fact_inv[0..N] for the current ``mint.mod``."""
        mod = mint.mod
        def mod_mul(a,b): return a*b%mod
        fact = list(accumulate(range(1,N+1), mod_mul, initial=1))
        # Start from (N!)^-1 and walk down using (i-1)!^-1 = i!^-1 * i.
        fact_inv = list(accumulate(range(N,0,-1), mod_mul, initial=mod_inv(fact[N], mod)))
        fact_inv.reverse()
        modcomb.fact, modcomb.fact_inv = fact, fact_inv

    @staticmethod
    def extend_inv(N):
        """Grow ``modcomb.inv`` so that inv[i] is defined for every i <= N."""
        N, inv, mod = N+1, modcomb.inv, mint.mod
        while len(inv) < N:
            # inv[i] = -(mod // i) * inv[mod % i]; needs i and mod % i
            # invertible (e.g. a prime modulus).
            j, k = divmod(mod, len(inv))
            inv.append(-inv[k] * j % mod)

    @staticmethod
    def factorial(n: int, /) -> mint:
        """Return n! as a mint."""
        return mint(modcomb.fact[n])

    @staticmethod
    def comb(n: int, k: int, /) -> mint:
        """Return C(n, k); zero when k < 0 or k > n."""
        inv, mod = modcomb.fact_inv, mint.mod
        # Guard k < 0 too: a negative index would silently read from the table end.
        if k < 0 or n < k: return mint.zero
        return mint(inv[k] * inv[n-k] % mod * modcomb.fact[n])
    nCk = binom = comb

    @staticmethod
    def comb_with_replacement(n: int, k: int, /) -> mint:
        """Return the number of size-k multisets drawn from n kinds (n > 0)."""
        if n <= 0: return mint.zero
        return modcomb.nCk(n + k - 1, k)
    nHk = comb_with_replacement

    @staticmethod
    def multinom(n: int, *K: int) -> mint:
        """Return n! / (k1! k2! ... (n - sum K)!) built as a product of binomials."""
        nCk, res = modcomb.nCk, mint.one
        for k in K: res, n = res*nCk(n,k), n-k
        return res

    @staticmethod
    def perm(n: int, k: int, /) -> mint:
        '''Returns P(n,k) mod p'''
        # Same guard as comb: k < 0 would index fact_inv past n.
        if k < 0 or n < k: return mint.zero
        return mint(modcomb.fact[n] * modcomb.fact_inv[n-k])
    nPk = perm

    @staticmethod
    def catalan(n: int, /) -> mint:
        """Return the n-th Catalan number C_n = binom(2n, n) / (n + 1).

        Uses C_n = (2n)! / (n! (n+1)!), so ``precomp`` must cover 2n and n+1.
        The previous form multiplied by (n+1)!^-1 instead of (n+1)^-1.
        """
        if n < 0: return mint.zero
        fact, fact_inv, mod = modcomb.fact, modcomb.fact_inv, mint.mod
        return mint(fact[2*n] * fact_inv[n] % mod * fact_inv[n+1])
class NTT:
def __init__(self, mod = 998244353) -> None:
self.mod = m = mod
self.g = g = self.primitive_root(m)
self.rank2 = rank2 = ((m-1)&(1-m)).bit_length() - 1
self.root = root = [0] * (rank2 + 1)
root[rank2] = pow(g, (m - 1) >> rank2, m)
self.iroot = iroot = [0] * (rank2 + 1)
iroot[rank2] = pow(root[rank2], m - 2, m)
for i in range(rank2 - 1, -1, -1):
root[i] = root[i+1] * root[i+1] % m
iroot[i] = iroot[i+1] * iroot[i+1] % m
def rates(s):
r8,ir8 = [0]*max(0,rank2-s+1), [0]*max(0,rank2-s+1)
p = ip = 1
for i in range(rank2-s+1):
r, ir = root[i+s], iroot[i+s]
p,ip,r8[i],ir8[i]= p*ir%m,ip*r%m,r*p%m,ir*ip%m
return r8, ir8
self.rate2, self.irate2 = rates(2)
self.rate3, self.irate3 = rates(3)
def primitive_root(self, m):
if m == 2: return 1
if m == 167772161: return 3
if m == 469762049: return 3
if m == 754974721: return 11
if m == 998244353: return 3
divs = [0] * 20
cnt, divs[0], x = 1, 2, (m - 1) // 2
while x % 2 == 0: x //= 2
i=3
while i*i <= x:
if x%i == 0:
divs[cnt],cnt = i,cnt+1
while x%i==0:x//=i
i+=2
if x > 1: divs[cnt],cnt = x,cnt+1
for g in range(2,m):
for i in range(cnt):
if pow(g,(m-1)//divs[i],m)==1:break
else:return g
def fntt(self, A: list[int]):
im, r8, m, h = self.root[2],self.rate3,self.mod,(len(A)-1).bit_length()
for L in range(0,h-1,2):
p, r = 1<<(h-L-2),1
for s in range(1 << L):
r3,of=(r2:=r*r%m)*r%m,s<<(h-L)
for i in range(p):
i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
a0,a1,a2,a3 = A[i0],A[i1]*r,A[i2]*r2,A[i3]*r3
a0,a1,a2,a3 = a0+a2,a1+a3,a0-a2,(a1-a3)%m*im
A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a0-a1)%m,(a2+a3)%m,(a2-a3)%m
r=r*r8[(~s&-~s).bit_length()-1]%m
if h&1:
r, r8 = 1, self.rate2
for s in range(1<<(h-1)):
i1=(i0:=s<<1)+1
al,ar = A[i0],A[i1]*r%m
A[i0],A[i1] = (al+ar)%m,(al-ar)%m
r=r*r8[(~s&-~s).bit_length()-1]%m
return A
def _ifntt(self, A: list[int]):
im, r8, m, h = self.iroot[2],self.irate3,self.mod,(len(A)-1).bit_length()
for L in range(h,1,-2):
p,r = 1<<(h-L),1
for s in range(1<<(L-2)):
r3,of=(r2:=r*r%m)*r%m,s<<(h-L+2)
for i in range(p):
i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
a0,a1,a2,a3 = A[i0],A[i1],A[i2],A[i3]
a0,a1,a2,a3 = a0+a1,a2+a3,a0-a1,(a2-a3)*im%m
A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a2+a3)*r%m,(a0-a1)*r2%m,(a2-a3)*r3%m
r=r*r8[(~s&-~s).bit_length()-1]%m
if h&1:
for i0 in range(p:=1<<(h-1)):
al,ar = A[i0],A[i1:=i0+p]
A[i0],A[i1] = (al+ar)%m,(al-ar)%m
return A
def ifntt(self, A: list[int]):
self._ifntt(A)
iz = mod_inv(N:=len(A),mod:=self.mod)
for i in range(N): A[i]=A[i]*iz%mod
return A
def conv_naive(self, A, B, N):
n, m, mod = len(A),len(B),self.mod
C = [0]*N
if n < m: A,B,n,m = B,A,m,n
for i,a in enumerate(A):
for j in range(min(m,N-i)):
C[ij]=(C[ij:=i+j]+a*B[j])%mod
return C
def conv_fntt(self, A, B, N):
n,m,mod=len(A),len(B),self.mod
z=1<<(n+m-2).bit_length()
self.fntt(A:=A+[0]*(z-n)), self.fntt(B:=B+[0]*(z-m))
for i, b in enumerate(B): A[i] = A[i] * b % mod
self.ifntt(A)
del A[N:]
return A
def deconv(self, C, B, N = None):
n, m = len(C), len(B)
if N is None: N = n - m + 1
z = 1 << (n + m - 2).bit_length()
self.fntt(C := C+[0]*(z-n)), self.fntt(B := B+[0]*(z - m))
A = [0] * z
for i in range(z):
if B[i] == 0:
raise ValueError("Division by zero in NTT domain - deconvolution not possible")
b_inv = mod_inv(B[i], self.mod)
A[i] = (C[i] * b_inv) % self.mod
self.ifntt(A)
return A[:N]
def conv_half(self, A, Bres):
mod = self.mod
self.fntt(A)
for i, b in enumerate(Bres): A[i] = A[i] * b % mod
self.ifntt(A)
return A
def conv(self, A, B, N = None):
n,m = len(A), len(B)
N = n+m-1 if N is None else N
if min(n,m) <= 60: return self.conv_naive(A, B, N)
return self.conv_fntt(A, B, N)
def cycle_conv(self, A, B):
n,m,mod=len(A),len(B),self.mod
assert n == m
if n==0:return[]
con,res=self.conv(A,B),[0]*n
for i in range(n-1):res[i]=(con[i]+con[i+n])%mod
res[n-1]=con[n-1]
return res
class mint(mint):
    """Modular-int class whose ``set_mod`` also builds an NTT engine.

    After ``set_mod(mod)``, the class attribute ``ntt`` holds ``NTT(mod)``.
    """
    ntt: NTT

    @classmethod
    def set_mod(cls, mod: int):
        """Set the modulus via the base class, then attach a matching NTT."""
        super().set_mod(mod)
        engine = NTT(mod)
        cls.ntt = engine