cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ cp_library/alg/tree/tree_proto.py

Depends on

Required by

Verified with

Code

from cp_library.ds.elist_fn import elist
import cp_library.alg.tree.__header__

from typing import overload, Literal, Union
from functools import cached_property
from math import inf
from collections import deque
from cp_library.alg.graph.dfs_options_cls import DFSFlags, DFSEvent
from cp_library.alg.graph.graph_proto import GraphProtocol
from cp_library.alg.tree.lca_table_iterative_cls import LCATable

class TreeProtocol(GraphProtocol):

    @cached_property
    def lca(T):
        return LCATable(T)
    
    @overload
    def diameter(T) -> int: ...
    @overload
    def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
    def diameter(T, endpoints = False):
        mask = (1 << (shift := T.N.bit_length())) - 1
        s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
        dg = max(d << shift | v for v,d in enumerate(T.distance(s))) 
        diam, g = dg >> shift, dg & mask
        return (diam, s, g) if endpoints else diam
    
    @overload
    def distance(T) -> list[list[int]]: ...
    @overload
    def distance(T, s: int = 0) -> list[int]: ...
    @overload
    def distance(T, s: int, g: int) -> int: ...
    def distance(T, s = None, g = None):
        if s == None:
            return [T.dfs(u) for u in range(T.N)]
        else:
            return T.dfs(s, g)
            
    @overload
    def dfs(T, s: int = 0) -> list[int]: ...
    @overload
    def dfs(T, s: int, g: int) -> int: ...
    def dfs(T, s = 0, g = None):
        D = [inf for _ in range(T.N)]
        D[s] = 0
        state = [True for _ in range(T.N)]
        stack = [s]

        while stack:
            u = stack.pop()
            if u == g: return D[u]
            state[u] = False
            for v in T[u]:
                if state[v]:
                    D[v] = D[u]+1
                    stack.append(v)
        return D if g is None else inf 


    def dfs_events(G, flags: DFSFlags, s: int = 0):         
        events = []
        stack = [(s,-1)]
        adj = [None]*G.N


        while stack:
            u, p = stack[-1]
            
            if adj[u] is None:
                adj[u] = iter(G.neighbors(u))
                if DFSFlags.ENTER in flags:
                    events.append((DFSEvent.ENTER, u))
            
            if (v := next(adj[u], None)) is not None:
                if v == p:
                    if DFSFlags.BACK in flags:
                        events.append((DFSEvent.BACK, u, v))
                else:
                    if DFSFlags.DOWN in flags:
                        events.append((DFSEvent.DOWN, u, v))
                    stack.append((v,u))
            else:
                stack.pop()

                if DFSFlags.LEAVE in flags:
                    events.append((DFSEvent.LEAVE, u))
                if p != -1 and DFSFlags.UP in flags:
                    events.append((DFSEvent.UP, u, p))
        return events
    
    def euler_tour(T, s = 0):
        N = len(T)
        T.tin = tin = [-1] * N
        T.tout = tout = [-1] * N
        T.par = par = [-1] * N
        T.order = order = elist(2*N)
        T.delta = delta = elist(2*N)
        
        stack = elist(N)
        stack.append(s)

        while stack:
            u = stack.pop()
            p = par[u]
            
            if tin[u] == -1:
                tin[u] = len(order)
                
                for v in T[u]:
                    if v != p:
                        par[v] = u
                        stack.append(u)
                        stack.append(v)
                
                delta.append(1)
            else:
                delta.append(-1)
            
            order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0

    def hld_precomp(T, r = 0):
        N, time = T.N, 0
        tin, tout, size = [0]*N, [0]*N, [1]*N+[0]
        par, heavy, head = [-1]*N, [-1]*N, [r]*N
        depth, order, state = [0]*N, [0]*N, [0]*N
        stack = elist(N)
        stack.append(r)
        while stack:
            if (s := state[v := stack.pop()]) == 0: # dfs down
                p, state[v] = par[v], 1
                stack.append(v)
                for c in T[v]:
                    if c != p:
                        depth[c], par[c] = depth[v]+1, v
                        stack.append(c)

            elif s == 1: # dfs up
                p, l = par[v], -1
                for c in T[v]:
                    if c != p:
                        size[v] += size[c]
                        if size[c] > size[l]:
                            l = c
                heavy[v] = l
                if p == -1:
                    state[v] = 2
                    stack.append(v)

            elif s == 2: # decompose down
                p, h, l = par[v], head[v], heavy[v]
                tin[v], order[time], state[v] = time, v, 3
                time += 1
                stack.append(v)
                
                for c in T[v]:
                    if c != p and c != l:
                        head[c], state[c] = c, 2
                        stack.append(c)

                if l != -1:
                    head[l], state[l] = h, 2
                    stack.append(l)

            elif s == 3: # decompose up
                tout[v] = time
        T.size, T.depth = size, depth
        T.order, T.tin, T.tout = order, tin, tout
        T.par, T.heavy, T.head = par, heavy, head
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''

