This documentation is automatically generated by online-judge-tools/verification-helper
from cp_library.ds.elist_fn import elist
import cp_library.alg.tree.__header__
from typing import overload, Literal, Union
from functools import cached_property
from math import inf
from collections import deque
from cp_library.alg.graph.dfs_options_cls import DFSFlags, DFSEvent
from cp_library.alg.graph.graph_proto import GraphProtocol
from cp_library.alg.tree.lca_table_iterative_cls import LCATable
class TreeProtocol(GraphProtocol):
@cached_property
def lca(T):
return LCATable(T)
@overload
def diameter(T) -> int: ...
@overload
def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
def diameter(T, endpoints = False):
mask = (1 << (shift := T.N.bit_length())) - 1
s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
dg = max(d << shift | v for v,d in enumerate(T.distance(s)))
diam, g = dg >> shift, dg & mask
return (diam, s, g) if endpoints else diam
@overload
def distance(T) -> list[list[int]]: ...
@overload
def distance(T, s: int = 0) -> list[int]: ...
@overload
def distance(T, s: int, g: int) -> int: ...
def distance(T, s = None, g = None):
if s == None:
return [T.dfs(u) for u in range(T.N)]
else:
return T.dfs(s, g)
@overload
def dfs(T, s: int = 0) -> list[int]: ...
@overload
def dfs(T, s: int, g: int) -> int: ...
def dfs(T, s = 0, g = None):
D = [inf for _ in range(T.N)]
D[s] = 0
state = [True for _ in range(T.N)]
stack = [s]
while stack:
u = stack.pop()
if u == g: return D[u]
state[u] = False
for v in T[u]:
if state[v]:
D[v] = D[u]+1
stack.append(v)
return D if g is None else inf
def dfs_events(G, flags: DFSFlags, s: int = 0):
events = []
stack = [(s,-1)]
adj = [None]*G.N
while stack:
u, p = stack[-1]
if adj[u] is None:
adj[u] = iter(G.neighbors(u))
if DFSFlags.ENTER in flags:
events.append((DFSEvent.ENTER, u))
if (v := next(adj[u], None)) is not None:
if v == p:
if DFSFlags.BACK in flags:
events.append((DFSEvent.BACK, u, v))
else:
if DFSFlags.DOWN in flags:
events.append((DFSEvent.DOWN, u, v))
stack.append((v,u))
else:
stack.pop()
if DFSFlags.LEAVE in flags:
events.append((DFSEvent.LEAVE, u))
if p != -1 and DFSFlags.UP in flags:
events.append((DFSEvent.UP, u, p))
return events
def euler_tour(T, s = 0):
N = len(T)
T.tin = tin = [-1] * N
T.tout = tout = [-1] * N
T.par = par = [-1] * N
T.order = order = elist(2*N)
T.delta = delta = elist(2*N)
stack = elist(N)
stack.append(s)
while stack:
u = stack.pop()
p = par[u]
if tin[u] == -1:
tin[u] = len(order)
for v in T[u]:
if v != p:
par[v] = u
stack.append(u)
stack.append(v)
delta.append(1)
else:
delta.append(-1)
order.append(u)
tout[u] = len(order)
delta[0] = delta[-1] = 0
def hld_precomp(T, r = 0):
N, time = T.N, 0
tin, tout, size = [0]*N, [0]*N, [1]*N+[0]
par, heavy, head = [-1]*N, [-1]*N, [r]*N
depth, order, state = [0]*N, [0]*N, [0]*N
stack = elist(N)
stack.append(r)
while stack:
if (s := state[v := stack.pop()]) == 0: # dfs down
p, state[v] = par[v], 1
stack.append(v)
for c in T[v]:
if c != p:
depth[c], par[c] = depth[v]+1, v
stack.append(c)
elif s == 1: # dfs up
p, l = par[v], -1
for c in T[v]:
if c != p:
size[v] += size[c]
if size[c] > size[l]:
l = c
heavy[v] = l
if p == -1:
state[v] = 2
stack.append(v)
elif s == 2: # decompose down
p, h, l = par[v], head[v], heavy[v]
tin[v], order[time], state[v] = time, v, 3
time += 1
stack.append(v)
for c in T[v]:
if c != p and c != l:
head[c], state[c] = c, 2
stack.append(c)
if l != -1:
head[l], state[l] = h, 2
stack.append(l)
elif s == 3: # decompose up
tout[v] = time
T.size, T.depth = size, depth
T.order, T.tin, T.tout = order, tin, tout
T.par, T.heavy, T.head = par, heavy, head
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
def elist(est_len: int) -> list: ...  # typing stub; rebound below


