cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/enumerative-combinatorics/stirling_number_of_the_second_kind.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/stirling_number_of_the_second_kind

def main():
    '''Read N and print S(N, 0), ..., S(N, N) (Stirling numbers of the second kind) mod 998244353.'''
    n = read(int)
    mint.set_mod(998244353)
    modcomb.precomp(n)
    row = stirling2_n(n)
    write(*row)

from cp_library.math.table.modcomb_cls import modcomb
from cp_library.math.table.stirling2_n_fn import stirling2_n
from cp_library.math.mod.mint_ntt_cls import mint
from cp_library.io.read_fn import read
from cp_library.io.write_fn import write

if __name__ == '__main__':
    # Run the solution only when executed as a script, not when imported.
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/stirling_number_of_the_second_kind

def main():
    '''Read N and print S(N, 0), ..., S(N, N) (Stirling numbers of the second kind) mod 998244353.'''
    n = read(int)
    mint.set_mod(998244353)
    modcomb.precomp(n)
    row = stirling2_n(n)
    write(*row)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''

    
class mint(int):
    '''Integer residue modulo ``mint.mod`` (an ``int`` subclass).

    ``set_mod`` must be called first: it fixes the modulus and creates the
    shared instances for 0, 1 and 2 that the constructor hands back. All
    arithmetic results are reduced into [0, mod).
    '''
    mod: int
    zero: 'mint'
    one: 'mint'
    two: 'mint'
    cache: list['mint']

    def __new__(cls, *args, **kwargs):
        value = int(*args, **kwargs)
        if value < 0 or value > 2:
            return cls.fix(value)
        # 0, 1 and 2 are the pre-built singletons created by set_mod.
        return cls.cache[value]

    @classmethod
    def set_mod(cls, mod: int):
        '''Install ``mod`` and rebuild the cached small constants.'''
        mint.mod = cls.mod = mod
        zero, one, two = cls.cast(0), cls.fix(1), cls.fix(2)
        mint.zero = cls.zero = zero
        mint.one = cls.one = one
        mint.two = cls.two = two
        mint.cache = cls.cache = [zero, one, two]

    @classmethod
    def fix(cls, x):
        '''Reduce ``x`` modulo ``cls.mod`` and wrap it.'''
        return cls.cast(x % cls.mod)

    @classmethod
    def cast(cls, x):
        '''Wrap ``x`` without reducing it.'''
        return int.__new__(cls, x)

    @classmethod
    def mod_inv(cls, x):
        '''Inverse of ``x`` modulo ``cls.mod`` (extended Euclid); ValueError if none exists.'''
        r0, r1 = int(x), cls.mod
        s0, s1 = 1, 0
        while r1:
            q = r0 // r1
            r0, r1 = r1, r0 % r1
            s0, s1 = s1, s0 - q * s1
        if r0 != 1:
            raise ValueError(f"{x} is not invertible in mod {cls.mod}")
        return cls.fix(s0)

    @property
    def inv(self):
        return mint.mod_inv(self)

    def __add__(self, x):
        return mint.fix(int.__add__(self, x))

    def __radd__(self, x):
        return mint.fix(int.__radd__(self, x))

    def __sub__(self, x):
        return mint.fix(int.__sub__(self, x))

    def __rsub__(self, x):
        return mint.fix(int.__rsub__(self, x))

    def __mul__(self, x):
        return mint.fix(int.__mul__(self, x))

    def __rmul__(self, x):
        return mint.fix(int.__rmul__(self, x))

    # Floor division and true division both mean "multiply by the inverse".
    def __floordiv__(self, x):
        return self * mint.mod_inv(x)

    def __rfloordiv__(self, x):
        return self.inv * x

    def __truediv__(self, x):
        return self * mint.mod_inv(x)

    def __rtruediv__(self, x):
        return self.inv * x

    def __pow__(self, x):
        # Three-argument pow keeps intermediates small (negative x gives an inverse power).
        return self.cast(int.__pow__(self, x, self.mod))

    def __neg__(self):
        return mint.mod - self

    def __pos__(self):
        return self

    def __abs__(self):
        return self


def mod_inv(x, mod):
    '''Return the inverse of ``x`` modulo ``mod`` in [0, mod).

    Uses the extended Euclidean algorithm; raises ValueError when
    gcd(x, mod) != 1.
    '''
    r0, r1 = x, mod
    s0, s1 = 1, 0
    while r1:
        q = r0 // r1
        r0, r1 = r1, r0 % r1
        s0, s1 = s1, s0 - q * s1
    if r0 != 1:
        raise ValueError(f"{x} is not invertible in mod {mod}")
    return s0 % mod
