cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/set-power-series/subset_convolution_snippet.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/subset_convolution

def main():
    """Solve Library Checker "subset_convolution".

    Reads n, then 2**n values of a and 2**n values of b, and prints their
    subset convolution modulo 998244353 on a single line.
    """
    mod = 998244353
    n = rd()
    size = 1 << n
    a = rdl(size)
    b = rdl(size)
    result = isubset_conv(a, b, n, mod)
    wtnl(result)

from cp_library.alg.dp.butterfly.butterfly_masks_fn import subset_zeta_pair, subset_mobius
from cp_library.bit.popcnts_fn import popcnts

def isubset_conv(A,B,N,mod):
    """Subset convolution modulo ``mod``, written into A in place.

    For every mask S in [0, 2**N) this sets
        A[S] = (sum over T subset of S of A[T] * B[S ^ T]) % mod
    using ranked (popcount-layered) zeta / Mobius transforms, and returns A.
    B is only read.

    Both lists must have the same length, at least 2**N. Only the first 2**N
    entries are read, and only those entries of A are overwritten.
    """
    assert len(A) == len(B)
    M = 1 << N               # size of one rank layer
    Z = (N + 1) * M          # ranks 0..N stacked: index = rank << N | mask
    P = popcnts(N)
    Ar = [0] * Z
    Br = [0] * Z
    Cr = [0] * Z
    # Place each input value in the layer given by its mask's popcount.
    for s, rank in enumerate(P):
        slot = rank << N | s
        Ar[slot] = A[s]
        Br[slot] = B[s]
    subset_zeta_pair(Ar, Br, N)
    Ar = [x % mod for x in Ar]
    Br = [x % mod for x in Br]
    # Pointwise ranked product: layer p of A times layer q of B feeds layer p + q (<= N).
    for i in range(0, Z, M):
        for j in range(0, Z - i, M):
            base = i + j
            for k in range(M):
                dst = base | k
                Cr[dst] = (Cr[dst] + Ar[i | k] * Br[j | k]) % mod
    subset_mobius(Cr, N)
    # The layer equal to popcount(S) holds the sum over disjoint (T, S ^ T) pairs.
    for s, rank in enumerate(P):
        A[s] = Cr[rank << N | s] % mod
    return A

from cp_library.io.fast.fast_io_fn import rd, rdl, wtnl

# Run the solver only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/subset_convolution

def main():
    """Solve Library Checker "subset_convolution".

    Reads n, then 2**n values of a and 2**n values of b, and prints their
    subset convolution modulo 998244353 on a single line.
    """
    mod = 998244353
    n = rd()
    size = 1 << n
    a = rdl(size)
    b = rdl(size)
    result = isubset_conv(a, b, n, mod)
    wtnl(result)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''


'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
    x₀ ────────●─●────────●───●────────●───────●────────► X₀
                ╳          ╲ ╱          ╲     ╱          
    x₄ ────────●─●────────●─╳─●────────●─╲───╱─●────────► X₁
                           ╳ ╳          ╲ ╲ ╱ ╱          
    x₂ ────────●─●────────●─╳─●────────●─╲─╳─╱─●────────► X₂
                ╳          ╱ ╲          ╲ ╳ ╳ ╱          
    x₆ ────────●─●────────●───●────────●─╳─╳─╳─●────────► X₃
                                        ╳ ╳ ╳ ╳         
    x₁ ────────●─●────────●───●────────●─╳─╳─╳─●────────► X₄
                ╳          ╲ ╱          ╱ ╳ ╳ ╲          
    x₅ ────────●─●────────●─╳─●────────●─╱─╳─╲─●────────► X₅
                           ╳ ╳          ╱ ╱ ╲ ╲          
    x₃ ────────●─●────────●─╳─●────────●─╱───╲─●────────► X₆
                ╳          ╱ ╲          ╱     ╲          
    x₇ ────────●─●────────●───●────────●───────●────────► X₇
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
                 Algorithms - DP - Butterfly                     
