cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ cp_library/math/fps/fps_exp_fn.py

Depends on

Required by

Verified with

Code

import cp_library.math.fps.__header__
from cp_library.math.fps.fps_deriv_fn import fps_deriv

def fps_exp(P: list) -> list:
    """Formal-power-series exponential (per the name) of the coefficient list P.

    Returns a list of len(P) coefficients; the arithmetic is carried out
    modulo ``mint.mod`` using the transform object ``mint.ntt``.

    The result is seeded with ``[1, P[1]]`` and each pass of the loop doubles
    the number of computed coefficients ``p`` until ``p >= len(P)``.

    NOTE(review): the seed ignores P[0], so this presumably expects
    P[0] == 0 -- confirm against callers.
    """
    # For deg >= 1 this is the smallest power of two that is >= deg.
    max_sz = 1 << ((deg := len(P))-1).bit_length()
    # Make sure modcomb.inv[i] is available for every i <= max_sz.
    modcomb.extend_inv(max_sz)
    inv, mod, ntt = modcomb.inv, mint.mod, mint.ntt
    fntt, ifntt, conv_half = ntt.fntt, ntt.ifntt, ntt.conv_half
    # Coefficients i*P[i] from fps_deriv, zero-padded (to max_sz entries when deg >= 1).
    dP = fps_deriv(P) + [0]*(max_sz-deg+1)
    # R is the result computed so far.  E / Eres presumably track the inverse
    # of R and its transform -- NOTE(review): confirm.
    R, E, Eres = [1, (P[1] if 1 < deg else 0)], [1], [1, 1]
    # Capacity hint for the two lists that grow inside the loop.
    reserve(R, max_sz), reserve(E, max_sz)
    p = 2
    while p < deg:
        # Transform of R zero-padded to length 2p (R has p entries here).
        Rres = fntt(R + [0]*p)
        # Pointwise product of -Eres with the first len(Eres) entries of Rres,
        # transformed back.
        x = ifntt([Rres[i]*-e%mod for i, e in enumerate(Eres)])
        # The right-hand side is evaluated first, so h is bound before the
        # slice target uses it; this clears the low half of x.
        x[:h] = [0]*(h:=p>>1)
        # Extend E from h to p entries.
        E[h:] = conv_half(x, Eres)[h:]
        Eres = fntt(E + [0]*p)
        # Length-p product of the first p-1 derivative coefficients with R,
        # using only the first p values of R's transform.
        x = conv_half(dP[:p-1]+[0], Rres[:p])
        # Subtract the coefficients i*R[i] (term-wise derivative of R).
        for i in range(1,p): x[i-1] -= R[i]*i % mod
        # Grow x to 2p entries and move x[0..p-2] up by p slots, zeroing the
        # slots they came from.
        x += [0] * p
        for i in range(p-1): x[p+i],x[i] = x[i],0
        # conv_half works in place, so the product with Eres is left in x.
        conv_half(x,Eres)
        # Shift up by one slot while multiplying by inv[i], then add P[i].
        # Iterating downwards reads x[i-1] before it is overwritten.
        for i in range(min(deg, p<<1)-1,p-1,-1): x[i] = P[i]+x[i-1]*inv[i]%mod 
        x[:p] = [0] * p
        # Append the next p coefficients to R.
        R[p:] = conv_half(x,Rres)[p:]
        p <<= 1
    return R[:deg]

from cp_library.ds.reserve_fn import reserve
from cp_library.math.table.modcomb_cls import modcomb
from cp_library.math.mod.mint_ntt_cls import mint
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''

def fps_deriv(P: list[int]):
    """Differentiate the coefficient list P term by term, modulo ``mint.mod``.

    The coefficient at index k (k >= 1) is multiplied by k and moved down one
    slot, so the result has ``len(P) - 1`` entries (and is empty when P has
    fewer than two entries).
    """
    modulus = mint.mod
    return [coef*k%modulus for k, coef in enumerate(P[1:], 1)]


    
