This documentation is automatically generated by online-judge-tools/verification-helper
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/lca
def main():
    """Library Checker "lca": read a rooted tree and answer LCA queries."""
    # First line: vertex count and query count; second line: parent of each vertex 1..N-1.
    N, Q = read()
    parents = read(list[int])
    tree = Tree(N, parents, range(1, N))
    table = LCATable(tree)
    for _ in range(Q):
        u, v = read()
        ancestor, _depth = table.query(u, v)
        write(ancestor)
from cp_library.alg.tree.lca_table_iterative_cls import LCATable
from cp_library.alg.tree.csr.tree_cls import Tree
from cp_library.io.read_fn import read
from cp_library.io.write_fn import write
# Run the solver only when this file is executed as a script.
if __name__ == '__main__':
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/lca
def main():
    """Answer LCA queries on a tree given by its parent array (vertex 0 is the root)."""
    N, Q = read()
    children = range(1, N)
    # Edges are parent[i] -- children[i].
    lca = LCATable(Tree(N, read(list[int]), children))
    for _ in range(Q):
        u, v = read()
        write(lca.query(u, v)[0])
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
def sort2(a, b):
    """Return the pair ordered so the smaller comes first (``(b, a)`` unless ``a < b``)."""
    if a < b:
        return (a, b)
    return (b, a)
import operator
from itertools import accumulate
from typing import Callable, Iterable
from typing import TypeVar
# Generic type variables shared by the annotations in this file.
_S = TypeVar('S')
_T = TypeVar('T')
_U = TypeVar('U')
_T1 = TypeVar('T1')
_T2 = TypeVar('T2')
_T3 = TypeVar('T3')
_T4 = TypeVar('T4')
_T5 = TypeVar('T5')
_T6 = TypeVar('T6')
def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    """Return inclusive prefix folds of ``iter`` under ``func`` (default: addition).

    * ``step == 1``: exactly ``list(itertools.accumulate(iter, func, initial=initial))``.
    * ``step >= 2``: strided prefix folds, ``A[i] = func(A[i-step], A[i])``, so each
      residue class modulo ``step`` is accumulated independently. The accumulated value
      is passed first, matching ``accumulate``'s ``func(total, element)`` order.

    ``initial``, when not None, is prepended before folding.

    Raises:
        ValueError: if ``step < 1``.
    """
    if step == 1:
        return list(accumulate(iter, func, initial=initial))
    if step < 1:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError(f"step must be >= 1, got {step}")
    if func is None: func = operator.add
    A = list(iter)
    if initial is not None: A = [initial] + A
    for i in range(step, len(A)):
        A[i] = func(A[i-step], A[i])
    return A
def min2(a, b):
    """Return ``a`` if ``a < b``, otherwise ``b`` (so ties return ``b``)."""
    if a < b:
        return a
    return b
class MinSparseTable:
    """Static range-minimum structure with O(N log N) build and O(1) queries.

    ``data`` is one flat list of ``log`` levels of length ``N``; entry ``j`` of
    level ``k`` (at offset ``k*N``) holds the minimum of ``arr[j : j + 2**k]``.
    """
    def __init__(st, arr: list):
        st.N = N = len(arr)
        st.log = N.bit_length()
        st.data = data = [0] * (st.log*N)
        data[:N] = arr
        for lvl in range(1, st.log):
            cur = lvl * N
            prev = cur - N
            half = 1 << (lvl - 1)
            for j in range(N - (1 << lvl) + 1):
                # Two half-width blocks from the previous level cover this one.
                data[cur + j] = min2(data[prev + j], data[prev + half + j])
    def query(st, l: int, r: int):
        """Minimum of ``arr[l:r]``; requires ``l < r``."""
        k = (r - l).bit_length() - 1
        base = k * st.N
        # Two (possibly overlapping) blocks of width 2**k cover [l, r).
        return min2(st.data[base + l], st.data[base + r - (1 << k)])
class LCATable(MinSparseTable):
    """Lowest common ancestor via range-minimum over a tree's Euler tour.

    Each tour position stores ``depth << shift | vertex``. The minimum over the
    tour range between two vertices' first visits is therefore their LCA (and
    its depth).
    """
    def __init__(lca, T, root = 0):
        N = len(T)
        T.euler_tour(root)
        lca.depth = depth = presum(T.delta)
        lca.tin, lca.tout = T.tin[:], T.tout[:]
        # path() needs parents w.r.t. `root`. Keep our own copy (like tin/tout)
        # so a later euler_tour() on T with another root cannot desynchronize it.
        lca.par = T.par[:]
        lca.mask = (1 << (shift := N.bit_length()))-1
        lca.shift = shift
        packets = [d << shift | v for d, v in zip(depth, T.order)]
        super().__init__(packets)
    def _query(lca, u, v):
        """Return ``(l, r, a, d)``.

        ``[l, r)`` is the tour range between the first visits of u and v;
        ``a`` is their LCA and ``d`` its depth.
        """
        l, r = sort2(lca.tin[u], lca.tin[v]); r += 1
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift
    def query(lca, u, v) -> tuple[int,int]:
        """Return ``(lca_vertex, depth_of_lca)``."""
        l, r, a, d = lca._query(u, v)
        return a, d
    def distance(lca, u, v) -> int:
        """Number of edges on the path between u and v."""
        l, r, a, d = lca._query(u, v)
        # depth[l] / depth[r-1] are the depths of u and v (their first-visit positions).
        return lca.depth[l] + lca.depth[r-1] - 2*d
    def path(lca, u, v):
        """Return the vertices on the path from u to v, both ends inclusive."""
        par = lca.par
        top = lca.query(u, v)[0]
        path, c = [], u
        while c != top:
            path.append(c)
            c = par[c]
        path.append(top)
        rev_path, c = [], v
        while c != top:
            rev_path.append(c)
            c = par[c]
        path.extend(reversed(rev_path))
        return path