try:
    # PyPy can pre-size a list's storage from a length hint.
    from __pypy__ import newlist_hint
except ImportError:
    def newlist_hint(hint):
        """Fallback for non-PyPy interpreters: ignore the hint, return []."""
        return []


# elist(n): a new empty list, pre-sized for about n elements on PyPy.
elist = newlist_hint
from typing import overload, Literal, Union
from functools import cached_property
from math import inf
from collections import deque
from enum import auto, IntFlag, IntEnum
class DFSFlags(IntFlag):
    """Bit flags selecting which DFS events to emit and what extra data to return.

    Single bits are powers of two in declaration order (the same values
    ``auto()`` generates for a Flag), followed by common combinations.
    """
    ENTER = 1 << 0           # vertex entered (pre-order)
    DOWN = 1 << 1            # tree edge parent -> child
    BACK = 1 << 2            # edge to a vertex still in progress
    CROSS = 1 << 3           # edge to an already finished vertex
    LEAVE = 1 << 4           # vertex left (post-order)
    UP = 1 << 5              # tree edge child -> parent on the way back
    MAXDEPTH = 1 << 6        # depth limit reached at a vertex
    RETURN_PARENTS = 1 << 7  # also return the parent array
    RETURN_DEPTHS = 1 << 8   # also return the depth array
    BACKTRACK = 1 << 9       # un-mark vertices when leaving them
    CONNECT_ROOTS = 1 << 10  # emit edges from -1 to each root

    # Common combinations
    ALL_EDGES = DOWN | BACK | CROSS
    EULER_TOUR = DOWN | UP
    INTERVAL = ENTER | LEAVE
    TOPDOWN = DOWN | CONNECT_ROOTS
    BOTTOMUP = UP | CONNECT_ROOTS
    RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS
class DFSEvent(IntEnum):
    """Tag placed first in each DFS event tuple.

    Every value is the integer value of the DFSFlags bit with the same name,
    so an event kind can be compared directly against the flags.
    """
    ENTER = int(DFSFlags.ENTER)
    DOWN = int(DFSFlags.DOWN)
    BACK = int(DFSFlags.BACK)
    CROSS = int(DFSFlags.CROSS)
    LEAVE = int(DFSFlags.LEAVE)
    UP = int(DFSFlags.UP)
    MAXDEPTH = int(DFSFlags.MAXDEPTH)