class mint(int):
    """``int`` subclass whose arithmetic results are reduced modulo ``mint.mod``.

    ``set_mod`` must be called before any instance is created: construction
    and every operator read the class-level modulus and the small-value cache.

    Binary operators hand ``NotImplemented`` back to Python unchanged when the
    underlying ``int`` operation does not support the other operand, so the
    normal fallback to the other operand's reflected method still happens
    (for example ``mint(3) * [0]`` repeats the list).
    """
    mod: int              # modulus shared by all instances
    zero: 'mint'          # cached instance for 0
    one: 'mint'           # cached instance for 1 (reduced by mod)
    two: 'mint'           # cached instance for 2 (reduced by mod)
    cache: list['mint']   # [zero, one, two], returned directly by __new__

    def __new__(cls, *args, **kwargs):
        """Build a value from ``int(*args, **kwargs)``, reduced modulo ``mod``."""
        # Values 0..2 are served from the cache instead of allocating.
        if 0<= (x := int(*args, **kwargs)) <= 2:
            return cls.cache[x]
        else:
            return cls.fix(x)

    @classmethod
    def set_mod(cls, mod: int):
        """Install ``mod`` and rebuild the cached constants.

        The state is written both to the global ``mint`` and to ``cls`` so the
        two stay in sync when this is called on a subclass.
        """
        mint.mod = cls.mod = mod
        mint.zero = cls.zero = cls.cast(0)
        mint.one = cls.one = cls.fix(1)
        mint.two = cls.two = cls.fix(2)
        mint.cache = cls.cache = [cls.zero, cls.one, cls.two]

    @classmethod
    def fix(cls, x):
        """Reduce ``x`` modulo ``cls.mod`` and wrap it as an instance."""
        return cls.cast(x%cls.mod)

    @classmethod
    def cast(cls, x):
        """Wrap ``x`` as an instance without reducing it."""
        return super().__new__(cls,x)

    @classmethod
    def _lift(cls, r):
        """Reduce an ``int`` operator result; pass ``NotImplemented`` through."""
        return r if r is NotImplemented else cls.fix(r)

    @classmethod
    def mod_inv(cls, x):
        """Return the inverse of ``x`` modulo ``cls.mod``.

        Raises:
            ValueError: if ``x`` and the modulus are not coprime.
        """
        # Extended Euclid: s tracks the coefficient of x.
        a,b,s,t = int(x), cls.mod, 1, 0
        while b: a,b,s,t = b,a%b,t,s-a//b*t
        if a == 1: return cls.fix(s)
        raise ValueError(f"{x} is not invertible in mod {cls.mod}")
    
    @property
    def inv(self):
        """Modular inverse of this value."""
        return mint.mod_inv(self)

    def __add__(self, x): return mint._lift(super().__add__(x))
    def __radd__(self, x): return mint._lift(super().__radd__(x))
    def __sub__(self, x): return mint._lift(super().__sub__(x))
    def __rsub__(self, x): return mint._lift(super().__rsub__(x))
    def __mul__(self, x): return mint._lift(super().__mul__(x))
    def __rmul__(self, x): return mint._lift(super().__rmul__(x))
    # Both division operators mean multiplication by the modular inverse.
    def __floordiv__(self, x): return self * mint.mod_inv(x)
    def __rfloordiv__(self, x): return self.inv * x
    def __truediv__(self, x): return self * mint.mod_inv(x)
    def __rtruediv__(self, x): return self.inv * x
    def __pow__(self, x): 
        # Three-argument int pow keeps the result below the modulus.
        return self.cast(super().__pow__(x, self.mod))
    # int - mint dispatches to __rsub__, so the result is reduced (and -0 == 0).
    def __neg__(self): return mint.mod-self
    def __pos__(self): return self
    def __abs__(self): return self