from math import inf
from typing import Callable, Sequence, Union, overload
from types import GenericAlias
class Parsable:
    """Mixin that lets a class be used directly as a ``Parser``/``read`` spec.

    ``compile()`` returns a callable that builds an instance from the next
    token of an input stream. Subclasses may override it with a richer parser.
    Subscripting the class (``Cls[...]``) produces a ``types.GenericAlias``.
    """
    @classmethod
    def compile(cls):
        return lambda io: cls(next(io))
    @classmethod
    def __class_getitem__(cls, item):
        alias = GenericAlias(cls, item)
        return alias
def chmin(dp, i, v):
    """Lower ``dp[i]`` to ``v`` if ``v`` is smaller.

    Returns True if an update happened, False otherwise.
    """
    improved = dp[i] > v
    if improved:
        dp[i] = v
    return improved
from enum import auto, IntFlag, IntEnum
class DFSFlags(IntFlag):
    """Bit flags selecting which DFS events or outputs are requested.

    Single-bit members come first, in declaration order (1, 2, 4, ...).
    Multi-bit aliases for common combinations follow.
    """
    ENTER = 1 << 0
    DOWN = 1 << 1
    BACK = 1 << 2
    CROSS = 1 << 3
    LEAVE = 1 << 4
    UP = 1 << 5
    MAXDEPTH = 1 << 6
    RETURN_PARENTS = 1 << 7
    RETURN_DEPTHS = 1 << 8
    BACKTRACK = 1 << 9
    CONNECT_ROOTS = 1 << 10
    # Common combinations
    ALL_EDGES = DOWN | BACK | CROSS
    EULER_TOUR = DOWN | UP
    INTERVAL = ENTER | LEAVE
    TOPDOWN = DOWN | CONNECT_ROOTS
    BOTTOMUP = UP | CONNECT_ROOTS
    RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS
class DFSEvent(IntEnum):
    """Event tags for DFS traversals.

    Each member's value is taken from the ``DFSFlags`` bit of the same name, so
    an event can be tested against a flag mask. ``GraphBase.dfs_enter_leave``
    packs ``ENTER``/``LEAVE`` into its output this way.
    """
    ENTER = DFSFlags.ENTER
    DOWN = DFSFlags.DOWN
    BACK = DFSFlags.BACK
    CROSS = DFSFlags.CROSS
    LEAVE = DFSFlags.LEAVE
    UP = DFSFlags.UP
    MAXDEPTH = DFSFlags.MAXDEPTH
