cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ cp_library/alg/graph/edmonds_fn.py

Depends on

Verified with

Code

import cp_library.alg.graph.__header__
from functools import reduce
from heapq import heapify
import cp_library.misc.setrecursionlimit
from cp_library.ds.dsu_cls import DSU
from cp_library.alg.graph.floyds_cycle_fn import floyds_cycle

def edmonds_branching(E, N, root) -> list[tuple[int, int, object]]:
    """Minimum-weight spanning arborescence rooted at ``root`` (Chu-Liu/Edmonds).

    Args:
        E: sequence of edges ``(u, v, w)``, each directed ``u -> v`` with weight ``w``.
        N: number of vertices, labelled ``0 .. N-1``.
        root: root vertex; edges entering ``root`` are ignored.

    Returns:
        The selected edges as the original elements of ``E`` (order unspecified).

    Precondition: every vertex must be reachable from ``root``. Otherwise no
    arborescence exists and the returned list is not one (it may omit vertices
    or even contain a cycle); callers must validate reachability themselves.
    """
    # Incoming edges per vertex as mutable [weight, source, edge_id] records.
    # Self-loops can never be part of an arborescence, so drop them up front
    # instead of contracting them later as trivial one-vertex cycles.
    Gin = [[] for _ in range(N)]
    for eid, (u, v, w) in enumerate(E):
        if v != root and u != v:
            Gin[v].append([w, u, eid])

    # Only the minimum incoming edge (index 0) is ever read; heapify puts it there.
    for v in range(N):
        heapify(Gin[v])

    groups = DSU(N)
    # Leaders of the current (possibly contracted) non-root super-vertices.
    active = set(range(N))
    active.discard(root)

    def find_cycle(min_in):
        """Return a cycle of the functional graph v -> min_in[v], or None."""
        for v in active:
            cyc = floyds_cycle(min_in, v)
            if cyc: return cyc
        return None

    def contract(cyc):
        """Merge the super-vertices of ``cyc`` and rebuild the merged in-edge list.

        Returns a dict mapping every edge id that now enters the merged vertex to
        the id of the cycle edge it displaces if that edge is chosen.
        """
        kickout = {}
        active.difference_update(cyc)
        nv = reduce(groups.merge, cyc)
        active.add(nv)
        new_edges = []

        for v in cyc:
            cw, _, cid = Gin[v][0]
            for edge in Gin[v]:
                _, u, eid = edge
                if groups.leader(u) != nv:
                    # Reduced cost: taking this edge instead of cycle edge `cid`
                    # changes the total weight by exactly this amount.
                    edge[0] -= cw
                    kickout[eid] = cid
                    new_edges.append(edge)
                    # Keep the minimum at index 0 (the only position read later).
                    if edge[0] < new_edges[0][0]:
                        new_edges[0], new_edges[-1] = new_edges[-1], new_edges[0]
            Gin[v].clear()
        Gin[nv] = new_edges
        return kickout

    # Phase 1: contract cycles of minimum in-edges until none remain, recording
    # each cycle's chosen edges and kick-out map (iterative: no deep recursion).
    history = []
    while True:
        min_in = [groups.leader(Gin[v][0][1]) if Gin[v] else -1 for v in range(N)]
        cyc = find_cycle(min_in)
        if not cyc:
            break
        C = {Gin[v][0][2] for v in cyc}
        history.append((C, contract(cyc)))

    # Phase 2: expand cycles innermost-last; the edge chosen to enter a contracted
    # vertex replaces exactly one edge of that cycle.
    MCA = [edges[0][2] for edges in Gin if edges]
    while history:
        C, kickout = history.pop()
        for eid in MCA:
            if eid in kickout:
                C.discard(kickout[eid])
        MCA.extend(C)
    return [E[eid] for eid in MCA]
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
from functools import reduce
from heapq import heapify


import sys
sys.setrecursionlimit(10**6)
# pypyjit exists only under PyPy; on other interpreters the JIT tweak is moot,
# so a missing module must not abort the import of this file.
try:
    import pypyjit
    pypyjit.set_param("max_unroll_recursion=-1")
except ImportError:
    pass
from typing import Collection

import typing
from collections import deque
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection, Iterator, Union
import os
from io import BytesIO, IOBase