def fps_exp(P: list) -> list:
    """Formal-power-series exponential (per the name) of the coefficient list P.

    Returns a list of len(P) coefficients; the arithmetic is carried out
    modulo ``mint.mod`` using the transform object ``mint.ntt``.

    The result is seeded with ``[1, P[1]]`` and each pass of the loop doubles
    the number of computed coefficients ``p`` until ``p >= len(P)``.

    NOTE(review): the seed ignores P[0], so this presumably expects
    P[0] == 0 -- confirm against callers.
    """
    # For deg >= 1 this is the smallest power of two that is >= deg.
    max_sz = 1 << ((deg := len(P))-1).bit_length()
    # Make sure modcomb.inv[i] is available for every i <= max_sz.
    modcomb.extend_inv(max_sz)
    inv, mod, ntt = modcomb.inv, mint.mod, mint.ntt
    fntt, ifntt, conv_half = ntt.fntt, ntt.ifntt, ntt.conv_half
    # Coefficients i*P[i] from fps_deriv, zero-padded (to max_sz entries when deg >= 1).
    dP = fps_deriv(P) + [0]*(max_sz-deg+1)
    # R is the result computed so far.  E / Eres presumably track the inverse
    # of R and its transform -- NOTE(review): confirm.
    R, E, Eres = [1, (P[1] if 1 < deg else 0)], [1], [1, 1]
    # Capacity hint for the two lists that grow inside the loop.
    reserve(R, max_sz), reserve(E, max_sz)
    p = 2
    while p < deg:
        # Transform of R zero-padded to length 2p (R has p entries here).
        Rres = fntt(R + [0]*p)
        # Pointwise product of -Eres with the first len(Eres) entries of Rres,
        # transformed back.
        x = ifntt([Rres[i]*-e%mod for i, e in enumerate(Eres)])
        # The right-hand side is evaluated first, so h is bound before the
        # slice target uses it; this clears the low half of x.
        x[:h] = [0]*(h:=p>>1)
        # Extend E from h to p entries.
        E[h:] = conv_half(x, Eres)[h:]
        Eres = fntt(E + [0]*p)
        # Length-p product of the first p-1 derivative coefficients with R,
        # using only the first p values of R's transform.
        x = conv_half(dP[:p-1]+[0], Rres[:p])
        # Subtract the coefficients i*R[i] (term-wise derivative of R).
        for i in range(1,p): x[i-1] -= R[i]*i % mod
        # Grow x to 2p entries and move x[0..p-2] up by p slots, zeroing the
        # slots they came from.
        x += [0] * p
        for i in range(p-1): x[p+i],x[i] = x[i],0
        # conv_half works in place, so the product with Eres is left in x.
        conv_half(x,Eres)
        # Shift up by one slot while multiplying by inv[i], then add P[i].
        # Iterating downwards reads x[i-1] before it is overwritten.
        for i in range(min(deg, p<<1)-1,p-1,-1): x[i] = P[i]+x[i-1]*inv[i]%mod 
        x[:p] = [0] * p
        # Append the next p coefficients to R.
        R[p:] = conv_half(x,Rres)[p:]
        p <<= 1
    return R[:deg]



# Signature stub for readers and tooling; the name is rebound below to the
# PyPy capacity hint or to a no-op fallback.
def reserve(A: list, est_len: int) -> None: ...
try:
    from __pypy__ import resizelist_hint
except ImportError:
    # The PyPy helper is unavailable (e.g. on CPython): fall back to a no-op
    # that leaves the list untouched.  Only import failures are caught so
    # unrelated errors and interrupts are not swallowed.
    def resizelist_hint(A: list, est_len: int):
        pass
reserve = resizelist_hint



def mod_inv(x, mod):
    """Return the multiplicative inverse of ``x`` modulo ``mod``.

    Runs the extended Euclidean algorithm on ``(x, mod)`` while tracking the
    coefficient of ``x``.

    Raises:
        ValueError: if the gcd of ``x`` and ``mod`` is not 1.
    """
    a, b = x, mod
    s, t = 1, 0
    while b:
        quotient = a // b
        a, b = b, a % b
        s, t = t, s - quotient * t
    if a != 1:
        raise ValueError(f"{x} is not invertible in mod {mod}")
    return s % mod
from itertools import accumulate

