This documentation is automatically generated by online-judge-tools/verification-helper
import cp_library.__header__
import cp_library.alg.__header__
from cp_library.alg.graph.csr.graph_weighted_meta_cls import GraphWeightedMeta
import cp_library.alg.tree.__header__
import cp_library.alg.tree.csr.__header__
from cp_library.alg.tree.csr.tree_weighted_base_cls import TreeWeightedBase
class TreeWeightedMeta(TreeWeightedBase, GraphWeightedMeta):
    '''Weighted tree with extra per-edge metadata columns.

    NOTE(review): an identical definition appears again at the end of this
    generated bundle; both copies must stay in sync.
    '''
    @classmethod
    def compile(cls, N: int, T: list[type] = [-1,-1,int,int]):
        # A tree on N vertices always has N-1 edges, so reuse the graph parser.
        return GraphWeightedMeta.compile.__func__(cls, N, N-1, T)
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
import typing
from math import inf
from typing import overload
def chmin(dp, i, v):
    '''Lower dp[i] to v when v is smaller; return True iff dp[i] changed.'''
    improved = dp[i] > v
    if improved:
        dp[i] = v
    return improved
def argsort(A: list[int], reverse=False):
    '''Return the indices that sort A ascending (descending when reverse=True).

    Stable: equal values keep their original relative order, because the
    original index is packed into the low bits before sorting.
    '''
    I = A.copy()
    P = Packer(len(I) - 1)
    P.ienumerate(I, reverse)
    I.sort()
    P.iindices(I)
    return I
class Packer:
    '''Packs a pair (a, b) into one int, reserving enough low bits for b <= mx.'''
    def __init__(P, mx: int):
        P.s = mx.bit_length()       # number of low bits reserved for b
        P.m = (1 << P.s) - 1        # mask extracting the low field

    def enc(P, a: int, b: int):
        '''Pack (a, b) into a single int.'''
        return (a << P.s) | b

    def dec(P, x: int) -> tuple[int, int]:
        '''Unpack x back into (a, b).'''
        return x >> P.s, x & P.m

    def enumerate(P, A, reverse=False):
        '''Return a copy of A with each value packed together with its index.'''
        B = A.copy()
        P.ienumerate(B, reverse)
        return B

    def ienumerate(P, A, reverse=False):
        '''In-place: A[i] becomes enc(±A[i], i); negation gives descending sort order.'''
        sign = -1 if reverse else 1
        for i, a in enumerate(A):
            A[i] = P.enc(sign * a, i)

    def indices(P, A: list[int]):
        '''Return a copy of A reduced to the packed low (index) fields.'''
        B = A.copy()
        P.iindices(B)
        return B

    def iindices(P, A):
        '''In-place: keep only the low (index) field of each packed value.'''
        for i, x in enumerate(A):
            A[i] = x & P.m
from collections import deque
from typing import Callable, Sequence, Union, overload
from numbers import Number
from types import GenericAlias
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase
class FastIO(IOBase):
    '''Buffered byte-level file wrapper built on os.read/os.write for speed.'''
    BUFSIZE = 8192
    newlines = 0  # count of complete lines currently sitting in the buffer
    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # writable iff opened for writing ("x" mode, or any mode without "r")
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None
    def read(self):
        '''Read the rest of the file into the buffer and return it all.'''
        BUFSIZE = self.BUFSIZE
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            ptr = self.buffer.tell()
            # append at the end, then restore the current read position
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()
    def readline(self):
        '''Return the next line, refilling the buffer until one is available.'''
        BUFSIZE = self.BUFSIZE
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # an empty read (EOF) counts as one final line
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()
    def flush(self):
        '''Write out and reset the buffer (no-op when not writable).'''
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)
class IOWrapper(IOBase):
    '''Text facade over FastIO; encodes/decodes everything as ASCII.'''
    stdin: 'IOWrapper' = None   # populated at import time when wrapping succeeds
    stdout: 'IOWrapper' = None
    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
    def write(self, s):
        return self.buffer.write(s.encode("ascii"))
    def read(self):
        return self.buffer.read().decode("ascii")
    def readline(self):
        return self.buffer.readline().decode("ascii")
try:
    # Swap the standard streams for the fast wrappers. If stdin/stdout have no
    # real file descriptor (REPL, some judges), the plain streams stay in place.
    sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
    sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)
except Exception:  # narrowed from bare `except:` so KeyboardInterrupt/SystemExit propagate
    pass
from typing import TypeVar
_T = TypeVar('T')
_U = TypeVar('U')
class TokenStream(Iterator):
    '''Iterator of whitespace-separated tokens read lazily from the input stream.'''
    stream = IOWrapper.stdin
    def __init__(self):
        self.queue = deque()  # tokens of the current line not yet consumed
    def __next__(self):
        if not self.queue: self.queue.extend(self._line())
        return self.queue.popleft()
    def wait(self):
        '''Yield once per token remaining on the current line (drive with `for _ in ts.wait()`).'''
        if not self.queue: self.queue.extend(self._line())
        while self.queue: yield
    def _line(self):
        # split the next raw line into word tokens
        return TokenStream.stream.readline().split()
    def line(self):
        '''Return all remaining tokens of the current line, or read a fresh line.'''
        if self.queue:
            A = list(self.queue)
            self.queue.clear()
            return A
        return self._line()
TokenStream.default = TokenStream()
class CharStream(TokenStream):
    '''TokenStream variant whose tokens are single characters.'''
    def _line(self):
        # every character of the line (minus the trailing newline) is a token
        return TokenStream.stream.readline().rstrip()
