cp-library

This documentation is automatically generated by online-judge-tools/verification-helper.

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/graph/incremental_scc.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/incremental_scc

def main():
    """Solve Library Checker "incremental_scc" reading stdin, writing stdout.

    Input: N, M, then N vertex values X, then M edges given as (U[e], V[e]).
    scc_incremental supplies a time W[e] per edge. Edges are processed in
    increasing W order and their endpoints are merged in a union-find whose
    roots carry the sum of their members' values. Merging two components
    whose sums are S_a and S_b adds S_a * S_b to the running total
    (mod 998244353). res[t] is the total after every edge with W[e] <= t has
    been merged; presumably this is the required per-prefix answer.
    """
    MOD = 998244353
    n, m = rd(), rd()
    val = rdl(n)
    src, dst = [0] * m, [0] * m
    for i in range(m):
        src[i], dst[i] = rd(), rd()
    when = scc_incremental(n, m, src, dst)
    root = [*range(n)]
    res = [0] * m
    acc = 0
    t = 0
    for e in argsort_bounded(when, m):
        # Every time strictly before this edge's merge time sees the current total.
        while t < when[e]:
            res[t] = acc
            t += 1
        a, b = src[e], dst[e]
        # Union-find lookup with path halving.
        while a != root[a]:
            root[a] = a = root[root[a]]
        while b != root[b]:
            root[b] = b = root[root[b]]
        if a == b:
            continue
        acc = (acc + val[a] * val[b]) % MOD
        val[a] = (val[a] + val[b]) % MOD
        root[b] = a
    # Remaining times all see the final total.
    res[t:] = [acc] * (m - t)
    wtnl(res)

from cp_library.alg.graph.fast.snippets.scc_incremental_fn import scc_incremental
from cp_library.alg.iter.argsort_bounded_fn import argsort_bounded
from cp_library.io.fast.fast_io_fn import rd, rdl, wtnl

# Run the solver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/incremental_scc

def main():
    """Solve Library Checker "incremental_scc" reading stdin, writing stdout.

    Input: N, M, then N vertex values X, then M edges given as (U[e], V[e]).
    scc_incremental supplies a time W[e] per edge. Edges are processed in
    increasing W order and their endpoints are merged in a union-find whose
    roots carry the sum of their members' values. Merging two components
    whose sums are S_a and S_b adds S_a * S_b to the running total
    (mod 998244353). res[t] is the total after every edge with W[e] <= t has
    been merged; presumably this is the required per-prefix answer.
    """
    MOD = 998244353
    n, m = rd(), rd()
    val = rdl(n)
    src, dst = [0] * m, [0] * m
    for i in range(m):
        src[i], dst[i] = rd(), rd()
    when = scc_incremental(n, m, src, dst)
    root = [*range(n)]
    res = [0] * m
    acc = 0
    t = 0
    for e in argsort_bounded(when, m):
        # Every time strictly before this edge's merge time sees the current total.
        while t < when[e]:
            res[t] = acc
            t += 1
        a, b = src[e], dst[e]
        # Union-find lookup with path halving.
        while a != root[a]:
            root[a] = a = root[root[a]]
        while b != root[b]:
            root[b] = b = root[root[b]]
        if a == b:
            continue
        acc = (acc + val[a] * val[b]) % MOD
        val[a] = (val[a] + val[b]) % MOD
        root[b] = a
    # Remaining times all see the final total.
    res[t:] = [acc] * (m - t)
    wtnl(res)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''





