cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

:heavy_check_mark: test/library-checker/enumerative-combinatorics/stirling_number_of_the_second_kind.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/stirling_number_of_the_second_kind

def main():
    """Read N and print the Stirling numbers of the second kind S(N, 0), ..., S(N, N).

    The modulus is fixed to 998244353 and the factorial tables are built up to N
    before the row is computed.
    """
    n = read(int)
    mint.set_mod(998244353)
    mcomb.precomp(n)
    row = stirling2_n(n)
    write(*row)

from cp_library.math.table.mcomb_cls import mcomb
from cp_library.math.table.stirling2_n_fn import stirling2_n
from cp_library.math.mod.mint_ntt_cls import mint
from cp_library.io.read_fn import read
from cp_library.io.write_fn import write

# Run the solver only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/stirling_number_of_the_second_kind

def main():
    """Read N and print the Stirling numbers of the second kind S(N, 0), ..., S(N, N).

    The modulus is fixed to 998244353 and the factorial tables are built up to N
    before the row is computed.
    """
    n = read(int)
    mint.set_mod(998244353)
    mcomb.precomp(n)
    row = stirling2_n(n)
    write(*row)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''



    
class mint(int):
    """An ``int`` subclass whose arithmetic is reduced modulo a class-wide modulus.

    ``set_mod`` must be called before any instance is created: it installs the
    modulus and a cache of the residues of 0..63, which ``__new__`` hands out
    instead of allocating new objects for small values.
    """
    mod: int
    zero: 'mint'
    one: 'mint'
    two: 'mint'
    cache: list['mint']

    def __new__(cls, *args, **kwargs):
        # Accept anything int() accepts; small non-negative values come from the cache.
        value = int(*args, **kwargs)
        if 0 <= value < 64:
            return cls.cache[value]
        return cls.fix(value)

    @classmethod
    def set_mod(cls, mod: int):
        """Install ``mod`` and rebuild the constants and the small-value cache."""
        mint.mod = cls.mod = mod
        mint.zero = cls.zero = cls.cast(0)
        mint.one = cls.one = cls.fix(1)
        mint.two = cls.two = cls.fix(2)
        small = [cls.zero, cls.one, cls.two]
        mint.cache = cls.cache = small
        small.extend(cls.fix(x) for x in range(3, 64))

    @classmethod
    def fix(cls, x):
        """Reduce ``x`` modulo the current modulus and wrap it."""
        return cls.cast(x % cls.mod)

    @classmethod
    def cast(cls, x):
        """Wrap ``x`` without reducing it (the caller guarantees it is already reduced)."""
        return super().__new__(cls, x)

    @classmethod
    def mod_inv(cls, x):
        """Return the modular inverse of ``x`` via the extended Euclidean algorithm.

        Raises ValueError when ``x`` shares a factor with the modulus.
        """
        r0, r1 = int(x), cls.mod
        c0, c1 = 1, 0
        while r1:
            q = r0 // r1
            r0, r1, c0, c1 = r1, r0 % r1, c1, c0 - q * c1
        if r0 != 1:
            raise ValueError(f"{x} is not invertible in mod {cls.mod}")
        return cls.fix(c0)

    @property
    def inv(self):
        """Modular inverse of this value."""
        return mint.mod_inv(self)

    # Additive and multiplicative operators: do the plain int operation, then reduce.
    def __add__(self, x):
        return mint.fix(super().__add__(x))

    def __radd__(self, x):
        return mint.fix(super().__radd__(x))

    def __sub__(self, x):
        return mint.fix(super().__sub__(x))

    def __rsub__(self, x):
        return mint.fix(super().__rsub__(x))

    def __mul__(self, x):
        return mint.fix(super().__mul__(x))

    def __rmul__(self, x):
        return mint.fix(super().__rmul__(x))

    # Both division operators mean "multiply by the modular inverse".
    def __floordiv__(self, x):
        return self * mint.mod_inv(x)

    def __rfloordiv__(self, x):
        return self.inv * x

    def __truediv__(self, x):
        return self * mint.mod_inv(x)

    def __rtruediv__(self, x):
        return self.inv * x

    def __pow__(self, x):
        # Three-argument int pow already yields a reduced result, so no extra fix().
        return self.cast(super().__pow__(x, self.mod))

    def __neg__(self):
        return mint.mod - self

    def __pos__(self):
        return self

    def __abs__(self):
        return self

    def __class_getitem__(self, x: int):
        # mint[k] is a shortcut to the cached residue at index k.
        return self.cache[x]