'''

def butterfly_masks(N, Z):
    """Yield the (lo, hi) index pairs of an in-place butterfly over the low N bits.

    Bits are visited in increasing order, i = 0 .. N-1. For bit i, the
    generator yields, in increasing order of hi, every index hi < Z that has
    bit i set, paired with lo = hi with bit i cleared. Because all pairs for
    bit i come before any pair for bit i+1, the transforms in this module can
    update the list in place.
    """
    for i in range(N):
        bit = 1 << i
        span = bit << 1
        # Indices are grouped into runs of length 2*bit: the lower half has the
        # bit clear (lo), the upper half has it set (hi = lo + bit).
        for start in range(0, Z, span):
            stop = min(start + bit, Z - bit)   # keep hi = lo + bit below Z
            for lo in range(start, stop):
                yield lo, lo + bit

def fwht(A: list, N: int):
    """Unnormalized Walsh-Hadamard (XOR) transform, in place over the low N bits.

    Each butterfly replaces (A[lo], A[hi]) with (A[lo] + A[hi], A[lo] - A[hi]).
    No scaling is applied. Returns A.
    """
    for lo, hi in butterfly_masks(N, len(A)):
        x = A[lo]
        y = A[hi]
        A[lo], A[hi] = x + y, x - y
    return A

def subset_zeta(A: list[int], N: int):
    """Subset-sum (zeta) transform, in place over the low N index bits.

    Bit by bit, adds A[mask without the bit] into A[mask with the bit]. As a
    result, A[S] becomes the sum of the original A[T] over all submasks T of S.
    Returns A.
    """
    pairs = butterfly_masks(N, len(A))
    for src, dst in pairs:
        A[dst] += A[src]
    return A

def subset_zeta_pair(A: list[int], B: list[int], N: int):
    """Apply the subset-sum (zeta) transform to A and B together, in place.

    The butterfly pairs are enumerated once, from len(A), and applied to both
    lists, so B is expected to be at least as long as A. Returns (A, B).
    """
    pairs = butterfly_masks(N, len(A))
    for src, dst in pairs:
        A[dst] += A[src]
        B[dst] += B[src]
    return A, B

def subset_mobius(A: list[int], N: int):
    """Subset Mobius transform, in place over the low N index bits.

    Bit by bit, subtracts A[mask without the bit] from A[mask with the bit].
    This undoes subset_zeta applied with the same N. Returns A.
    """
    pairs = butterfly_masks(N, len(A))
    for src, dst in pairs:
        A[dst] -= A[src]
    return A

def superset_zeta(A, N: int):
    """Superset-sum (zeta) transform, in place over the low N index bits.

    Bit by bit, adds A[mask with the bit] into A[mask without the bit].
    Returns A.
    """
    pairs = butterfly_masks(N, len(A))
    for dst, src in pairs:
        A[dst] += A[src]
    return A

def superset_mobius(A, N: int):
    """Superset Mobius transform, in place over the low N index bits.

    Bit by bit, subtracts A[mask with the bit] from A[mask without the bit].
    Returns A.
    """
    pairs = butterfly_masks(N, len(A))
    for dst, src in pairs:
        A[dst] -= A[src]
    return A

def popcnts(N):
    """Return P where P[m] is the number of set bits of m, for m in [0, 2**N)."""
    P = [0] * (1 << N)
    for m in range(1, len(P)):
        # m has the bits of m >> 1, plus its own lowest bit.
        P[m] = P[m >> 1] + (m & 1)
    return P


def subset_conv(A,B,N):
    """Subset convolution without modular reduction, written into A in place.

    For every mask S in [0, 2**N) this sets
        A[S] = sum over T subset of S of A[T] * B[S ^ T]
    using ranked (popcount-layered) zeta / Mobius transforms, and returns A.
    B is only read.

    Both lists must have the same length, at least 2**N. Only the first 2**N
    entries are read, and only those entries of A are overwritten.
    """
    assert len(A) == len(B)
    M = 1 << N               # size of one rank layer
    Z = (N + 1) * M          # ranks 0..N stacked: index = rank << N | mask
    P = popcnts(N)
    Ar = [0] * Z
    Br = [0] * Z
    Cr = [0] * Z
    # Place each input value in the layer given by its mask's popcount.
    for s, rank in enumerate(P):
        slot = rank << N | s
        Ar[slot] = A[s]
        Br[slot] = B[s]
    subset_zeta_pair(Ar, Br, N)
    # Layer p of A times layer q of B accumulates into layer p + q (<= N).
    for i in range(0, Z, M):
        for j in range(0, Z - i, M):
            base = i + j
            for k in range(M):
                Cr[base | k] += Ar[i | k] * Br[j | k]
    subset_mobius(Cr, N)
    # The layer equal to popcount(S) holds the sum over disjoint (T, S ^ T) pairs.
    for s, rank in enumerate(P):
        A[s] = Cr[rank << N | s]
    return A


def popcnts(N):
    """Return P where P[m] is the number of set bits of m, for m in [0, 2**N)."""
    P = [0] * (1 << N)
    for m in range(1, len(P)):
        # m has the bits of m >> 1, plus its own lowest bit.
        P[m] = P[m >> 1] + (m & 1)
    return P

def isubset_conv(A,B,N,mod):
    """Subset convolution modulo ``mod``, written into A in place.

    For every mask S in [0, 2**N) this sets
        A[S] = (sum over T subset of S of A[T] * B[S ^ T]) % mod
    using ranked (popcount-layered) zeta / Mobius transforms, and returns A.
    B is only read.

    Both lists must have the same length, at least 2**N. Only the first 2**N
    entries are read, and only those entries of A are overwritten.
    """
    assert len(A) == len(B)
    M = 1 << N               # size of one rank layer
    Z = (N + 1) * M          # ranks 0..N stacked: index = rank << N | mask
    P = popcnts(N)
    Ar = [0] * Z
    Br = [0] * Z
    Cr = [0] * Z
    # Place each input value in the layer given by its mask's popcount.
    for s, rank in enumerate(P):
        slot = rank << N | s
        Ar[slot] = A[s]
        Br[slot] = B[s]
    subset_zeta_pair(Ar, Br, N)
    Ar = [x % mod for x in Ar]
    Br = [x % mod for x in Br]
    # Pointwise ranked product: layer p of A times layer q of B feeds layer p + q (<= N).
    for i in range(0, Z, M):
        for j in range(0, Z - i, M):
            base = i + j
            for k in range(M):
                dst = base | k
                Cr[dst] = (Cr[dst] + Ar[i | k] * Br[j | k]) % mod
    subset_mobius(Cr, N)
    # The layer equal to popcount(S) holds the sum over disjoint (T, S ^ T) pairs.
    for s, rank in enumerate(P):
        A[s] = Cr[rank << N | s] % mod
    return A

from __pypy__.builders import StringBuilder
import sys
from os import read as os_read, write as os_write
from atexit import register as atexist_register

class Fastio:
    """Fast I/O on raw file descriptors 0 (input) and 1 (output).

    Input is read with os.read in chunks of up to 131072 bytes into ``ibuf``
    and parsed by hand. Output strings are appended to a PyPy ``StringBuilder``
    and reach fd 1 only when ``flush`` or ``flush_atexit`` is called.
    Requires PyPy, because StringBuilder comes from ``__pypy__.builders``.
    """
    ibuf = bytes()  # buffered input; bytes before pil have already been consumed
    pil = pir = 0  # pil: read cursor into ibuf; pir: len(ibuf) as of the last load()
    sb = StringBuilder()  # class-level output builder (shared until flush() rebinds self.sb)
    def load(self):
        # Drop the consumed prefix, append up to 131072 bytes read from fd 0,
        # and reset the cursors so that pil/pir index the new buffer.
        self.ibuf = self.ibuf[self.pil:]
        self.ibuf += os_read(0, 131072)
        self.pil = 0; self.pir = len(self.ibuf)
    # Write all buffered output to fd 1 without resetting the builder
    # (this is the method registered with atexit at module level).
    def flush_atexit(self): os_write(1, self.sb.build().encode())
    def flush(self):
        # Write all buffered output to fd 1, then start a fresh builder so the
        # same text is not written a second time.
        os_write(1, self.sb.build().encode())
        self.sb = StringBuilder()
    def fastin(self):
        # Parse and return the next decimal integer, with an optional leading '-'.
        # The buffer is refilled only here, when fewer than 64 unread bytes remain.
        # This assumes the skipped separators plus the number fit in what is then buffered.
        if self.pir - self.pil < 64: self.load()
        minus = x = 0
        # Skip separator bytes, i.e. anything below 45 ('-'), such as space or newline.
        while self.ibuf[self.pil] < 45: self.pil += 1
        if self.ibuf[self.pil] == 45: minus = 1; self.pil += 1
        # Accumulate digits while the byte is >= 48 ('0'); (byte & 15) is the digit value.
        # NOTE(review): there is no end-of-buffer check. If the input ends right after
        # the last digit (no trailing newline), the next index raises IndexError.
        while self.ibuf[self.pil] >= 48:
            x = x * 10 + (self.ibuf[self.pil] & 15)
            self.pil += 1
        if minus: return -x
        return x
    def fastin_string(self):
        # Return the next token, a maximal run of bytes > 32, as a bytearray.
        # Leading bytes <= 32 (whitespace/control characters) are skipped first.
        if self.pir - self.pil < 64: self.load()
        while self.ibuf[self.pil] <= 32: self.pil += 1
        res = bytearray()
        # NOTE(review): as in fastin, a token that ends exactly at end of input
        # (no trailing whitespace) makes the loop condition index past ibuf.
        while self.ibuf[self.pil] > 32:
            # Long tokens: refill whenever fewer than 64 unread bytes remain.
            if self.pir - self.pil < 64: self.load()
            res.append(self.ibuf[self.pil])
            self.pil += 1
        return res
    # Append str(x) to the output buffer, with no separator.
    def fastout(self, x): self.sb.append(str(x))
    # Append str(x) followed by a newline to the output buffer.
    def fastoutln(self, x): self.sb.append(str(x)); self.sb.append('\n')
# Single shared I/O instance, plus short aliases for its methods:
# rd reads an int, rds reads a token, wt/wtn write a value (wtn adds a newline),
# and flush writes out the buffered output.
fastio = Fastio()
rd = fastio.fastin; rds = fastio.fastin_string; wt = fastio.fastout; wtn = fastio.fastoutln; flush = fastio.flush
# Write whatever is still buffered in fastio.sb when the interpreter exits.
atexist_register(fastio.flush_atexit)
# Detach the standard text streams; from here on, I/O is expected to go
# through fastio, which uses raw fds 0 and 1.
sys.stdin = None; sys.stdout = None
def rdl(n):
    """Read n integers with rd() and return them as a list, in input order."""
    values = []
    for _ in range(n):
        values.append(rd())
    return values
def wtnl(l):
    """Buffer the items of l joined by single spaces, followed by a newline."""
    line = ' '.join(str(item) for item in l)
    wtn(line)

# Run the solver only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Back to top page