def elist(est_len: int) -> list: ...
# Use PyPy's __pypy__.newlist_hint (empty list with a capacity hint) when it is
# available; on other interpreters fall back to a plain new list.
try:
    from __pypy__ import newlist_hint
except ImportError:
    def newlist_hint(hint):
        '''Fallback: ignore the size hint and return a new empty list.'''
        return []
elist = newlist_hint
    


from typing import overload, Literal, Union
from functools import cached_property
from math import inf
from collections import deque


from enum import auto, IntFlag, IntEnum

class DFSFlags(IntFlag):
    # Event-selection bits: each asks a traversal to report that kind of event.
    ENTER = 1 << 0
    DOWN = 1 << 1
    BACK = 1 << 2
    CROSS = 1 << 3
    LEAVE = 1 << 4
    UP = 1 << 5
    MAXDEPTH = 1 << 6

    # Behaviour / return-shape bits.
    RETURN_PARENTS = 1 << 7
    RETURN_DEPTHS = 1 << 8
    BACKTRACK = 1 << 9
    CONNECT_ROOTS = 1 << 10

    # Common combinations
    ALL_EDGES = DOWN | BACK | CROSS
    EULER_TOUR = DOWN | UP
    INTERVAL = ENTER | LEAVE
    TOPDOWN = DOWN | CONNECT_ROOTS
    BOTTOMUP = UP | CONNECT_ROOTS
    RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS

class DFSEvent(IntEnum):
    # Tags for the event tuples produced by the dfs_* traversals:
    # (DFSEvent.ENTER/LEAVE/MAXDEPTH, u) or (DFSEvent.DOWN/UP/BACK/CROSS, u, v).
    # Each member has the same integer value as the DFSFlags bit of the same name.
    ENTER = DFSFlags.ENTER 
    DOWN = DFSFlags.DOWN 
    BACK = DFSFlags.BACK 
    CROSS = DFSFlags.CROSS 
    LEAVE = DFSFlags.LEAVE 
    UP = DFSFlags.UP 
    MAXDEPTH = DFSFlags.MAXDEPTH
    

import typing
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase


class FastIO(IOBase):
    '''
    Byte I/O over a raw file descriptor backed by an in-memory BytesIO.

    Reads pull data with ``os.read`` into the buffer; writes accumulate in the
    buffer and are emitted by ``flush``.
    '''
    BUFSIZE = 8192
    # readline calls that can still be served from the buffer without another os.read
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # NOTE: replaces IOBase.writable() with a plain bool, which the rest of the file relies on.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        '''Read everything until EOF and return the unread remainder of the buffer.'''
        BUFSIZE = self.BUFSIZE
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        '''Return the next line (including its newline), or b"" at EOF.'''
        BUFSIZE = self.BUFSIZE
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # an empty read (EOF) still counts as one servable (empty) line
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        '''
        Write all buffered output to the descriptor, then clear the buffer.

        ``os.write`` may write fewer bytes than requested, so keep writing the
        unwritten tail until the whole buffer has been emitted.
        '''
        if self.writable:
            pending = memoryview(self.buffer.getvalue())
            while pending:
                pending = pending[os.write(self._fd, pending):]
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    '''ASCII text facade over a FastIO byte buffer.'''
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None
    
    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        '''Encode ``s`` as ASCII and append it to the output buffer.'''
        encoded = s.encode("ascii")
        return self.buffer.write(encoded)
    
    def read(self):
        '''Read the rest of the input, decoded as ASCII.'''
        raw = self.buffer.read()
        return raw.decode("ascii")
    
    def readline(self):
        '''Read one line of input, decoded as ASCII.'''
        raw = self.buffer.readline()
        return raw.decode("ascii")

