cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/set-power-series/polynomial_composite_set_power_series.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/polynomial_composite_set_power_series

def main():
    """Solve one instance: read the header ``M N``, then one line of ints for
    the polynomial (capacity hint ``M``) and one line for the set power series
    (capacity hint ``2**N``); print their composite modulo 998244353."""
    MOD = 998244353
    header = rd()
    M, N = header
    poly = rdl(M)
    series = rdl(1 << N)
    wtnl(sps_composite(poly, series, MOD))

from cp_library.math.sps.mod.sps_composite_fn import sps_composite
from cp_library.io.fast_io_fn import rd, rdl, wtnl

# Entry point: solve the single instance read from standard input.
main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/polynomial_composite_set_power_series

def main():
    # Header line holds M and N; A is read with capacity hint M and B with
    # capacity hint 2**N (each is one input line of ints).
    M, N = rd()
    A, B = rdl(M), rdl(1 << N)
    # Compose modulo the NTT-friendly prime and print space-separated.
    wtnl(sps_composite(A, B, 998244353))

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''


def popcnts(N):
    """Return a list ``P`` of length ``2**N`` with ``P[m]`` = number of set bits of ``m``.

    Built by doubling: the masks in ``[2**i, 2**(i+1))`` are the masks below
    ``2**i`` with bit ``i`` added, so each has one more set bit.
    """
    P = [0] * (1 << N)
    for i in range(N):
        half = 1 << i
        P[half:2 * half] = [c + 1 for c in P[:half]]
    return P



def elist(hint: int) -> list: ...  # typing stub; rebound to newlist_hint below
try:
    # PyPy: preallocating list with capacity ``hint`` (still empty).
    from __pypy__ import newlist_hint
except ImportError:
    # Not PyPy (no __pypy__ module): the hint is only an optimization, so a
    # plain empty list is an exact functional substitute.
    def newlist_hint(hint): return []
elist = newlist_hint
    

from typing import Generic
from typing import TypeVar

# Shared module-level type variables.  PEP 484 requires the string given to
# TypeVar() to equal the name of the variable it is assigned to; type
# checkers reject mismatches.
_S = TypeVar('_S'); _T = TypeVar('_T'); _U = TypeVar('_U')
_T1 = TypeVar('_T1'); _T2 = TypeVar('_T2'); _T3 = TypeVar('_T3')
_T4 = TypeVar('_T4'); _T5 = TypeVar('_T5'); _T6 = TypeVar('_T6')
import sys

def list_find(lst: list, value, start = 0, stop = sys.maxsize):
    """Return the index of the first ``value`` in ``lst[start:stop]``, or -1.

    Works like ``str.find`` for lists. The index is relative to ``lst``, not
    to ``start``. A missing value yields -1 instead of raising.
    """
    try:
        return lst.index(value, start, stop)
    except ValueError:
        # list.index reports "not found" with ValueError.  Other errors
        # (e.g. a non-integer start/stop) now propagate instead of being
        # misreported as "not found".
        return -1