class GraphBase(Parsable):
def __init__(G, N: int, M: int, U: list[int], V: list[int],
deg: list[int], La: list[int], Ra: list[int],
Ua: list[int], Va: list[int], Ea: list[int], twin: list[int] = None):
G.N = N
'''The number of vertices.'''
G.M = M
'''The number of edges.'''
G.U = U
'''A list of source vertices in the original edge list.'''
G.V = V
'''A list of destination vertices in the original edge list.'''
G.deg = deg
'''deg[u] is the out degree of vertex u.'''
G.La = La
'''La[u] stores the start index of the list of adjacent vertices from u.'''
G.Ra = Ra
'''Ra[u] stores the stop index of the list of adjacent vertices from u.'''
G.Ua = Ua
'''Ua[i] = u for La[u] <= i < Ra[u], useful for backtracking.'''
G.Va = Va
'''Va[i] lists adjacent vertices to u for La[u] <= i < Ra[u].'''
G.Ea = Ea
'''Ea[i] lists the edge ids that start from u for La[u] <= i < Ra[u].
For undirected graphs, edge ids in range M<= e <2*M are edges from V[e-M] -> U[e-M].
'''
G.twin = twin if twin is not None else range(len(Ua))
'''twin[i] in undirected graphs stores index j of the same edge but with u and v swapped.'''
G.st: list[int] = None
G.order: list[int] = None
G.vis: list[int] = None
G.back: list[int] = None
G.tin: list[int] = None
def clear(G):
G.vis = G.back = G.tin = None
def prep_vis(G):
if G.vis is None: G.vis = u8f(G.N)
return G.vis
def prep_st(G):
if G.st is None: G.st = elist(G.N)
else: G.st.clear()
return G.st
def prep_order(G):
if G.order is None: G.order = elist(G.N)
else: G.order.clear()
return G.order
def prep_back(G):
if G.back is None: G.back = i32f(G.N, -2)
return G.back
def prep_tin(G):
if G.tin is None: G.tin = i32f(G.N, -1)
return G.tin
def _remove(G, a: int):
G.deg[u := G.Ua[a]] -= 1
G.Ra[u] = (r := G.Ra[u]-1)
G.Ua[a], G.Va[a], G.Ea[a] = G.Ua[r], G.Va[r], G.Ea[r]
G.twin[a], G.twin[r] = G.twin[r], G.twin[a]
G.twin[G.twin[a]] = a
G.twin[G.twin[r]] = r
def remove(G, a: int):
b = G.twin[a]; G._remove(a)
if a != b: G._remove(b)
def __len__(G) -> int: return G.N
def __getitem__(G, u): return view(G.Va, G.La[u], G.Ra[u])
def range(G, u): return range(G.La[u],G.Ra[u])
@overload
def distance(G) -> list[list[int]]: ...
@overload
def distance(G, s: int = 0) -> list[int]: ...
@overload
def distance(G, s: int, g: int) -> int: ...
def distance(G, s = None, g = None):
if s == None: return G.floyd_warshall()
else: return G.bfs(s, g)
def recover_path(G, s, t):
P = u32f(0)
while s != t: P.append(a := G.back[t]); t = G.Ua[a]
return P
def shortest_path(G, s: int, t: int):
if G.distance(s, t) >= inf: return None
P = G.recover_path(s, t)
P.reverse()
return P
@overload
def bfs(G, s: Union[int,list] = 0) -> list[int]: ...
@overload
def bfs(G, s: Union[int,list], g: int) -> int: ...
def bfs(G, s: int = 0, g: int = None):
S, Va, back, D = G.starts(s), G.Va, i32f(N := G.N, -1), [inf]*N
G.back, G.D = back, D
for u in S: D[u] = 0
que = Que(S)
while que:
nd = D[u := que.pop()]+1
if u == g: return nd-1
for i in G.range(u):
if chmin(D, v := Va[i], nd): back[v] = i; que.push(v)
return D if g is None else inf
def floyd_warshall(G) -> list[list[int]]:
G.D = D = [[inf]*G.N for _ in range(G.N)]
for u in range(G.N): D[u][u] = 0
for i in range(len(G.Ua)): D[G.Ua[i]][G.Va[i]] = 1
for k, Dk in enumerate(D):
for Di in D:
if (Dik := Di[k]) == inf: continue
for j in range(G.N):
chmin(Di, j, Dik+Dk[j])
return D
def find_cycle_indices(G, s: Union[int, None] = None):
Ea, Ua, Va, vis, back = G.Ea, G. Ua, G.Va, u8f(N := G.N), u32f(N, i32_max)
G.vis, G.back, st = vis, back, elist(N)
for s in G.starts(s):
if vis[s]: continue
st.append(s)
while st:
if not vis[u := st.pop()]:
st.append(u)
vis[u], pe = 1, Ea[j] if (j := back[u]) != i32_max else i32_max
for i in G.range(u):
if not vis[v := Va[i]]:
back[v] = i
st.append(v)
elif vis[v] == 1 and pe != Ea[i]:
I = u32f(1,i)
while v != u: I.append(i := back[u]), (u := Ua[i])
I.reverse()
return I
else:
vis[u] = 2
# check for self loops
for i in range(len(Ua)):
if Ua[i] == Va[i]:
return u32f(1,i)
def find_cycle(G, s: Union[int, None] = None):
if I := G.find_cycle_indices(s): return [G.Ua[i] for i in I]
def find_cycle_edge_ids(G, s: Union[int, None] = None):
if I := G.find_cycle_indices(s): return [G.Ea[i] for i in I]
def find_minimal_cycle(G, s=0):
D, par, que, Va = u32f(N := G.N, u32_max), i32f(N, -1), Que([s]), G.Va
D[s] = 0
while que:
for i in G.range(u := que.pop()):
if (v := Va[i]) == s: # Found cycle back to start
cycle = [u]
while u != s: cycle.append(u := par[u])
return cycle
if D[v] < u32_max: continue
D[v], par[v] = D[u]+1, u; que.push(v)
def dfs_topo(G, s: Union[int,list] = None) -> list[int]:
'''Returns lists of indices i where Ua[i] -> Va[i] are edges in order of top down discovery'''
vis, st, order = G.prep_vis(), G.prep_st(), G.prep_order()
for s in G.starts(s):
if vis[s]: continue
vis[s] = 1; st.append(s)
while st:
for i in G.range(st.pop()):
if vis[v := G.Va[i]]: continue
vis[v] = 1; order.append(i); st.append(v)
return order
def dfs(G, s: Union[int,list] = None, /,
backtrack = False,
max_depth = None,
enter_fn: Callable[[int],None] = None,
leave_fn: Callable[[int],None] = None,
max_depth_fn: Callable[[int],None] = None,
down_fn: Callable[[int,int,int],None] = None,
back_fn: Callable[[int,int,int],None] = None,
forward_fn: Callable[[int,int,int],None] = None,
cross_fn: Callable[[int,int,int],None] = None,
up_fn: Callable[[int,int,int],None] = None):
I, time, vis, st, back, tin = G.La[:], -1, G.prep_vis(), G.prep_st(), G.prep_back(), G.prep_tin()
for s in G.starts(s):
if vis[s]: continue
back[s], tin[s] = -1, (time := time+1); st.append(s)
while st:
if vis[u := st[-1]] == 0:
vis[u] = 1
if enter_fn: enter_fn(u)
if max_depth is not None and len(st) > max_depth:
I[u] = G.Ra[u]
if max_depth_fn: max_depth_fn(u)
if (i := I[u]) < G.Ra[u]:
I[u] += 1
if (s := vis[v := G.Va[i]]) == 0:
back[v], tin[v] = i, (time := time+1); st.append(v)
if down_fn: down_fn(u,v,i)
elif back_fn and s == 1 and back[u] != G.twin[i]: back_fn(u,v,i)
elif (cross_fn or forward_fn) and s == 2:
if forward_fn and tin[u] < tin[v]: forward_fn(u,v,i)
elif cross_fn: cross_fn(u,v,i)
else:
vis[u] = 2; st.pop()
if backtrack: vis[u], I[u] = 0, G.La[u]
if leave_fn: leave_fn(u)
if up_fn and st: up_fn(u, st[-1], back[u])
def dfs_enter_leave(G, s: Union[int,list[int],None] = None) -> Sequence[tuple[DFSEvent,int]]:
N, I = G.N, G.La[:]
st, back, plst = elist(N), i32f(N,-2), PacketList(order := elist(2*N), N-1)
G.back, ENTER, LEAVE = back, int(DFSEvent.ENTER) << plst.shift, int(DFSEvent.LEAVE) << plst.shift
for s in G.starts(s):
if back[s] >= -1: continue
back[s] = -1
order.append(ENTER | s), st.append(s)
while st:
if (i := I[u := st[-1]]) < G.Ra[u]:
I[u] += 1
if back[v := G.Va[i]] >= -1: continue
back[v] = i; order.append(ENTER | v); st.append(v)
else:
order.append(LEAVE | u); st.pop()
return plst
def starts(G, s: Union[int,list[int],None] = None) -> list[int]:
if isinstance(s, int): return [s]
elif s is None: return range(G.N)
elif isinstance(s, list): return s
else: return list(s)
@classmethod
def compile(cls, N: int, M: int, shift: int = -1):
def parse(io: IOBase):
U, V = u32f(M), u32f(M)
for i in range(M): u, v = io.readints(); U[i], V[i] = u+shift, v+shift
return cls(N, U, V)
return parse
# Largest unsigned / signed 32-bit values; used in this file as "unset"
# sentinels inside 'I' / 'i' typed arrays.
u32_max = 0xFFFF_FFFF
i32_max = 0x7FFF_FFFF
from array import array
def u8f(N: int, elm: int = 0):
    """Return an ``array('B')`` (unsigned char) of length ``N`` with every slot set to ``elm``."""
    return array('B', [elm]) * N