class FastIO(IOBase):
    """Byte I/O on a raw file descriptor, staged through an in-memory BytesIO.

    Reads pull large chunks with ``os.read`` and append them to the buffer;
    writes accumulate in the buffer until ``flush`` sends them with ``os.write``.
    """
    BUFSIZE = 8192
    newlines = 0  # complete lines already buffered but not yet returned by readline

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Writable unless the file was opened in a pure read mode.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def _append(self, b):
        """Append ``b`` at the end of the buffer without moving the read position."""
        ptr = self.buffer.tell()
        self.buffer.seek(0, 2)
        self.buffer.write(b)
        self.buffer.seek(ptr)

    def read(self):
        """Read everything up to EOF and return the unread part of the buffer."""
        BUFSIZE = self.BUFSIZE
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            self._append(b)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line (with its newline), or the final partial line at EOF."""
        BUFSIZE = self.BUFSIZE
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty read (EOF) counts as one "line" so the tail is still returned.
            self.newlines = b.count(b"\n") + (not b)
            self._append(b)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered output to the descriptor and reset the buffer."""
        if self.writable:
            # os.write may write fewer bytes than requested; keep going until done.
            data = memoryview(self.buffer.getvalue())
            while data:
                data = data[os.write(self._fd, data):]
            self.buffer.truncate(0)
            self.buffer.seek(0)


class IOWrapper(IOBase):
    """ASCII text adapter over a FastIO byte buffer."""
    # Shared wrapper instances; these class attributes start out as None.
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        """Encode ``s`` as ASCII and stage it in the underlying buffer."""
        encoded = s.encode("ascii")
        return self.buffer.write(encoded)

    def read(self):
        """Read to EOF and decode as ASCII."""
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        """Read one line and decode it as ASCII."""
        raw = self.buffer.readline()
        return raw.decode("ascii")

# Wrap the process's standard streams once, keep the wrappers on IOWrapper,
# and make them the interpreter-wide sys.stdin / sys.stdout.
IOWrapper.stdin = IOWrapper(sys.stdin)
IOWrapper.stdout = IOWrapper(sys.stdout)
sys.stdin, sys.stdout = IOWrapper.stdin, IOWrapper.stdout
from typing import TypeVar
# Generic element type; the TypeVar's name must match the variable it is bound to.
_T = TypeVar('_T')

class TokenStream(Iterator):
    """Iterator over whitespace-separated tokens of ``TokenStream.stream``.

    Tokens are buffered one input line at a time. A blank line leaves the
    buffer empty, so ``next()`` then raises IndexError from ``popleft``.
    """
    stream = IOWrapper.stdin

    def __init__(self):
        self.queue = deque()

    def __next__(self):
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        return pending.popleft()

    def wait(self):
        """Yield once per still-unconsumed token of the current (or next) line."""
        if not self.queue:
            self.queue.extend(self._line())
        while self.queue:
            yield

    def _line(self):
        """Read the next raw line and split it into tokens."""
        return TokenStream.stream.readline().split()

    def line(self):
        """Return the rest of the buffered line, or the next line if nothing is buffered."""
        if not self.queue:
            return self._line()
        rest = list(self.queue)
        self.queue.clear()
        return rest
TokenStream.default = TokenStream()

class CharStream(TokenStream):
    """TokenStream whose tokens are the single characters of each input line."""
    def _line(self):
        # A str is returned; queue.extend() then enqueues it character by character.
        raw = TokenStream.stream.readline()
        return raw.rstrip()
CharStream.default = CharStream()


ParseFn = Callable[[TokenStream],_T]
class Parser:
    """Build TokenStream parsers from a declarative spec.

    ``Parser(spec)(ts)`` parses one value described by ``spec``. ``compile``
    dispatches on the spec, in this order:
      * a class or generic alias (``int``, ``list[int]``, ``tuple[int, ...]``,
        Parsable subclasses, ...) -> ``compile_type``;
      * a number ``k`` -> parse a token as ``type(k)`` and add ``k``;
      * a tuple instance -> one child parser per element, in order;
      * any other collection instance -> ``compile_collection``;
      * any other callable -> call it on the next token.
    """
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> ParseFn[_T]:
        """Parser for class ``cls`` with generic arguments ``args``."""
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        """Compile ``spec`` into a parser function (see the class docstring)."""
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        """Parser that collects the rest of the current (or next) line into ``cls``."""
        if spec is int:
            # Fast path: convert the raw tokens directly, no per-token parser needed.
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        else:
            fn = Parser.compile(spec)
            def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
            return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        """Parser that applies ``spec`` exactly ``N`` times and collects into ``cls``."""
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        """Parser that applies each spec in ``specs`` once, in order."""
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        """``(spec, ...)`` reads a whole line; any other tuple parses element-wise."""
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        """Empty / single-spec / set -> whole line; ``[spec, N]`` -> ``N`` repeats."""
        # NOTE(review): a set with more than one element is unpacked into several
        # positional args for compile_line(cls, spec) and raises TypeError.
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()