CharStream.default = CharStream()
ParseFn = Callable[[TokenStream],_T]  # type of a compiled parser function
class Parser:
    '''Compiles a parse specification into a function TokenStream -> value.

    Specs can be types (int, str, tuple[...], list[...]), Parsable subclasses,
    literal numbers (meaning "parse and add this offset"), tuples/collections of
    sub-specs, or arbitrary callables applied to one token.
    '''
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)
    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)
    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        '''Compile a parser for a (possibly generic) type object.'''
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            # scalar: consume exactly one token and convert it
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()
    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        '''Dispatch on the kind of spec: type, offset number, tuple, collection, callable.'''
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            # a literal number means "parse a number and add this offset"
            # (e.g. -1 converts 1-indexed input to 0-indexed)
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()
    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        '''Compile a parser consuming one whole input line into collection cls.'''
        if spec is int:
            # fast path: plain ints need no per-token compiled parser
            # (fix: the original also compiled an unused parser here)
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        else:
            fn = Parser.compile(spec)
            def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
            return parse
    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        '''Compile a parser reading exactly N items of spec into collection cls.'''
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse
    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        '''Compile a parser reading one item per child spec into collection cls.'''
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse
    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        '''tuple[X, ...] consumes a whole line of X; otherwise one item per spec.'''
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)
    @staticmethod
    def compile_collection(cls, specs):
        '''list[X] / set[X] read a line; (spec, N) reads a fixed count.'''
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()
class Parsable:
    '''Mixin for types that can build their own TokenStream parser.'''

    @classmethod
    def compile(cls):
        '''Default parser: construct the class from a single token.'''
        def parse(ts: TokenStream):
            return cls(next(ts))
        return parse

    @classmethod
    def __class_getitem__(cls, item):
        # allow Parsable subclasses to be used as generic specs, e.g. Graph[int]
        return GenericAlias(cls, item)
from enum import auto, IntFlag, IntEnum
class DFSFlags(IntFlag):
    '''Bit flags selecting which DFS events/outputs a traversal should produce.

    NOTE: member order fixes the auto() bit values; do not reorder.
    '''
    ENTER = auto()
    DOWN = auto()
    BACK = auto()
    CROSS = auto()
    LEAVE = auto()
    UP = auto()
    MAXDEPTH = auto()
    RETURN_PARENTS = auto()
    RETURN_DEPTHS = auto()
    BACKTRACK = auto()
    CONNECT_ROOTS = auto()
    # Common combinations
    ALL_EDGES = DOWN | BACK | CROSS
    EULER_TOUR = DOWN | UP
    INTERVAL = ENTER | LEAVE
    TOPDOWN = DOWN | CONNECT_ROOTS
    BOTTOMUP = UP | CONNECT_ROOTS
    RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS
class DFSEvent(IntEnum):
    '''Single DFS event kinds; values mirror the corresponding DFSFlags bits.'''
    ENTER = DFSFlags.ENTER
    DOWN = DFSFlags.DOWN
    BACK = DFSFlags.BACK
    CROSS = DFSFlags.CROSS
    LEAVE = DFSFlags.LEAVE
    UP = DFSFlags.UP
    MAXDEPTH = DFSFlags.MAXDEPTH