def mod_inv(x, mod):
    """Return the inverse of ``x`` modulo ``mod`` as a value in ``[0, mod)``.

    Uses the extended Euclidean algorithm, tracking only the coefficient of ``x``.
    Raises ValueError when gcd(x, mod) is not 1.
    """
    r0, r1 = x, mod
    c0, c1 = 1, 0
    while r1:
        q = r0 // r1
        r0, r1, c0, c1 = r1, r0 % r1, c1, c0 - q * c1
    if r0 != 1:
        raise ValueError(f"{x} is not invertible in mod {mod}")
    return c0 % mod
from itertools import accumulate

class mcomb():
    """Factorial-table combinatorics modulo ``mint.mod``.

    ``precomp(N)`` must run (after ``mint.set_mod``) before any method that
    reads ``fact`` / ``fact_inv``; afterwards both tables cover indices 0..N.
    All public helpers return ``mint`` values.
    """
    fact: list[int]         # fact[i] = i! mod p, filled by precomp
    fact_inv: list[int]     # fact_inv[i] = (i!)^-1 mod p, filled by precomp
    inv: list[int] = [0,1]  # inv[i] = i^-1 mod p, grown on demand by extend_inv

    @staticmethod
    def precomp(N):
        """Build ``fact`` and ``fact_inv`` for 0..N with a single modular inversion."""
        mod = mint.mod
        def mod_mul(a,b): return a*b%mod
        fact = list(accumulate(range(1,N+1), mod_mul, initial=1))
        # Start from 1/N! and multiply by N, N-1, ... to walk down to 1/0!, then reverse.
        fact_inv = list(accumulate(range(N,0,-1), mod_mul, initial=mod_inv(fact[N], mod)))
        fact_inv.reverse()
        mcomb.fact, mcomb.fact_inv = fact, fact_inv

    @staticmethod
    def extend_inv(N):
        """Grow the shared ``inv`` table so that it covers indices 0..N."""
        N, inv, mod = N+1, mcomb.inv, mint.mod
        while len(inv) < N:
            # With i = len(inv): mod = j*i + k, hence i^-1 = -j * k^-1 (mod p).
            j, k = divmod(mod, len(inv))
            inv.append(-inv[k] * j % mod)

    @staticmethod
    def factorial(n: int, /) -> mint:
        """Return n! mod p."""
        return mint(mcomb.fact[n])

    @staticmethod
    def comb(n: int, k: int, /) -> mint:
        """Return C(n, k) mod p, or zero when k is outside 0..n."""
        inv, mod = mcomb.fact_inv, mint.mod
        if n < k or k < 0: return mint.zero
        return mint(inv[k] * inv[n-k] % mod * mcomb.fact[n])
    nCk = binom = comb

    @staticmethod
    def comb_with_replacement(n: int, k: int, /) -> mint:
        """Return the multiset coefficient H(n, k) = C(n+k-1, k); zero when n <= 0."""
        if n <= 0: return mint.zero
        return mcomb.nCk(n + k - 1, k)
    nHk = comb_with_replacement

    @staticmethod
    def multinom(n: int, *K: int) -> mint:
        """Return the product C(n, k1) * C(n-k1, k2) * ... over the given parts K."""
        nCk, res = mcomb.nCk, mint.one
        for k in K: res, n = res*nCk(n,k), n-k
        return res

    @staticmethod
    def perm(n: int, k: int, /) -> mint:
        '''Returns P(n,k) mod p, or zero when k is outside 0..n.'''
        # A negative k used to index fact_inv past n; treat it like comb does.
        if n < k or k < 0: return mint.zero
        return mint(mcomb.fact[n] * mcomb.fact_inv[n-k])
    nPk = perm

    @staticmethod
    def catalan(n: int, /) -> mint:
        """Return the n-th Catalan number C(2n, n) / (n + 1) mod p.

        Requires the tables to cover index max(2n, n+1).
        """
        fact, inv, mod = mcomb.fact, mcomb.fact_inv, mint.mod
        # C(2n,n)/(n+1) = (2n)! / (n! * (n+1)!).  Multiplying C(2n,n) by
        # fact_inv[n+1] alone would divide by (n+1)! instead of by (n+1).
        return mint(fact[2*n] * inv[n] % mod * inv[n+1])