# Replace the interpreter's standard streams with fast wrappers; the class
# attributes keep a handle to them for TokenStream.
IOWrapper.stdin = IOWrapper(sys.stdin)
sys.stdin = IOWrapper.stdin
IOWrapper.stdout = IOWrapper(sys.stdout)
sys.stdout = IOWrapper.stdout
from typing import TypeVar
_T = TypeVar('T')

class TokenStream(Iterator):
    '''Iterator over whitespace-separated tokens of ``TokenStream.stream``, read one line at a time.'''
    stream = IOWrapper.stdin

    def __init__(self):
        # tokens of the current line not yet consumed
        self.queue = deque()

    def __next__(self):
        '''Return the next token, reading a new line when the current one is used up.'''
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        return pending.popleft()
    
    def wait(self):
        '''
        Yield once per remaining token of the current line (reading a line first
        if none are pending). The consumer must take a token on each step,
        otherwise this never terminates.
        '''
        if not self.queue:
            self.queue.extend(self._line())
        while self.queue:
            yield
 
    def _line(self):
        return TokenStream.stream.readline().split()

    def line(self):
        '''Return the remaining tokens of the current line, or the next line's tokens if none remain.'''
        if not self.queue:
            return self._line()
        rest = list(self.queue)
        self.queue.clear()
        return rest
TokenStream.default = TokenStream()

class CharStream(TokenStream):
    '''TokenStream whose tokens are the characters of each line, with trailing whitespace stripped.'''
    def _line(self):
        text = TokenStream.stream.readline()
        return text.rstrip()
CharStream.default = CharStream()


# Type of a compiled parser: a callable that reads tokens from a TokenStream and returns a _T.
ParseFn = Callable[[TokenStream],_T]
class Parser:
    '''
    Build token-stream parsers from a specification.

    A spec may be a type (``int``, ``tuple[int, str]``, ``list[int]``, a
    ``Parsable`` subclass, ...), a number (parse a token with the number's type
    and add the number as an offset), a tuple/collection of specs, or a callable
    applied to a single token.
    '''
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)
    
    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        '''Compile a parser for class ``cls`` with generic arguments ``args``.'''
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))              
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            # must come after the str check: str is itself a Collection
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))              
            return parse
        else:
            raise NotImplementedError()
    
    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        '''Compile any supported spec into a parser function.'''
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number): 
            cls = type(spec)  
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):      
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):  
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable): 
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        '''Parse every remaining token of the current line with ``spec`` and wrap the results in ``cls``.'''
        if spec is int:
            # fast path: convert tokens directly, no per-token parser is needed
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        else:
            fn = Parser.compile(spec)
            def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
            return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        '''Parse ``N`` consecutive values with ``spec`` and wrap them in ``cls``.'''
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        '''Parse one value per spec in ``specs`` (in order) and wrap them in ``cls``.'''
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])  
        return parse
            
    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        '''``(spec, ...)`` parses the rest of the line; any other tuple parses one value per spec.'''
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        '''
        Empty / single-spec collections parse the rest of the line;
        ``[spec, N]`` parses exactly ``N`` values.
        '''
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()

class Parsable:
    '''Mixin for classes that build a parser via ``compile``; the default parses one token as ``cls(token)``.'''
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser
from typing import Iterable, Union, overload