class GraphBase(Parsable):
    '''CSR (compressed sparse row) adjacency structure shared by all graph types.'''
    def __init__(G, N: int, M: int, U: list[int], V: list[int],
                 deg: list[int], La: list[int], Ra: list[int],
                 Ua: list[int], Va: list[int], Ea: list[int], twin: list[int] = None):
        G.N = N
        '''The number of vertices.'''
        G.M = M
        '''The number of edges.'''
        G.U = U
        '''A list of source vertices in the original edge list.'''
        G.V = V
        '''A list of destination vertices in the original edge list.'''
        G.deg = deg
        '''deg[u] is the out degree of vertex u.'''
        G.La = La
        '''La[u] stores the start index of the list of adjacent vertices from u.'''
        G.Ra = Ra
        '''Ra[u] stores the stop index of the list of adjacent vertices from u.'''
        G.Ua = Ua
        '''Ua[i] = u for La[u] <= i < Ra[u], useful for backtracking.'''
        G.Va = Va
        '''Va[i] lists adjacent vertices to u for La[u] <= i < Ra[u].'''
        G.Ea = Ea
        '''Ea[i] lists the edge ids that start from u for La[u] <= i < Ra[u].
        For undirected graphs, edge ids in range M<= e <2*M are edges from V[e-M] -> U[e-M].
        '''
        G.twin = twin if twin is not None else range(len(Ua))
        '''twin[i] in undirected graphs stores index j of the same edge but with u and v swapped.'''
        # lazily allocated scratch state reused by the traversal methods below
        G.st: list[int] = None
        G.order: list[int] = None
        G.vis: list[int] = None
        G.back: list[int] = None
        G.tin: list[int] = None
    def clear(G):
        # reset traversal state so the next search starts from scratch
        G.vis = G.back = G.tin = None
    def prep_vis(G):
        # vis[u]: 0 = unvisited, 1 = open (on stack), 2 = finished
        if G.vis is None: G.vis = u8f(G.N)
        return G.vis
    def prep_st(G):
        if G.st is None: G.st = elist(G.N)
        else: G.st.clear()
        return G.st
    def prep_order(G):
        if G.order is None: G.order = elist(G.N)
        else: G.order.clear()
        return G.order
    def prep_back(G):
        # back[v]: CSR index of the edge that discovered v (-2 = untouched)
        if G.back is None: G.back = i32f(G.N, -2)
        return G.back
    def prep_tin(G):
        if G.tin is None: G.tin = i32f(G.N, -1)
        return G.tin
    def __len__(G) -> int: return G.N
    def __getitem__(G, u): return G.Va[G.La[u]:G.Ra[u]]
    def range(G, u): return range(G.La[u],G.Ra[u])
    @overload
    def distance(G) -> list[list[int]]: ...
    @overload
    def distance(G, s: int = 0) -> list[int]: ...
    @overload
    def distance(G, s: int, g: int) -> int: ...
    def distance(G, s = None, g = None):
        '''All-pairs (no args), single-source (s), or point-to-point (s, g) hop counts.'''
        if s == None: return G.floyd_warshall()
        else: return G.bfs(s, g)
    def recover_path(G, s, t):
        '''Rebuild the t->s vertex walk from the last search's back pointers.'''
        Ua, back, vertices = G.Ua, G.back, u32f(1, v := t)
        while v != s: vertices.append(v := Ua[back[v]])
        return vertices
    def recover_path_edge_ids(G, s, t):
        '''Rebuild the t->s edge-id walk from the last search's back pointers.'''
        Ea, Ua, back, edges, v = G.Ea, G.Ua, G.back, u32f(0), t
        while v != s: edges.append(Ea[i := back[v]]), (v := Ua[i])
        return edges
    def shortest_path(G, s: int, t: int):
        '''Vertices of a shortest s->t path, or None when t is unreachable.'''
        if G.distance(s, t) >= inf: return None
        vertices = G.recover_path(s, t)
        vertices.reverse()
        return vertices
    def shortest_path_edge_ids(G, s: int, t: int):
        '''Edge ids of a shortest s->t path, or None when t is unreachable.'''
        if G.distance(s, t) >= inf: return None
        edges = G.recover_path_edge_ids(s, t)
        edges.reverse()
        return edges
    @overload
    def bfs(G, s: Union[int,list] = 0) -> list[int]: ...
    @overload
    def bfs(G, s: Union[int,list], g: int) -> int: ...
    def bfs(G, s: int = 0, g: int = None):
        '''(Multi-source) BFS; returns hop distances, or the distance to g if given.'''
        S, Va, back, D = G.starts(s), G.Va, i32f(N := G.N, -1), [inf]*N
        G.back, G.D = back, D
        for u in S: D[u] = 0
        que = deque(S)
        while que:
            nd = D[u := que.popleft()]+1
            if u == g: return nd-1
            for i in G.range(u):
                if nd < D[v := Va[i]]:
                    D[v], back[v] = nd, i
                    que.append(v)
        return D if g is None else inf
    def floyd_warshall(G) -> list[list[int]]:
        '''All-pairs hop distances in O(N^3); every edge counts as 1.'''
        G.D = D = [[inf]*G.N for _ in range(G.N)]
        for u in range(G.N): D[u][u] = 0
        for i in range(len(G.Ua)): D[G.Ua[i]][G.Va[i]] = 1
        for k, Dk in enumerate(D):
            for Di in D:
                if (Dik := Di[k]) == inf: continue
                for j in range(G.N):
                    chmin(Di, j, Dik+Dk[j])
        return D
    def find_cycle_indices(G, s: Union[int, None] = None):
        '''CSR indices of the edges of some cycle, or None when acyclic.

        The parent edge id (pe) is excluded so that traversing an undirected
        edge back and forth is not reported as a 2-cycle; self loops are
        checked separately at the end.
        '''
        Ea, Ua, Va, vis, back = G.Ea, G.Ua, G.Va, u8f(N := G.N), u32f(N, i32_max)
        G.vis, G.back, st = vis, back, elist(N)
        for s in G.starts(s):
            if vis[s]: continue
            st.append(s)
            while st:
                if not vis[u := st.pop()]:
                    st.append(u)
                    vis[u], pe = 1, Ea[j] if (j := back[u]) != i32_max else i32_max
                    for i in G.range(u):
                        if not vis[v := Va[i]]:
                            back[v] = i
                            st.append(v)
                        elif vis[v] == 1 and pe != Ea[i]:
                            # open vertex reached through a non-parent edge: cycle found
                            I = u32f(1,i)
                            while v != u: I.append(i := back[u]), (u := Ua[i])
                            I.reverse()
                            return I
                else:
                    vis[u] = 2
        # check for self loops
        for i in range(len(Ua)):
            if Ua[i] == Va[i]:
                return u32f(1,i)
    def find_cycle(G, s: Union[int, None] = None):
        '''Vertices along some cycle, or None.'''
        if I := G.find_cycle_indices(s): return [G.Ua[i] for i in I]
    def find_cycle_edge_ids(G, s: Union[int, None] = None):
        '''Edge ids along some cycle, or None.'''
        if I := G.find_cycle_indices(s): return [G.Ea[i] for i in I]
    def find_minimal_cycle(G, s=0):
        '''Shortest cycle through s found by BFS, or None if no such cycle.'''
        D, par, que, Va = u32f(N := G.N, u32_max), i32f(N, -1), deque([s]), G.Va
        D[s] = 0
        while que:
            for i in G.range(u := que.popleft()):
                if (v := Va[i]) == s: # Found cycle back to start
                    cycle = [u]
                    while u != s: cycle.append(u := par[u])
                    return cycle
                if D[v] < u32_max: continue
                D[v], par[v] = D[u]+1, u; que.append(v)
    def dfs_topdown(G, s: Union[int,list] = None) -> list[int]:
        '''Returns lists of indices i where Ua[i] -> Va[i] are edges in order of top down discovery'''
        vis, st, order = G.prep_vis(), G.prep_st(), G.prep_order()
        for s in G.starts(s):
            if vis[s]: continue
            vis[s] = 1; st.append(s)
            while st:
                for i in G.range(st.pop()):
                    if vis[v := G.Va[i]]: continue
                    vis[v] = 1; order.append(i); st.append(v)
        return order
    def dfs(G, s: Union[int,list] = None, /,
            backtrack = False,
            max_depth = None,
            enter_fn: Callable[[int],None] = None,
            leave_fn: Callable[[int],None] = None,
            max_depth_fn: Callable[[int],None] = None,
            down_fn: Callable[[int,int,int],None] = None,
            back_fn: Callable[[int,int,int],None] = None,
            forward_fn: Callable[[int,int,int],None] = None,
            cross_fn: Callable[[int,int,int],None] = None,
            up_fn: Callable[[int,int,int],None] = None):
        '''Iterative DFS with optional event callbacks.

        Edge callbacks receive (u, v, i) where i is the CSR index of the edge
        u -> v. backtrack=True re-opens vertices on leave (use together with
        max_depth for bounded enumeration); max_depth limits the stack depth.
        '''
        I, time, vis, st, back, tin = G.La[:], -1, G.prep_vis(), G.prep_st(), G.prep_back(), G.prep_tin()
        for s in G.starts(s):
            if vis[s]: continue
            back[s], tin[s] = -1, (time := time+1); st.append(s)
            while st:
                if vis[u := st[-1]] == 0:
                    vis[u] = 1
                    if enter_fn: enter_fn(u)
                    if max_depth is not None and len(st) > max_depth:
                        # depth cut-off: pretend all of u's edges are consumed
                        I[u] = G.Ra[u]
                        if max_depth_fn: max_depth_fn(u)
                if (i := I[u]) < G.Ra[u]:
                    I[u] += 1
                    if (s := vis[v := G.Va[i]]) == 0:
                        back[v], tin[v] = i, (time := time+1); st.append(v)
                        if down_fn: down_fn(u,v,i)
                    elif back_fn and s == 1 and back[u] != G.twin[i]: back_fn(u,v,i)
                    elif (cross_fn or forward_fn) and s == 2:
                        # finished vertex: forward edge when discovered later, else cross
                        if forward_fn and tin[u] < tin[v]: forward_fn(u,v,i)
                        elif cross_fn: cross_fn(u,v,i)
                else:
                    vis[u] = 2; st.pop()
                    if backtrack: vis[u], I[u] = 0, G.La[u]
                    if leave_fn: leave_fn(u)
                    if up_fn and st: up_fn(u, st[-1], back[u])
    def dfs_enter_leave(G, s: Union[int,list[int],None] = None) -> Sequence[tuple[DFSEvent,int]]:
        '''Sequence of (ENTER/LEAVE, vertex) DFS events, packed into single ints.'''
        N, I = G.N, G.La[:]
        st, back, plst = elist(N), i32f(N,-2), PacketList(order := elist(2*N), N-1)
        G.back, ENTER, LEAVE = back, int(DFSEvent.ENTER) << plst.shift, int(DFSEvent.LEAVE) << plst.shift
        for s in G.starts(s):
            if back[s] >= -1: continue
            back[s] = -1
            order.append(ENTER | s), st.append(s)
            while st:
                if (i := I[u := st[-1]]) < G.Ra[u]:
                    I[u] += 1
                    if back[v := G.Va[i]] >= -1: continue
                    back[v] = i; order.append(ENTER | v); st.append(v)
                else:
                    order.append(LEAVE | u); st.pop()
        return plst
    def starts(G, s: Union[int,list[int],None] = None) -> list[int]:
        '''Normalize a start spec: int -> [s], None -> all vertices, iterable -> list.'''
        if isinstance(s, int): return [s]
        elif s is None: return range(G.N)
        elif isinstance(s, list): return s
        else: return list(s)
    @classmethod
    def compile(cls, N: int, M: int, shift: int = -1):
        '''Parser reading M edge lines "u v", offsetting vertices by `shift`.'''
        def parse(ts: TokenStream):
            U, V = u32f(M), u32f(M)
            for i in range(M):
                u, v = ts._line()
                U[i], V[i] = int(u)+shift, int(v)+shift
            return cls(N, U, V)
        return parse