from typing import SupportsIndex

class NTT:
    """Number-theoretic transform tables and convolution routines for one modulus.

    All routines operate on plain lists of ints taken modulo ``self.mod``.
    NOTE(review): ``__init__`` uses ``pow(x, m - 2, m)`` as a modular inverse,
    which presumes ``mod`` is prime -- confirm before using other moduli.
    """
    def __init__(self, mod = 998244353) -> None:
        """Precompute the root tables and the per-stage twiddle "rate" tables for ``mod``."""
        self.mod = m = mod
        self.g = g = self.primitive_root(m)
        # (m-1)&(1-m) isolates the lowest set bit of m-1, so rank2 is the number
        # of trailing zero bits of m-1.
        self.rank2 = rank2 = ((m-1)&(1-m)).bit_length() - 1
        # root[rank2] = g^((m-1) >> rank2); every lower entry is the square of the one above.
        self.root = root = [0] * (rank2 + 1)
        root[rank2] = pow(g, (m - 1) >> rank2, m)
        # iroot mirrors root, starting from the inverse of root[rank2].
        self.iroot = iroot = [0] * (rank2 + 1)
        iroot[rank2] = pow(root[rank2], m - 2, m)
        for i in range(rank2 - 1, -1, -1):
            root[i] = root[i+1] * root[i+1] % m
            iroot[i] = iroot[i+1] * iroot[i+1] % m
        def rates(s):
            # r8[i]  = root[i+s]  * (product of iroot[j+s] for j < i)
            # ir8[i] = iroot[i+s] * (product of root[j+s]  for j < i)
            r8,ir8 = [0]*max(0,rank2-s+1), [0]*max(0,rank2-s+1)
            p = ip = 1
            for i in range(rank2-s+1):
                r, ir = root[i+s], iroot[i+s]
                p,ip,r8[i],ir8[i]= p*ir%m,ip*r%m,r*p%m,ir*ip%m
            return r8, ir8
        # rate2 feeds the radix-2 pass of fntt, rate3/irate3 the radix-4 passes.
        self.rate2, self.irate2 = rates(2)
        self.rate3, self.irate3 = rates(3)
 
    def primitive_root(self, m):
        """Return a generator candidate for modulus ``m``.

        Known moduli are answered from a hard-coded table. Otherwise the distinct
        prime factors of m-1 are collected by trial division and the smallest g
        in 2..m-1 with g^((m-1)/q) != 1 for every such factor q is returned.
        Falls through (returning None) if no g qualifies.
        """
        if m == 2: return 1
        if m == 167772161: return 3
        if m == 469762049: return 3
        if m == 754974721: return 11
        if m == 998244353: return 3
        # divs[:cnt] holds the distinct prime factors of m-1; 2 is entered up front.
        divs = [0] * 20
        cnt, divs[0], x = 1, 2, (m - 1) // 2
        while x % 2 == 0: x //= 2
        i=3
        while i*i <= x:
            if x%i == 0:
                divs[cnt],cnt = i,cnt+1
                while x%i==0:x//=i
            i+=2
        if x > 1: divs[cnt],cnt = x,cnt+1
        for g in range(2,m):
            for i in range(cnt):
                if pow(g,(m-1)//divs[i],m)==1:break
            else:return g
    
    def fntt(self, A: list[int]):
        """Forward transform of ``A`` in place; returns ``A``.

        Runs radix-4 passes two levels at a time, then one radix-2 pass when the
        number of levels ``h`` is odd.
        NOTE(review): presumably len(A) must be a power of two (conv_fntt pads to one) -- confirm.
        """
        im, r8, m, h = self.root[2],self.rate3,self.mod,(len(A)-1).bit_length()
        for L in range(0,h-1,2):
            p, r = 1<<(h-L-2),1
            for s in range(1 << L):
                r3,of=(r2:=r*r%m)*r%m,s<<(h-L)
                for i in range(p):
                    i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
                    # Products are left unreduced here; every stored value is reduced below.
                    a0,a1,a2,a3 = A[i0],A[i1]*r,A[i2]*r2,A[i3]*r3
                    a0,a1,a2,a3 = a0+a2,a1+a3,a0-a2,(a1-a3)%m*im
                    A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a0-a1)%m,(a2+a3)%m,(a2-a3)%m
                # (~s&-~s).bit_length()-1 is the index of the lowest zero bit of s.
                r=r*r8[(~s&-~s).bit_length()-1]%m
        if h&1:
            r, r8 = 1, self.rate2
            for s in range(1<<(h-1)):
                i1=(i0:=s<<1)+1
                al,ar = A[i0],A[i1]*r%m
                A[i0],A[i1] = (al+ar)%m,(al-ar)%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        return A
    
    def _ifntt(self, A: list[int]):
        """Inverse butterflies of ``fntt`` in place, without the final 1/len(A) scaling; returns ``A``."""
        im, r8, m, h = self.iroot[2],self.irate3,self.mod,(len(A)-1).bit_length()
        for L in range(h,1,-2):
            p,r = 1<<(h-L),1
            for s in range(1<<(L-2)):
                r3,of=(r2:=r*r%m)*r%m,s<<(h-L+2)
                for i in range(p):
                    i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
                    a0,a1,a2,a3 = A[i0],A[i1],A[i2],A[i3]
                    a0,a1,a2,a3 = a0+a1,a2+a3,a0-a1,(a2-a3)*im%m
                    A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a2+a3)*r%m,(a0-a1)*r2%m,(a2-a3)*r3%m
                r=r*r8[(~s&-~s).bit_length()-1]%m
        if h&1:
            # Remaining single level when h is odd: plain sum/difference of the two halves.
            for i0 in range(p:=1<<(h-1)):
                al,ar = A[i0],A[i1:=i0+p]
                A[i0],A[i1] = (al+ar)%m,(al-ar)%m
        return A

    def ifntt(self, A: list[int]):
        """Inverse transform of ``A`` in place: ``_ifntt`` followed by scaling with the inverse of len(A)."""
        self._ifntt(A)
        iz = mod_inv(N:=len(A),mod:=self.mod)
        for i in range(N): A[i]=A[i]*iz%mod
        return A
    
    def conv_naive(self, A, B, N):
        """Schoolbook convolution of ``A`` and ``B`` keeping only the first ``N`` coefficients."""
        n, m, mod = len(A),len(B),self.mod
        C = [0]*N
        # Put the longer sequence in the outer loop.
        if n < m: A,B,n,m = B,A,m,n
        for i,a in enumerate(A):
            for j in range(min(m,N-i)):
                # The right-hand side is evaluated first, so ij is bound before C[ij] is stored.
                C[ij]=(C[ij:=i+j]+a*B[j])%mod
        return C
    
    def conv_fntt(self, A, B, N):
        """Transform-based convolution truncated to at most ``N`` coefficients.

        Works on zero-padded copies, so the caller's lists are not modified.
        """
        n,m,mod=len(A),len(B),self.mod
        z=1<<(n+m-2).bit_length()
        self.fntt(A:=A+[0]*(z-n)), self.fntt(B:=B+[0]*(z-m))
        for i, b in enumerate(B): A[i] = A[i] * b % mod
        self.ifntt(A)
        del A[N:]
        return A
    
    def deconv(self, C, B, N = None):
        """Divide ``C`` by ``B`` pointwise in the transformed domain and return the first ``N`` coefficients.

        ``N`` defaults to len(C) - len(B) + 1. Raises ValueError when any
        transformed entry of ``B`` is zero.
        """
        n, m = len(C), len(B)
        if N is None: N = n - m + 1
        z = 1 << (n + m - 2).bit_length()
        self.fntt(C := C+[0]*(z-n)), self.fntt(B := B+[0]*(z - m))

        A = [0] * z
        for i in range(z):
            if B[i] == 0:
                raise ValueError("Division by zero in NTT domain - deconvolution not possible")
            b_inv = mod_inv(B[i], self.mod)
            A[i] = (C[i] * b_inv) % self.mod
        
        self.ifntt(A)
        return A[:N]
    
    def conv_half(self, A, Bres):
        """Transform ``A`` in place, multiply pointwise by ``Bres``, and transform back.

        NOTE(review): ``Bres`` is presumably an already-transformed operand of
        matching length -- confirm with callers.
        """
        mod = self.mod
        self.fntt(A)
        for i, b in enumerate(Bres): A[i] = A[i] * b % mod
        self.ifntt(A)
        return A
    
    def conv(self, A, B, N = None):
        """Convolve ``A`` and ``B``; ``N`` (default len(A)+len(B)-1) limits the output length.

        Uses the schoolbook routine when the shorter input has at most 60 entries.
        """
        n,m = len(A), len(B)
        N = n+m-1 if N is None else N
        if min(n,m) <= 60: return self.conv_naive(A, B, N)
        return self.conv_fntt(A, B, N)

    def cycle_conv(self, A, B):
        """Cyclic convolution of two equal-length sequences (asserts the lengths match)."""
        n,m,mod=len(A),len(B),self.mod
        assert n == m
        if n==0:return[]
        con,res=self.conv(A,B),[0]*n
        # The linear result has 2n-1 terms; wrap index i+n back onto i.
        for i in range(n-1):res[i]=(con[i]+con[i+n])%mod
        res[n-1]=con[n-1]
        return res