from itertools import accumulate

class modcomb():
    '''Combinatorics modulo ``mint.mod`` backed by factorial tables.

    ``precomp(N)`` must be called (after ``mint.set_mod``) before the query
    methods. They index ``fact``/``fact_inv`` directly, so every factorial they
    touch must be at most N.
    '''
    fact: list[int]
    fact_inv: list[int]
    # inv[i] = 1/i mod p, grown on demand by extend_inv; inv[0] is a placeholder.
    # NOTE(review): this table lives at class level and is not reset when the
    # modulus changes.
    inv: list[int] = [0,1]

    @staticmethod
    def precomp(N):
        '''Build fact[i] = i! and fact_inv[i] = 1/i! (mod mint.mod) for 0 <= i <= N.'''
        mod = mint.mod
        def mod_mul(a,b): return a*b%mod
        fact = list(accumulate(range(1,N+1), mod_mul, initial=1))
        # Start from 1/N! and multiply by N, N-1, ..., 1 to walk down to 1/0!.
        fact_inv = list(accumulate(range(N,0,-1), mod_mul, initial=mod_inv(fact[N], mod)))
        fact_inv.reverse()
        modcomb.fact, modcomb.fact_inv = fact, fact_inv

    @staticmethod
    def extend_inv(N):
        '''Extend modcomb.inv so that it holds 1/i mod p for every 1 <= i <= N.'''
        N, inv, mod = N+1, modcomb.inv, mint.mod
        while len(inv) < N:
            # Recurrence: 1/i = -(p // i) * 1/(p % i)  (mod p).
            j, k = divmod(mod, len(inv))
            inv.append(-inv[k] * j % mod)

    @staticmethod
    def factorial(n: int, /) -> mint:
        '''Return n! mod p.'''
        return mint(modcomb.fact[n])

    @staticmethod
    def comb(n: int, k: int, /) -> mint:
        '''Return C(n, k) mod p; zero when k < 0 or k > n.'''
        # A negative k (or n - k) would otherwise index the tables from the end.
        if k < 0 or n < k: return mint.zero
        inv, mod = modcomb.fact_inv, mint.mod
        return mint(inv[k] * inv[n-k] % mod * modcomb.fact[n])
    nCk = binom = comb

    @staticmethod
    def comb_with_replacement(n: int, k: int, /) -> mint:
        '''Return H(n, k) = C(n + k - 1, k) mod p; zero when n <= 0.'''
        if n <= 0: return mint.zero
        return modcomb.nCk(n + k - 1, k)
    nHk = comb_with_replacement

    @staticmethod
    def multinom(n: int, *K: int) -> mint:
        '''Return n! / (k1! k2! ... (n - sum K)!) as a product of binomials C(n, k1) C(n - k1, k2) ...'''
        nCk, res = modcomb.nCk, mint.one
        for k in K: res, n = res*nCk(n,k), n-k
        return res

    @staticmethod
    def perm(n: int, k: int, /) -> mint:
        '''Returns P(n,k) mod p (zero when k < 0 or k > n)'''
        if k < 0 or n < k: return mint.zero
        return mint(modcomb.fact[n] * modcomb.fact_inv[n-k])
    nPk = perm

    @staticmethod
    def catalan(n: int, /) -> mint:
        '''Return the n-th Catalan number C(2n, n) / (n + 1) = (2n)! / (n! (n + 1)!) mod p.'''
        fact, fact_inv, mod = modcomb.fact, modcomb.fact_inv, mint.mod
        # Divide by (n + 1)! and n!, not by C(2n, n)'s extra (n + 1)!.
        return mint(fact[2*n] * fact_inv[n] % mod * fact_inv[n+1])
from typing import SupportsIndex