u32_max = (1<<32)-1  # max unsigned 32-bit value; used as "unset" sentinel in u32 arrays
i32_max = (1<<31)-1  # max signed 32-bit value; used as "unset" sentinel in i32 arrays
from array import array

def u8f(N: int, elm: int = 0):
    '''Length-N array of unsigned 8-bit ints ('B'), filled with elm.'''
    return array('B', [elm]) * N

def u32f(N: int, elm: int = 0):
    '''Length-N array of unsigned 32-bit ints ('I'), filled with elm.'''
    return array('I', [elm]) * N

def i32f(N: int, elm: int = 0):
    '''Length-N array of signed 32-bit ints ('i'), filled with elm.'''
    return array('i', [elm]) * N
def elist(est_len: int) -> list: ...
'''elist(n) returns an empty list, pre-sized to hold ~n items when running on PyPy.'''
try:
    from __pypy__ import newlist_hint
except ImportError:  # narrowed from bare `except:`; CPython simply has no hint API
    def newlist_hint(hint):
        return []
elist = newlist_hint
class PacketList(Sequence[tuple[int,int]]):
    '''Read-only view of packed ints A as (hi, lo) pairs.

    The low field is wide enough to hold values up to max1; the high field
    occupies the remaining bits.
    '''
    def __init__(lst, A: list[int], max1: int):
        shift = max1.bit_length()
        lst.A = A
        lst.shift = shift
        lst.mask = (1 << shift) - 1

    def __len__(lst):
        return len(lst.A)

    def __contains__(lst, x: tuple[int,int]):
        # pack the pair and test membership against the raw ints
        return (x[0] << lst.shift | x[1]) in lst.A

    def __getitem__(lst, key) -> tuple[int,int]:
        packed = lst.A[key]
        return packed >> lst.shift, packed & lst.mask