class mint(mint):
    """Rebinds the name ``mint`` to a subclass of the earlier ``mint`` that also carries an NTT instance."""
    ntt: NTT  # assigned by set_mod for the modulus in use

    @classmethod
    def set_mod(cls, mod: int):
        """Run the base-class ``set_mod`` and then store ``NTT(mod)`` as ``cls.ntt``."""
        super().set_mod(mod)
        cls.ntt = NTT(mod)

def stirling2_n(n: SupportsIndex):
    """Return [S(n, 0), ..., S(n, n)], the Stirling numbers of the second kind, as mints.

    Uses S(n, k) = sum over t of (t^n / t!) * ((-1)^(k-t) / (k-t)!), i.e. one
    convolution of two length n+1 sequences. ``mcomb.precomp`` must already
    cover index n.
    """
    mod = mint.mod
    fact_inv = mcomb.fact_inv
    # powers[t] = t^n / t!
    powers = [fact_inv[t] * pow(t, n, mod) % mod for t in range(n + 1)]
    # signs[t] = (-1)^t / t!, with -1 represented as mod - 1
    signs = [fact_inv[t] * (mod - 1 if t & 1 else 1) % mod for t in range(n + 1)]
    return [mint(c) for c in mint.ntt.conv(powers, signs, n + 1)]
