cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

:heavy_check_mark: test/library-checker/data-structure/unionfind_with_potential.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/unionfind_with_potential

from cp_library.ds.potentialized_dsu_cls import PotentializedDSU
from cp_library.io.read_int_fn import read
from cp_library.io.write_fn import write

mod = 998244353
N, Q = read()


def op(x, y):
    '''Combine two potentials by addition modulo ``mod``.'''
    return (x + y) % mod


def inv(x):
    '''Return the additive inverse of ``x`` modulo ``mod``.'''
    return (-x) % mod


pdsu = PotentializedDSU(op, inv, 0, N)

for _ in range(Q):
    t, *q = read()
    if not t:
        # Query type 0: report whether the constraint (u, v, x) agrees with
        # what is already recorded, then hand it to merge().
        u, v, x = q
        write(int(pdsu.consistent(u, v, x)))
        pdsu.merge(u, v, x)
        continue
    # Any other query type: print the potential difference of u and v, or -1
    # when they are not in the same set.
    u, v = q
    if pdsu.same(u, v):
        ans = pdsu.diff(u, v)
    else:
        ans = -1
    write(ans)
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/unionfind_with_potential

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''

class PotentializedDSU:
    '''Union-find whose elements carry a potential relative to their root.

    Potentials are combined with ``op`` and inverted with ``inv``; ``e`` is
    the initial potential of every element when the structure is built from
    a size. The relations documented on ``merge`` and ``diff`` assume that
    ``(op, inv, e)`` behave like a group and that root potentials equal ``e``.

    Attributes:
        n:   number of elements.
        par: parent index for non-roots; for a root, minus the size of its set.
        pot: potential of each element relative to its parent (relative to
             the root once ``leader`` has been called on the element).
    '''

    def __init__(self, op, inv, e, v) -> None:
        '''Create the structure.

        Args:
            op:  binary operation combining two potentials.
            inv: unary operation returning the inverse of a potential.
            e:   initial potential used when ``v`` is an int.
            v:   either the number of elements, or a sequence of initial
                 potentials (its length is the number of elements).
        '''
        n = v if isinstance(v, int) else len(v)
        self.n = n
        self.par = [-1] * n
        self.op = op
        self.inv = inv
        self.e = e
        # NOTE(review): a sequence passed as ``v`` is stored as-is (not
        # copied), so it is mutated by later operations. leader() also folds
        # the root's potential into its direct children on every call, which
        # is a no-op only while root potentials equal ``e`` -- confirm the
        # intended use of non-identity initial potentials.
        self.pot = [e] * n if isinstance(v, int) else v

    def leader(self, x: int) -> int:
        '''Return the root of the set containing ``x``, compressing the path.

        Afterwards every node on the walked path points directly at the root
        and its potential has been combined with those of its former
        ancestors.
        '''
        assert 0 <= x < self.n
        path = []
        while self.par[x] >= 0:
            path.append(x)
            x = self.par[x]
        # Process the node nearest the root first so that each parent's
        # potential is already root-relative when its child is updated.
        for y in reversed(path):
            self.pot[y] = self.op(self.pot[y], self.pot[self.par[y]])
            self.par[y] = x
        return x

    def consistent(self, x: int, y: int, w) -> bool:
        '''Tell whether the constraint ``diff(x, y) == w`` can hold.

        Returns True when ``x`` and ``y`` are in different sets (nothing is
        known yet), otherwise whether the recorded difference equals ``w``.
        '''
        rx = self.leader(x)
        ry = self.leader(y)
        if rx == ry:
            return self.op(self.pot[x], self.inv(self.pot[y])) == w
        return True

    def merge(self, x: int, y: int, w) -> tuple[int, int]:
        '''Join the sets of ``x`` and ``y`` so that ``diff(x, y) == w``.

        If both are already in one set nothing changes and ``w`` is ignored;
        use ``consistent`` beforehand to validate it.

        Returns:
            ``(kept_root, attached_root)``; both entries are the common root
            when no merge took place.
        '''
        assert 0 <= x < self.n
        assert 0 <= y < self.n
        rx = self.leader(x)
        ry = self.leader(y)
        if rx != ry:
            par = self.par
            # Union by size: ``par`` of a root is minus its size, so the
            # smaller value is the larger set. Swap so that ``rx`` is the
            # root that gets attached, inverting ``w`` to match.
            if par[rx] < par[ry]:
                x, y, w, rx, ry = y, x, self.inv(w), ry, rx

            par[ry] += par[rx]
            par[rx] = ry
            # Choose the attached root's potential so that x's root-relative
            # potential becomes op(w, pot[y]).
            self.pot[rx] = self.op(
                self.op(self.inv(self.pot[x]), w), self.pot[y]
            )
        return ry, rx

    def same(self, x: int, y: int) -> bool:
        '''Return whether ``x`` and ``y`` belong to the same set.'''
        assert 0 <= x < self.n
        assert 0 <= y < self.n
        return self.leader(x) == self.leader(y)

    def size(self, x: int) -> int:
        '''Return the number of elements in the set containing ``x``.'''
        assert 0 <= x < self.n
        return -self.par[self.leader(x)]

    def groups(self):
        '''Return the sets as lists of element indices.

        Elements inside a list are in increasing order, and the lists are
        ordered by the index of their root.
        '''
        leader_buf = [self.leader(i) for i in range(self.n)]

        result = [[] for _ in range(self.n)]
        for i in range(self.n):
            result[leader_buf[i]].append(i)

        return [group for group in result if group]

    def diff(self, x: int, y: int):
        '''Return ``op(pot[x], inv(pot[y]))`` for two elements of one set.

        Raises:
            AssertionError: if ``x`` and ``y`` are in different sets (only
                when assertions are enabled).
        '''
        # Resolve both leaders outside the assert: leader() also compresses
        # the paths, which the potentials read below depend on, and assert
        # statements are stripped entirely under ``python -O``.
        rx = self.leader(x)
        ry = self.leader(y)
        assert rx == ry
        return self.op(self.pot[x], self.inv(self.pot[y]))