class GraphWeightedBase(GraphBase):
    '''GraphBase extended with per-edge weights (parallel W/Wa arrays).'''
    def __init__(self, N: int, M: int, U: list[int], V: list[int], W: list[int],
                 deg: list[int], La: list[int], Ra: list[int],
                 Ua: list[int], Va: list[int], Wa: list[int], Ea: list[int], twin: list[int] = None):
        super().__init__(N, M, U, V, deg, La, Ra, Ua, Va, Ea, twin)
        self.W = W
        self.Wa = Wa
        '''Wa[i] lists weights to edges from u for La[u] <= i < Ra[u].'''
    def __getitem__(G, u):
        '''Iterate (neighbor, weight) pairs of vertex u.'''
        l,r = G.La[u],G.Ra[u]
        return zip(G.Va[l:r], G.Wa[l:r])
    @overload
    def distance(G) -> list[list[int]]: ...
    @overload
    def distance(G, s: int = 0) -> list[int]: ...
    @overload
    def distance(G, s: int, g: int) -> int: ...
    def distance(G, s = None, g = None):
        '''All-pairs (no args), single-source (s), or point-to-point (s, g) distance.'''
        if s == None: return G.floyd_warshall()
        else: return G.dijkstra(s, g)
    def dijkstra(G, s: int, t: int = None):
        '''Shortest weighted paths from s (or multi-source list); weights must be non-negative.

        Returns the full distance list, or just dist(s, t) when t is given.
        '''
        G.back, G.D, S = i32f(G.N, -1), [inf]*G.N, G.starts(s)
        for s in S: G.D[s] = 0
        que = PriorityQueue(G.N, S)
        while que:
            d, u = que.pop()
            if d > G.D[u]: continue  # stale queue entry; a shorter path was found
            if u == t: return d
            i, r = G.La[u]-1, G.Ra[u]
            while (i:=i+1)<r:
                if chmin(G.D, v := G.Va[i], nd := d + G.Wa[i]):
                    G.back[v] = i; que.push(nd, v)
        return G.D if t is None else inf
    def kruskal(G):
        '''Minimum spanning tree edge ids (sorted-edge Kruskal); None if disconnected.'''
        U, V, W, dsu, MST, need = G.U, G.V, G.W, DSU(N := G.N), [0]*(N-1), N-1
        for e in argsort(W):
            u, v = dsu.merge(U[e],V[e])
            if u != v:
                MST[need := need-1] = e
                if not need: break
        return None if need else MST
    def kruskal_heap(G):
        '''Kruskal variant drawing edges lazily from a priority queue.'''
        N, M, U, V, W = G.N, G.M, G.U, G.V, G.W
        que, dsu, MST = PriorityQueue(M, list(range(M)), W), DSU(N), [0]*(need := N-1)
        while que and need:
            _, e = que.pop()
            u, v = dsu.merge(U[e],V[e])
            if u != v:
                MST[need := need-1] = e
        return None if need else MST
    def bellman_ford(G, s: int = 0) -> list[int]:
        '''O(N*M) single-source shortest paths; tolerates negative weights.'''
        Ua, Va, Wa, D = G.Ua, G.Va, G.Wa, [inf]*(N := G.N)
        D[s] = 0
        for _ in range(N-1):
            for i, u in enumerate(Ua):
                if D[u] < inf: chmin(D, Va[i], D[u] + Wa[i])
        return D
    def bellman_ford_neg_cyc_check(G, s: int = 0) -> tuple[bool, list[int]]:
        '''Bellman-Ford plus a flag: True when some edge is still relaxable (negative cycle).'''
        M, U, V, W, D = G.M, G.U, G.V, G.W, G.bellman_ford(s)
        neg_cycle = any(D[U[i]]+W[i]<D[V[i]] for i in range(M) if D[U[i]] < inf)
        return neg_cycle, D
    def floyd_warshall(G) -> list[list[int]]:
        '''All-pairs weighted shortest paths, O(N^3).'''
        N, Ua, Va, Wa = G.N, G.Ua, G.Va, G.Wa
        D = [[inf]*N for _ in range(N)]
        for u in range(N): D[u][u] = 0
        for i in range(len(Ua)): chmin(D[Ua[i]], Va[i], Wa[i])
        for k, Dk in enumerate(D):
            for Di in D:
                if Di[k] >= inf: continue
                for j in range(N):
                    if Dk[j] >= inf: continue
                    chmin(Di, j, Di[k]+Dk[j])
        return D
    def floyd_warshall_neg_cyc_check(G):
        '''Floyd-Warshall plus a flag: True when some diagonal entry went negative.'''
        D = G.floyd_warshall()
        return any(D[i][i] < 0 for i in range(G.N)), D
    @classmethod
    def compile(cls, N: int, M: int, shift: int = -1):
        '''Parser reading M edge lines "u v w", offsetting vertices by `shift`.'''
        def parse(ts: TokenStream):
            U, V, W = u32f(M), u32f(M), [0]*M
            for i in range(M):
                u, v, w = ts._line()
                U[i], V[i], W[i] = int(u)+shift, int(v)+shift, int(w)
            return cls(N, U, V, W)
        return parse
class DSU(Parsable):
    '''Disjoint-set union with union-by-size and grandparent path compression.

    par[r] of a root r holds -size of its set; non-roots hold their parent.
    '''
    def __init__(dsu, N):
        dsu.N, dsu.cc, dsu.par = N, N, [-1]*N

    def merge(dsu, u, v):
        '''Union the sets of u and v.

        Returns (root, other): when a merge happened `root` absorbed `other`;
        when u and v were already together, both equal the shared root.
        '''
        a, b = dsu.root(u), dsu.root(v)
        if a == b:
            return a, b
        # union by size: keep the larger set's root (more negative par value)
        if dsu.par[a] > dsu.par[b]:
            a, b = b, a
        dsu.par[a] += dsu.par[b]
        dsu.par[b] = a
        dsu.cc -= 1
        return a, b

    def root(dsu, i) -> int:
        '''Representative of i's set, compressing paths toward the root.'''
        par = dsu.par
        p = par[i]
        while p >= 0:
            g = par[p]
            if g < 0:
                return p
            par[i] = g          # skip i straight to its grandparent
            i, p = g, par[g]
        return i

    def groups(dsu) -> 'CSRIncremental[int]':
        '''Connected components as a CSR of vertex lists.'''
        sizes, row, p = [0]*dsu.cc, [-1]*dsu.N, 0
        for i in range(dsu.cc):
            # advance p to the next root and record its component size
            while dsu.par[p] >= 0:
                p += 1
            sizes[i], row[p] = -dsu.par[p], i
            p += 1
        csr = CSRIncremental(sizes)
        for i in range(dsu.N):
            csr.append(row[dsu.root(i)], i)
        return csr
    __iter__ = groups

    def merge_dest(dsu, u, v):
        '''Merge u and v; return only the surviving root.'''
        return dsu.merge(u, v)[0]

    def same(dsu, u: int, v: int):
        return dsu.root(u) == dsu.root(v)

    def size(dsu, i) -> int:
        return -dsu.par[dsu.root(i)]

    def __len__(dsu):
        return dsu.cc

    def __contains__(dsu, uv):
        u, v = uv
        return dsu.same(u, v)

    @classmethod
    def compile(cls, N: int, M: int, shift = -1):
        '''Parser reading M "u v" lines and merging them into a DSU of N elements.'''
        def parse_fn(ts: TokenStream):
            dsu = cls(N)
            for _ in range(M):
                u, v = ts._line()
                dsu.merge(int(u)+shift, int(v)+shift)
            return dsu
        return parse_fn