class NTT:
    def __init__(self, mod = 998244353) -> None:
        self.mod = m = mod
        self.g = g = self.primitive_root(m)
        self.rank2 = rank2 = ((m-1)&(1-m)).bit_length() - 1
        self.root = root = [0] * (rank2 + 1)
        root[rank2] = pow(g, (m - 1) >> rank2, m)
        self.iroot = iroot = [0] * (rank2 + 1)
        iroot[rank2] = pow(root[rank2], m - 2, m)
        for i in range(rank2 - 1, -1, -1):
            root[i] = root[i+1] * root[i+1] % m
            iroot[i] = iroot[i+1] * iroot[i+1] % m
        def rates(s):
            r8,ir8 = [0]*max(0,rank2-s+1), [0]*max(0,rank2-s+1)
            p = ip = 1
            for i in range(rank2-s+1):
                r, ir = root[i+s], iroot[i+s]
                p,ip,r8[i],ir8[i]= p*ir%m,ip*r%m,r*p%m,ir*ip%m
            return r8, ir8
        self.rate2, self.irate2 = rates(2)
        self.rate3, self.irate3 = rates(3)
 
    def primitive_root(self, m):
        if m == 2: return 1
        if m == 167772161: return 3
        if m == 469762049: return 3
        if m == 754974721: return 11
        if m == 998244353: return 3
        divs = [0] * 20
        cnt, divs[0], x = 1, 2, (m - 1) // 2
        while x % 2 == 0: x //= 2
        i=3
        while i*i <= x:
            if x%i == 0:
                divs[cnt],cnt = i,cnt+1
                while x%i==0:x//=i
            i+=2
        if x > 1: divs[cnt],cnt = x,cnt+1
        for g in range(2,m):
            for i in range(cnt):
                if pow(g,(m-1)//divs[i],m)==1:break
            else:return g
    
    def fntt(self, A: list[int]):
        im, r8, m, h = self.root[2],self.rate3,self.mod,(len(A)-1).bit_length()
        for L in range(0,h-1,2):
            p, r = 1<<(h-L-2),1
            for s in range(1 << L):
                r3,of=(r2:=r*r%m)*r%m,s<<(h-L)
                for i in range(p):
                    i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
                    a0,a1,a2,a3 = A[i0],A[i1]*r,A[i2]*r2,A[i3]*r3
                    a0,a1,a2,a3 = a0+a2,a1+a3,a0-a2,(a1-a3)%m*im
                    A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a0-a1)%m,(a2+a3)%m,(a2-a3)%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        if h&1:
            r, r8 = 1, self.rate2
            for s in range(1<<(h-1)):
                i1=(i0:=s<<1)+1
                al,ar = A[i0],A[i1]*r%m
                A[i0],A[i1] = (al+ar)%m,(al-ar)%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        return A
    
    def _ifntt(self, A: list[int]):
        im, r8, m, h = self.iroot[2],self.irate3,self.mod,(len(A)-1).bit_length()
        for L in range(h,1,-2):
            p,r = 1<<(h-L),1
            for s in range(1<<(L-2)):
                r3,of=(r2:=r*r%m)*r%m,s<<(h-L+2)
                for i in range(p):
                    i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
                    a0,a1,a2,a3 = A[i0],A[i1],A[i2],A[i3]
                    a0,a1,a2,a3 = a0+a1,a2+a3,a0-a1,(a2-a3)*im%m
                    A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a2+a3)*r%m,(a0-a1)*r2%m,(a2-a3)*r3%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        if h&1:
            for i0 in range(p:=1<<(h-1)):
                al,ar = A[i0],A[i1:=i0+p]
                A[i0],A[i1] = (al+ar)%m,(al-ar)%m
        return A

    def ifntt(self, A: list[int]):
        self._ifntt(A)
        iz = mod_inv(N:=len(A),mod:=self.mod)
        for i in range(N): A[i]=A[i]*iz%mod
        return A
    
    def conv_naive(self, A, B, N):
        n, m, mod = len(A),len(B),self.mod
        C = [0]*N
        if n < m: A,B,n,m = B,A,m,n
        for i,a in enumerate(A):
            for j in range(min(m,N-i)):
                C[ij]=(C[ij:=i+j]+a*B[j])%mod
        return C
    
    def conv_fntt(self, A, B, N):
        n,m,mod=len(A),len(B),self.mod
        z=1<<(n+m-2).bit_length()
        self.fntt(A:=A+[0]*(z-n)), self.fntt(B:=B+[0]*(z-m))
        for i, b in enumerate(B): A[i] = A[i] * b % mod
        self.ifntt(A)
        del A[N:]
        return A
    
    def deconv(self, C, B, N = None):
        n, m = len(C), len(B)
        if N is None: N = n - m + 1
        z = 1 << (n + m - 2).bit_length()
        self.fntt(C := C+[0]*(z-n)), self.fntt(B := B+[0]*(z - m))

        A = [0] * z
        for i in range(z):
            if B[i] == 0:
                raise ValueError("Division by zero in NTT domain - deconvolution not possible")
            b_inv = mod_inv(B[i], self.mod)
            A[i] = (C[i] * b_inv) % self.mod
        
        self.ifntt(A)
        return A[:N]
    
    def conv_half(self, A, Bres):
        mod = self.mod
        self.fntt(A)
        for i, b in enumerate(Bres): A[i] = A[i] * b % mod
        self.ifntt(A)
        return A
    
    def conv(self, A, B, N = None):
        n,m = len(A), len(B)
        N = n+m-1 if N is None else N
        if min(n,m) <= 60: return self.conv_naive(A, B, N)
        return self.conv_fntt(A, B, N)

    def cycle_conv(self, A, B):
        n,m,mod=len(A),len(B),self.mod
        assert n == m
        if n==0:return[]
        con,res=self.conv(A,B),[0]*n
        for i in range(n-1):res[i]=(con[i]+con[i+n])%mod
        res[n-1]=con[n-1]
        return res