class Parsable:
    """Mixin for types constructible from a single token via ``cls(token)``."""
    @classmethod
    def compile(cls):
        """Return a parser that builds one ``cls`` from the next token of a TokenStream."""
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser

from typing import Sequence


class CSRIncremental(Sequence[list[_T]]):
    """Fixed-capacity rows stored in one flat array (CSR layout), filled by appending.

    Row ``i`` occupies ``A[L[i]:R[i]]``; ``R[i]`` advances on each ``append``.
    Indexing is mixed: iteration and ``len`` are over rows, while ``csr[j]``
    indexes the flat array ``A`` (typically with ``j`` from ``csr.range(i)``).
    """
    def __init__(csr, sizes: list[int]):
        csr.L, N = [0]*len(sizes), 0
        for i,sz in enumerate(sizes):
            csr.L[i] = N; N += sz
        csr.R, csr.A = csr.L[:], [0]*N

    def append(csr, i: int, x: _T):
        # No capacity check: overflowing row i writes into the next row's slots
        # (or raises IndexError past the end of A).
        csr.A[csr.R[i]] = x; csr.R[i] += 1

    def __iter__(csr):
        for i,l in enumerate(csr.L):
            yield csr.A[l:csr.R[i]]

    def __getitem__(csr, i: int) -> _T:
        return csr.A[i]

    def __len__(csr):
        return len(csr.L)

    # The return annotation resolves to the builtin ``range`` because this method
    # is not yet bound in the class namespace when its signature is evaluated.
    def range(csr, i: int) -> range:
        """Flat indices of the filled part of row ``i``."""
        return range(csr.L[i], csr.R[i])

class DSU(Parsable, Collection):
    """Disjoint-set union with union by size and path halving.

    ``par[i]`` is the parent of ``i``, or ``-size`` of its set when ``i`` is a
    root; ``cc`` is the current number of components.
    """
    def __init__(dsu, N):
        dsu.N, dsu.cc, dsu.par = N, N, [-1]*N

    def merge(dsu, u, v, src = False):
        """Union the sets of ``u`` and ``v`` and return the surviving leader.

        With ``src=True`` return ``(x, y)``: ``x`` is the surviving leader and ``y``
        the leader attached beneath it (``x == y`` if already in one set).
        """
        x, y = dsu.leader(u), dsu.leader(v)
        if x == y: return (x,y) if src else x
        # Sizes are stored negated: keep the larger tree's root as x.
        if dsu.par[x] > dsu.par[y]: x, y = y, x
        dsu.par[x] += dsu.par[y]; dsu.par[y] = x; dsu.cc -= 1
        return (x,y) if src else x

    def same(dsu, u: int, v: int):
        return dsu.leader(u) == dsu.leader(v)

    def leader(dsu, i) -> int:
        """Root of ``i``'s tree; halves the path while walking up."""
        p = (par := dsu.par)[i]
        while p >= 0:
            if par[p] < 0: return p
            # Re-point i at its grandparent and continue from there.
            par[i], i, p = par[p], par[p], par[par[p]]
        return i

    def size(dsu, i) -> int:
        return -dsu.par[dsu.leader(i)]

    def groups(dsu) -> CSRIncremental[int]:
        """Components as CSR rows, ordered by root index, members ascending."""
        sizes, row, p = [0]*dsu.cc, [-1]*dsu.N, 0
        for i in range(dsu.cc):
            while dsu.par[p] >= 0: p += 1
            sizes[i], row[p] = -dsu.par[p], i; p += 1
        csr = CSRIncremental(sizes)
        for i in range(dsu.N): csr.append(row[dsu.leader(i)], i)
        return csr

    def __iter__(dsu):
        # The iterator protocol requires an iterator; groups() returns an
        # iterable CSR, so adapt it with iter().
        return iter(dsu.groups())

    def __len__(dsu):
        return dsu.cc

    def __contains__(dsu, uv):
        """``(u, v) in dsu`` is True when u and v are in the same set."""
        u, v = uv
        return dsu.same(u, v)

    @classmethod
    def compile(cls, N: int, M: int, shift = -1):
        """Parser that builds a DSU of size N by merging M "u v" lines (shifted by ``shift``)."""
        def parse_fn(ts: TokenStream):
            dsu = cls(N)
            for _ in range(M):
                u, v = ts._line()
                dsu.merge(int(u)+shift, int(v)+shift)
            return dsu
        return parse_fn