def u32f(N: int, elm: int = 0):
    """Return an ``array('I')`` (unsigned int) of length ``N`` with every slot set to ``elm``."""
    return array('I', [elm]) * N
def i32f(N: int, elm: int = 0):
    """Return an ``array('i')`` (signed int) of length ``N`` with every slot set to ``elm``."""
    return array('i', [elm]) * N
def elist(hint: int) -> list: ...  # typing stub; rebound below to the real implementation
try:
    # PyPy-only helper; presumably pre-sizes the list's capacity for `hint` items
    # (it still returns an empty list).
    from __pypy__ import newlist_hint
except ImportError:
    # Narrowed from a bare ``except:`` so unrelated errors (e.g. KeyboardInterrupt)
    # during import are no longer swallowed.
    def newlist_hint(hint):
        """Fallback for non-PyPy interpreters: ignore the size hint and return an empty list."""
        return []
elist = newlist_hint
class PacketList(Sequence[tuple[int,int]]):
    """Read-only sequence of ``(hi, lo)`` pairs stored packed as ``hi << shift | lo``.

    ``shift`` is the bit length of ``max1``, the largest value the low part may take.
    """
    def __init__(lst, A: list[int], max1: int):
        shift = max1.bit_length()
        lst.A = A
        lst.mask = (1 << shift) - 1
        lst.shift = shift
    def __len__(lst):
        return len(lst.A)
    def __contains__(lst, x: tuple[int,int]):
        packed = x[0] << lst.shift | x[1]
        return packed in lst.A
    def __getitem__(lst, key) -> tuple[int,int]:
        packed = lst.A[key]
        hi, lo = packed >> lst.shift, packed & lst.mask
        return hi, lo
class Que:
    """FIFO queue over a growing list with a moving head index.

    Popped slots stay in the list (nothing is shifted), so ``pop`` is O(1).
    Indexing and ``len`` are relative to the current head. An int argument to
    the constructor is passed to ``elist`` as a size hint; any other truthy
    iterable seeds the queue.
    """
    def __init__(que, v = None): que.q = elist(v) if isinstance(v, int) else list(v) if v else []; que.h = 0
    def push(que, item): que.q.append(item)
    def pop(que):
        # Read before advancing the head. Popping an empty queue still raises
        # IndexError, but the queue stays consistent (len() >= 0, bool() works).
        item = que.q[que.h]
        que.h += 1
        return item
    def extend(que, items): que.q.extend(items)
    def __getitem__(que, i: int): return que.q[que.h+i]
    def __setitem__(que, i: int, v): que.q[que.h+i] = v
    def __len__(que): return que.q.__len__() - que.h
    def __hash__(que): return hash(tuple(que.q[que.h:]))
from typing import Generic
import sys
def list_find(lst: list, value, start = 0, stop = sys.maxsize):
    """Return the absolute index of ``value`` within ``lst[start:stop]``, or -1 if absent.

    Works for any sequence whose ``index`` accepts ``start``/``stop``
    (``list``, and ``array.array`` on Python 3.10+).
    """
    try:
        return lst.index(value, start, stop)
    except ValueError:
        # "not found". Other errors (e.g. an unsupported index() signature) now
        # surface instead of being silently reported as -1.
        return -1