class mint(mint):
    '''mint extended with an NTT engine (``mint.ntt``) bound to the active modulus.'''
    ntt: NTT

    @classmethod
    def set_mod(cls, mod: int):
        # Install the modulus and cached constants via the base class first,
        # then build transform tables for the same modulus.
        super().set_mod(mod)
        engine = NTT(mod)
        cls.ntt = engine

def stirling2_n(n: SupportsIndex):
    '''Return [S(n, 0), ..., S(n, n)], the Stirling numbers of the second kind, as mint values.

    Uses S(n, k) = sum_t t^n / t! * (-1)^(k-t) / (k-t)!, i.e. the first n+1
    terms of the convolution of a[t] = t^n / t! with b[t] = (-1)^t / t!.
    Requires mint.set_mod (for mint.ntt) and modcomb.precomp(N) with N >= n.
    '''
    fact_inv = modcomb.fact_inv
    convolve = mint.ntt.conv
    mod = mint.mod
    neg_one = mod - 1
    powers_over_fact = [fact_inv[t] * pow(t, n, mod) % mod for t in range(n + 1)]
    signed_inv_fact = [fact_inv[t] * (neg_one if t & 1 else 1) % mod for t in range(n + 1)]
    return list(map(mint, convolve(powers_over_fact, signed_inv_fact, n + 1)))


from typing import Iterable, Type, Union, overload
import typing
from collections import deque
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase


class FastIO(IOBase):
    '''Byte-level buffered reader/writer that works directly on a file descriptor.

    Reads pull large chunks with ``os.read`` into an in-memory ``BytesIO``.
    Writes append to that ``BytesIO``, and ``flush`` hands the contents to
    ``os.write``.
    '''
    BUFSIZE = 8192
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Mode-string heuristic: anything not opened for reading (or opened
        # with "x") is treated as an output stream.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        '''Read until EOF and return the unread remainder of the buffer.'''
        BUFSIZE = self.BUFSIZE
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        '''Return the next line as bytes (with its newline), reading more input only when no full line is buffered.'''
        BUFSIZE = self.BUFSIZE
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # At EOF (b == b"") count one pseudo-line so the loop ends and the
            # final unterminated line (or b"") is returned.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            # Append at the end without moving the current read position.
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        '''Write all buffered output to the descriptor, then reset the buffer.'''
        if self.writable:
            data = memoryview(self.buffer.getvalue())
            # os.write may write fewer bytes than requested; keep going until
            # everything has been handed to the OS instead of dropping the rest.
            while data:
                data = data[os.write(self._fd, data):]
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    '''ASCII text adapter over FastIO.

    ``write``/``read``/``readline`` encode or decode with the "ascii" codec and
    delegate to the wrapped FastIO. ``flush`` and ``writable`` are taken from
    it directly.
    '''
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        data = s.encode("ascii")
        return self.buffer.write(data)

    def read(self):
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        raw = self.buffer.readline()
        return raw.decode("ascii")

# Replace the interpreter's standard streams with buffered ASCII wrappers.
# IOWrapper.stdin / IOWrapper.stdout keep references that other helpers read
# (TokenStream.stream at class creation, write()'s default file).
# NOTE(review): nothing in this file flushes stdout explicitly; output appears
# to rely on IOBase finalization calling flush() at exit — confirm.
sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)
from typing import TypeVar
_T = TypeVar('T')

class TokenStream(Iterator):
    '''Iterator over whitespace-separated tokens, read one line at a time.

    Lines come from ``TokenStream.stream``, which is captured from
    ``IOWrapper.stdin`` when the class is created.
    NOTE(review): ``__next__`` refills from a single line only, so a blank
    input line leaves the queue empty and ``popleft`` raises IndexError.
    '''
    stream = IOWrapper.stdin

    def __init__(self):
        self.queue = deque()

    def __next__(self):
        queue = self.queue
        if not queue:
            queue.extend(self._line())
        return queue.popleft()

    def wait(self):
        '''Yield once per token still pending on the current line (reading a line if none are pending).'''
        queue = self.queue
        if not queue:
            queue.extend(self._line())
        while queue:
            yield

    def _line(self):
        return TokenStream.stream.readline().split()

    def line(self):
        '''Return the pending tokens of the current line, or the tokens of the next line if none are pending.'''
        if not self.queue:
            return self._line()
        tokens = list(self.queue)
        self.queue.clear()
        return tokens