def floyds_cycle(F, root):
    """Find the cycle reachable from ``root`` in the functional graph ``F``.

    ``F[x]`` is the successor of ``x``, or ``-1`` if ``x`` has none. Returns the
    cycle's vertices in traversal order, starting at the tortoise/hare meeting
    point, or None if the walk from ``root`` reaches a ``-1``.
    """
    tortoise = hare = root
    while True:
        ahead = F[hare]
        if ahead == -1 or F[ahead] == -1:
            return None
        tortoise, hare = F[tortoise], F[ahead]
        if tortoise == hare:
            break
    start = tortoise
    cycle = [start]
    node = F[start]
    while node != start:
        cycle.append(node)
        node = F[node]
    return cycle

def edmonds_branching(E, N, root) -> list[tuple[int, int, object]]:
    """Minimum-weight spanning arborescence rooted at ``root`` (Chu-Liu/Edmonds).

    Args:
        E: sequence of edges ``(u, v, w)``, each directed ``u -> v`` with weight ``w``.
        N: number of vertices, labelled ``0 .. N-1``.
        root: root vertex; edges entering ``root`` are ignored.

    Returns:
        The selected edges as the original elements of ``E`` (order unspecified).

    Precondition: every vertex must be reachable from ``root``. Otherwise no
    arborescence exists and the returned list is not one (it may omit vertices
    or even contain a cycle); callers must validate reachability themselves.
    """
    # Incoming edges per vertex as mutable [weight, source, edge_id] records.
    # Self-loops can never be part of an arborescence, so drop them up front
    # instead of contracting them later as trivial one-vertex cycles.
    Gin = [[] for _ in range(N)]
    for eid, (u, v, w) in enumerate(E):
        if v != root and u != v:
            Gin[v].append([w, u, eid])

    # Only the minimum incoming edge (index 0) is ever read; heapify puts it there.
    for v in range(N):
        heapify(Gin[v])

    groups = DSU(N)
    # Leaders of the current (possibly contracted) non-root super-vertices.
    active = set(range(N))
    active.discard(root)

    def find_cycle(min_in):
        """Return a cycle of the functional graph v -> min_in[v], or None."""
        for v in active:
            cyc = floyds_cycle(min_in, v)
            if cyc: return cyc
        return None

    def contract(cyc):
        """Merge the super-vertices of ``cyc`` and rebuild the merged in-edge list.

        Returns a dict mapping every edge id that now enters the merged vertex to
        the id of the cycle edge it displaces if that edge is chosen.
        """
        kickout = {}
        active.difference_update(cyc)
        nv = reduce(groups.merge, cyc)
        active.add(nv)
        new_edges = []

        for v in cyc:
            cw, _, cid = Gin[v][0]
            for edge in Gin[v]:
                _, u, eid = edge
                if groups.leader(u) != nv:
                    # Reduced cost: taking this edge instead of cycle edge `cid`
                    # changes the total weight by exactly this amount.
                    edge[0] -= cw
                    kickout[eid] = cid
                    new_edges.append(edge)
                    # Keep the minimum at index 0 (the only position read later).
                    if edge[0] < new_edges[0][0]:
                        new_edges[0], new_edges[-1] = new_edges[-1], new_edges[0]
            Gin[v].clear()
        Gin[nv] = new_edges
        return kickout

    # Phase 1: contract cycles of minimum in-edges until none remain, recording
    # each cycle's chosen edges and kick-out map (iterative: no deep recursion).
    history = []
    while True:
        min_in = [groups.leader(Gin[v][0][1]) if Gin[v] else -1 for v in range(N)]
        cyc = find_cycle(min_in)
        if not cyc:
            break
        C = {Gin[v][0][2] for v in cyc}
        history.append((C, contract(cyc)))

    # Phase 2: expand cycles innermost-last; the edge chosen to enter a contracted
    # vertex replaces exactly one edge of that cycle.
    MCA = [edges[0][2] for edges in Gin if edges]
    while history:
        C, kickout = history.pop()
        for eid in MCA:
            if eid in kickout:
                C.discard(kickout[eid])
        MCA.extend(C)
    return [E[eid] for eid in MCA]
Back to top page