class view(Generic[_T]):
    """Mutable window ``A[l:r]`` over an underlying sequence, without copying.

    Indices are relative to ``l``. ``__getitem__`` is bounds-checked, so plain
    iteration stops at the window's end. ``__setitem__``, ``append`` and
    ``appendleft`` are not bounds-checked and write straight into ``A``.
    """
    __slots__ = 'A', 'l', 'r'
    def __init__(V, A: list[_T], l: int = 0, r: int = 0): V.A, V.l, V.r = A, l, r
    def __len__(V): return V.r - V.l
    def __getitem__(V, i: int):
        if 0 <= i < V.r - V.l: return V.A[V.l+i]
        else: raise IndexError
    def __setitem__(V, i: int, v: _T): V.A[V.l+i] = v
    def __contains__(V, v: _T): return list_find(V.A, v, V.l, V.r) != -1
    def set_range(V, l: int, r: int): V.l, V.r = l, r
    def index(V, v: _T): return V.A.index(v, V.l, V.r) - V.l
    def reverse(V):
        """Reverse the window in place."""
        l, r = V.l, V.r-1
        while l < r: V.A[l], V.A[r] = V.A[r], V.A[l]; l += 1; r -= 1
    def sort(V, /, *args, **kwargs):
        """Sort the window in place; accepts the same ``key``/``reverse`` options as ``sorted``."""
        # sorted() works on any slice. The previous ``slice.sort()`` failed for
        # array.array, which has no sort method; GraphBase.__getitem__ wraps
        # exactly such arrays.
        A = sorted(V.A[V.l:V.r], *args, **kwargs)
        for i,a in enumerate(A,V.l): V.A[i] = a
    def pop(V): V.r -= 1; return V.A[V.r]
    def append(V, v: _T): V.A[V.r] = v; V.r += 1
    def popleft(V): V.l += 1; return V.A[V.l-1]
    def appendleft(V, v: _T): V.l -= 1; V.A[V.l] = v
    def validate(V): return 0 <= V.l <= V.r <= len(V.A)
class IOBase:
    """Interface for the fast reader/writer used by ``read``/``write``/``Parser``.

    Every method here is a stub that returns None; ``IO`` supplies the real behaviour.
    """
    @property
    def char(io) -> bool:
        """True when tokens are read one character at a time."""
    @property
    def writable(io) -> bool:
        """True when ``flush`` actually writes out."""
    def __next__(io) -> str:
        """Next token (or next character in char mode)."""
    def write(io, s: str) -> None:
        """Queue ``s`` for output."""
    def readline(io) -> str:
        """Rest of the current line."""
    def readtoken(io) -> str:
        """Next space-separated token."""
    def readtokens(io) -> list[str]:
        """Tokens of the current line."""
    def readints(io) -> list[int]:
        """Integers of the current line."""
    def readdigits(io) -> list[int]:
        """Digits of the current line."""
    def readnums(io) -> list[int]:
        """Digits in char mode, integers otherwise."""
    def readchar(io) -> str:
        """Next character."""
    def readchars(io) -> str:
        """Rest of the current line without its terminator."""
    def readinto(io, lst: list[str]) -> list[str]:
        """Append the current line's chars (char mode) or tokens to ``lst``."""
    def readcharsinto(io, lst: list[str]) -> list[str]:
        """Append the current line's characters to ``lst``."""
    def readtokensinto(io, lst: list[str]) -> list[str]:
        """Append the current line's tokens to ``lst``."""
    def readintsinto(io, lst: list[int]) -> list[int]:
        """Append the current line's integers to ``lst``."""
    def readdigitsinto(io, lst: list[int]) -> list[int]:
        """Append the current line's digits to ``lst``."""
    def readnumsinto(io, lst: list[int]) -> list[int]:
        """Append digits (char mode) or integers to ``lst``."""
    def wait(io):
        """Iterate while the current line still has unread input."""
    def flush(io) -> None:
        """Write out everything queued by ``write``."""
    def line(io) -> list[str]:
        """Tokens (or characters) of the current line."""
class Graph(GraphBase):
    """Undirected graph in CSR form, built from parallel edge lists ``U``/``V``.

    Edge ``e`` with ``U[e] != V[e]`` yields two arcs (u->v and v->u) that point
    at each other through ``twin``. A self loop yields a single arc that is its
    own twin.
    """
    def __init__(G, N: int, U: list[int], V: list[int]):
        M = len(U)
        deg, arcs = u32f(N), 0
        # Pass 1: out-degree of every vertex and total arc count.
        for e in range(M):
            u, v = U[e], V[e]
            deg[u] += 1
            if u != v:
                deg[v] += 1
                arcs += 2
            else:
                arcs += 1
        twin, Ea = i32f(arcs), i32f(arcs)
        Ua, Va = u32f(arcs), u32f(arcs)
        La, Ra = u32f(N), u32f(N)
        # Pass 2: each vertex's arc range starts where the previous one ends.
        offset = 0
        for u in range(N):
            La[u] = Ra[u] = offset
            offset += deg[u]
        # Pass 3: place arcs; Ra[u] is the next free slot of u until filled.
        for e in range(M):
            u, v = U[e], V[e]
            a, b = Ra[u], Ra[v]
            Ra[u] = a + 1
            Ua[a], Va[a], Ea[a], twin[a] = u, v, e, b
            if a != b:
                Ra[v] = b + 1
                Ua[b], Va[b], Ea[b], twin[b] = v, u, e, a
        super().__init__(N, M, U, V, deg, La, Ra, Ua, Va, Ea, twin)
