cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/graph/directedmst.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/directedmst
import os,sys,io
input=io.BytesIO(os.read(0,os.fstat(0).st_size)).readline


class UnionFind:
    """Disjoint-set forest with path compression.

    ``parent[x] < 0`` marks ``x`` as a root, and ``-parent[x]`` is its set size.
    ``merge(x, y)`` always keeps ``x``'s root as the new root (no union by
    size), so callers can rely on a chosen representative staying the root.
    """

    def __init__(self, n):
        self.parent = [-1] * n
        self.n = n

    def root(self, x):
        """Return the root of ``x``'s set, compressing the path to it.

        Iterative: since ``merge`` does not balance by size, parent chains can
        be long, and a recursive walk could exceed Python's recursion limit.
        """
        parent = self.parent
        r = x
        while parent[r] >= 0:
            r = parent[r]
        # Second pass: point every node on the path directly at the root.
        while x != r:
            parent[x], x = r, parent[x]
        return r

    def merge(self, x, y):
        """Union the sets of ``x`` and ``y`` under ``x``'s root.

        Returns False if they were already in the same set, True otherwise.
        """
        x = self.root(x)
        y = self.root(y)
        if x == y:
            return False
        self.parent[x] += self.parent[y]
        self.parent[y] = x
        return True

    def same(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.root(x) == self.root(y)

from cp_library.ds.heap.skew_heap_cls import SkewHeap

def directed_mst(n, edges, root):
    """Minimum-cost spanning arborescence (directed MST) rooted at ``root``.

    Chu-Liu/Edmonds with one meldable min-heap of incoming edges per vertex:
    cycles are contracted with a union-find, and a lazy heap add turns the
    remaining incoming costs of each cycle member into reduced costs.

    Args:
        n: number of vertices, labelled ``0 .. n-1``.
        edges: sequence of ``(fr, to, cost)`` integer triples. Each heap key
            packs ``cost << 31 | index``, so ``len(edges)`` must be below 2**31.
        root: the root vertex.

    Returns:
        ``(total_cost, tree)`` where ``tree[v]`` is the parent of ``v`` and
        ``tree[root] == -1``. If some vertex runs out of incoming edges (no
        arborescence exists) returns ``(-1, [])``.
    """
    m = len(edges)
    OFFSET = 1 << 31
    from_cost = [0] * n   # cost (possibly reduced) of the edge chosen to enter v
    from_heap = [SkewHeap() for _ in range(n)]  # candidate incoming edges per (super-)vertex
    from_ = [0] * n       # union-find representative of the chosen edge's source

    uf = UnionFind(n)     # contracted cycles; uf.merge(v, p) keeps v as representative
    par_e = [-1] * m      # par_e[e]: edge chosen right after e's cycle was contracted
    stem = [-1] * n       # stem[v]: first non-self-loop edge chosen to enter vertex v
    used = [0] * n        # 0 = unvisited, 1 = on the current path, 2 = finished
    used[root] = 2
    idxs = []             # every chosen edge, in selection order

    for idx, (_, to, cost) in enumerate(edges):
        from_heap[to].push(cost << 31 | idx)

    res = 0
    for v in range(n):
        if used[v] != 0:
            continue
        processing = []
        chi_e = []
        cycle = 0
        while used[v] != 2:
            used[v] = 1
            processing.append(v)
            if from_heap[v].empty():
                # No incoming edge left for v: no arborescence exists.
                return -1, []
            from_cost[v], idx = divmod(from_heap[v].pop(), OFFSET)
            from_[v] = uf.root(edges[idx][0])
            if from_[v] == v:
                # Self-loop or edge internal to v's contracted cycle: skip it.
                # Checked before recording the stem so a self-loop never
                # becomes a stem (its par_e chain would never reach idx).
                continue
            if stem[v] == -1:
                stem[v] = idx
            res += from_cost[v]
            idxs.append(idx)
            # Edges of the most recently contracted cycle are superseded by idx.
            while cycle:
                par_e[chi_e.pop()] = idx
                cycle -= 1
            chi_e.append(idx)
            if used[from_[v]] == 1:
                # The chosen edge closes a cycle on the current path: contract it into v.
                p = v
                while True:
                    # Reduced costs: subtract p's chosen cost from its remaining candidates.
                    if from_heap[p]:
                        from_heap[p].add(-from_cost[p] << 31)
                    if p != v:
                        uf.merge(v, p)
                        from_heap[v].merge(from_heap[p])
                    p = uf.root(from_[p])
                    cycle += 1
                    if p == v:
                        break
            else:
                v = from_[v]
        for v in processing:
            used[v] = 2

    # Expand contractions: walk chosen edges newest-first. Each edge not yet
    # superseded fixes its target's parent and supersedes the chain of edges
    # from that target's stem up to it.
    used_e = [0] * m
    tree = [-1] * n
    for idx in reversed(idxs):
        if used_e[idx]:
            continue
        fr, to, _ = edges[idx]
        tree[to] = fr
        x = stem[to]
        while x != idx:
            used_e[x] = 1
            x = par_e[x]
    return res, tree


# Input: "n m root", then m lines of "fr to cost".
n, m, root = map(int, input().split())
edges = [list(map(int, input().split())) for _ in range(m)]


res, par = directed_mst(n, edges, root)
print(res)
if res != -1:
    # Entries of -1 in par are printed as the vertex's own index.
    print(*[i if p == -1 else p for i, p in enumerate(par)])
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/directedmst
import os,sys,io
input=io.BytesIO(os.read(0,os.fstat(0).st_size)).readline


class UnionFind:
    """Disjoint-set forest with path compression.

    ``parent[x] < 0`` marks ``x`` as a root, and ``-parent[x]`` is its set size.
    ``merge(x, y)`` always keeps ``x``'s root as the new root (no union by
    size), so callers can rely on a chosen representative staying the root.
    """

    def __init__(self, n):
        self.parent = [-1] * n
        self.n = n

    def root(self, x):
        """Return the root of ``x``'s set, compressing the path to it.

        Iterative: since ``merge`` does not balance by size, parent chains can
        be long, and a recursive walk could exceed Python's recursion limit.
        """
        parent = self.parent
        r = x
        while parent[r] >= 0:
            r = parent[r]
        # Second pass: point every node on the path directly at the root.
        while x != r:
            parent[x], x = r, parent[x]
        return r

    def merge(self, x, y):
        """Union the sets of ``x`` and ``y`` under ``x``'s root.

        Returns False if they were already in the same set, True otherwise.
        """
        x = self.root(x)
        y = self.root(y)
        if x == y:
            return False
        self.parent[x] += self.parent[y]
        self.parent[y] = x
        return True

    def same(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.root(x) == self.root(y)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
import operator
from typing import Generic, TypeVar
_S = TypeVar('S')
_T = TypeVar('T')
_U = TypeVar('U')


_TSkewHeap = TypeVar("SkewHeap", bound="SkewHeap")
class SkewHeap(Generic[_T]):
    __slots__ = 'root', 'op', 'e'
    V, A, L, R, st = [-1], [-1], [-1], [-1], []
    def __init__(H, op = operator.add, e: _T = 0):
        H.root, H.op, H.e = -1, op, e
    
    def merge(H: _TSkewHeap, O: _TSkewHeap):
        H.root = H.merge_nodes(H.root, O.root)
        O.root = -1

    def min(H):
        assert ~H.root
        H.propagate(H.root)
        return H.V[H.root]

    def push(H, x: _T):
        id = len(H.V)
        H.V.append(x); H.A.append(H.e); H.L.append(-1); H.R.append(-1)
        H.root = H.merge_nodes(H.root, id)

    def pop(H) -> _T:
        assert ~H.root
        H.propagate(H.root)
        val, H.root = H.V[H.root], H.merge_nodes(H.L[H.root], H.R[H.root])
        return val
    
    def add(H, val: _T): H.A[H.root] = H.op(H.A[H.root], val)
    def empty(H): return H.root == -1
    def __bool__(H): return H.root != -1
    
    def propagate(H, u: int):
        if (a := H.A[u]) != H.e:
            if ~(l := H.L[u]): H.A[l] = H.op(H.A[l], a)
            if ~(r := H.R[u]): H.A[r] = H.op(H.A[r], a)
            H.V[u] = H.op(H.V[u], a); H.A[u] = H.e

    def merge_nodes(H, u: int, v:int):
        while ~u and ~v:
            H.propagate(u); H.propagate(v)
            if H.V[v] < H.V[u]: u, v = v, u
            H.st.append(u); H.R[u], u = H.L[u], H.R[u]
        u = u if ~u else v
        while H.st: H.L[u := H.st.pop()] = u
        return u

def directed_mst(n, edges, root):
    """Minimum-cost spanning arborescence (directed MST) rooted at ``root``.

    Chu-Liu/Edmonds with one meldable min-heap of incoming edges per vertex:
    cycles are contracted with a union-find, and a lazy heap add turns the
    remaining incoming costs of each cycle member into reduced costs.

    Args:
        n: number of vertices, labelled ``0 .. n-1``.
        edges: sequence of ``(fr, to, cost)`` integer triples. Each heap key
            packs ``cost << 31 | index``, so ``len(edges)`` must be below 2**31.
        root: the root vertex.

    Returns:
        ``(total_cost, tree)`` where ``tree[v]`` is the parent of ``v`` and
        ``tree[root] == -1``. If some vertex runs out of incoming edges (no
        arborescence exists) returns ``(-1, [])``.
    """
    m = len(edges)
    OFFSET = 1 << 31
    from_cost = [0] * n   # cost (possibly reduced) of the edge chosen to enter v
    from_heap = [SkewHeap() for _ in range(n)]  # candidate incoming edges per (super-)vertex
    from_ = [0] * n       # union-find representative of the chosen edge's source

    uf = UnionFind(n)     # contracted cycles; uf.merge(v, p) keeps v as representative
    par_e = [-1] * m      # par_e[e]: edge chosen right after e's cycle was contracted
    stem = [-1] * n       # stem[v]: first non-self-loop edge chosen to enter vertex v
    used = [0] * n        # 0 = unvisited, 1 = on the current path, 2 = finished
    used[root] = 2
    idxs = []             # every chosen edge, in selection order

    for idx, (_, to, cost) in enumerate(edges):
        from_heap[to].push(cost << 31 | idx)

    res = 0
    for v in range(n):
        if used[v] != 0:
            continue
        processing = []
        chi_e = []
        cycle = 0
        while used[v] != 2:
            used[v] = 1
            processing.append(v)
            if from_heap[v].empty():
                # No incoming edge left for v: no arborescence exists.
                return -1, []
            from_cost[v], idx = divmod(from_heap[v].pop(), OFFSET)
            from_[v] = uf.root(edges[idx][0])
            if from_[v] == v:
                # Self-loop or edge internal to v's contracted cycle: skip it.
                # Checked before recording the stem so a self-loop never
                # becomes a stem (its par_e chain would never reach idx).
                continue
            if stem[v] == -1:
                stem[v] = idx
            res += from_cost[v]
            idxs.append(idx)
            # Edges of the most recently contracted cycle are superseded by idx.
            while cycle:
                par_e[chi_e.pop()] = idx
                cycle -= 1
            chi_e.append(idx)
            if used[from_[v]] == 1:
                # The chosen edge closes a cycle on the current path: contract it into v.
                p = v
                while True:
                    # Reduced costs: subtract p's chosen cost from its remaining candidates.
                    if from_heap[p]:
                        from_heap[p].add(-from_cost[p] << 31)
                    if p != v:
                        uf.merge(v, p)
                        from_heap[v].merge(from_heap[p])
                    p = uf.root(from_[p])
                    cycle += 1
                    if p == v:
                        break
            else:
                v = from_[v]
        for v in processing:
            used[v] = 2

    # Expand contractions: walk chosen edges newest-first. Each edge not yet
    # superseded fixes its target's parent and supersedes the chain of edges
    # from that target's stem up to it.
    used_e = [0] * m
    tree = [-1] * n
    for idx in reversed(idxs):
        if used_e[idx]:
            continue
        fr, to, _ = edges[idx]
        tree[to] = fr
        x = stem[to]
        while x != idx:
            used_e[x] = 1
            x = par_e[x]
    return res, tree


# Input: "n m root", then m lines of "fr to cost".
n, m, root = map(int, input().split())
edges = [list(map(int, input().split())) for _ in range(m)]


res, par = directed_mst(n, edges, root)
print(res)
if res != -1:
    # Entries of -1 in par are printed as the vertex's own index.
    print(*[i if p == -1 else p for i, p in enumerate(par)])
Back to top page