from typing import Type, Union, overload
from typing import TypeVar

# Type variables for the generic I/O helpers. PEP 484 requires the string
# passed to TypeVar() to equal the name it is assigned to, so each one is
# declared under its own (underscore-prefixed) name, one per line.
_S = TypeVar('_S')
_T = TypeVar('_T')
_U = TypeVar('_U')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_T3 = TypeVar('_T3')
_T4 = TypeVar('_T4')
_T5 = TypeVar('_T5')
_T6 = TypeVar('_T6')


@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_T], char=False):
    """Parse the next value(s) from ``IO.stdin`` according to ``specs``.

    With no spec, the next line is read as a list of numbers. With one spec, a
    single value of that shape is returned; with several, the specs are
    compiled together as a tuple spec. ``char`` selects per-character reading
    on the stream instead of per-token reading.
    """
    stream = IO.stdin
    stream.char = char
    if not specs:
        return stream.readnumsinto([])
    if len(specs) == 1:
        spec = specs[0]
    else:
        spec = specs
    parse = Parser.compile(spec)
    return parse(stream)
from os import read as os_read, write as os_write, fstat as os_fstat
import sys
from __pypy__.builders import StringBuilder



def max2(a, b):
    """Return the larger of two values; ``b`` is returned when they compare equal."""
    if a > b:
        return a
    return b