def read(shift=0, base=10):
    '''Read one line of input and return its integers as a list.

    Each whitespace-separated token is parsed in the given ``base`` and then
    offset by ``shift``.
    '''
    values = []
    for token in input().split():
        values.append(int(token, base) + shift)
    return values
import os
import sys
from io import BytesIO, IOBase


class FastIO(IOBase):
    '''Byte-level buffer in front of a file descriptor.

    Input is pulled from the descriptor with ``os.read`` and staged in an
    in-memory ``BytesIO``; output is collected in the same buffer and pushed
    out with ``os.write`` when ``flush`` is called.
    '''
    BUFSIZE = 8192
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        mode = file.mode
        self.writable = "x" in mode or "r" not in mode
        if self.writable:
            self.write = self.buffer.write
        else:
            self.write = None

    def _append(self, chunk):
        '''Add ``chunk`` at the end of the buffer, keeping the read position.'''
        position = self.buffer.tell()
        self.buffer.seek(0, 2)
        self.buffer.write(chunk)
        self.buffer.seek(position)

    def read(self):
        '''Drain the descriptor and return all bytes not yet consumed.'''
        chunk_size = self.BUFSIZE
        while True:
            chunk = os.read(
                self._fd, max(os.fstat(self._fd).st_size, chunk_size)
            )
            if not chunk:
                break
            self._append(chunk)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        '''Return the next line, reading more data only when none is buffered.'''
        chunk_size = self.BUFSIZE
        while self.newlines == 0:
            chunk = os.read(
                self._fd, max(os.fstat(self._fd).st_size, chunk_size)
            )
            # An empty chunk counts as one pending "line" so the loop ends
            # once the descriptor has nothing more to give.
            self.newlines = chunk.count(b"\n") + (not chunk)
            self._append(chunk)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        '''Write the buffered bytes to the descriptor and reset the buffer.'''
        if not self.writable:
            return
        os.write(self._fd, self.buffer.getvalue())
        self.buffer.truncate(0)
        self.buffer.seek(0)


class IOWrapper(IOBase):
    '''ASCII text facade over a ``FastIO`` byte buffer.'''
    # Slots for wrapped standard streams; None until something assigns them.
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        '''Encode ``s`` as ASCII and append it to the underlying buffer.'''
        payload = s.encode("ascii")
        return self.buffer.write(payload)

    def read(self):
        '''Return all remaining input decoded as ASCII.'''
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        '''Return the next input line decoded as ASCII.'''
        raw = self.buffer.readline()
        return raw.decode("ascii")

# Replace the process-wide standard streams with IOWrapper instances and keep
# the same objects on the IOWrapper class, where write() looks up its default
# output stream.
sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)

def write(*args, **kwargs):
    '''Write the values to a stream, ``IOWrapper.stdout`` by default.

    Recognised keyword arguments: ``sep`` (default ``" "``), ``file``,
    ``end`` (default ``"\\n"``) and ``flush`` (default ``False``).
    '''
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", IOWrapper.stdout)
    for position, value in enumerate(args):
        # The separator goes between values, never before the first one.
        if position:
            file.write(sep)
        file.write(str(value))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()

mod = 998244353
N, Q = read()


def op(x, y):
    '''Combine two potentials by addition modulo ``mod``.'''
    return (x + y) % mod


def inv(x):
    '''Return the additive inverse of ``x`` modulo ``mod``.'''
    return (-x) % mod


pdsu = PotentializedDSU(op, inv, 0, N)

for _ in range(Q):
    t, *q = read()
    if not t:
        # Query type 0: report whether the constraint (u, v, x) agrees with
        # what is already recorded, then hand it to merge().
        u, v, x = q
        write(int(pdsu.consistent(u, v, x)))
        pdsu.merge(u, v, x)
        continue
    # Any other query type: print the potential difference of u and v, or -1
    # when they are not in the same set.
    u, v = q
    if pdsu.same(u, v):
        ans = pdsu.diff(u, v)
    else:
        ans = -1
    write(ans)
Back to top page