from typing import Callable, Literal, Union, overload
class TreeBase(GraphBase):
    """Graph algorithms that assume the graph is a tree (connected, no cycles)."""
    @overload
    def distance(T) -> list[list[int]]: ...
    @overload
    def distance(T, s: int = 0) -> list[int]: ...
    @overload
    def distance(T, s: int, g: int) -> int: ...
    def distance(T, s = None, g = None):
        """All-pairs distances (one dfs_distance per vertex) when s is None; otherwise dfs_distance(s, g)."""
        if s == None:
            return [T.dfs_distance(u) for u in range(T.N)]
        else:
            return T.dfs_distance(s, g)
    @overload
    def diameter(T) -> int: ...
    @overload
    def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
    def diameter(T, endpoints = False):
        """Edge count of a longest path; with endpoints=True returns (diam, s, g).

        Double sweep: s is a farthest vertex from 0, then g is a farthest vertex from s.
        """
        # Pack (distance, vertex) into one int so a single max() finds both.
        mask = (1 << (shift := T.N.bit_length())) - 1
        s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
        dg = max(d << shift | v for v,d in enumerate(T.distance(s)))
        diam, g = dg >> shift, dg & mask
        return (diam, s, g) if endpoints else diam
    def dfs_distance(T, s: int, g: Union[int,None] = None):
        """Edge-count distances from s, found with an explicit stack (exact on a tree).

        Returns the distance list, or the distance to g (inf if unreached)
        when g is given. Stores T.D and T.back (arc used to reach each vertex).
        """
        st, Va = elist(N := T.N), T.Va
        T.D, T.back = D, back = [inf]*N, i32f(N, -1)
        D[s] = 0
        st.append(s)
        while st:
            nd = D[u := st.pop()]+1
            if u == g: return nd-1
            for i in T.range(u):
                if nd < D[v := Va[i]]:
                    D[v], back[v] = nd, i
                    st.append(v)
        return D if g is None else inf
    def rerooting_dp(T, e: _T,
                     merge: Callable[[_T,_T],_T],
                     edge_op: Callable[[_T,int,int,int],_T] = lambda s,i,p,u:s,
                     s: int = 0):
        """Rerooting DP: for every vertex, fold ``merge`` over its neighbours' contributions.

        ``edge_op(x, i, p, u)`` turns ``x``, the accumulated value of u's side,
        into the contribution merged at p across arc i (i is always the
        top-down arc of the DFS tree rooted at s). Assumes ``e`` is an identity
        for ``merge`` — TODO confirm with callers.

        NOTE(review): uses dfs_topo(), whose prep_vis() does not reset an existing
        T.vis. If another traversal already filled T.vis, the order may be
        empty; call T.clear() first.
        """
        La, Ua, Va = T.La, T.Ua, T.Va
        # I[u] counts down from Ra[u] in the up pass and back up in the down pass,
        # indexing suf by child rank; this relies on dfs_topo listing each
        # vertex's child arcs contiguously.
        order, dp, suf, I = T.dfs_topo(s), [e]*T.N, [e]*len(Ua), T.Ra[:]
        # up
        for i in order[::-1]:
            u,v = Ua[i], Va[i]
            # subtree v finished up pass, store value to accumulate for u
            dp[v] = new = edge_op(dp[v], i, u, v)
            dp[u] = merge(dp[u], new)
            # suffix accumulation
            if (c:=I[u]-1) > La[u]: suf[c-1] = merge(suf[c], new)
            I[u] = c
        # down
        dp[s] = e # at this point dp stores values to be merged in parent
        for i in order:
            u,v = Ua[i], Va[i]
            # pre: parent-side value merged with earlier siblings; suf: later siblings.
            dp[u] = merge(pre := dp[u], dp[v])
            dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u)
            I[u] += 1
        return dp
    def euler_tour(T, s = 0):
        """Euler tour from root s, stored on T.

        * T.order[k]: vertex at tour step k.
        * T.tin[u] / T.tout[u]: first index of u in order / one past its last index.
        * T.par[u]: parent vertex (-1 for the root).
        * T.back[u]: arc index from the parent (left 0 for the root).
        * T.delta[k]: +1 when order[k] is first entered, -1 when the tour returns
          to it; delta[0] and delta[-1] are forced to 0.

        NOTE(review): zeroing delta[-1] makes presum(delta)[-1] equal 1 rather
        than the root's depth 0 when N > 1. LCATable never queries that
        position; confirm the intent before relying on it elsewhere.
        """
        N, Va = len(T), T.Va
        tin, tout, par, back = [-1]*N,[-1]*N,[-1]*N,[0]*N
        order, delta = elist(2*N), elist(2*N)
        st = elist(N); st.append(s)
        while st:
            p = par[u := st.pop()]
            if tin[u] == -1:
                tin[u] = len(order)
                for i in T.range(u):
                    if (v := Va[i]) != p:
                        par[v], back[v] = u, i
                        # Push u under v so the tour revisits u after v's subtree.
                        st.append(u); st.append(v)
                delta.append(1)
            else:
                delta.append(-1)
            order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0
        T.tin, T.tout, T.par, T.back = tin, tout, par, back
        T.order, T.delta = order, delta
    @classmethod
    def compile(cls, N: int, shift: int = -1):
        """Parser for a tree given as N-1 edge lines (delegates to GraphBase.compile)."""
        return GraphBase.compile.__func__(cls, N, N-1, shift)