def scc_incremental(N, M, U, V):
    """Offline incremental SCC: for each edge, find when its endpoints merge.

    Edges are considered added in index order 0, 1, ..., M-1. A
    divide-and-conquer binary search over times assigns each edge a time
    W[e]. At a call covering (tl, tr) with midpoint tm, an edge whose
    endpoints share an SCC while only edges with id <= tm are active gets
    W[e] = tm and goes to the earlier half. Otherwise it goes to the later
    half, on the graph contracted by those SCCs (endpoints relabeled).
    Edges never found strongly connected keep W[e] == M. W[e] is presumably
    the smallest t such that U[e] and V[e] are strongly connected using
    edges 0..t.

    Args:
        N: number of vertices; labels are 0..N-1.
        M: number of edges.
        U, V: edge endpoints, edge e being U[e] -> V[e]. Both lists are
            copied (U[:], V[:]) and not mutated.

    Returns:
        W: list of length M as described above.

    NOTE(review): the leftmost leaf call, with (tl, tr) == (-1, 0), tests
    time -1 with no active edges. An edge whose endpoints are already equal
    there (a self-loop) appears to receive W[e] == -1. Confirm callers
    tolerate that value.
    """
    # Working copies of the endpoints (relabeled in place when contracted),
    # the result W, CSR row start/end cursors La/Ra, and CSR targets Va.
    U, V, W, La, Ra, Va = U[:], V[:], [M]*M, [0]*N, [0]*N, [0]*M
    # E/F: ping-pong buffers of edge ids; sccs: component label per vertex;
    # st: DFS stack; buf: Tarjan vertex stack; tin/low: discovery time / low-link.
    E, F, sccs, st, buf, tin, low = [*range(M)], [*range(M)], [0]*N, [0]*N, [0]*N, [-1]*N, [-1]*N

    def build_csr(N, E, el, er):
        # Build a CSR adjacency over vertices 0..N-1 from edges E[el:er], so
        # that afterwards Va[La[u]:Ra[u]] holds the heads of u's out-edges.
        # Also resets tin[u] = -1 (unvisited) for every vertex.
        u = tot = 0
        while u < N: La[u], tin[u] = 0, -1; u += 1
        i = el
        # Count out-degrees.
        while i < er: La[U[e := E[i]]] += 1; i += 1
        u = 0
        # Inclusive prefix sums: La[u] = Ra[u] = end of u's slice.
        while u < N: La[u] = Ra[u] = (tot := tot + La[u]); u += 1
        i = el
        # Decrement-then-store fills each slice from its end, leaving La[u]
        # at the start of u's slice.
        while i < er: La[u] = a = La[u := U[e := E[i]]]-1; Va[a] = V[e]; i += 1

    def scc_labels(N, E, el, em, er, La, Ra, Va):
        # Iterative Tarjan over the CSR, started from the tails of the active
        # edges E[el:em]. Writes sccs[v] for every vertex it reaches. A
        # finished vertex gets tin[v] = N, larger than any discovery time, so
        # it can never lower another vertex's low-link. Endpoints of the
        # inactive edges E[em:er] that were not reached get fresh singleton
        # labels. Returns the number of labels assigned (labels 0..count-1).
        t = cnt = -1; i = el
        while i < em:
            u = U[E[i]]; i += 1
            if tin[u] < 0:
                st[0] = u; d = b = 0
                while d >= 0:
                    # First visit: assign discovery time and push onto buf.
                    if tin[u := st[d]] == -1: tin[u] = low[u] = (t:=t+1); buf[b] = u; b += 1
                    if La[u] < Ra[u]:
                        # Consume one out-edge: descend into an unvisited head,
                        # else relax low[u] with the head's tin (still-open heads only).
                        if (tv := tin[Va[La[u]]])== -1: st[d:=d+1] = Va[La[u]]
                        elif tv < low[u]: low[u] = tv
                        La[u] += 1
                    else:
                        # Out-edges exhausted: pop u and propagate low to its parent.
                        if (d:=d-1) >= 0 and low[u] < low[st[d]]: low[st[d]] = low[u]
                        if low[u] == tin[u]:
                            # u is an SCC root: pop its members off buf and label them.
                            v, cnt = -1, cnt+1
                            while u != v: tin[v := buf[b:=b-1]], sccs[buf[b]] = N, cnt
        while i < er:
            u, v = U[E[i]], V[E[i]]; i += 1
            if tin[u] < 0: tin[u], sccs[u] = N, (cnt:=cnt+1)
            if tin[v] < 0: tin[v], sccs[v] = N, (cnt:=cnt+1)
        return cnt+1
    
    def partition(el, er, tm):
        # Stable-partition E[el:er] into F[el:er]. Edges whose endpoints share
        # an SCC come first and get W[e] = tm. The remaining edges follow with
        # their endpoints relabeled to SCC ids. Returns the split index.
        i = em = el
        while i < er:
            if sccs[U[e := E[i]]] == sccs[V[e]]: W[e], F[em] = tm, e; em += 1
            i += 1
        i, fm = el, em
        while i < er:
            if (u := sccs[U[e := E[i]]]) != (v := sccs[V[e]]): U[e], V[e], F[fm] = u, v, e; fm += 1
            i += 1
        return em
    
    def div_con(N, el, er, tl, tr):
        # Resolve edges E[el:er] on a graph with N (possibly contracted)
        # vertices; presumably their answers are known to lie in (tl, tr].
        nonlocal E, F
        if el == er: return
        tm, em = (tl+tr) >> 1, el
        # E[el:er] stays in increasing edge-id order (the initial range and a
        # stable partition), so edges active at time tm form a prefix E[el:em].
        while em < er and E[em] <= tm: em += 1
        build_csr(N, E, el, em)
        nN = scc_labels(N, E, el, em, er, La, Ra, Va)
        em = partition(el, er, tm)
        # Only one candidate time left: stop recursing.
        if tr-tl==1: return
        # The partitioned order lives in F; swap so the recursive calls read it as E.
        E, F = F, E
        # Later half runs on the contracted graph (nN vertices, relabeled
        # endpoints); earlier half keeps the current labels.
        div_con(nN, em, er, tm, tr)
        div_con(N, el, em, tl, tm)
        E, F = F, E
    # Start with every edge and candidate times (-1, M]; M means "never".
    div_con(N, 0, M, -1, M)
    return W