class view(Generic[_T]):
    """Mutable window ``A[l:r]`` over a backing list, without copying.

    Indices given to ``__getitem__``, ``__setitem__`` and ``index`` are
    relative to ``l``. ``pop``/``append`` move the right edge and
    ``popleft``/``appendleft`` move the left edge. All operations read and
    write the backing list in place. Only ``__getitem__`` checks bounds.
    """
    __slots__ = 'A', 'l', 'r'

    def __init__(self, A: list[_T], l: int = 0, r: int = 0):
        self.A = A
        self.l = l
        self.r = r

    def __len__(self):
        return self.r - self.l

    def __getitem__(self, i: int):
        # Only 0 <= i < len(self) is accepted.  Raising IndexError also lets
        # the legacy __getitem__ iteration protocol terminate.
        if not 0 <= i < self.r - self.l:
            raise IndexError
        return self.A[self.l + i]

    def __setitem__(self, i: int, v: _T):
        self.A[self.l + i] = v

    def __contains__(self, v: _T):
        return list_find(self.A, v, self.l, self.r) != -1

    def set_range(self, l: int, r: int):
        self.l = l
        self.r = r

    def index(self, v: _T):
        return self.A.index(v, self.l, self.r) - self.l

    def reverse(self):
        # Two-pointer swap inside the window.
        A = self.A
        lo, hi = self.l, self.r - 1
        while lo < hi:
            A[lo], A[hi] = A[hi], A[lo]
            lo += 1
            hi -= 1

    def sort(self, /, *args, **kwargs):
        # Sort a copy of the window, then write it back element by element.
        window = self.A[self.l:self.r]
        window.sort(*args, **kwargs)
        for offset, item in enumerate(window):
            self.A[self.l + offset] = item

    def pop(self):
        self.r -= 1
        return self.A[self.r]

    def append(self, v: _T):
        self.A[self.r] = v
        self.r += 1

    def popleft(self):
        self.l += 1
        return self.A[self.l - 1]

    def appendleft(self, v: _T):
        self.l -= 1
        self.A[self.l] = v

    def validate(self):
        return 0 <= self.l <= self.r <= len(self.A)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
    x₀ ────────●─●────────●───●────────●───────●────────► X₀
                ╳          ╲ ╱          ╲     ╱          
    x₄ ────────●─●────────●─╳─●────────●─╲───╱─●────────► X₁
                           ╳ ╳          ╲ ╲ ╱ ╱          
    x₂ ────────●─●────────●─╳─●────────●─╲─╳─╱─●────────► X₂
                ╳          ╱ ╲          ╲ ╳ ╳ ╱          
    x₆ ────────●─●────────●───●────────●─╳─╳─╳─●────────► X₃
                                        ╳ ╳ ╳ ╳         
    x₁ ────────●─●────────●───●────────●─╳─╳─╳─●────────► X₄
                ╳          ╲ ╱          ╱ ╳ ╳ ╲          
    x₅ ────────●─●────────●─╳─●────────●─╱─╳─╲─●────────► X₅
                           ╳ ╳          ╱ ╱ ╲ ╲          
    x₃ ────────●─●────────●─╳─●────────●─╱───╲─●────────► X₆
                ╳          ╱ ╲          ╱     ╲          
    x₇ ────────●─●────────●───●────────●───────●────────► X₇
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
                      Math - Convolution                     