class IOBase:
    """Interface outline for the fast I/O streams.

    Every member here is a placeholder that does nothing and returns None;
    the concrete ``IO`` class supplies the behaviour summarised below.
    """

    @property
    def char(io) -> bool:
        """Reading mode: per character when true, per token otherwise."""

    @property
    def writable(io) -> bool:
        """Whether buffered output may be written out on flush."""

    def __next__(io) -> str:
        """Return the next character or token, depending on ``char``."""

    def write(io, s: str) -> None:
        """Queue ``s`` for output."""

    def readline(io) -> str:
        """Return the rest of the current line."""

    def readtoken(io) -> str:
        """Return the next space-delimited token."""

    def readtokens(io) -> list[str]:
        """Return the tokens of the current line."""

    def readints(io) -> list[int]:
        """Return the integers of the current line."""

    def readdigits(io) -> list[int]:
        """Return the current run of digit characters as single-digit ints."""

    def readnums(io) -> list[int]:
        """Return digits or integers, depending on ``char``."""

    def readchar(io) -> str:
        """Return the next single character."""

    def readchars(io) -> str:
        """Return the rest of the current line without its terminator."""

    def readinto(io, lst: list[str]) -> list[str]:
        """Append characters or tokens of the current line to ``lst``."""

    def readcharsinto(io, lst: list[str]) -> list[str]:
        """Append the characters of the current line to ``lst``."""

    def readtokensinto(io, lst: list[str]) -> list[str]:
        """Append the tokens of the current line to ``lst``."""

    def readintsinto(io, lst: list[int]) -> list[int]:
        """Append the integers of the current line to ``lst``."""

    def readdigitsinto(io, lst: list[int]) -> list[int]:
        """Append the current run of digits to ``lst``."""

    def readnumsinto(io, lst: list[int]) -> list[int]:
        """Append digits or integers to ``lst``, depending on ``char``."""

    def wait(io):
        """Iterate while the current line still has unread data."""

    def flush(io) -> None:
        """Write out any queued output."""

    def line(io) -> list[str]:
        """Return the current line split according to ``char``."""