def argsort_bounded(A, mx):
    """Return a stable argsort of A using counting sort.

    Every value in A must be usable as an index into a list of length
    mx + 1 (normally 0 <= a <= mx). The result lists the positions of A
    ordered by value; equal values keep their original relative order.
    Runs in O(len(A) + mx) time.
    """
    order = [0] * len(A)
    start = [0] * (mx + 1)
    for value in A:
        start[value] += 1
    # Convert the histogram into exclusive prefix sums: start[key] becomes
    # the first output slot for that key.
    running = 0
    for key in range(mx + 1):
        running, start[key] = running + start[key], running
    for pos, value in enumerate(A):
        order[start[value]] = pos
        start[value] += 1
    return order


from __pypy__.builders import StringBuilder
import sys
from os import read as os_read, write as os_write
from atexit import register as atexist_register

class Fastio:
    """Buffered I/O on raw file descriptors 0 (input) and 1 (output).

    Input bytes are read with os_read in chunks of up to 131072 bytes into
    ``ibuf``; ``pil`` is the read cursor and ``pir`` the end of buffered
    data. fastin refills only when fewer than 64 unread bytes remain and
    then scans without bounds checks, so an integer token must fit in that
    look-ahead. fastin_string also refills while scanning a token. Output is
    appended to a PyPy ``StringBuilder`` and written to fd 1 by ``flush`` or
    ``flush_atexit``.
    """
    ibuf = bytes()
    pil = pir = 0
    # Set once os_read has reported end of input and the sentinel was added.
    eof = False
    sb = StringBuilder()
    def load(self):
        """Discard consumed bytes and append the next chunk read from fd 0.

        At end of input (os_read returns an empty bytes object) a single
        newline sentinel is appended instead, and no further reads are made.
        Without it, a last token not followed by whitespace would make the
        scan loops in fastin/fastin_string index past the end of ``ibuf``
        and raise IndexError.
        """
        buf = self.ibuf[self.pil:]
        if not self.eof:
            chunk = os_read(0, 131072)
            if chunk:
                buf += chunk
            else:
                self.eof = True
                buf += b'\n'
        self.ibuf = buf
        self.pil = 0; self.pir = len(buf)
    def flush_atexit(self): os_write(1, self.sb.build().encode())
    def flush(self):
        """Write buffered output to fd 1 and start a fresh builder."""
        os_write(1, self.sb.build().encode())
        self.sb = StringBuilder()
    def fastin(self):
        """Read the next decimal integer (optional leading '-') from the input."""
        if self.pir - self.pil < 64: self.load()
        minus = x = 0
        # Skip separator bytes: anything below '-' (45), which includes whitespace.
        while self.ibuf[self.pil] < 45: self.pil += 1
        if self.ibuf[self.pil] == 45: minus = 1; self.pil += 1
        # Accumulate digits until a byte below '0' (48) such as a space or newline.
        while self.ibuf[self.pil] >= 48:
            x = x * 10 + (self.ibuf[self.pil] & 15)
            self.pil += 1
        if minus: return -x
        return x
    def fastin_string(self):
        """Read the next whitespace-delimited token as a bytearray."""
        if self.pir - self.pil < 64: self.load()
        while self.ibuf[self.pil] <= 32: self.pil += 1
        res = bytearray()
        while self.ibuf[self.pil] > 32:
            # Refill inside the loop so tokens longer than the look-ahead work.
            if self.pir - self.pil < 64: self.load()
            res.append(self.ibuf[self.pil])
            self.pil += 1
        return res
    def fastout(self, x): self.sb.append(str(x))
    def fastoutln(self, x): self.sb.append(str(x)); self.sb.append('\n')
# Single module-level instance; rd/rds/wt/wtn/flush are its bound methods,
# exported under short names for the solution code.
fastio = Fastio()
rd = fastio.fastin; rds = fastio.fastin_string; wt = fastio.fastout; wtn = fastio.fastoutln; flush = fastio.flush
# Write any output still buffered when the interpreter exits
# ("atexist_register" (sic) is the alias used by the atexit import above).
atexist_register(fastio.flush_atexit)
# Detach the stdlib text streams; all I/O goes through fd 0/1 via fastio.
sys.stdin = None; sys.stdout = None
def rdl(n):
    """Read n integers with rd() and return them as a list, in input order."""
    values = [0] * n
    for k in range(n):
        values[k] = rd()
    return values
def wtnl(l):
    """Write the items of l separated by single spaces, then a newline, via wtn."""
    line = ' '.join([str(item) for item in l])
    wtn(line)

# Run the solver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Back to top page