class modcomb():
    """Factorial and inverse tables plus counting formulas modulo ``mint.mod``.

    ``precomp(N)`` must be called before any method that reads ``fact`` or
    ``fact_inv``; indices up to the largest ``n`` used must be covered by N.
    """
    fact: list[int]          # fact[i] = i! mod p, filled by precomp
    fact_inv: list[int]      # fact_inv[i] = inverse of i! mod p, filled by precomp
    inv: list[int] = [0,1]   # inv[i] = inverse of i mod p, grown by extend_inv

    @staticmethod
    def precomp(N):
        """Build ``fact`` and ``fact_inv`` for indices 0..N."""
        mod = mint.mod
        def mod_mul(a,b): return a*b%mod
        fact = list(accumulate(range(1,N+1), mod_mul, initial=1))
        # Start from 1/N! and multiply by N, N-1, ... to get 1/(N-1)!, ...;
        # the list comes out in descending index order, hence the reverse.
        fact_inv = list(accumulate(range(N,0,-1), mod_mul, initial=mod_inv(fact[N], mod)))
        fact_inv.reverse()
        modcomb.fact, modcomb.fact_inv = fact, fact_inv
    
    @staticmethod
    def extend_inv(N):
        """Grow the shared ``inv`` table in place until it covers indices 0..N."""
        N, inv, mod = N+1, modcomb.inv, mint.mod
        while len(inv) < N:
            # With i = len(inv): mod = j*i + k, hence 1/i = -j * (1/k) mod p.
            j, k = divmod(mod, len(inv))
            inv.append(-inv[k] * j % mod)

    @staticmethod
    def factorial(n: int, /) -> mint:
        """Return n! as a ``mint``."""
        return mint(modcomb.fact[n])

    @staticmethod
    def comb(n: int, k: int, /) -> mint:
        """Return C(n, k); zero when k < 0 or k > n."""
        inv, mod = modcomb.fact_inv, mint.mod
        # A negative k would otherwise index the tables from the end.
        if k < 0 or n < k: return mint.zero
        return mint(inv[k] * inv[n-k] % mod * modcomb.fact[n])
    nCk = binom = comb
    
    @staticmethod
    def comb_with_replacement(n: int, k: int, /) -> mint:
        """Return C(n + k - 1, k); zero when n <= 0."""
        if n <= 0: return mint.zero
        return modcomb.nCk(n + k - 1, k)
    nHk = comb_with_replacement
    
    @staticmethod
    def multinom(n: int, *K: int) -> mint:
        """Return the product of C(n, k1) * C(n - k1, k2) * ... over the given K."""
        nCk, res = modcomb.nCk, mint.one
        for k in K: res, n = res*nCk(n,k), n-k
        return res

    @staticmethod
    def perm(n: int, k: int, /) -> mint:
        '''Returns P(n,k) mod p; zero when k < 0 or k > n.'''
        # A negative k would otherwise read fact_inv at the wrong index.
        if k < 0 or n < k: return mint.zero
        return mint(modcomb.fact[n] * modcomb.fact_inv[n-k])
    nPk = perm
    
    @staticmethod
    def catalan(n: int, /) -> mint:
        """Return the n-th Catalan number C(2n, n) / (n + 1)."""
        # fact[n] * fact_inv[n+1] = n! / (n+1)! = 1 / (n+1).  Multiplying by
        # fact_inv[n+1] alone would divide by (n+1)! instead of (n+1).
        return mint(modcomb.nCk(2*n,n) * modcomb.fact[n] * modcomb.fact_inv[n+1])