'''

def ior_zeta(A: list[int], N: int, Z: int = None) -> list[int]:
    """In-place OR (subset-sum) zeta transform over the low ``N`` index bits.

    Afterwards ``A[m]`` is the sum of the original ``A[s]`` over every ``s``
    that differs from ``m`` only by clearing some of its low ``N`` bits. The
    higher index bits are never changed, so a ranked layout
    ``rank << N | mask`` is transformed row by row. Only indices ``< Z`` are
    updated. ``Z`` defaults to ``len(A)``, and an explicit ``Z = 0`` is a
    no-op. No modulus is applied.

    Returns ``A`` for convenience.
    """
    if Z is None:
        Z = len(A)
    for i in range(N):
        b = 1 << i
        m = b
        # Visit, in increasing order, every index m < Z with bit i set:
        # (m + 1) | b is the next integer above m that still has bit i set.
        while m < Z:
            A[m] += A[m ^ b]
            m = (m + 1) | b
    return A

def ior_mobius_ranked(A: list[int], N: int, M: int, Z: int) -> list[int]:
    """In-place OR Mobius transform (inverse of ``ior_zeta``) of a ranked table, truncated per row.

    ``A`` is read as consecutive rows of length ``M`` starting at ``0, M, 2M, ... < Z``.
    The row starting at ``i`` has rank ``k = i >> N``. This assumes
    ``M == 1 << N``, as ``sps_composite`` passes, so that the low ``N`` bits of
    ``i`` are zero and ``i >> N`` is the row number. Each row is inverted over
    the low ``N`` bits, but only for masks ``0 .. M - 2**(N-k)``. No modulus is
    applied, so entries may become negative.

    Returns ``A``.
    """
    for i in range(0, Z, M):
        # With M == 1 << N, M - 2**(N-k) is the largest N-bit mask with
        # popcount k.  A Mobius value at mask x only depends on submasks of x
        # (all <= x), so results for every mask of popcount k are exact even
        # though larger masks in the row are skipped.
        l, r = i, i+M-(1<<(N-(i>>N)))+1
        for j in range(N):
            m = l|(b := 1<<j)
            # Visit indices in [l, r) with bit j set; (m+1)|b is the next one.
            while m < r: A[m] -= A[m^b]; m = m+1|b
    return A



def isubset_conv_zeta_ranked(Ar: list[int], Br: list[int], n: int, N: int, mod: int) -> None:
    """Pointwise ranked product in the zeta domain, filling the masks that have bit ``n`` set.

    Layouts:
      * ``Ar`` -- rows of stride ``1 << N``; entry ``rank << N | mask``.
      * ``Br`` -- rows of stride ``1 << n``; entry ``rank << n | mask``.

    For each rank ``ij`` in ``0..n`` and each ``k < 1 << n`` this OVERWRITES::

        Ar[(ij+1) << N | (1 << n) | k] =
            sum_{i + j == ij} Br[i << n | k] * Ar[j << N | k]   (mod ``mod``)

    The result goes one rank higher, to account for the extra set bit ``n``,
    into the upper half ``[1 << n, 2 << n)`` of each row. Only the lower half
    ``[0, 1 << n)`` is read. Reads and writes never overlap provided
    ``n < N`` (``sps_composite`` only passes such ``n``), so iteration order
    does not matter. The upper half of rank row 0 is not touched. ``Ar`` is
    mutated in place and nothing is returned.
    """
    m = 1<<n
    for ij in range(n,-1,-1):
        # ij_: base of the upper half of target row ij+1; i_: base of Br row ij.
        ij_, i_ = (ij+1)<<N|m, ij<<n
        # Term i = ij, j = 0 (Ar row 0) initializes the target -- overwrite, not accumulate.
        for k in range(m): Ar[ij_|k] = Br[i_|k] * Ar[k] % mod
        # Remaining terms: i = 0..ij-1 paired with j = ij - i >= 1.
        for i in range(ij):
            j = ij-i; i_, j_ = i<<n, j<<N
            for k in range(m): Ar[ij_|k] = (Ar[ij_|k] + Br[i_|k] * Ar[j_|k]) % mod

def sps_composite(A: list[int], B: list[int], mod: int) -> list[int]:
    """Return the composite ``A(B)`` of polynomial ``A`` with set power series ``B``, modulo ``mod``.

    Inputs:
      * ``A`` -- coefficient list; ``A[j]`` multiplies ``x**j``, as the Horner
        evaluation below shows.
      * ``B`` -- ``1 << N`` entries indexed by subset masks, where
        ``N = len(B).bit_length() - 1``. ``len(B)`` is therefore assumed to be
        a power of two; ``len(B) == 0`` raises ``ValueError`` from a negative
        shift.

    Products inside ``A(B)`` are subset convolutions. The result has
    ``1 << N`` entries reduced into ``[0, mod)``. An empty ``A`` gives all
    zeros.

    Outline:
      1. Precompute two things:
         * ``Br[n]`` for each ``n < N``: the block ``B[1<<n : 2<<n]`` (masks
           whose highest set bit is ``n``), split into rows by the popcount
           of the low ``n`` bits, then OR-zeta transformed.
         * ``B0[n]``: the ``n``-th derivative of ``A`` evaluated at ``B[0]``.
      2. ``Cr`` is a ranked zeta table (entry ``rank << N | mask``). Iteration
         ``n`` covers the first ``1 << n`` masks:
         * Each block ``[1<<m, 2<<m)`` is produced from the previous contents
           of ``[0, 1<<m)`` and ``Br[m]`` by a subset convolution. This is the
           usual derivative-chain recurrence: adding element ``m`` multiplies
           by ``B``'s part containing ``m`` and lowers the derivative order
           by one.
         * Partial zeta updates then merge the blocks.
         * Rank row 0 is reset to ``B0[N-n]``.
         After iteration ``N`` the derivative order is 0, i.e. ``A`` itself.
      3. A ranked Mobius transform recovers ``C``, reading mask ``i`` from the
         row of rank ``popcount(i)``.
    """
    C = [0]*(M := 1 << (N := len(B).bit_length() - 1))
    if not A: return C
    dA, B0, B1, Br, Cr, pcnt = A[:], elist(N+1), view(B), elist(N), [0]*(Z := (N+1)*M), popcnts(N)
    for n in range(N+1):
        if n < N:
            # Br[n]: B[1<<n : 2<<n] placed at row popcount(i) of the low n bits,
            # then OR-zeta transformed (row by row) and reduced mod `mod`.
            B1.set_range(1<<n, 2<<n)
            br = [0]*(z := (n+1)*(m := 1<<n))
            for i in range(m): br[pcnt[i]<<n|i] = B1[i]
            ior_zeta(br, n)
            for i in range(z): br[i] %= mod
            Br.append(br)
        # B0[n] = (n-th derivative of A)(B[0]), evaluated with Horner's method.
        t = 0
        for j in range(len(dA)-1, -1, -1): t = (t * B[0] + dA[j]) % mod
        B0.append(t)
        # Differentiate dA in place; its length stays fixed, so the top
        # coefficient becomes 0.
        for j in range(1, len(dA)): dA[j-1] = (j * dA[j]) % mod
        if dA: dA[-1] = 0
    for n in range(N+1):
        for m in range(n-1, -1, -1):
            # Effectively computes `C[1<<m:2<<m] = subset_conv(C[:1<<m], B[1<<m:2<<m])`,
            # but on `Cr`, the ranked zeta transform of `C`.  Descending m
            # ensures each block reads the previous iteration's lower masks.
            # Partial zeta updates are applied after this loop to propagate
            # contributions between blocks.
            isubset_conv_zeta_ranked(Cr, Br[m], m, N, mod)
        # Partial zeta updates: for bit m, add masks [0, 1<<m) into
        # [1<<m, 2<<m) for rank rows 0..m (left unreduced; reduced later).
        for m in range(n):
            b = 1 << m
            for j in range(m+1):
                j <<= N
                for k in range(j, j|b): Cr[k|b] += Cr[k]
        # Rank row 0 over the first 1<<n masks becomes B0[N-n] (~n == -n-1).
        for k in range(1<<n): Cr[k] = B0[~n]
    # Invert the ranked zeta transform, then read mask i from row popcount(i).
    ior_mobius_ranked(Cr, N, M, Z)
    for i, p in enumerate(pcnt): C[i] = Cr[p<<N|i] % mod
    return C

from os import read as os_read, write as os_write, fstat as os_fstat
from __pypy__.builders import StringBuilder



def max2(a, b):
    """Return the larger of ``a`` and ``b``; ``b`` wins ties (cheaper than builtin ``max``)."""
    if a > b:
        return a
    return b

class IOBase:
    """Declares the fast-I/O interface implemented by ``IO``.

    Every member here is a placeholder whose body does nothing and returns
    None; the subclass provides the real behaviour.
    """

    @property
    def char(io) -> bool:
        """Whether reads are character/digit oriented rather than token/int oriented."""

    @property
    def writable(io) -> bool:
        """Whether ``flush`` may write to the underlying file."""

    def __next__(io) -> str:
        """Return the next character or token, depending on ``char``."""

    def write(io, s: str) -> None:
        """Buffer ``s`` for output."""

    def readline(io) -> str:
        """Return the rest of the current line."""

    def readtoken(io) -> str:
        """Return the next space-separated token."""

    def readtokens(io) -> list[str]:
        """Return the remaining tokens of the current line."""

    def readints(io) -> list[int]:
        """Return the remaining integers of the current line."""

    def readdigits(io) -> list[int]:
        """Return the digits of the next run of non-space bytes."""

    def readnums(io) -> list[int]:
        """Return digits or integers, depending on ``char``."""

    def readchar(io) -> str:
        """Return the next single character."""

    def readchars(io) -> str:
        """Return the rest of the current line without its terminator."""

    def readinto(io, lst: list[str]) -> list[str]:
        """Append characters or tokens, depending on ``char``, to ``lst``."""

    def readcharsinto(io, lst: list[str]) -> list[str]:
        """Append the characters of the current line to ``lst``."""

    def readtokensinto(io, lst: list[str]) -> list[str]:
        """Append the tokens of the current line to ``lst``."""

    def readintsinto(io, lst: list[int]) -> list[int]:
        """Append the integers of the current line to ``lst``."""

    def readdigitsinto(io, lst: list[int]) -> list[int]:
        """Append the digits of the next run of non-space bytes to ``lst``."""

    def readnumsinto(io, lst: list[int]) -> list[int]:
        """Append digits or integers, depending on ``char``, to ``lst``."""

    def wait(io):
        """Placeholder for iterating while the current line has unread data."""

    def flush(io) -> None:
        """Write out buffered output."""

    def line(io) -> list[str]:
        """Return the current line split into characters or tokens."""

class IO(IOBase):
    """Chunked fast I/O over a file descriptor (output uses PyPy's ``StringBuilder``).

    Reading state:
      * ``B`` accumulates raw input bytes and is viewed through the
        memoryview ``V``.
      * ``O`` lists the offset just past every newline byte found so far.
      * ``l`` is the index of the current line in ``O``; ``p`` is the read
        cursor in ``B``.
      * ``char`` switches ``__next__``/``readinto``/``readnumsinto`` from
        token/int mode to character/digit mode.

    At EOF a missing final newline is appended, so every line is
    newline-terminated.

    Writing: ``write`` appends to the builder ``S``; ``flush`` emits it only
    when ``writable``.
    """
    BUFSIZE = 1 << 16; stdin: 'IO'; stdout: 'IO'
    __slots__ = 'f', 'file', 'B', 'O', 'V', 'S', 'l', 'p', 'char', 'sz', 'st', 'ist', 'writable', 'encoding', 'errors'
    def __init__(io, file):
        io.file = file
        try: io.f = file.fileno(); io.sz, io.writable = max2(io.BUFSIZE, os_fstat(io.f).st_size), ('x' in file.mode or 'r' not in file.mode)
        except Exception:
            # Best effort: objects without a usable descriptor or mode (or a
            # failing fstat) fall back to defaults and are never written to.
            io.f, io.sz, io.writable = -1, io.BUFSIZE, False
        io.B, io.O, io.S = bytearray(), [], StringBuilder(); io.V = memoryview(io.B); io.l = io.p = 0
        io.char, io.st, io.ist, io.encoding, io.errors = False, [], [], 'ascii', 'ignore'
    def _dec(io, l, r): return io.V[l:r].tobytes().decode(io.encoding, io.errors)
    def readbytes(io, sz): return os_read(io.f, sz)
    def load(io):
        # Read chunks until line io.l is complete, or EOF is reached.
        while io.l >= len(io.O):
            if not (b := io.readbytes(io.sz)):
                # EOF.  If the data does not end with a newline -- including
                # input with no newline at all, where io.O is still empty and
                # io.O[-1] would raise -- append one.  The readers below strip
                # the final byte of each line, so every line, the last one
                # included, must be newline-terminated.
                if io.B and io.B[-1] != 10: io.B.append(10); io.O.append(len(io.B))
                break
            pos = len(io.B); io.B.extend(b)
            while ~(pos := io.B.find(b'\n', pos)): io.O.append(pos := pos+1)
    def __next__(io):
        if io.char: return io.readchar()
        else: return io.readtoken()
    def readchar(io):
        io.load(); r = io.O[io.l]
        c = chr(io.B[io.p])
        # Reaching the line terminator moves to the next line.
        if io.p >= r-1: io.p = r; io.l += 1
        else: io.p += 1
        return c
    def write(io, s: str): io.S.append(s)
    def readline(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p)
    def readtoken(io):
        io.load(); r = io.O[io.l]
        if ~(p := io.B.find(b' ', io.p, r)): s = io._dec(io.p, p); io.p = p+1
        else: s = io._dec(io.p, r-1); io.p = r; io.l += 1  # last token: drop the newline
        return s
    # The read* helpers below reuse io.st / io.ist: the returned list is
    # cleared and refilled by the next call of the same kind.
    def readtokens(io): io.st.clear(); return io.readtokensinto(io.st)
    def readints(io): io.ist.clear(); return io.readintsinto(io.ist)
    def readdigits(io): io.ist.clear(); return io.readdigitsinto(io.ist)
    def readnums(io): io.ist.clear(); return io.readnumsinto(io.ist)
    def readchars(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p-1)
    def readinto(io, lst):
        if io.char: return io.readcharsinto(lst)
        else: return io.readtokensinto(lst)
    def readcharsinto(io, lst): lst.extend(io.readchars()); return lst
    def readtokensinto(io, lst): 
        io.load(); r = io.O[io.l]
        while ~(p := io.B.find(b' ', io.p, r)): lst.append(io._dec(io.p, p)); io.p = p+1
        lst.append(io._dec(io.p, r-1)); io.p = r; io.l += 1; return lst
    def _readint(io, r):
        # Parse one (optionally negative) decimal int before offset r, or
        # return None if only whitespace/control bytes remain.
        while io.p < r and io.B[io.p] <= 32: io.p += 1
        if io.p >= r: return None
        minus = x = 0
        if io.B[io.p] == 45: minus = 1; io.p += 1
        while io.p < r and io.B[io.p] >= 48: x = x * 10 + (io.B[io.p] & 15); io.p += 1
        io.p += 1  # skip the separator that ended the number
        return -x if minus else x
    def readintsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and (x := io._readint(r)) is not None: lst.append(x)
        io.l += 1; return lst
    def _readdigit(io): d = io.B[io.p] & 15; io.p += 1; return d
    def readdigitsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and io.B[io.p] > 32: lst.append(io._readdigit())
        if io.B[io.p] == 10: io.l += 1
        io.p += 1
        return lst
    def readnumsinto(io, lst):
        if io.char: return io.readdigitsinto(lst)
        else: return io.readintsinto(lst)
    def line(io): io.st.clear(); return io.readinto(io.st)
    def wait(io):
        # Generator: yields while the current line still has unread bytes;
        # the consumer must advance io.p between iterations.
        io.load(); r = io.O[io.l]
        while io.p < r: yield
    def flush(io):
        if io.writable: os_write(io.f, io.S.build().encode(io.encoding, io.errors)); io.S = StringBuilder()
# Install fast wrappers as the standard streams (and as IO.stdin/IO.stdout).
# NOTE(review): nothing in this file calls IO.flush explicitly; presumably the
# interpreter's shutdown flush of sys.stdout is what emits buffered output.
sys.stdin = IO.stdin = IO(sys.stdin)
sys.stdout = IO.stdout = IO(sys.stdout)


def rd():
    """Return the remaining ints of the current input line (list reused by the next call)."""
    return IO.stdin.readints()


def rds():
    """Return the next token (or character when ``IO.stdin.char`` is set)."""
    return next(IO.stdin)


def rdl(n):
    """Return the ints of the current input line in a fresh list sized with hint ``n``."""
    return IO.stdin.readintsinto(elist(n))


def wt(s):
    """Buffer ``s`` for output as-is."""
    IO.stdout.write(s)


def wtn(s):
    """Buffer ``s`` followed by a newline."""
    IO.stdout.write(f'{s}\n')


def wtnl(l):
    """Buffer the items of ``l`` separated by single spaces (no trailing newline)."""
    IO.stdout.write(' '.join(str(x) for x in l))

# Entry point of the bundled script: solve the instance read from stdin.
# Output is buffered in IO.stdout; this file never calls flush() explicitly.
main()
Back to top page