class GraphProtocol(list, Parsable):
    def __init__(G, N: int, E: list = None, adj: Iterable = None):
        G.N = N
        if E is not None:
            G.M, G.E = len(E), E
        if adj is not None:
            super().__init__(adj)

    def neighbors(G, v: int) -> Iterable[int]:
        return G[v]
    
    def edge_ids(G) -> list[list[int]]: ...

    @overload
    def distance(G) -> list[list[int]]: ...
    @overload
    def distance(G, s: int = 0) -> list[int]: ...
    @overload
    def distance(G, s: int, g: int) -> int: ...
    def distance(G, s = None, g = None):
        if s == None:
            return G.floyd_warshall()
        else:
            return G.bfs(s, g)

    @overload
    def bfs(G, s: Union[int,list] = 0) -> list[int]: ...
    @overload
    def bfs(G, s: Union[int,list], g: int) -> int: ...
    def bfs(G, s = 0, g = None):
        D = [inf for _ in range(G.N)]
        q = deque([s] if isinstance(s, int) else s)
        for u in q: D[u] = 0
        while q:
            nd = D[u := q.popleft()]+1
            if u == g: return D[u]
            for v in G.neighbors(u):
                if nd < D[v]:
                    D[v] = nd
                    q.append(v)
        return D if g is None else inf 

    @overload
    def shortest_path(G, s: int, g: int) -> Union[list[int],None]: ...
    @overload
    def shortest_path(G, s: int, g: int, distances = True) -> tuple[Union[list[int],None],list[int]]: ...
    def shortest_path(G, s: int, g: int, distances = False) -> list[int]:
        D = [inf] * G.N
        D[s] = 0
        if s == g:
            return ([], D) if distances else []
            
        par = [-1] * G.N
        par_edge = [-1] * G.N
        Eid = G.edge_ids()
        q = deque([s])
        
        while q:
            nd = D[u := q.popleft()] + 1
            if u == g: break
                
            for v, eid in zip(G[u], Eid[u]):
                if nd < D[v]:
                    D[v] = nd
                    par[v] = u
                    par_edge[v] = eid
                    q.append(v)
        
        if D[g] == inf:
            return (None, D) if distances else None
            
        path = []
        current = g
        while current != s:
            path.append(par_edge[current])
            current = par[current]
            
        return (path[::-1], D) if distances else path[::-1]
            
     
            
        
    def floyd_warshall(G) -> list[list[int]]:
        D = [[inf]*G.N for _ in range(G.N)]

        for u in range(G.N):
            D[u][u] = 0
            for v in G.neighbors(u):
                D[u][v] = 1
        
        for k, Dk in enumerate(D):
            for Di in D:
                if Di[k] == inf: continue
                for j in range(G.N):
                    if Dk[j] == inf: continue
                    Di[j] = min(Di[j], Di[k]+Dk[j])
        return D
    
    def find_cycle(G, s = 0, vis = None, par = None):
        N = G.N
        vis = vis or [0] * N
        par = par or [-1] * N
        if vis[s]: return None
        vis[s] = 1
        stack = [(True, s)]
        while stack:
            forw, v = stack.pop()
            if forw:
                stack.append((False, v))
                vis[v] = 1
                for u in G.neighbors(v):
                    if vis[u] == 1 and u != par[v]:
                        # Cycle detected
                        cyc = [u]
                        vis[u] = 2
                        while v != u:
                            cyc.append(v)
                            vis[v] = 2
                            v = par[v]
                        return cyc
                    elif vis[u] == 0:
                        par[u] = v
                        stack.append((True, u))
            else:
                vis[v] = 2
        return None

    def find_minimal_cycle(G, s=0):
        D, par, que = [inf] * (N := G.N), [-1] * N, deque([s])
        D[s] = 0
        while que:
            for v in G[u := que.popleft()]:
                if v == s:  # Found cycle back to start
                    cycle = [u]
                    while u != s: cycle.append(u := par[u])
                    return cycle
                if D[v] < inf: continue
                D[v], par[v] = D[u]+1, u
                que.append(v)
    
    def bridges(G):
        tin = [-1] * G.N
        low = [-1] * G.N
        par = [-1] * G.N
        vis = [0] * G.N
        in_edge = [-1] * G.N

        Eid = G.edge_ids()
        time = 0
        bridges = []
        stack = list(range(G.N))
        while stack:
            p = par[v := stack.pop()]
            if vis[v] == 0:
                vis[v] = 1
                tin[v] = low[v] = time
                time += 1
                stack.append(v)
                for i, child in enumerate(G.neighbors(v)):
                    if child == p: continue
                    if vis[child] == 0: # Tree edge - recurse
                        par[child] = v
                        in_edge[child] = Eid[v][i]
                        stack.append(child)
                    else: # Back edge - update low-link value
                        low[v] = min(low[v], tin[child])
            elif vis[v] == 1:
                vis[v] = 2
                if p != -1:
                    low[p] = min(low[p], low[v])
                    if low[v] > tin[p]: bridges.append(in_edge[v])
        return bridges

    def articulation_points(G):
        '''
        Find articulation points in an undirected graph using DFS events.
        Returns a boolean list that is True for indices where the vertex is an articulation point.
        '''
        N = G.N
        order = [-1] * N
        low = [-1] * N
        par = [-1] * N
        state = [0] * N
        children = [0] * N
        ap = [False] * N
        time = 0
        stack = list(range(N))

        while stack:
            v = stack.pop()
            p = par[v]
            if state[v] == 0:
                state[v] = 1
                order[v] = low[v] = time
                time += 1
            
                stack.append(v)
                for child in G[v]:
                    if order[child] == -1:
                        par[child] = v
                        stack.append(child)
                    elif child != p:
                        low[v] = min(low[v], order[child])
                if p != -1:
                    children[p] += 1
            elif state[v] == 1:
                state[v] = 2
                ap[v] |= p == -1 and children[v] > 1
                if p != -1:
                    low[p] = min(low[p], low[v])
                    ap[p] |= par[p] != -1 and low[v] >= order[p]

        return ap
    
    def dfs_events(G, flags: DFSFlags, s: Union[int,list,None] = None, max_depth: Union[int,None] = None):
        if flags == DFSFlags.INTERVAL:
            if max_depth is None:
                return G.dfs_enter_leave(s)
        elif flags == DFSFlags.DOWN or flags == DFSFlags.TOPDOWN:
            if max_depth is None:
                edges = G.dfs_topdown(s, DFSFlags.CONNECT_ROOTS in flags)
                return [(DFSEvent.DOWN, p, u) for p,u in edges]
        elif flags == DFSFlags.UP or flags == DFSFlags.BOTTOMUP:
            if max_depth is None:
                edges = G.dfs_bottomup(s, DFSFlags.CONNECT_ROOTS in flags)
                return [(DFSEvent.UP, p, u) for p,u in edges]
        elif flags & DFSFlags.BACKTRACK:
            return G.dfs_backtrack(flags, s, max_depth)
        state = [0] * G.N
        child = [0] * G.N
        stack = [0] * G.N
        if flags & DFSFlags.RETURN_PARENTS:
            parents = [-1] * G.N
        if flags & DFSFlags.RETURN_DEPTHS:
            depths = [-1] * G.N

        events = []
        for s in G.starts(s):
            stack[depth := 0] = s
            if (DFSFlags.DOWN|DFSFlags.CONNECT_ROOTS) in flags:
                events.append((DFSEvent.DOWN,-1,s))
            while depth != -1:
                u = stack[depth]
                
                if not state[u]:
                    state[u] = 1
                    if flags & DFSFlags.ENTER:
                        events.append((DFSEvent.ENTER, u))
                    if flags & DFSFlags.RETURN_DEPTHS:
                        depths[u] = depth
                
                if (c := child[u]) < len(G[u]):
                    child[u] += 1
                    if (s := state[v := G[u][c]]) == 0: # Unvisited
                        if max_depth is None or depth <= max_depth:
                            if flags & DFSFlags.DOWN:
                                events.append((DFSEvent.DOWN, u, v))
                            stack[depth := depth+1] = v
                            if flags & DFSFlags.RETURN_PARENTS:
                                parents[v] = u
                    elif s == 1:  # In progress
                        if flags & DFSFlags.BACK:
                            events.append((DFSEvent.BACK, u, v))
                    elif s == 2: # Completed
                        if flags & DFSFlags.CROSS:
                            events.append((DFSEvent.CROSS, u, v))
                else:
                    depth -= 1
                    state[u] = 0 if DFSFlags.BACKTRACK in flags else 2
                    if flags & DFSFlags.LEAVE:
                        events.append((DFSEvent.LEAVE, u))
                    if depth != -1 and flags & DFSFlags.UP:
                        events.append((DFSEvent.UP, stack[depth], u))
            if (DFSFlags.UP|DFSFlags.CONNECT_ROOTS) in flags:
                events.append((DFSEvent.UP,-1,s))
        ret = tuple((events,)) if DFSFlags.RETURN_ALL & flags else events
        if DFSFlags.RETURN_PARENTS in flags:
            ret += (parents,)
        if DFSFlags.RETURN_DEPTHS in flags:
            ret += (depths,)
        return ret

    def dfs_backtrack(G, flags: DFSFlags, s: Union[int,list] = None, max_depth: Union[int,None] = None):
        stack_depth = (max_depth+1 if max_depth is not None else G.N)
        stack = [0]*stack_depth
        child = [0]*stack_depth
        state = [0]*G.N
        events: list[tuple[DFSEvent, int]|tuple[DFSEvent, int, int]] = []

        for s in G.starts(s):
            if state[s]: continue
            state[s] = 1
            stack[depth := 0] = s
            if DFSFlags.DOWN|DFSFlags.CONNECT_ROOTS in flags:
                events.append((DFSEvent.DOWN,-1,s))
            while depth != -1:
                u = stack[depth]
                if state[u] == 1:
                    state[u] = 2
                    if DFSFlags.ENTER in flags:
                        events.append((DFSEvent.ENTER,u))
                    if max_depth is not None and depth >= max_depth:
                        child[depth] = len(G[u])
                        if DFSFlags.MAXDEPTH in flags:
                            events.append((DFSEvent.MAXDEPTH,u))

                if (c := child[depth]) < len(G[u]):
                    child[depth] += 1
                    if state[v := G[u][c]]:
                        if DFSFlags.BACK in flags:
                            events.append((DFSEvent.BACK,u,v))
                        continue
                    state[v] = 1
                    if DFSFlags.DOWN in flags:
                        events.append((DFSEvent.DOWN,u,v))
                    stack[depth := depth+1] = v
                else:
                    state[u] = 0
                    if DFSFlags.LEAVE in flags:
                        events.append((DFSEvent.LEAVE,u))
                    child[depth] = 0
                    depth -= 1
                    if depth and DFSFlags.UP in flags:
                        events.append((DFSEvent.UP, stack[depth], u))
            if DFSFlags.UP|DFSFlags.CONNECT_ROOTS in flags:
                events.append((DFSEvent.UP,-1,s))
        return events

    def dfs_enter_leave(G, s: Union[int,list,None] = None):
        state = [True] * G.N
        child: list[int] = elist(G.N)
        stack: list[int] = elist(G.N)

        events = []
        for s in G.starts(s):
            if not state[s]: continue
            stack.append(s)
            child.append(0)
            
            while stack:
                u = stack[-1]
                
                if state[u]:
                    state[u] = False
                    events.append((DFSEvent.ENTER, u))

                
                if (c := child[-1]) < len(G[u]):
                    child[-1] += 1
                    if state[v := G[u][c]]:
                        stack.append(v)
                        child.append(0)
                else:
                    stack.pop()
                    child.pop()
                    events.append((DFSEvent.LEAVE, u))

        return events
    
    def dfs_topdown(G, s: Union[int,list,None] = None, connect_roots = False):
        '''Returns list of (u,v) representing u->v edges in order of top down discovery'''
        stack: list[int] = elist(G.N)
        vis = [False]*G.N
        edges: list[tuple[int,int]] = elist(G.N)

        for s in G.starts(s):
            if vis[s]: continue
            if connect_roots:
                edges.append((-1,s))
            vis[s] = True
            stack.append(s)
            while stack:
                u = stack.pop()
                for v in G[u]:
                    if vis[v]: continue
                    vis[v] = True
                    edges.append((u,v))
                    stack.append(v)
        return edges
    
    def dfs_bottomup(G, s: Union[int,list,None] = None, connect_roots = False):
        '''Returns list of (p,u) representing p->u edges in bottom up order'''
        edges = G.dfs_topdown(s, connect_roots)
        edges.reverse()
        return edges

    def is_bipartite(G):
        N = G.N
        que = deque()
        color = [-1]*N
                
        for s in range(N):
            if color[s] >= 0:
                continue
            color[s] = 1
            que.append(s)
            while que:
                u = que.popleft()
                for v in G[u]:
                    if color[v] == -1:
                        color[v] = 1 - color[u]
                        que.append(v)
                    elif color[v] == color[u]:
                        return False
        return True
    
    def starts(G, v: Union[int,list,None]) -> Iterable:
        if isinstance(v, int):
            return (v,)
        elif v is None:
            return range(G.N)
        else:
            return v

    @classmethod
    def compile(cls, N: int, M: int, E):
        edge = Parser.compile(E)
        def parse(ts: TokenStream):
            return cls(N, [edge(ts) for _ in range(M)])
        return parse
    