TokenStream.default = TokenStream()

class CharStream(TokenStream):
    '''TokenStream variant whose tokens are the individual characters of each line.

    ``_line`` returns the line with trailing whitespace stripped. Extending the
    inherited deque with a string enqueues it character by character.
    '''
    def _line(self):
        text = TokenStream.stream.readline()
        return text.rstrip()
CharStream.default = CharStream()


# Type of a compiled parser: a callable that consumes tokens from a
# TokenStream and returns a value of type _T.
ParseFn = Callable[[TokenStream],_T]
class Parser:
    '''Compiles a type/value "spec" into a function that parses a TokenStream.

    Supported specs (see ``compile``):
    - classes and generic aliases such as ``int`` or ``list[int]``;
    - numeric instances, which parse one token with their type and add the value as an offset;
    - tuples and other collections of sub-specs;
    - arbitrary callables, applied to the next token.
    '''
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> ParseFn[_T]:
        '''Compile the class ``cls`` (with generic type ``args``) into a parse function.'''
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        '''Compile ``spec`` into a parse function; raises NotImplementedError for unsupported specs.'''
        if isinstance(spec, (type, GenericAlias)):
            # e.g. int, list[int], tuple[int, ...]: dispatch on origin and type args.
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            # A numeric value: parse one token as type(spec) and add spec to it
            # (e.g. -1 turns 1-based input into 0-based indices).
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        '''Parse every remaining item of the current line with ``spec`` and wrap the list in ``cls``.'''
        if spec is int:
            # Fast path: convert the line's tokens directly, no per-item parser needed.
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
        return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        '''Parse exactly N items with ``spec`` and wrap the list in ``cls``.'''
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        '''Parse one item per sub-spec, in order, and wrap the list in ``cls``.'''
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        '''``(spec, ...)`` parses the rest of the line; any other tuple parses one item per element.'''
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        '''An empty or single-element spec parses the rest of the line; ``[spec, N]`` parses N items.'''
        # NOTE(review): a set spec with 2+ elements reaches compile_line(cls, *specs)
        # and raises TypeError (compile_line accepts a single spec).
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()

class Parsable:
    '''Mixin for types that Parser should build from a single token.

    Parser.compile_type routes subclasses to ``compile``. By default the
    resulting parser passes the next token straight to the class constructor.
    '''
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser

@overload
def read() -> Iterable[int]: ...
@overload
def read(spec: int) -> list[int]: ...
@overload
def read(spec: Union[Type[_T],_T], char=False) -> _T: ...
def read(spec: Union[Type[_T],_T] = None, char=False):
    '''Read from standard input according to ``spec``.

    With no spec (and ``char`` false), returns a lazy ``map`` of ints over the
    pending tokens of the current line, or over the next line if none are
    pending. Otherwise ``spec`` is compiled with ``Parser.compile`` and run
    against ``TokenStream.default``, or ``CharStream.default`` when ``char``
    is true.
    '''
    if spec is None and not char:
        return map(int, TokenStream.default.line())
    parse = Parser.compile(spec)
    source = CharStream.default if char else TokenStream.default
    return parse(source)

def write(*args, **kwargs):
    '''Print the values to a stream, or to ``IOWrapper.stdout`` by default.

    Mirrors the builtin ``print``:
    - each value is converted with ``str``;
    - values are separated by ``sep`` (default a single space) and followed
      by ``end`` (default a newline);
    - passing ``None`` for ``sep``, ``end`` or ``file`` selects the default;
    - ``flush=True`` flushes the stream afterwards.

    Other keyword arguments are ignored.
    '''
    sep = kwargs.pop("sep", None)
    end = kwargs.pop("end", None)
    file = kwargs.pop("file", None)
    if sep is None: sep = " "
    if end is None: end = "\n"
    # Looked up lazily so an explicit file does not require IOWrapper.stdout.
    if file is None: file = IOWrapper.stdout
    # A single write instead of two per value: every IOWrapper.write call
    # encodes its argument, so joining first is much cheaper for long output.
    file.write(sep.join(map(str, args)) + end)
    if kwargs.pop("flush", False):
        file.flush()

if __name__ == '__main__':
    # Run the solution only when executed as a script, not when imported.
    main()
Back to top page