class IO(IOBase):
    """Line-indexed reader and buffered writer over a raw file descriptor.

    Input bytes accumulate in ``B``; ``O`` lists, for every line seen so far,
    the offset just past its terminating newline. ``l`` is the index of the
    current line in ``O`` and ``p`` the current read position in ``B``.
    Output is collected in the string builder ``S`` and only written by
    ``flush``.

    The readers assume every line ends with a newline, so ``load`` appends one
    at end of input when the data does not already end with it (this also
    means ``readline`` returns the final line with a trailing newline).
    """
    BUFSIZE = 1 << 16; stdin: 'IO'; stdout: 'IO'
    __slots__ = 'f', 'file', 'B', 'O', 'V', 'S', 'l', 'p', 'char', 'sz', 'st', 'ist', 'writable', 'encoding', 'errors'
    def __init__(io, file):
        """Wrap ``file``; objects without a usable descriptor or mode get f = -1 and are not writable."""
        io.file = file
        try: io.f = file.fileno(); io.sz, io.writable = max2(io.BUFSIZE, os_fstat(io.f).st_size), ('x' in file.mode or 'r' not in file.mode)
        # Best-effort fallback, but let KeyboardInterrupt / SystemExit propagate.
        except Exception: io.f, io.sz, io.writable = -1, io.BUFSIZE, False
        io.B, io.O, io.S = bytearray(), [], StringBuilder(); io.V = memoryview(io.B); io.l = io.p = 0
        io.char, io.st, io.ist, io.encoding, io.errors = False, [], [], 'ascii', 'ignore'
    def _dec(io, l, r):
        """Decode the buffered bytes in [l, r) to str."""
        return io.V[l:r].tobytes().decode(io.encoding, io.errors)
    def readbytes(io, sz):
        """Read up to ``sz`` bytes from the descriptor."""
        return os_read(io.f, sz)
    def load(io):
        """Read more input until the current line is indexed or the input ends."""
        while io.l >= len(io.O):
            if not (b := io.readbytes(io.sz)):
                # End of input. If nothing is indexed yet, or bytes remain after the
                # last indexed newline, the data has an unterminated tail (possibly
                # empty). Terminate it so the readers, which slice up to r-1, do not
                # drop its last byte, and so an input without any newline no longer
                # fails on io.O[-1].
                if not io.O or io.O[-1] < len(io.B):
                    io.B.extend(b'\n'); io.O.append(len(io.B))
                break
            pos = len(io.B); io.B.extend(b)
            # Record the offset just past every newline in the new chunk.
            while ~(pos := io.B.find(b'\n', pos)): io.O.append(pos := pos+1)
    def __next__(io):
        """Return the next character or token, depending on ``char``."""
        if io.char: return io.readchar()
        else: return io.readtoken()
    def readchar(io):
        """Return the next character; after the last one of a line, skip its newline."""
        io.load(); r = io.O[io.l]
        c = chr(io.B[io.p])
        if io.p >= r-1: io.p = r; io.l += 1
        else: io.p += 1
        return c
    def write(io, s: str):
        """Queue ``s`` for output; nothing is written until ``flush``."""
        io.S.append(s)
    def readline(io):
        """Return the rest of the current line including its newline."""
        io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p)
    def readtoken(io):
        """Return the next space-delimited token of the current line."""
        io.load(); r = io.O[io.l]
        if ~(p := io.B.find(b' ', io.p, r)): s = io._dec(io.p, p); io.p = p+1
        # Last token on the line: stop before the newline and move to the next line.
        else: s = io._dec(io.p, r-1); io.p = r; io.l += 1
        return s
    # The list-returning readers below reuse one scratch list per instance,
    # so each call invalidates the previous call's result.
    def readtokens(io): io.st.clear(); return io.readtokensinto(io.st)
    def readints(io): io.ist.clear(); return io.readintsinto(io.ist)
    def readdigits(io): io.ist.clear(); return io.readdigitsinto(io.ist)
    def readnums(io): io.ist.clear(); return io.readnumsinto(io.ist)
    def readchars(io):
        """Return the rest of the current line without its newline."""
        io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p-1)
    def readinto(io, lst):
        """Append the characters or tokens of the current line to ``lst``."""
        if io.char: return io.readcharsinto(lst)
        else: return io.readtokensinto(lst)
    def readcharsinto(io, lst): lst.extend(io.readchars()); return lst
    def readtokensinto(io, lst):
        """Append every remaining space-delimited token of the current line to ``lst``."""
        io.load(); r = io.O[io.l]
        while ~(p := io.B.find(b' ', io.p, r)): lst.append(io._dec(io.p, p)); io.p = p+1
        lst.append(io._dec(io.p, r-1)); io.p = r; io.l += 1; return lst
    def _readint(io, r):
        """Parse one optionally negative decimal integer before offset ``r``; None if only whitespace is left."""
        while io.p < r and io.B[io.p] <= 32: io.p += 1
        if io.p >= r: return None
        minus = x = 0
        if io.B[io.p] == 45: minus = 1; io.p += 1
        # For the ASCII digits '0'..'9', the low four bits are the digit value.
        while io.p < r and io.B[io.p] >= 48: x = x * 10 + (io.B[io.p] & 15); io.p += 1
        io.p += 1  # step over the delimiter that ended the number
        return -x if minus else x
    def readintsinto(io, lst):
        """Append every remaining integer of the current line to ``lst``."""
        io.load(); r = io.O[io.l]
        while io.p < r and (x := io._readint(r)) is not None: lst.append(x)
        io.l += 1; return lst
    def _readdigit(io): d = io.B[io.p] & 15; io.p += 1; return d
    def readdigitsinto(io, lst):
        """Append the current run of non-whitespace bytes to ``lst`` as single digits."""
        io.load(); r = io.O[io.l]
        while io.p < r and io.B[io.p] > 32: lst.append(io._readdigit())
        # Only a newline ends the line; any other delimiter is just skipped.
        if io.B[io.p] == 10: io.l += 1
        io.p += 1
        return lst
    def readnumsinto(io, lst):
        """Append digits (char mode) or integers (token mode) to ``lst``."""
        if io.char: return io.readdigitsinto(lst)
        else: return io.readintsinto(lst)
    def line(io): io.st.clear(); return io.readinto(io.st)
    def wait(io):
        """Generator that yields while the read position is still inside the current line."""
        io.load(); r = io.O[io.l]
        while io.p < r: yield
    def flush(io):
        """Write the queued output to the descriptor (if writable) and start a fresh builder."""
        if io.writable: os_write(io.f, io.S.build().encode(io.encoding, io.errors)); io.S = StringBuilder()