def sort2(a, b):
    '''Return the pair ``a``, ``b`` in ascending order (``(b, a)`` when they compare equal).'''
    if a < b:
        return a, b
    return b, a

import operator
from itertools import accumulate

def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    '''
    Return running combinations of ``iter`` using ``func`` (default: addition).

    With ``step == 1`` this is ``list(itertools.accumulate(iter, func, initial=initial))``.
    With ``step >= 2`` each element is combined with the element ``step``
    positions before it, giving ``step`` interleaved running totals; ``initial``,
    if given, is prepended first.

    Raises:
        ValueError: if ``step`` is less than 1 (previously an ``assert``,
        which is skipped under ``python -O``).
    '''
    if step == 1:
        return list(accumulate(iter, func, initial=initial))
    if step < 2:
        raise ValueError(f"step must be >= 1, got {step!r}")
    if func is None:
        func = operator.add
    A = list(iter)
    if initial is not None:
        A = [initial] + A
    for i in range(step,len(A)):
        A[i] = func(A[i], A[i-step])
    return A
# from typing import Generic
# from cp_library.misc.typing import _T

def min2(a, b):
    '''Return the smaller of ``a`` and ``b`` (``b`` when they compare equal).'''
    if a < b:
        return a
    return b


class MinSparseTable:
    '''
    Sparse table for range-minimum queries over a static sequence.

    Level ``k`` occupies ``data[k*N : (k+1)*N]``; its entry ``j`` is the
    minimum of ``arr[j : j + 2**k]``. ``query`` combines two overlapping
    power-of-two windows.
    '''
    def __init__(st, arr: list):
        n = len(arr)
        levels = n.bit_length()
        table = [0] * (levels * n)
        table[:n] = arr
        st.N, st.log, st.data = n, levels, table
        for k in range(1, levels):
            half = 1 << (k - 1)
            cur, prev = k * n, (k - 1) * n
            for j in range(n - (1 << k) + 1):
                table[cur + j] = min2(table[prev + j], table[prev + half + j])

    def query(st, l: int, r: int):
        '''Minimum of ``arr[l:r]``; requires ``l < r`` (an empty range makes the shift negative).'''
        k = (r - l).bit_length() - 1
        row = k * st.N
        return min2(st.data[row + l], st.data[row + r - (1 << k)])
    