class Tree(TreeBase, Graph):
    """Undirected tree: Graph's CSR construction combined with TreeBase's algorithms."""
from typing import Type, Union, overload
@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_T], char=False):
    """Parse the next value(s) from ``IO.stdin``.

    With no specs, returns the numbers on the current line (its digits when
    ``char`` is true). Otherwise one spec is compiled as-is, and several specs
    are compiled together as a tuple spec; the resulting parser reads from stdin.
    """
    stdin = IO.stdin
    stdin.char = char
    if not specs:
        return stdin.readnumsinto([])
    spec = specs[0] if len(specs) == 1 else specs
    return Parser.compile(spec)(stdin)
from os import read as os_read, write as os_write, fstat as os_fstat
from __pypy__.builders import StringBuilder
def max2(a, b):
    """Return ``a`` if ``a > b``, otherwise ``b`` (so ties return ``b``)."""
    if a > b:
        return a
    return b
class IO(IOBase):
    """Buffered line-oriented reader and StringBuilder-backed writer over a file descriptor.

    State:
      * B: bytes read so far (all input is kept).
      * O: offsets one past each b'\\n' found in B, i.e. line ends.
      * l: index of the current line in O.
      * p: byte cursor into B.
      * S: pending output.

    NOTE(review): depends on PyPy (``__pypy__.builders.StringBuilder``).
    """
    BUFSIZE = 1 << 16; stdin: 'IO'; stdout: 'IO'
    __slots__ = 'f', 'file', 'B', 'O', 'V', 'S', 'l', 'p', 'char', 'sz', 'st', 'ist', 'writable', 'encoding', 'errors'
    def __init__(io, file):
        io.file = file
        # Read chunk size is at least BUFSIZE, or the whole file size if larger.
        # The stream is treated as writable when its mode has 'x' or lacks 'r'.
        try: io.f = file.fileno(); io.sz, io.writable = max2(io.BUFSIZE, os_fstat(io.f).st_size), ('x' in file.mode or 'r' not in file.mode)
        # Any failure (e.g. no fileno()/mode) falls back to fd -1, non-writable.
        except: io.f, io.sz, io.writable = -1, io.BUFSIZE, False
        io.B, io.O, io.S = bytearray(), [], StringBuilder(); io.V = memoryview(io.B); io.l = io.p = 0
        io.char, io.st, io.ist, io.encoding, io.errors = False, [], [], 'ascii', 'ignore'
    def _dec(io, l, r): return io.V[l:r].tobytes().decode(io.encoding, io.errors)
    def readbytes(io, sz): return os_read(io.f, sz)
    def load(io):
        """Read more input until the current line l has a known end in O.

        At EOF, a final unterminated line is closed at len(B).
        NOTE(review): io.O[-1] raises IndexError if the input is completely empty.
        NOTE(review): B is extended while V (a memoryview of B) is alive; CPython
        raises BufferError when resizing an exported bytearray, so this
        presumably relies on PyPy semantics — confirm.
        """
        while io.l >= len(io.O):
            if not (b := io.readbytes(io.sz)):
                if io.O[-1] < len(io.B): io.O.append(len(io.B))
                break
            pos = len(io.B); io.B.extend(b)
            # ~x is 0 only for x == -1, i.e. loop while find() succeeds.
            while ~(pos := io.B.find(b'\n', pos)): io.O.append(pos := pos+1)
    def __next__(io):
        if io.char: return io.readchar()
        else: return io.readtoken()
    def readchar(io):
        io.load(); r = io.O[io.l]
        c = chr(io.B[io.p])
        # Reaching the last byte of the line (its terminator) moves on to the next line.
        if io.p >= r-1: io.p = r; io.l += 1
        else: io.p += 1
        return c
    def write(io, s: str): io.S.append(s)
    def readline(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p)
    def readtoken(io):
        """Next space-separated token on the current line.

        NOTE(review): the last token is decoded up to r-1, which assumes a
        trailing b'\\n'; an input whose final line lacks it loses that line's
        last character.
        """
        io.load(); r = io.O[io.l]
        if ~(p := io.B.find(b' ', io.p, r)): s = io._dec(io.p, p); io.p = p+1
        else: s = io._dec(io.p, r-1); io.p = r; io.l += 1
        return s
    def readtokens(io): io.st.clear(); return io.readtokensinto(io.st)
    def readints(io): io.ist.clear(); return io.readintsinto(io.ist)
    def readdigits(io): io.ist.clear(); return io.readdigitsinto(io.ist)
    def readnums(io): io.ist.clear(); return io.readnumsinto(io.ist)
    def readchars(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p-1)
    def readinto(io, lst):
        if io.char: return io.readcharsinto(lst)
        else: return io.readtokensinto(lst)
    def readcharsinto(io, lst): lst.extend(io.readchars()); return lst
    def readtokensinto(io, lst):
        io.load(); r = io.O[io.l]
        while ~(p := io.B.find(b' ', io.p, r)): lst.append(io._dec(io.p, p)); io.p = p+1
        lst.append(io._dec(io.p, r-1)); io.p = r; io.l += 1; return lst
    def _readint(io, r):
        """Parse one (optionally '-'-prefixed) decimal integer before offset r, or return None."""
        # Skip whitespace / control bytes (<= 32).
        while io.p < r and io.B[io.p] <= 32: io.p += 1
        if io.p >= r: return None
        minus = x = 0
        if io.B[io.p] == 45: minus = 1; io.p += 1
        # Any byte >= ord('0') is treated as a digit; `& 15` maps '0'..'9' to 0..9.
        while io.p < r and io.B[io.p] >= 48: x = x * 10 + (io.B[io.p] & 15); io.p += 1
        io.p += 1
        return -x if minus else x
    def readintsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and (x := io._readint(r)) is not None: lst.append(x)
        io.l += 1; return lst
    def _readdigit(io): d = io.B[io.p] & 15; io.p += 1; return d
    def readdigitsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and io.B[io.p] > 32: lst.append(io._readdigit())
        # Only a b'\n' (10) terminator advances to the next line.
        if io.B[io.p] == 10: io.l += 1
        io.p += 1
        return lst
    def readnumsinto(io, lst):
        if io.char: return io.readdigitsinto(lst)
        else: return io.readintsinto(lst)
    def line(io): io.st.clear(); return io.readinto(io.st)
    def wait(io):
        """Yield repeatedly while the cursor is still inside the line that was current at the start."""
        io.load(); r = io.O[io.l]
        while io.p < r: yield
    def flush(io):
        """Encode and write all pending output to the fd (only if writable), then start a fresh buffer."""
        if io.writable: os_write(io.f, io.S.build().encode(io.encoding, io.errors)); io.S = StringBuilder()