class CSRIncremental(Sequence[list[_T]]):
    '''Compressed-sparse-row container whose rows are filled incrementally.

    Row capacities are fixed up front via `sizes`; append() then fills values
    row by row into one flat backing array.
    '''
    def __init__(csr, sizes: list[int]):
        csr.L, N = [0]*len(sizes), 0
        for i,sz in enumerate(sizes):
            csr.L[i] = N; N += sz
        # R[i] is the current write position of row i; starts at the row's L[i]
        csr.R, csr.A = csr.L[:], [0]*N
    def append(csr, i: int, x: _T):
        '''Append x to row i (caller must not exceed the row's declared size).'''
        csr.A[csr.R[i]] = x; csr.R[i] += 1
    def __iter__(csr):
        '''Yield each row as a list slice of the backing array.'''
        for i,l in enumerate(csr.L):
            yield csr.A[l:csr.R[i]]
    def __getitem__(csr, i: int) -> _T:
        # flat access: i indexes the backing array, not a row
        return csr.A[i]
    def __len__(csr):
        # fix: receiver renamed from `dsu` to `csr` for consistency with the class
        return len(csr.L)
    def range(csr, i: int) -> range:
        '''Range of flat indices currently covering row i.'''
        return range(csr.L[i], csr.R[i])
def heappush(heap: list, item):
    '''Add item to the min-heap in O(log n).'''
    heap.append(item)
    heapsiftdown(heap, 0, len(heap) - 1)

def heappop(heap: list):
    '''Remove and return the smallest element.'''
    tail = heap.pop()
    if not heap:
        return tail
    top, heap[0] = heap[0], tail
    heapsiftup(heap, 0)
    return top

def heapreplace(heap: list, item):
    '''Return the current minimum and put item in its place (pop-then-push).'''
    top, heap[0] = heap[0], item
    heapsiftup(heap, 0)
    return top

def heappushpop(heap: list, item):
    '''Equivalent to push followed by pop, using at most one sift.'''
    if heap and heap[0] < item:
        item, heap[0] = heap[0], item
        heapsiftup(heap, 0)
    return item