class LCATable(MinSparseTable):
    '''
    Lowest-common-ancestor queries via range minimum over an Euler tour.

    Each Euler-tour position stores ``depth << shift | vertex``, so the sparse
    table minimum over ``[tin[u], tin[v]]`` yields the shallowest vertex
    between them (the LCA) together with its depth.
    '''
    def __init__(lca, T, root = 0):
        N = len(T)
        T.euler_tour(root)
        lca.depth = depth = presum(T.delta)
        lca.tin, lca.tout = T.tin[:], T.tout[:]
        # Keep our own parent array for path(); T.par is overwritten by later
        # traversals (euler_tour / hld_precomp).
        lca.par = T.par[:]
        lca.mask = (1 << (shift := N.bit_length()))-1
        lca.shift = shift
        order = T.order
        M = len(order)
        packets = [0]*M
        for i in range(M):
            packets[i] = depth[i] << shift | order[i] 
        super().__init__(packets)

    def _query(lca, u, v):
        '''Return ``(l, r, ancestor, ancestor_depth)`` for the Euler range ``[l, r)`` spanning u and v.'''
        l, r = sort2(lca.tin[u], lca.tin[v]); r += 1
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift

    def query(lca, u, v) -> tuple[int,int]:
        '''Return ``(lca(u, v), depth of that ancestor)``.'''
        l, r, a, d = lca._query(u, v)
        return a, d
    
    def distance(lca, u, v) -> int:
        '''Number of edges on the tree path between u and v.'''
        l, r, a, d = lca._query(u, v)
        return lca.depth[l] + lca.depth[r-1] - 2*d
    
    def path(lca, u, v):
        '''Return the vertices on the tree path from u to v, both endpoints included.'''
        par, a = lca.par, lca.query(u, v)[0]
        path, c = [], u
        while c != a:
            path.append(c)
            c = par[c]
        path.append(a)
        rev_path, c = [], v
        while c != a:
            rev_path.append(c)
            c = par[c]
        path.extend(reversed(rev_path))
        return path