# Swap the interpreter's standard streams for the buffered IO wrappers.
# IO.stdin / IO.stdout are the shared instances that read() and write() use.
sys.stdin = IO.stdin = IO(sys.stdin)
sys.stdout = IO.stdout = IO(sys.stdout)
import typing
from numbers import Number
from typing import Callable, Collection
class Parser:
    """Compiles a declarative input spec into a callable ``parse(io)``.

    A spec may be:
      * a type or generic alias (``int``, ``list[int]``, a Parsable subclass);
      * a number, which acts as an offset added to an int of the same type;
      * a tuple or collection of specs;
      * any other callable applied to the next token.
    """
    def __init__(self, spec): self.parse = Parser.compile(spec)
    def __call__(self, io: IOBase): return self.parse(io)
    @staticmethod
    def compile_type(cls, args = ()):
        """Parser for a concrete class ``cls`` with generic arguments ``args``."""
        # Order matters: Parsable first (it may also be a collection), and tuple
        # before the generic Collection branch.
        if issubclass(cls, Parsable): return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(io: IOBase): return cls(next(io))
            return parse
        elif issubclass(cls, tuple): return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection): return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(io: IOBase): return cls(next(io))
            return parse
        else: raise NotImplementedError()
    @staticmethod
    def compile(spec=int):
        """Dispatch on the kind of ``spec`` (see the class docstring)."""
        if isinstance(spec, (type, GenericAlias)):
            # list[int] -> (list, (int,)); a plain class -> (cls, ()).
            cls, args = typing.get_origin(spec) or spec, typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            # e.g. read(-1): parse a value of type(spec) and add the offset.
            cls = type(spec)
            def parse(io: IOBase): return cls(next(io)) + offset
            return parse
        elif isinstance(args := spec, tuple): return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection): return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(io: IOBase): return fn(next(io))
            return parse
        else: raise NotImplementedError()
    @staticmethod
    def compile_line(cls, spec=int):
        """Parse the rest of the current line into ``cls``.

        Uses the fast paths for ``int``/``str``; otherwise repeats the element
        parser while the line has input.
        """
        if spec is int:
            def parse(io: IOBase): return cls(io.readnums())
        elif spec is str:
            def parse(io: IOBase): return cls(io.line())
        else:
            fn = Parser.compile(spec)
            def parse(io: IOBase): return cls((fn(io) for _ in io.wait()))
        return parse
    @staticmethod
    def compile_repeat(cls, spec, N):
        """Parse ``N`` values of ``spec`` into ``cls``."""
        fn = Parser.compile(spec)
        def parse(io: IOBase): return cls([fn(io) for _ in range(N)])
        return parse
    @staticmethod
    def compile_children(cls, specs):
        """Parse one value per spec in ``specs``, in order, into ``cls``."""
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(io: IOBase): return cls([fn(io) for fn in fns])
        return parse
    @staticmethod
    def compile_tuple(cls, specs):
        """``(spec, ...)`` means "rest of the line"; otherwise one child parser per element."""
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...: return Parser.compile_line(cls, specs[0])
        else: return Parser.compile_children(cls, specs)
    @staticmethod
    def compile_collection(cls, specs):
        """Empty, single-element, or set spec means "rest of the line"; ``[spec, N]`` means N repeats."""
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()
def write(*args, **kwargs):
    """print()-like output: ``str()`` of each arg, separated by ``sep`` and followed by ``end``.

    Goes to ``file`` (default ``IO.stdout``); calls ``file.flush()`` when
    ``flush`` is true. Unrecognised keyword arguments are ignored.
    """
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", IO.stdout)
    for idx, x in enumerate(args):
        if idx:
            file.write(sep)
        file.write(str(x))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()
# Run the solver only when the bundled file is executed as a script.
if __name__ == '__main__':
    main()