# Replace the process-wide standard streams with IO wrappers and keep them on the
# class as IO.stdin / IO.stdout, the handles that read() and write() use.
sys.stdin = IO.stdin = IO(sys.stdin); sys.stdout = IO.stdout = IO(sys.stdout)
import typing
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection

class Parsable:
    """Mixin for types that know how to build their own input parser."""

    @classmethod
    def compile(cls):
        """Return a parser that constructs ``cls`` from the next item of the stream."""
        return lambda io: cls(next(io))

    @classmethod
    def __class_getitem__(cls, item):
        # Subscripting (e.g. Parsable[int]) yields a generic alias carrying the
        # arguments, which the spec compiler can unpack later.
        alias = GenericAlias(cls, item)
        return alias

class Parser:
    """Turns a type or value "spec" into a callable that reads that shape from a stream."""

    def __init__(self, spec):
        self.parse = Parser.compile(spec)

    def __call__(self, io: IOBase):
        return self.parse(io)

    @staticmethod
    def compile_type(cls, args = ()):
        """Build a parser for the class ``cls`` with generic arguments ``args``."""
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        if not issubclass(cls, (Number, str)):
            # Scalars are ruled out first because str is itself a Collection.
            if issubclass(cls, tuple):
                return Parser.compile_tuple(cls, args)
            if issubclass(cls, Collection):
                return Parser.compile_collection(cls, args)
            if not callable(cls):
                raise NotImplementedError()
        # Numbers, strings and any other callable class: convert the next item.
        def parse(io: IOBase):
            return cls(next(io))
        return parse

    @staticmethod
    def compile(spec=int):
        """Build a parser from a type, a generic alias, a number, a container or a callable.

        A number acts as an offset: the next item is converted with the
        number's type and the number is added to it.
        """
        if isinstance(spec, (type, GenericAlias)):
            origin = typing.get_origin(spec) or spec
            params = typing.get_args(spec) or tuple()
            return Parser.compile_type(origin, params)
        if isinstance(spec, Number):
            kind, delta = type(spec), spec
            def parse(io: IOBase):
                return kind(next(io)) + delta
            return parse
        if isinstance(spec, tuple):
            return Parser.compile_tuple(type(spec), spec)
        if isinstance(spec, Collection):
            return Parser.compile_collection(type(spec), spec)
        if isinstance(spec, Callable):
            def parse(io: IOBase):
                return spec(next(io))
            return parse
        raise NotImplementedError()

    @staticmethod
    def compile_line(cls, spec=int):
        """Build a parser that reads the rest of the current line into a ``cls``."""
        if spec is int:
            def parse(io: IOBase):
                return cls(io.readnums())
            return parse
        if spec is str:
            def parse(io: IOBase):
                return cls(io.line())
            return parse
        item = Parser.compile(spec)
        def parse(io: IOBase):
            # Kept lazy so the line-exhausted check runs between items.
            return cls((item(io) for _ in io.wait()))
        return parse

    @staticmethod
    def compile_repeat(cls, spec, N):
        """Build a parser that reads ``spec`` exactly ``N`` times into a ``cls``."""
        item = Parser.compile(spec)
        def parse(io: IOBase):
            return cls([item(io) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls, specs):
        """Build a parser that reads one value per entry of ``specs`` into a ``cls``."""
        parsers = tuple(map(Parser.compile, specs))
        def parse(io: IOBase):
            return cls([p(io) for p in parsers])
        return parse

    @staticmethod
    def compile_tuple(cls, specs):
        """``(spec, ...)`` reads a whole line of ``spec``; anything else is read element-wise."""
        whole_line = isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...
        if whole_line:
            return Parser.compile_line(cls, specs[0])
        return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        """Empty or single-entry specs and sets read a line; ``(spec, int)`` reads a fixed count."""
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int):
            item_spec, count = specs[0], specs[1]
            return Parser.compile_repeat(cls, item_spec, count)
        raise NotImplementedError()

def write(*args, **kwargs):
    '''Write the values to ``file`` (default ``IO.stdout``), print-style.

    Recognised keyword arguments: ``sep`` (default a space), ``file``,
    ``end`` (default a newline) and ``flush`` (default False).
    '''
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", IO.stdout)
    for index, value in enumerate(args):
        if index:
            file.write(sep)
        file.write(str(value))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()

# Run the solver only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
Back to top page