class TreeProtocol(GraphProtocol):

    @cached_property
    def lca(T):
        return LCATable(T)
    
    @overload
    def diameter(T) -> int: ...
    @overload
    def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
    def diameter(T, endpoints = False):
        mask = (1 << (shift := T.N.bit_length())) - 1
        s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
        dg = max(d << shift | v for v,d in enumerate(T.distance(s))) 
        diam, g = dg >> shift, dg & mask
        return (diam, s, g) if endpoints else diam
    
    @overload
    def distance(T) -> list[list[int]]: ...
    @overload
    def distance(T, s: int = 0) -> list[int]: ...
    @overload
    def distance(T, s: int, g: int) -> int: ...
    def distance(T, s = None, g = None):
        if s == None:
            return [T.dfs(u) for u in range(T.N)]
        else:
            return T.dfs(s, g)
            
    @overload
    def dfs(T, s: int = 0) -> list[int]: ...
    @overload
    def dfs(T, s: int, g: int) -> int: ...
    def dfs(T, s = 0, g = None):
        D = [inf for _ in range(T.N)]
        D[s] = 0
        state = [True for _ in range(T.N)]
        stack = [s]

        while stack:
            u = stack.pop()
            if u == g: return D[u]
            state[u] = False
            for v in T[u]:
                if state[v]:
                    D[v] = D[u]+1
                    stack.append(v)
        return D if g is None else inf 


    def dfs_events(G, flags: DFSFlags, s: int = 0):         
        events = []
        stack = [(s,-1)]
        adj = [None]*G.N


        while stack:
            u, p = stack[-1]
            
            if adj[u] is None:
                adj[u] = iter(G.neighbors(u))
                if DFSFlags.ENTER in flags:
                    events.append((DFSEvent.ENTER, u))
            
            if (v := next(adj[u], None)) is not None:
                if v == p:
                    if DFSFlags.BACK in flags:
                        events.append((DFSEvent.BACK, u, v))
                else:
                    if DFSFlags.DOWN in flags:
                        events.append((DFSEvent.DOWN, u, v))
                    stack.append((v,u))
            else:
                stack.pop()

                if DFSFlags.LEAVE in flags:
                    events.append((DFSEvent.LEAVE, u))
                if p != -1 and DFSFlags.UP in flags:
                    events.append((DFSEvent.UP, u, p))
        return events
    
    def euler_tour(T, s = 0):
        N = len(T)
        T.tin = tin = [-1] * N
        T.tout = tout = [-1] * N
        T.par = par = [-1] * N
        T.order = order = elist(2*N)
        T.delta = delta = elist(2*N)
        
        stack = elist(N)
        stack.append(s)

        while stack:
            u = stack.pop()
            p = par[u]
            
            if tin[u] == -1:
                tin[u] = len(order)
                
                for v in T[u]:
                    if v != p:
                        par[v] = u
                        stack.append(u)
                        stack.append(v)
                
                delta.append(1)
            else:
                delta.append(-1)
            
            order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0

    def hld_precomp(T, r = 0):
        N, time = T.N, 0
        tin, tout, size = [0]*N, [0]*N, [1]*N+[0]
        par, heavy, head = [-1]*N, [-1]*N, [r]*N
        depth, order, state = [0]*N, [0]*N, [0]*N
        stack = elist(N)
        stack.append(r)
        while stack:
            if (s := state[v := stack.pop()]) == 0: # dfs down
                p, state[v] = par[v], 1
                stack.append(v)
                for c in T[v]:
                    if c != p:
                        depth[c], par[c] = depth[v]+1, v
                        stack.append(c)

            elif s == 1: # dfs up
                p, l = par[v], -1
                for c in T[v]:
                    if c != p:
                        size[v] += size[c]
                        if size[c] > size[l]:
                            l = c
                heavy[v] = l
                if p == -1:
                    state[v] = 2
                    stack.append(v)

            elif s == 2: # decompose down
                p, h, l = par[v], head[v], heavy[v]
                tin[v], order[time], state[v] = time, v, 3
                time += 1
                stack.append(v)
                
                for c in T[v]:
                    if c != p and c != l:
                        head[c], state[c] = c, 2
                        stack.append(c)

                if l != -1:
                    head[l], state[l] = h, 2
                    stack.append(l)

            elif s == 3: # decompose up
                tout[v] = time
        T.size, T.depth = size, depth
        T.order, T.tin, T.tout = order, tin, tout
        T.par, T.heavy, T.head = par, heavy, head
Back to top page