class NTT:
    """Number-theoretic transform for a fixed modulus, plus convolution helpers.

    The transforms work in place on plain lists of ints.
    NOTE(review): the transform length is presumably required to be a power of
    two, and ``mod`` a prime with enough factors of two in ``mod - 1`` --
    confirm; neither is checked here.
    """
    def __init__(self, mod = 998244353) -> None:
        """Precompute the root tables that the transform loops read."""
        self.mod = m = mod
        self.g = g = self.primitive_root(m)
        # (m-1) & -(m-1) isolates the lowest set bit of m-1, so rank2 is the
        # number of trailing zero bits of m-1.
        self.rank2 = rank2 = ((m-1)&(1-m)).bit_length() - 1
        self.root = root = [0] * (rank2 + 1)
        root[rank2] = pow(g, (m - 1) >> rank2, m)
        self.iroot = iroot = [0] * (rank2 + 1)
        # Power m-2 of the top root (its inverse when m is prime).
        iroot[rank2] = pow(root[rank2], m - 2, m)
        # Each lower entry is the square of the one above it.
        for i in range(rank2 - 1, -1, -1):
            root[i] = root[i+1] * root[i+1] % m
            iroot[i] = iroot[i+1] * iroot[i+1] % m
        def rates(s):
            # r8[i] is root[i+s] times the product of iroot[s..i+s-1]; ir8 is
            # the mirror image.  The transforms use these to step their
            # running factor from one block to the next.
            r8,ir8 = [0]*max(0,rank2-s+1), [0]*max(0,rank2-s+1)
            p = ip = 1
            for i in range(rank2-s+1):
                r, ir = root[i+s], iroot[i+s]
                p,ip,r8[i],ir8[i]= p*ir%m,ip*r%m,r*p%m,ir*ip%m
            return r8, ir8
        self.rate2, self.irate2 = rates(2)
        self.rate3, self.irate3 = rates(3)
 
    def primitive_root(self, m):
        """Return a generator candidate for modulus ``m``.

        Known moduli are answered from a hard-coded table.  Otherwise the
        distinct prime factors of m-1 are collected by trial division and the
        smallest g in [2, m) with g**((m-1)/d) != 1 for each factor d is
        returned.  Falls through (returning None) if no such g exists.
        """
        if m == 2: return 1
        if m == 167772161: return 3
        if m == 469762049: return 3
        if m == 754974721: return 11
        if m == 998244353: return 3
        # Fixed-size buffer for the distinct prime factors of m-1; 2 is
        # recorded up front and stripped from x.
        divs = [0] * 20
        cnt, divs[0], x = 1, 2, (m - 1) // 2
        while x % 2 == 0: x //= 2
        i=3
        while i*i <= x:
            if x%i == 0:
                divs[cnt],cnt = i,cnt+1
                while x%i==0:x//=i
            i+=2
        # Whatever is left above 1 is recorded as one more factor.
        if x > 1: divs[cnt],cnt = x,cnt+1
        for g in range(2,m):
            for i in range(cnt):
                if pow(g,(m-1)//divs[i],m)==1:break
            else:return g
    
    def fntt(self, A: list[int]):
        """Forward transform of ``A`` in place; returns the same list.

        NOTE(review): the output is presumably left in the permuted order that
        ``_ifntt`` expects rather than natural order -- confirm.
        """
        # h is log2 of the length when len(A) is a power of two.
        im, r8, m, h = self.root[2],self.rate3,self.mod,(len(A)-1).bit_length()
        # Two levels per pass: each inner step combines four elements spaced
        # p apart.
        for L in range(0,h-1,2):
            p, r = 1<<(h-L-2),1
            for s in range(1 << L):
                r3,of=(r2:=r*r%m)*r%m,s<<(h-L)
                for i in range(p):
                    i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
                    a0,a1,a2,a3 = A[i0],A[i1]*r,A[i2]*r2,A[i3]*r3
                    a0,a1,a2,a3 = a0+a2,a1+a3,a0-a2,(a1-a3)%m*im
                    A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a0-a1)%m,(a2+a3)%m,(a2-a3)%m
                # ~s & -~s is the lowest zero bit of s, so the table index is
                # the number of trailing one bits of s.
                r=r*r8[(~s&-~s).bit_length()-1]%m
        # Odd h leaves one level over; finish it on adjacent pairs.
        if h&1:
            r, r8 = 1, self.rate2
            for s in range(1<<(h-1)):
                i1=(i0:=s<<1)+1
                al,ar = A[i0],A[i1]*r%m
                A[i0],A[i1] = (al+ar)%m,(al-ar)%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        return A
    
    def _ifntt(self, A: list[int]):
        """Inverse butterflies in place, without the 1/len(A) scaling.

        ``ifntt`` applies the scaling afterwards.  Returns the same list.
        """
        im, r8, m, h = self.iroot[2],self.irate3,self.mod,(len(A)-1).bit_length()
        # Mirror of fntt: two levels per pass using the inverse tables.
        for L in range(h,1,-2):
            p,r = 1<<(h-L),1
            for s in range(1<<(L-2)):
                r3,of=(r2:=r*r%m)*r%m,s<<(h-L+2)
                for i in range(p):
                    i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
                    a0,a1,a2,a3 = A[i0],A[i1],A[i2],A[i3]
                    a0,a1,a2,a3 = a0+a1,a2+a3,a0-a1,(a2-a3)*im%m
                    A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a2+a3)*r%m,(a0-a1)*r2%m,(a2-a3)*r3%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        # Odd h: one remaining level, pairing each element with the one half
        # the list away.
        if h&1:
            for i0 in range(p:=1<<(h-1)):
                al,ar = A[i0],A[i1:=i0+p]
                A[i0],A[i1] = (al+ar)%m,(al-ar)%m
        return A

    def ifntt(self, A: list[int]):
        """Inverse transform of ``A`` in place, including the 1/len(A) factor."""
        self._ifntt(A)
        iz = mod_inv(N:=len(A),mod:=self.mod)
        for i in range(N): A[i]=A[i]*iz%mod
        return A
    
    def conv_naive(self, A, B, N):
        """Schoolbook product of A and B, keeping the first N coefficients."""
        n, m, mod = len(A),len(B),self.mod
        C = [0]*N
        # Put the longer operand in A so the inner loop runs over the shorter one.
        if n < m: A,B,n,m = B,A,m,n
        for i,a in enumerate(A):
            for j in range(min(m,N-i)):
                # The right-hand side runs first and binds ij for the target.
                C[ij]=(C[ij:=i+j]+a*B[j])%mod
        return C
    
    def conv_fntt(self, A, B, N):
        """Transform-based product of A and B, truncated to N coefficients.

        Works on zero-padded copies, so the arguments are not modified.
        """
        n,m,mod=len(A),len(B),self.mod
        # Power of two that is at least n+m-1, the length of the full product.
        z=1<<(n+m-2).bit_length()
        self.fntt(A:=A+[0]*(z-n)), self.fntt(B:=B+[0]*(z-m))
        for i, b in enumerate(B): A[i] = A[i] * b % mod
        self.ifntt(A)
        del A[N:]
        return A
    
    def deconv(self, C, B, N = None):
        """Divide C by B pointwise in the transformed domain.

        Returns the first N entries of the inverse transform; N defaults to
        len(C) - len(B) + 1.

        Raises:
            ValueError: if any transformed value of B is zero.
        """
        n, m = len(C), len(B)
        if N is None: N = n - m + 1
        z = 1 << (n + m - 2).bit_length()
        self.fntt(C := C+[0]*(z-n)), self.fntt(B := B+[0]*(z - m))

        A = [0] * z
        for i in range(z):
            if B[i] == 0:
                raise ValueError("Division by zero in NTT domain - deconvolution not possible")
            b_inv = mod_inv(B[i], self.mod)
            A[i] = (C[i] * b_inv) % self.mod
        
        self.ifntt(A)
        return A[:N]
    
    def conv_half(self, A, Bres):
        """Multiply A by an operand that is already transformed.

        A is transformed in place, multiplied pointwise by ``Bres`` over the
        first len(Bres) entries, and transformed back.  The same list A is
        returned, so callers may ignore the return value.
        """
        mod = self.mod
        self.fntt(A)
        for i, b in enumerate(Bres): A[i] = A[i] * b % mod
        self.ifntt(A)
        return A
    
    def conv(self, A, B, N = None):
        """Product of A and B truncated to N coefficients (default len(A)+len(B)-1)."""
        n,m = len(A), len(B)
        N = n+m-1 if N is None else N
        # Short operands go through the schoolbook loop instead of the transform.
        if min(n,m) <= 60: return self.conv_naive(A, B, N)
        return self.conv_fntt(A, B, N)

    def cycle_conv(self, A, B):
        """Cyclic convolution of two equal-length lists."""
        n,m,mod=len(A),len(B),self.mod
        assert n == m
        if n==0:return[]
        # Fold the upper part of the linear product back onto the lower part.
        con,res=self.conv(A,B),[0]*n
        for i in range(n-1):res[i]=(con[i]+con[i+n])%mod
        # The linear product has 2n-1 entries, so index n-1 has no partner.
        res[n-1]=con[n-1]
        return res

class mint(mint):
    """Extends the base ``mint`` with a shared ``NTT`` for the current modulus."""
    ntt: NTT  # transform helper built by set_mod for the current modulus

    @classmethod
    def set_mod(cls, mod: int):
        """Install ``mod`` through the base class, then build the matching ``NTT``.

        The transform is published on both the global ``mint`` and ``cls``,
        matching how the base ``set_mod`` publishes ``mod`` and the cached
        constants.  Code such as ``fps_exp`` reads ``mint.mod`` and
        ``mint.ntt`` together, so the two must never disagree when this is
        called on a further subclass.
        """
        super().set_mod(mod)
        mint.ntt = cls.ntt = NTT(mod)
Back to top page