def heapify(x: list):
    '''Rearrange x into a min-heap in place, O(n).'''
    for i in range(len(x)//2 - 1, -1, -1):
        heapsiftup(x, i)

def heapsiftdown(heap: list, root: int, pos: int):
    '''Bubble heap[pos] up toward `root` while it is smaller than its parent.'''
    item = heap[pos]
    while pos > root:
        parent = (pos - 1) >> 1
        if not item < heap[parent]:
            break
        heap[pos] = heap[parent]
        pos = parent
    heap[pos] = item

def heapsiftup(heap: list, pos: int):
    '''Sift heap[pos] down toward the leaves, always following the smaller child.'''
    end = len(heap) - 1
    item = heap[pos]
    child = 2*pos + 1
    while child < end:
        if heap[child + 1] < heap[child]:
            child += 1
        if not heap[child] < item:
            break
        heap[pos] = heap[child]
        pos = child
        child = 2*pos + 1
    # a lone last child may still need one final comparison
    if child == end and heap[child] < item:
        heap[pos], pos = heap[child], child
    heap[pos] = item
def heappop_max(heap: list):
    '''Remove and return the largest element of a max-heap.'''
    tail = heap.pop()
    if not heap:
        return tail
    top, heap[0] = heap[0], tail
    heapsiftup_max(heap, 0)
    return top

def heapreplace_max(heap: list, item):
    '''Return the current maximum and put item in its place (pop-then-push).

    Fix: this function was defined twice, identically, in the original; the
    redundant second definition has been removed.
    '''
    top, heap[0] = heap[0], item
    heapsiftup_max(heap, 0)
    return top

def heapify_max(x: list):
    '''Rearrange x into a max-heap in place, O(n).'''
    for i in range(len(x)//2 - 1, -1, -1):
        heapsiftup_max(x, i)

def heappush_max(heap: list, item):
    '''Add item to the max-heap in O(log n).'''
    heap.append(item)
    heapsiftdown_max(heap, 0, len(heap) - 1)

def heappushpop_max(heap: list, item):
    '''Equivalent to push followed by pop-max, using at most one sift.'''
    if heap and heap[0] > item:
        item, heap[0] = heap[0], item
        heapsiftup_max(heap, 0)
    return item

def heapsiftdown_max(heap: list, root: int, pos: int):
    '''Bubble heap[pos] up toward `root` while it is larger than its parent.'''
    item = heap[pos]
    while pos > root:
        parent = (pos - 1) >> 1
        if not heap[parent] < item:
            break
        heap[pos] = heap[parent]
        pos = parent
    heap[pos] = item

def heapsiftup_max(heap: list, pos: int):
    '''Sift heap[pos] down toward the leaves, always following the larger child.'''
    end = len(heap) - 1
    item = heap[pos]
    child = 2*pos + 1
    while child < end:
        if heap[child] < heap[child + 1]:
            child += 1
        if not item < heap[child]:
            break
        heap[pos] = heap[child]
        pos = child
        child = 2*pos + 1
    # a lone last child may still need one final comparison
    if child == end and item < heap[child]:
        heap[pos], pos = heap[child], child
    heap[pos] = item
from typing import Generic
class HeapProtocol(Generic[_T]):
    '''Shared interface for heap-backed containers exposing a `data` list.'''
    def peek(heap) -> _T: return heap.data[0]
    def pop(heap) -> _T: ...
    def push(heap, item: _T): ...
    def pushpop(heap, item: _T) -> _T: ...
    def replace(heap, item: _T) -> _T: ...
    def __contains__(heap, item: _T): return item in heap.data
    def __len__(heap): return len(heap.data)
    def clear(heap): heap.data.clear()
class PriorityQueue(HeapProtocol[int]):
    '''Min-priority queue of (priority, id) pairs packed into single ints.

    ids must be < N. NOTE(review): priorities live in the high bits of the
    packed value, so ordering assumes non-negative priorities — confirm
    callers never push negative distances.
    '''
    def __init__(que, N: int, ids: list[int] = None, priorities: list[int] = None, /):
        que.pkr = Packer(N)
        # three construction modes: empty, pre-packed ids, or (id, priority) pairs
        if ids is None: que.data = elist(N)
        elif priorities is None: heapify(ids); que.data = ids
        else:
            que.data = [0]*(M := len(ids))
            for i in range(M): que.data[i] = que.pkr.enc(priorities[i], ids[i])
            heapify(que.data)
    def pop(que): return que.pkr.dec(heappop(que.data))
    def push(que, priority: int, id: int): heappush(que.data, que.pkr.enc(priority, id))
    def pushpop(que, priority: int, id: int): return que.pkr.dec(heappushpop(que.data, que.pkr.enc(priority, id)))
    def replace(que, priority: int, id: int): return que.pkr.dec(heapreplace(que.data, que.pkr.enc(priority, id)))
    def peek(que): return que.pkr.dec(que.data[0])
class GraphWeighted(GraphWeightedBase):
    '''Undirected weighted graph built into CSR form from edge lists U, V, W.'''
    def __init__(G, N: int, U: list[int], V: list[int], W: list[int]):
        # count directed arcs: each edge contributes two, except self loops (one)
        Ma, deg = 0, u32f(N)
        for e in range(M := len(U)):
            distinct = (u := U[e]) != (v := V[e])
            deg[u] += 1; deg[v] += distinct; Ma += 1+distinct
        twin, Ea, Ua, Va, Wa = u32f(Ma), u32f(Ma), u32f(Ma), u32f(Ma), [0]*Ma
        # prefix sums of degrees give each vertex's CSR segment start
        La, i = u32f(N), 0
        for u,d in enumerate(deg):
            La[u], i = i, i + d
        Ra = La[:]  # Ra[u] advances as u's segment is filled
        for e in range(M):
            u, v, w = U[e], V[e], W[e]
            i, j = Ra[u], Ra[v]
            # place both directions and cross-link them via twin
            Ra[u],Ua[i],Va[i],Wa[i],Ea[i],twin[i] = i+1,u,v,w,e,j
            if i == j: continue # don't add self loops twice
            Ra[v],Ua[j],Va[j],Wa[j],Ea[j],twin[j] = j+1,v,u,w,e,i
        super().__init__(N, M, U, V, W, deg, La, Ra, Ua, Va, Wa, Ea, twin)
class GraphWeightedMeta(GraphWeighted):
    '''Weighted graph carrying up to three extra per-edge metadata columns X, Y, Z.'''
    def __init__(G, N: int, U: list[int], V: list[int], W: list[int],
                 X: list[int] = None, Y: list[int] = None, Z: list[int] = None):
        super().__init__(N, U, V, W)
        M2 = len(G.Ea)  # number of directed CSR arcs (2*M minus self loops)
        if X is not None:
            # scatter the per-edge column into CSR arc order via the edge ids
            Xa = [0]*M2
            for i,e in enumerate(G.Ea):
                Xa[i] = X[e]
            G.X = X
            '''A parallel lists of edge meta data from the original edge list.'''
            G.Xa = Xa
            '''Xa[i] parallel lists of adjacent meta data to u for La[u] <= i < Ra[u].'''
        if Y is not None:
            Ya = [0]*M2
            for i,e in enumerate(G.Ea):
                Ya[i] = Y[e]
            G.Y = Y
            '''A parallel lists of edge meta data from the original edge list.'''
            G.Ya = Ya
            '''Ya[i] parallel lists of adjacent meta data to u for La[u] <= i < Ra[u].'''
        if Z is not None:
            Za = [0]*M2
            for i,e in enumerate(G.Ea):
                Za[i] = Z[e]
            G.Z = Z
            '''A parallel lists of edge meta data from the original edge list.'''
            G.Za = Za
            '''Za[i] parallel lists of adjacent meta data to u for La[u] <= i < Ra[u].'''
    @classmethod
    def compile(cls, N: int, M: int, T: list[type] = [-1,-1,int,int]):
        '''Parser for M edge lines typed by the spec list T = [u, v, w, x(, y(, z))].

        The default [-1,-1,int,int] (1-indexed endpoints, int weight and meta)
        takes a hand-rolled fast path that skips the compiled sub-parsers.
        '''
        u, v, *w = map(Parser.compile, typing.get_args(T) or T)
        if len(w) == 2:
            if T == [-1,-1,int,int]:
                def parse(ts: TokenStream):
                    U, V, W, X = u32f(M), u32f(M), [0]*M, [0]*M
                    for i in range(M):
                        u,v,a,b = ts.line()
                        U[i], V[i], W[i], X[i] = int(u)-1, int(v)-1, int(a), int(b)
                    return cls(N, U, V, W, X)
            else:
                w, x = w
                def parse(ts: TokenStream):
                    U, V, W, X = u32f(M), u32f(M), [0]*M, [0]*M
                    for i in range(M):
                        U[i], V[i], W[i], X[i] = u(ts), v(ts), w(ts), x(ts)
                    return cls(N, U, V, W, X)
        elif len(w) == 3:
            w, x, y = w
            def parse(ts: TokenStream):
                U, V, W, X, Y = u32f(M), u32f(M), [0]*M, [0]*M, [0]*M
                for i in range(M):
                    U[i], V[i], W[i], X[i], Y[i] = u(ts), v(ts), w(ts), x(ts), y(ts)
                return cls(N, U, V, W, X, Y)
        else:
            w, x, y, z = w
            def parse(ts: TokenStream):
                U, V, W, X, Y, Z = u32f(M), u32f(M), [0]*M, [0]*M, [0]*M, [0]*M
                for i in range(M):
                    U[i], V[i], W[i], X[i], Y[i], Z[i] = u(ts), v(ts), w(ts), x(ts), y(ts), z(ts)
                return cls(N, U, V, W, X, Y, Z)
        return parse
from typing import Optional
from typing import Callable, Literal, Union, overload
class TreeBase(GraphBase):
    '''GraphBase specialised to trees (N vertices, N-1 edges).'''
    @overload
    def distance(T) -> list[list[int]]: ...
    @overload
    def distance(T, s: int = 0) -> list[int]: ...
    @overload
    def distance(T, s: int, g: int) -> int: ...
    def distance(T, s = None, g = None):
        '''All-pairs (no args), single-source (s), or point-to-point (s, g).'''
        if s == None:
            return [T.dfs_distance(u) for u in range(T.N)]
        else:
            return T.dfs_distance(s, g)
    @overload
    def diameter(T) -> int: ...
    @overload
    def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
    def diameter(T, endpoints = False):
        '''Tree diameter via double sweep: farthest from 0, then farthest from there.

        Distances and vertex ids are packed into single ints so max() picks the
        farthest vertex directly.
        '''
        mask = (1 << (shift := T.N.bit_length())) - 1
        s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
        dg = max(d << shift | v for v,d in enumerate(T.distance(s)))
        diam, g = dg >> shift, dg & mask
        return (diam, s, g) if endpoints else diam
    def dfs_distance(T, s: int, g: Union[int,None] = None):
        '''Hop distances from s via DFS (valid because a tree has unique paths).'''
        st, Va = elist(N := T.N), T.Va
        T.D, T.back = D, back = [inf]*N, i32f(N, -1)
        D[s] = 0
        st.append(s)
        while st:
            nd = D[u := st.pop()]+1
            if u == g: return nd-1
            for i in T.range(u):
                if nd < D[v := Va[i]]:
                    D[v], back[v] = nd, i
                    st.append(v)
        return D if g is None else inf
    def rerooting_dp(T, e: _T,
                     merge: Callable[[_T,_T],_T],
                     edge_op: Callable[[_T,int,int,int],_T] = lambda s,i,p,u:s,
                     s: int = 0):
        '''Rerooting technique: dp[v] = DP value of the whole tree rooted at v.

        e is the identity value, merge combines sibling results, and
        edge_op(sub, i, p, u) transforms a subtree value carried across edge i.
        '''
        La, Ua, Va = T.La, T.Ua, T.Va
        order, dp, suf, I = T.dfs_topdown(s), [e]*T.N, [e]*len(Ua), T.Ra[:]
        # up
        for i in order[::-1]:
            u,v = Ua[i], Va[i]
            # subtree v finished up pass, store value to accumulate for u
            dp[v] = new = edge_op(dp[v], i, u, v)
            dp[u] = merge(dp[u], new)
            # suffix accumulation
            if (c:=I[u]-1) > La[u]: suf[c-1] = merge(suf[c], new)
            I[u] = c
        # down
        dp[s] = e # at this point dp stores values to be merged in parent
        for i in order:
            u,v = Ua[i], Va[i]
            dp[u] = merge(pre := dp[u], dp[v])
            dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u)
            I[u] += 1
        return dp
    def euler_tour(T, s = 0):
        '''Euler tour from s: fills tin/tout, par, back, visit order and +/-1 depth deltas.'''
        N, Va = len(T), T.Va
        tin, tout, par, back = [-1]*N,[-1]*N,[-1]*N,[0]*N
        order, delta = elist(2*N), elist(2*N)
        st = elist(N); st.append(s)
        while st:
            p = par[u := st.pop()]
            if tin[u] == -1:
                # first visit: push (u, child) pairs so u is revisited after each child
                tin[u] = len(order)
                for i in T.range(u):
                    if (v := Va[i]) != p:
                        par[v], back[v] = u, i
                        st.append(u); st.append(v)
                delta.append(1)
            else:
                delta.append(-1)
            order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0
        T.tin, T.tout, T.par, T.back = tin, tout, par, back
        T.order, T.delta = order, delta
    @classmethod
    def compile(cls, N: int, shift: int = -1):
        '''Parse N-1 edges as a tree.'''
        return GraphBase.compile.__func__(cls, N, N-1, shift)
class TreeWeightedBase(TreeBase, GraphWeightedBase):
    '''Tree traversals specialised for weighted edges.'''
    def dfs_distance(T, s: int, g: Optional[int] = None):
        '''Weighted distances from s via DFS (unique paths make Dijkstra unnecessary).'''
        st, Wa, Va = elist(N := T.N), T.Wa, T.Va
        T.D, T.back = D, back = [inf]*N, i32f(N, -1)
        D[s] = 0; st.append(s)
        while st:
            d = D[u := st.pop()]
            if u == g: return d
            for i in T.range(u):
                if (nd := d+Wa[i]) < D[v := Va[i]]:
                    D[v], back[v] = nd, i; st.append(v)
        return D if g is None else inf
    def euler_tour(T, s = 0):
        '''Euler tour carrying signed weight deltas (Wdelta) alongside +/-1 depth deltas.'''
        N, Va, Wa = len(T), T.Va, T.Wa
        tin, tout, par = [-1]*N,[-1]*N,[-1]*N
        order, delta, Wdelta = elist(2*N), elist(2*N), elist(2*N)
        st, Wst = elist(N), elist(N)
        st.append(s); Wst.append(0)
        while st:
            p, wd = par[u := st.pop()], Wst.pop()
            if tin[u] == -1:
                # first visit: push (u, child) pairs with +w down / -w back up
                tin[u] = len(order)
                for i in T.range(u):
                    if (v := Va[i]) != p:
                        w, par[v] = Wa[i], u
                        st.append(u); st.append(v); Wst.append(-w); Wst.append(w)
                delta.append(1)
            else:
                delta.append(-1)
            Wdelta.append(wd); order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0
        T.tin, T.tout, T.par = tin, tout, par
        T.order, T.delta, T.Wdelta = order, delta, Wdelta
    @classmethod
    def compile(cls, N: int, shift: int = -1):
        '''Parse N-1 weighted edges as a tree.'''
        return GraphWeightedBase.compile.__func__(cls, N, N-1, shift)
class TreeWeightedMeta(TreeWeightedBase, GraphWeightedMeta):
    '''Weighted tree with extra per-edge metadata columns (see GraphWeightedMeta).'''
    @classmethod
    def compile(cls, N: int, T: list[type] = [-1,-1,int,int]):
        # A tree on N vertices always has N-1 edges, so reuse the graph parser.
        return GraphWeightedMeta.compile.__func__(cls, N, N-1, T)