import typing
from numbers import Number
from types import GenericAlias
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase
class FastIO(IOBase):
    """Byte-level buffered I/O working directly on a file descriptor.

    Input is pulled with ``os.read`` into an in-memory ``BytesIO``. Output is
    collected in that same ``BytesIO`` and only reaches the descriptor on
    ``flush``. ``newlines`` counts complete lines that are buffered but not yet
    returned by ``readline``.
    """
    BUFSIZE = 8192
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Writable when opened for exclusive creation or in a non-read mode.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Drain the descriptor to EOF and return every unread buffered byte."""
        fd, chunk_size, buf = self._fd, self.BUFSIZE, self.buffer
        while True:
            chunk = os.read(fd, max(os.fstat(fd).st_size, chunk_size))
            if not chunk:
                break
            # Append at the end without moving the read cursor.
            cursor = buf.tell()
            buf.seek(0, 2)
            buf.write(chunk)
            buf.seek(cursor)
        self.newlines = 0
        return buf.read()

    def readline(self):
        """Return the next line (bytes), reading more input only when needed."""
        fd, chunk_size, buf = self._fd, self.BUFSIZE, self.buffer
        while self.newlines == 0:
            chunk = os.read(fd, max(os.fstat(fd).st_size, chunk_size))
            # An empty chunk (EOF) counts as one final, possibly partial, line.
            self.newlines = chunk.count(b"\n") + (not chunk)
            cursor = buf.tell()
            buf.seek(0, 2)
            buf.write(chunk)
            buf.seek(cursor)
        self.newlines -= 1
        return buf.readline()

    def flush(self):
        """Write all buffered output to the descriptor (no-op when read-only)."""
        if not self.writable:
            return
        os.write(self._fd, self.buffer.getvalue())
        self.buffer.truncate(0)
        self.buffer.seek(0)
class IOWrapper(IOBase):
    """ASCII text facade over FastIO (strings in, strings out).

    The class attributes ``stdin``/``stdout`` hold the shared wrapped streams
    once they are installed.
    """
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        """Encode ``s`` as ASCII and append it to the output buffer."""
        payload = s.encode("ascii")
        return self.buffer.write(payload)

    def read(self):
        """Return the rest of the input decoded as ASCII."""
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        """Return the next input line decoded as ASCII."""
        raw = self.buffer.readline()
        return raw.decode("ascii")
# Replace the process's standard streams with buffered ASCII wrappers and
# keep them on the class so other helpers can reach them (TokenStream reads
# IOWrapper.stdin). Output reaches the fd only when sys.stdout is flushed.
IOWrapper.stdin = sys.stdin = IOWrapper(sys.stdin)
IOWrapper.stdout = sys.stdout = IOWrapper(sys.stdout)
from typing import TypeVar
_T = TypeVar('T')
class TokenStream(Iterator):
    """Iterator over whitespace-separated tokens read line by line.

    Tokens of the current line are buffered in ``queue``. ``_line`` supplies
    the next line's tokens and may be overridden (CharStream does).
    """
    stream = IOWrapper.stdin

    def __init__(self):
        self.queue = deque()

    def __next__(self):
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        return pending.popleft()

    def wait(self):
        """Yield once per remaining token of the current line.

        Loads a new line first if nothing is buffered. The caller is expected
        to consume a token on every step.
        """
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        while pending:
            yield

    def _line(self):
        return TokenStream.stream.readline().split()

    def line(self):
        """Return the leftover tokens of the current line, or a fresh line."""
        if not self.queue:
            return self._line()
        rest = list(self.queue)
        self.queue.clear()
        return rest


TokenStream.default = TokenStream()
class CharStream(TokenStream):
    """TokenStream whose tokens are the single characters of each line."""

    def _line(self):
        raw = TokenStream.stream.readline()
        return raw.rstrip()  # a str; extending the queue splits it into chars


CharStream.default = CharStream()
# Type of a compiled parser: consumes tokens from a TokenStream and returns
# the parsed value.
ParseFn = Callable[[TokenStream], _T]
class Parser:
    """Compile a specification into a function that parses it from a TokenStream.

    ``spec`` may be:
      * a type or ``types.GenericAlias`` (``int``, ``list[int]``,
        ``tuple[int, ...]``), dispatched through ``compile_type``;
      * a Number instance ``k``, which parses ``type(k)(token) + k``
        (e.g. ``-1`` turns 1-based input into 0-based);
      * a tuple or other collection instance of sub-specs;
      * any other callable, applied to the next token.
    """
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        """Parser for the (possibly parameterised) type ``cls[*args]``."""
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        """Entry point: build a parse function for ``spec`` (see class doc)."""
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or ()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        """Parse the rest of the current input line into ``cls([...])``."""
        if spec is int:
            # Fast path: convert every token of the line directly.
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
        return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        """Parse ``spec`` exactly ``N`` times into ``cls([...])``."""
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        """Parse each sub-spec in order into ``cls([...])``."""
        fns = tuple(Parser.compile(spec) for spec in specs)
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        """``(spec, ...)`` reads a whole line; otherwise one value per sub-spec."""
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        """No or one sub-spec: a line of values. ``(spec, N)``: N repeats.

        NOTE(review): a set with more than one element is unpacked into
        ``compile_line`` and raises TypeError; confirm sets are single-spec only.
        """
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()
class Parsable:
    """Mixin for classes that build themselves from one TokenStream token.

    Subclasses may override ``compile`` to read richer input.
    """
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser
from typing import Iterable, Union, overload
class GraphProtocol(list, Parsable):
def __init__(G, N: int, E: list = None, adj: Iterable = None):
G.N = N
if E is not None:
G.M, G.E = len(E), E
if adj is not None:
super().__init__(adj)
def neighbors(G, v: int) -> Iterable[int]:
return G[v]
def edge_ids(G) -> list[list[int]]: ...
@overload
def distance(G) -> list[list[int]]: ...
@overload
def distance(G, s: int = 0) -> list[int]: ...
@overload
def distance(G, s: int, g: int) -> int: ...
def distance(G, s = None, g = None):
if s == None:
return G.floyd_warshall()
else:
return G.bfs(s, g)
@overload
def bfs(G, s: Union[int,list] = 0) -> list[int]: ...
@overload
def bfs(G, s: Union[int,list], g: int) -> int: ...
def bfs(G, s = 0, g = None):
D = [inf for _ in range(G.N)]
q = deque([s] if isinstance(s, int) else s)
for u in q: D[u] = 0
while q:
nd = D[u := q.popleft()]+1
if u == g: return D[u]
for v in G.neighbors(u):
if nd < D[v]:
D[v] = nd
q.append(v)
return D if g is None else inf
@overload
def shortest_path(G, s: int, g: int) -> Union[list[int],None]: ...
@overload
def shortest_path(G, s: int, g: int, distances = True) -> tuple[Union[list[int],None],list[int]]: ...
def shortest_path(G, s: int, g: int, distances = False) -> list[int]:
D = [inf] * G.N
D[s] = 0
if s == g:
return ([], D) if distances else []
par = [-1] * G.N
par_edge = [-1] * G.N
Eid = G.edge_ids()
q = deque([s])
while q:
nd = D[u := q.popleft()] + 1
if u == g: break
for v, eid in zip(G[u], Eid[u]):
if nd < D[v]:
D[v] = nd
par[v] = u
par_edge[v] = eid
q.append(v)
if D[g] == inf:
return (None, D) if distances else None
path = []
current = g
while current != s:
path.append(par_edge[current])
current = par[current]
return (path[::-1], D) if distances else path[::-1]
def floyd_warshall(G) -> list[list[int]]:
D = [[inf]*G.N for _ in range(G.N)]
for u in range(G.N):
D[u][u] = 0
for v in G.neighbors(u):
D[u][v] = 1
for k, Dk in enumerate(D):
for Di in D:
if Di[k] == inf: continue
for j in range(G.N):
if Dk[j] == inf: continue
Di[j] = min(Di[j], Di[k]+Dk[j])
return D
def find_cycle(G, s = 0, vis = None, par = None):
N = G.N
vis = vis or [0] * N
par = par or [-1] * N
if vis[s]: return None
vis[s] = 1
stack = [(True, s)]
while stack:
forw, v = stack.pop()
if forw:
stack.append((False, v))
vis[v] = 1
for u in G.neighbors(v):
if vis[u] == 1 and u != par[v]:
# Cycle detected
cyc = [u]
vis[u] = 2
while v != u:
cyc.append(v)
vis[v] = 2
v = par[v]
return cyc
elif vis[u] == 0:
par[u] = v
stack.append((True, u))
else:
vis[v] = 2
return None
def find_minimal_cycle(G, s=0):
D, par, que = [inf] * (N := G.N), [-1] * N, deque([s])
D[s] = 0
while que:
for v in G[u := que.popleft()]:
if v == s: # Found cycle back to start
cycle = [u]
while u != s: cycle.append(u := par[u])
return cycle
if D[v] < inf: continue
D[v], par[v] = D[u]+1, u
que.append(v)
def bridges(G):
tin = [-1] * G.N
low = [-1] * G.N
par = [-1] * G.N
vis = [0] * G.N
in_edge = [-1] * G.N
Eid = G.edge_ids()
time = 0
bridges = []
stack = list(range(G.N))
while stack:
p = par[v := stack.pop()]
if vis[v] == 0:
vis[v] = 1
tin[v] = low[v] = time
time += 1
stack.append(v)
for i, child in enumerate(G.neighbors(v)):
if child == p: continue
if vis[child] == 0: # Tree edge - recurse
par[child] = v
in_edge[child] = Eid[v][i]
stack.append(child)
else: # Back edge - update low-link value
low[v] = min(low[v], tin[child])
elif vis[v] == 1:
vis[v] = 2
if p != -1:
low[p] = min(low[p], low[v])
if low[v] > tin[p]: bridges.append(in_edge[v])
return bridges
def articulation_points(G):
'''
Find articulation points in an undirected graph using DFS events.
Returns a boolean list that is True for indices where the vertex is an articulation point.
'''
N = G.N
order = [-1] * N
low = [-1] * N
par = [-1] * N
state = [0] * N
children = [0] * N
ap = [False] * N
time = 0
stack = list(range(N))
while stack:
v = stack.pop()
p = par[v]
if state[v] == 0:
state[v] = 1
order[v] = low[v] = time
time += 1
stack.append(v)
for child in G[v]:
if order[child] == -1:
par[child] = v
stack.append(child)
elif child != p:
low[v] = min(low[v], order[child])
if p != -1:
children[p] += 1
elif state[v] == 1:
state[v] = 2
ap[v] |= p == -1 and children[v] > 1
if p != -1:
low[p] = min(low[p], low[v])
ap[p] |= par[p] != -1 and low[v] >= order[p]
return ap
def dfs_events(G, flags: DFSFlags, s: Union[int,list,None] = None, max_depth: Union[int,None] = None):
if flags == DFSFlags.INTERVAL:
if max_depth is None:
return G.dfs_enter_leave(s)
elif flags == DFSFlags.DOWN or flags == DFSFlags.TOPDOWN:
if max_depth is None:
edges = G.dfs_topdown(s, DFSFlags.CONNECT_ROOTS in flags)
return [(DFSEvent.DOWN, p, u) for p,u in edges]
elif flags == DFSFlags.UP or flags == DFSFlags.BOTTOMUP:
if max_depth is None:
edges = G.dfs_bottomup(s, DFSFlags.CONNECT_ROOTS in flags)
return [(DFSEvent.UP, p, u) for p,u in edges]
elif flags & DFSFlags.BACKTRACK:
return G.dfs_backtrack(flags, s, max_depth)
state = [0] * G.N
child = [0] * G.N
stack = [0] * G.N
if flags & DFSFlags.RETURN_PARENTS:
parents = [-1] * G.N
if flags & DFSFlags.RETURN_DEPTHS:
depths = [-1] * G.N
events = []
for s in G.starts(s):
stack[depth := 0] = s
if (DFSFlags.DOWN|DFSFlags.CONNECT_ROOTS) in flags:
events.append((DFSEvent.DOWN,-1,s))
while depth != -1:
u = stack[depth]
if not state[u]:
state[u] = 1
if flags & DFSFlags.ENTER:
events.append((DFSEvent.ENTER, u))
if flags & DFSFlags.RETURN_DEPTHS:
depths[u] = depth
if (c := child[u]) < len(G[u]):
child[u] += 1
if (s := state[v := G[u][c]]) == 0: # Unvisited
if max_depth is None or depth <= max_depth:
if flags & DFSFlags.DOWN:
events.append((DFSEvent.DOWN, u, v))
stack[depth := depth+1] = v
if flags & DFSFlags.RETURN_PARENTS:
parents[v] = u
elif s == 1: # In progress
if flags & DFSFlags.BACK:
events.append((DFSEvent.BACK, u, v))
elif s == 2: # Completed
if flags & DFSFlags.CROSS:
events.append((DFSEvent.CROSS, u, v))
else:
depth -= 1
state[u] = 0 if DFSFlags.BACKTRACK in flags else 2
if flags & DFSFlags.LEAVE:
events.append((DFSEvent.LEAVE, u))
if depth != -1 and flags & DFSFlags.UP:
events.append((DFSEvent.UP, stack[depth], u))
if (DFSFlags.UP|DFSFlags.CONNECT_ROOTS) in flags:
events.append((DFSEvent.UP,-1,s))
ret = tuple((events,)) if DFSFlags.RETURN_ALL & flags else events
if DFSFlags.RETURN_PARENTS in flags:
ret += (parents,)
if DFSFlags.RETURN_DEPTHS in flags:
ret += (depths,)
return ret
def dfs_backtrack(G, flags: DFSFlags, s: Union[int,list] = None, max_depth: Union[int,None] = None):
stack_depth = (max_depth+1 if max_depth is not None else G.N)
stack = [0]*stack_depth
child = [0]*stack_depth
state = [0]*G.N
events: list[tuple[DFSEvent, int]|tuple[DFSEvent, int, int]] = []
for s in G.starts(s):
if state[s]: continue
state[s] = 1
stack[depth := 0] = s
if DFSFlags.DOWN|DFSFlags.CONNECT_ROOTS in flags:
events.append((DFSEvent.DOWN,-1,s))
while depth != -1:
u = stack[depth]
if state[u] == 1:
state[u] = 2
if DFSFlags.ENTER in flags:
events.append((DFSEvent.ENTER,u))
if max_depth is not None and depth >= max_depth:
child[depth] = len(G[u])
if DFSFlags.MAXDEPTH in flags:
events.append((DFSEvent.MAXDEPTH,u))
if (c := child[depth]) < len(G[u]):
child[depth] += 1
if state[v := G[u][c]]:
if DFSFlags.BACK in flags:
events.append((DFSEvent.BACK,u,v))
continue
state[v] = 1
if DFSFlags.DOWN in flags:
events.append((DFSEvent.DOWN,u,v))
stack[depth := depth+1] = v
else:
state[u] = 0
if DFSFlags.LEAVE in flags:
events.append((DFSEvent.LEAVE,u))
child[depth] = 0
depth -= 1
if depth and DFSFlags.UP in flags:
events.append((DFSEvent.UP, stack[depth], u))
if DFSFlags.UP|DFSFlags.CONNECT_ROOTS in flags:
events.append((DFSEvent.UP,-1,s))
return events
def dfs_enter_leave(G, s: Union[int,list,None] = None):
state = [True] * G.N
child: list[int] = elist(G.N)
stack: list[int] = elist(G.N)
events = []
for s in G.starts(s):
if not state[s]: continue
stack.append(s)
child.append(0)
while stack:
u = stack[-1]
if state[u]:
state[u] = False
events.append((DFSEvent.ENTER, u))
if (c := child[-1]) < len(G[u]):
child[-1] += 1
if state[v := G[u][c]]:
stack.append(v)
child.append(0)
else:
stack.pop()
child.pop()
events.append((DFSEvent.LEAVE, u))
return events
def dfs_topdown(G, s: Union[int,list,None] = None, connect_roots = False):
'''Returns list of (u,v) representing u->v edges in order of top down discovery'''
stack: list[int] = elist(G.N)
vis = [False]*G.N
edges: list[tuple[int,int]] = elist(G.N)
for s in G.starts(s):
if vis[s]: continue
if connect_roots:
edges.append((-1,s))
vis[s] = True
stack.append(s)
while stack:
u = stack.pop()
for v in G[u]:
if vis[v]: continue
vis[v] = True
edges.append((u,v))
stack.append(v)
return edges
def dfs_bottomup(G, s: Union[int,list,None] = None, connect_roots = False):
'''Returns list of (p,u) representing p->u edges in bottom up order'''
edges = G.dfs_topdown(s, connect_roots)
edges.reverse()
return edges
def is_bipartite(G):
N = G.N
que = deque()
color = [-1]*N
for s in range(N):
if color[s] >= 0:
continue
color[s] = 1
que.append(s)
while que:
u = que.popleft()
for v in G[u]:
if color[v] == -1:
color[v] = 1 - color[u]
que.append(v)
elif color[v] == color[u]:
return False
return True
def starts(G, v: Union[int,list,None]) -> Iterable:
if isinstance(v, int):
return (v,)
elif v is None:
return range(G.N)
else:
return v
@classmethod
def compile(cls, N: int, M: int, E):
edge = Parser.compile(E)
def parse(ts: TokenStream):
return cls(N, [edge(ts) for _ in range(M)])
return parse
def sort2(a, b):
    """Return the two values in ascending order; on ties the result is (b, a)."""
    if a < b:
        return a, b
    return b, a
import operator
from itertools import accumulate
def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    """Prefix combination of ``iter`` under ``func`` (default: addition).

    ``step == 1`` is ``itertools.accumulate``. For ``step >= 2`` every element
    is combined with the element ``step`` positions earlier, which gives an
    independent running total for each index class modulo ``step``.
    ``initial``, when given, is placed in front before accumulating.
    """
    if step == 1:
        return list(accumulate(iter, func, initial=initial))
    assert step >= 2
    combine = operator.add if func is None else func
    out = list(iter) if initial is None else [initial, *iter]
    for i in range(step, len(out)):
        out[i] = combine(out[i], out[i - step])
    return out
# from typing import Generic
# from cp_library.misc.typing import _T
def min2(a, b):
    """Smaller of two values; on ties returns ``b`` (faster than builtin min)."""
    if a < b:
        return a
    return b
class MinSparseTable:
    """Static range-minimum queries: O(N log N) build, O(1) per query.

    ``data`` is flat. Level ``i`` starts at offset ``i*N``, and its entry
    ``j`` is the minimum of ``arr[j : j + 2**i]``.
    """
    def __init__(st, arr: list):
        n = len(arr)
        levels = n.bit_length()
        st.N = n
        st.log = levels
        table = [0] * (levels * n)
        table[:n] = arr
        st.data = table
        for lvl in range(1, levels):
            dst = lvl * n
            src = (lvl - 1) * n
            half = 1 << (lvl - 1)
            for j in range(n - (1 << lvl) + 1):
                table[dst + j] = min2(table[src + j], table[src + half + j])

    def query(st, l: int, r: int):
        """Minimum of ``arr[l:r]``; expects ``l < r`` (l == r indexes wrongly)."""
        k = (r - l).bit_length() - 1
        base = k * st.N
        return min2(st.data[base + l], st.data[base + r - (1 << k)])
class LCATable(MinSparseTable):
    """Lowest-common-ancestor queries via an Euler tour and a sparse table.

    Every Euler-tour position ``i`` is packed as ``depth[i] << shift | vertex``,
    so a range minimum over tour positions yields the shallowest vertex
    between two first occurrences (the LCA) together with its depth.
    """
    def __init__(lca, T, root = 0):
        N = len(T)
        T.euler_tour(root)
        lca.depth = depth = presum(T.delta)
        lca.tin, lca.tout = T.tin[:], T.tout[:]
        # Own copy of the parent array for path(): T.par is rewritten by later
        # euler_tour / hld_precomp calls, possibly with a different root.
        lca.par = T.par[:]
        lca.mask = (1 << (shift := N.bit_length()))-1
        lca.shift = shift
        order = T.order
        packets = [depth[i] << shift | order[i] for i in range(len(order))]
        super().__init__(packets)

    def _query(lca, u, v):
        """Return ``(l, r, lca_vertex, lca_depth)`` for the tour range [l, r)."""
        l, r = sort2(lca.tin[u], lca.tin[v]); r += 1
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift

    def query(lca, u, v) -> tuple[int,int]:
        """Return ``(ancestor, depth_of_ancestor)`` for ``u`` and ``v``."""
        l, r, a, d = lca._query(u, v)
        return a, d

    def distance(lca, u, v) -> int:
        """Number of edges on the tree path between ``u`` and ``v``."""
        l, r, a, d = lca._query(u, v)
        return lca.depth[l] + lca.depth[r-1] - 2*d

    def path(lca, u, v):
        """Vertices on the tree path from ``u`` to ``v``, both ends included."""
        anc = lca.query(u, v)[0]
        par = lca.par
        path, c = [], u
        while c != anc:
            path.append(c)
            c = par[c]
        path.append(anc)
        rev_path, c = [], v
        while c != anc:
            rev_path.append(c)
            c = par[c]
        path.extend(reversed(rev_path))
        return path
class TreeProtocol(GraphProtocol):
@cached_property
def lca(T):
return LCATable(T)
@overload
def diameter(T) -> int: ...
@overload
def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
def diameter(T, endpoints = False):
mask = (1 << (shift := T.N.bit_length())) - 1
s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
dg = max(d << shift | v for v,d in enumerate(T.distance(s)))
diam, g = dg >> shift, dg & mask
return (diam, s, g) if endpoints else diam
@overload
def distance(T) -> list[list[int]]: ...
@overload
def distance(T, s: int = 0) -> list[int]: ...
@overload
def distance(T, s: int, g: int) -> int: ...
def distance(T, s = None, g = None):
if s == None:
return [T.dfs(u) for u in range(T.N)]
else:
return T.dfs(s, g)
@overload
def dfs(T, s: int = 0) -> list[int]: ...
@overload
def dfs(T, s: int, g: int) -> int: ...
def dfs(T, s = 0, g = None):
D = [inf for _ in range(T.N)]
D[s] = 0
state = [True for _ in range(T.N)]
stack = [s]
while stack:
u = stack.pop()
if u == g: return D[u]
state[u] = False
for v in T[u]:
if state[v]:
D[v] = D[u]+1
stack.append(v)
return D if g is None else inf
def dfs_events(G, flags: DFSFlags, s: int = 0):
events = []
stack = [(s,-1)]
adj = [None]*G.N
while stack:
u, p = stack[-1]
if adj[u] is None:
adj[u] = iter(G.neighbors(u))
if DFSFlags.ENTER in flags:
events.append((DFSEvent.ENTER, u))
if (v := next(adj[u], None)) is not None:
if v == p:
if DFSFlags.BACK in flags:
events.append((DFSEvent.BACK, u, v))
else:
if DFSFlags.DOWN in flags:
events.append((DFSEvent.DOWN, u, v))
stack.append((v,u))
else:
stack.pop()
if DFSFlags.LEAVE in flags:
events.append((DFSEvent.LEAVE, u))
if p != -1 and DFSFlags.UP in flags:
events.append((DFSEvent.UP, u, p))
return events
def euler_tour(T, s = 0):
N = len(T)
T.tin = tin = [-1] * N
T.tout = tout = [-1] * N
T.par = par = [-1] * N
T.order = order = elist(2*N)
T.delta = delta = elist(2*N)
stack = elist(N)
stack.append(s)
while stack:
u = stack.pop()
p = par[u]
if tin[u] == -1:
tin[u] = len(order)
for v in T[u]:
if v != p:
par[v] = u
stack.append(u)
stack.append(v)
delta.append(1)
else:
delta.append(-1)
order.append(u)
tout[u] = len(order)
delta[0] = delta[-1] = 0
def hld_precomp(T, r = 0):
N, time = T.N, 0
tin, tout, size = [0]*N, [0]*N, [1]*N+[0]
par, heavy, head = [-1]*N, [-1]*N, [r]*N
depth, order, state = [0]*N, [0]*N, [0]*N
stack = elist(N)
stack.append(r)
while stack:
if (s := state[v := stack.pop()]) == 0: # dfs down
p, state[v] = par[v], 1
stack.append(v)
for c in T[v]:
if c != p:
depth[c], par[c] = depth[v]+1, v
stack.append(c)
elif s == 1: # dfs up
p, l = par[v], -1
for c in T[v]:
if c != p:
size[v] += size[c]
if size[c] > size[l]:
l = c
heavy[v] = l
if p == -1:
state[v] = 2
stack.append(v)
elif s == 2: # decompose down
p, h, l = par[v], head[v], heavy[v]
tin[v], order[time], state[v] = time, v, 3
time += 1
stack.append(v)
for c in T[v]:
if c != p and c != l:
head[c], state[c] = c, 2
stack.append(c)
if l != -1:
head[l], state[l] = h, 2
stack.append(l)
elif s == 3: # decompose up
tout[v] = time
T.size, T.depth = size, depth
T.order, T.tin, T.tout = order, tin, tout
T.par, T.heavy, T.head = par, heavy, head