cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/tree/lca.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/lca

def main():
    '''Library Checker "lca": read a tree given by parent links, answer LCA queries.'''
    n, q = read()
    parents = read(list[int])
    # Vertex i (1 <= i < n) hangs below parents[i-1]; LCATable roots the tree at 0.
    table = LCATable(Tree(n, parents, range(1, n)))
    for _ in range(q):
        u, v = read()
        ancestor, _depth = table.query(u, v)
        write(ancestor)

from cp_library.alg.tree.lca_table_iterative_cls import LCATable
from cp_library.alg.tree.csr.tree_cls import Tree
from cp_library.io.read_fn import read
from cp_library.io.write_fn import write

# Run the solver only when this file is executed as a script.
if __name__ == '__main__':
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/lca

def main():
    '''Library Checker "lca": read a tree given by parent links, answer LCA queries.'''
    n, q = read()
    parents = read(list[int])
    # Vertex i (1 <= i < n) hangs below parents[i-1]; LCATable roots the tree at 0.
    table = LCATable(Tree(n, parents, range(1, n)))
    for _ in range(q):
        u, v = read()
        ancestor, _depth = table.query(u, v)
        write(ancestor)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''




def sort2(a, b):
    '''Return (a, b) if a < b, otherwise (b, a).'''
    if a < b:
        return a, b
    return b, a
import operator
from itertools import accumulate
from typing import Callable, Iterable

from typing import TypeVar

_S = TypeVar('S'); _T = TypeVar('T'); _U = TypeVar('U'); _T1 = TypeVar('T1'); _T2 = TypeVar('T2'); _T3 = TypeVar('T3'); _T4 = TypeVar('T4'); _T5 = TypeVar('T5'); _T6 = TypeVar('T6')

def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    '''Return the (optionally strided) prefix accumulation of ``iter``.

    With ``step == 1`` this is ``list(itertools.accumulate(iter, func, initial=initial))``,
    i.e. ``out[i] = func(out[i-1], x[i])``.  With ``step >= 2`` each element is folded
    with the element ``step`` positions earlier, ``out[i] = func(out[i-step], x[i])``,
    which gives independent prefix accumulations per residue class ``i % step``.
    ``func`` defaults to addition.  ``initial``, if not None, is prepended before
    accumulating.

    Raises:
        ValueError: if ``step < 1``.
    '''
    if step == 1:
        return list(accumulate(iter, func, initial=initial))
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    if func is None: func = operator.add
    A = list(iter)
    if initial is not None: A = [initial] + A
    for i in range(step, len(A)):
        # Accumulated value first, matching accumulate(), so non-commutative
        # funcs behave the same for every step.
        A[i] = func(A[i-step], A[i])
    return A

def min2(a, b):
    '''Return a if a < b, otherwise b (so b wins ties).'''
    if a < b:
        return a
    return b


class MinSparseTable:
    '''Sparse table answering range-minimum queries on a static array.

    Level k (0 <= k < log) is stored in ``data[k*N:(k+1)*N]``; its entry j holds
    ``min(arr[j:j + 2**k])`` for each j with ``j + 2**k <= N``.  Construction is
    O(N log N); a query combines two overlapping power-of-two windows in O(1).
    '''
    def __init__(st, arr: list):
        n = len(arr)
        levels = n.bit_length()
        st.N, st.log = n, levels
        st.data = table = [0] * (levels * n)
        table[:n] = arr
        for k in range(1, levels):
            half = 1 << (k - 1)
            dst, src = k * n, (k - 1) * n
            for j in range(n - 2 * half + 1):
                left, right = table[src + j], table[src + j + half]
                table[dst + j] = left if left < right else right

    def query(st, l: int, r: int):
        '''Return ``min(arr[l:r])``; requires ``0 <= l < r <= N``.'''
        k = (r - l).bit_length() - 1
        row = k * st.N
        left, right = st.data[row + l], st.data[row + r - (1 << k)]
        return left if left < right else right
    

class LCATable(MinSparseTable):
    '''Lowest-common-ancestor queries in O(1) after O(N log N) preprocessing.

    Runs ``T.euler_tour(root)`` and builds a range-minimum sparse table over the
    tour.  Tour entry i is packed as ``depth[i] << shift | order[i]``, so a single
    integer minimum yields the shallowest vertex between two first occurrences,
    which is their LCA, together with its depth.
    '''
    def __init__(lca, T, root = 0):
        N = len(T)
        T.euler_tour(root)
        lca.T = T
        lca.depth = depth = presum(T.delta)
        # Copy the tour bookkeeping so a later euler_tour on T (e.g. from another
        # root) does not invalidate this table.  path() needs the parent links.
        lca.tin, lca.tout, lca.par = T.tin[:], T.tout[:], T.par[:]
        lca.mask = (1 << (shift := N.bit_length()))-1
        lca.shift = shift
        order = T.order
        super().__init__([depth[i] << shift | order[i] for i in range(len(order))])

    def _query(lca, u, v):
        '''Return (l, r, a, d): tour range [l, r) spanning u and v, their LCA a and its depth d.'''
        l, r = sort2(lca.tin[u], lca.tin[v]); r += 1
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift

    def query(lca, u, v) -> tuple[int,int]:
        '''Return (lca(u, v), depth of that LCA).'''
        l, r, a, d = lca._query(u, v)
        return a, d
    
    def distance(lca, u, v) -> int:
        '''Return the number of edges on the tree path between u and v.'''
        l, r, a, d = lca._query(u, v)
        # l and r-1 are the first-occurrence positions of u and v (in some order).
        return lca.depth[l] + lca.depth[r-1] - 2*d
    
    def path(lca, u, v) -> list[int]:
        '''Return the vertices on the tree path from u to v (both inclusive), in order.'''
        par, a = lca.par, lca.query(u, v)[0]
        path, c = [], u
        while c != a:
            path.append(c)
            c = par[c]
        path.append(a)
        tail, c = [], v
        while c != a:
            tail.append(c)
            c = par[c]
        path.extend(reversed(tail))
        return path



from math import inf
from typing import Callable, Sequence, Union, overload
from types import GenericAlias


class Parsable:
    '''Mixin for classes that can be built from a single input token.

    ``compile()`` returns ``parser(io)``, which takes ``next(io)`` and passes it to
    the class constructor.  Subscripting (``Cls[T]``) returns a
    ``types.GenericAlias`` so annotations such as ``Cls[int]`` work at runtime.
    '''
    @classmethod
    def compile(cls):
        def parser(io: 'IOBase'):
            token = next(io)
            return cls(token)
        return parser

    @classmethod
    def __class_getitem__(cls, item):
        alias = GenericAlias(cls, item)
        return alias

def chmin(dp, i, v):
    '''Lower dp[i] to v when dp[i] > v; return True iff dp[i] was changed.'''
    improved = dp[i] > v
    if improved:
        dp[i] = v
    return improved
from enum import auto, IntFlag, IntEnum

class DFSFlags(IntFlag):
    '''Bit flags selecting DFS events to report and traversal options.

    Event bits: ENTER, DOWN, BACK, CROSS, LEAVE, UP, MAXDEPTH.
    Option bits: RETURN_PARENTS, RETURN_DEPTHS, BACKTRACK, CONNECT_ROOTS.
    '''
    ENTER = 1 << 0
    DOWN = 1 << 1
    BACK = 1 << 2
    CROSS = 1 << 3
    LEAVE = 1 << 4
    UP = 1 << 5
    MAXDEPTH = 1 << 6

    RETURN_PARENTS = 1 << 7
    RETURN_DEPTHS = 1 << 8
    BACKTRACK = 1 << 9
    CONNECT_ROOTS = 1 << 10

    # Frequently used combinations.
    ALL_EDGES = DOWN | BACK | CROSS
    EULER_TOUR = DOWN | UP
    INTERVAL = ENTER | LEAVE
    TOPDOWN = DOWN | CONNECT_ROOTS
    BOTTOMUP = UP | CONNECT_ROOTS
    RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS

class DFSEvent(IntEnum):
    '''Event codes emitted by DFS traversals; each value equals the matching DFSFlags bit.'''
    ENTER = int(DFSFlags.ENTER)
    DOWN = int(DFSFlags.DOWN)
    BACK = int(DFSFlags.BACK)
    CROSS = int(DFSFlags.CROSS)
    LEAVE = int(DFSFlags.LEAVE)
    UP = int(DFSFlags.UP)
    MAXDEPTH = int(DFSFlags.MAXDEPTH)

class GraphBase(Parsable):
    def __init__(G, N: int, M: int, U: list[int], V: list[int], 
                 deg: list[int], La: list[int], Ra: list[int],
                 Ua: list[int], Va: list[int], Ea: list[int], twin: list[int] = None):
        G.N = N
        '''The number of vertices.'''
        G.M = M
        '''The number of edges.'''
        G.U = U
        '''A list of source vertices in the original edge list.'''
        G.V = V
        '''A list of destination vertices in the original edge list.'''
        G.deg = deg
        '''deg[u] is the out degree of vertex u.'''
        G.La = La
        '''La[u] stores the start index of the list of adjacent vertices from u.'''
        G.Ra = Ra
        '''Ra[u] stores the stop index of the list of adjacent vertices from u.'''
        G.Ua = Ua
        '''Ua[i] = u for La[u] <= i < Ra[u], useful for backtracking.'''
        G.Va = Va
        '''Va[i] lists adjacent vertices to u for La[u] <= i < Ra[u].'''
        G.Ea = Ea
        '''Ea[i] lists the edge ids that start from u for La[u] <= i < Ra[u].
        For undirected graphs, edge ids in range M<= e <2*M are edges from V[e-M] -> U[e-M].
        '''
        G.twin = twin if twin is not None else range(len(Ua))
        '''twin[i] in undirected graphs stores index j of the same edge but with u and v swapped.'''
        G.st: list[int] = None
        G.order: list[int] = None
        G.vis: list[int] = None
        G.back: list[int] = None
        G.tin: list[int] = None
    
    def clear(G):
        G.vis = G.back = G.tin = None

    def prep_vis(G):
        if G.vis is None: G.vis = u8f(G.N)
        return G.vis
    
    def prep_st(G):
        if G.st is None: G.st = elist(G.N)
        else: G.st.clear()
        return G.st
    
    def prep_order(G):
        if G.order is None: G.order = elist(G.N)
        else: G.order.clear()
        return G.order
    
    def prep_back(G):
        if G.back is None: G.back = i32f(G.N, -2)
        return G.back
    
    def prep_tin(G):
        if G.tin is None: G.tin = i32f(G.N, -1)
        return G.tin
    
    def _remove(G, a: int):
        G.deg[u := G.Ua[a]] -= 1
        G.Ra[u] = (r := G.Ra[u]-1)
        G.Ua[a], G.Va[a], G.Ea[a] = G.Ua[r], G.Va[r], G.Ea[r]
        G.twin[a], G.twin[r] = G.twin[r], G.twin[a]
        G.twin[G.twin[a]] = a
        G.twin[G.twin[r]] = r

    def remove(G, a: int):
        b = G.twin[a]; G._remove(a)
        if a != b: G._remove(b)

    def __len__(G) -> int: return G.N
    def __getitem__(G, u): return view(G.Va, G.La[u], G.Ra[u])
    def range(G, u): return range(G.La[u],G.Ra[u])
    
    @overload
    def distance(G) -> list[list[int]]: ...
    @overload
    def distance(G, s: int = 0) -> list[int]: ...
    @overload
    def distance(G, s: int, g: int) -> int: ...
    def distance(G, s = None, g = None):
        if s == None: return G.floyd_warshall()
        else: return G.bfs(s, g)

    def recover_path(G, s, t):
        P = u32f(0)
        while s != t: P.append(a := G.back[t]); t = G.Ua[a] 
        return P
    
    def shortest_path(G, s: int, t: int):
        if G.distance(s, t) >= inf: return None
        P = G.recover_path(s, t)
        P.reverse()
        return P
    
    @overload
    def bfs(G, s: Union[int,list] = 0) -> list[int]: ...
    @overload
    def bfs(G, s: Union[int,list], g: int) -> int: ...
    def bfs(G, s: int = 0, g: int = None):
        S, Va, back, D = G.starts(s), G.Va, i32f(N := G.N, -1), [inf]*N
        G.back, G.D = back, D
        for u in S: D[u] = 0
        que = Que(S)
        while que:
            nd = D[u := que.pop()]+1
            if u == g: return nd-1
            for i in G.range(u):
                if chmin(D, v := Va[i], nd): back[v] = i; que.push(v)
        return D if g is None else inf 

    def floyd_warshall(G) -> list[list[int]]:
        G.D = D = [[inf]*G.N for _ in range(G.N)]
        for u in range(G.N): D[u][u] = 0
        for i in range(len(G.Ua)): D[G.Ua[i]][G.Va[i]] = 1
        for k, Dk in enumerate(D):
            for Di in D:
                if (Dik := Di[k]) == inf: continue
                for j in range(G.N):
                    chmin(Di, j, Dik+Dk[j])
        return D

    def find_cycle_indices(G, s: Union[int, None] = None):
        Ea, Ua, Va, vis, back = G.Ea, G. Ua, G.Va, u8f(N := G.N), u32f(N, i32_max)
        G.vis, G.back, st = vis, back, elist(N)
        for s in G.starts(s):
            if vis[s]: continue
            st.append(s)
            while st:
                if not vis[u := st.pop()]:
                    st.append(u)
                    vis[u], pe = 1, Ea[j] if (j := back[u]) != i32_max else i32_max
                    for i in G.range(u):
                        if not vis[v := Va[i]]:
                            back[v] = i
                            st.append(v)
                        elif vis[v] == 1 and pe != Ea[i]:
                            I = u32f(1,i)
                            while v != u: I.append(i := back[u]), (u := Ua[i])
                            I.reverse()
                            return I
                else:
                    vis[u] = 2
        # check for self loops
        for i in range(len(Ua)):
            if Ua[i] == Va[i]:
                return u32f(1,i)
    
    def find_cycle(G, s: Union[int, None] = None):
        if I := G.find_cycle_indices(s): return [G.Ua[i] for i in I]
    
    def find_cycle_edge_ids(G, s: Union[int, None] = None):
        if I := G.find_cycle_indices(s): return [G.Ea[i] for i in I]

    def find_minimal_cycle(G, s=0):
        D, par, que, Va = u32f(N := G.N, u32_max), i32f(N, -1), Que([s]), G.Va
        D[s] = 0
        while que:
            for i in G.range(u := que.pop()):
                if (v := Va[i]) == s:  # Found cycle back to start
                    cycle = [u]
                    while u != s: cycle.append(u := par[u])
                    return cycle
                if D[v] < u32_max: continue
                D[v], par[v] = D[u]+1, u; que.push(v)

    def dfs_topo(G, s: Union[int,list] = None) -> list[int]:
        '''Returns lists of indices i where Ua[i] -> Va[i] are edges in order of top down discovery'''
        vis, st, order = G.prep_vis(), G.prep_st(), G.prep_order()
        for s in G.starts(s):
            if vis[s]: continue
            vis[s] = 1; st.append(s) 
            while st:
                for i in G.range(st.pop()):
                    if vis[v := G.Va[i]]: continue
                    vis[v] = 1; order.append(i); st.append(v)
        return order

    def dfs(G, s: Union[int,list] = None, /, 
            backtrack = False,
            max_depth = None,
            enter_fn: Callable[[int],None] = None,
            leave_fn: Callable[[int],None] = None,
            max_depth_fn: Callable[[int],None] = None,
            down_fn: Callable[[int,int,int],None] = None,
            back_fn: Callable[[int,int,int],None] = None,
            forward_fn: Callable[[int,int,int],None] = None,
            cross_fn: Callable[[int,int,int],None] = None,
            up_fn: Callable[[int,int,int],None] = None):
        I, time, vis, st, back, tin = G.La[:], -1, G.prep_vis(), G.prep_st(), G.prep_back(), G.prep_tin()
        for s in G.starts(s):
            if vis[s]: continue
            back[s], tin[s] = -1, (time := time+1); st.append(s)
            while st:
                if vis[u := st[-1]] == 0:
                    vis[u] = 1
                    if enter_fn: enter_fn(u)
                    if max_depth is not None and len(st) > max_depth:
                        I[u] = G.Ra[u]
                        if max_depth_fn: max_depth_fn(u)
                if (i := I[u]) < G.Ra[u]:
                    I[u] += 1
                    if (s := vis[v := G.Va[i]]) == 0:
                        back[v], tin[v] = i, (time := time+1); st.append(v)
                        if down_fn: down_fn(u,v,i)
                    elif back_fn and s == 1 and back[u] != G.twin[i]: back_fn(u,v,i)
                    elif (cross_fn or forward_fn) and s == 2:
                        if forward_fn and tin[u] < tin[v]: forward_fn(u,v,i)
                        elif cross_fn: cross_fn(u,v,i)
                else:
                    vis[u] = 2; st.pop()
                    if backtrack: vis[u], I[u] = 0, G.La[u]
                    if leave_fn: leave_fn(u)
                    if up_fn and st: up_fn(u, st[-1], back[u])
    
    def dfs_enter_leave(G, s: Union[int,list[int],None] = None) -> Sequence[tuple[DFSEvent,int]]:
        N, I = G.N, G.La[:]
        st, back, plst = elist(N), i32f(N,-2), PacketList(order := elist(2*N), N-1)
        G.back, ENTER, LEAVE = back, int(DFSEvent.ENTER) << plst.shift, int(DFSEvent.LEAVE) << plst.shift
        for s in G.starts(s):
            if back[s] >= -1: continue
            back[s] = -1
            order.append(ENTER | s), st.append(s)
            while st:
                if (i := I[u := st[-1]]) < G.Ra[u]:
                    I[u] += 1
                    if back[v := G.Va[i]] >= -1: continue
                    back[v] = i; order.append(ENTER | v); st.append(v)
                else:
                    order.append(LEAVE | u); st.pop()
        return plst
    
    def starts(G, s: Union[int,list[int],None] = None) -> list[int]:
        if isinstance(s, int): return [s]
        elif s is None: return range(G.N)
        elif isinstance(s, list): return s
        else: return list(s)

    @classmethod
    def compile(cls, N: int, M: int, shift: int = -1):
        def parse(io: IOBase):
            U, V = u32f(M), u32f(M)
            for i in range(M): u, v = io.readints(); U[i], V[i] = u+shift, v+shift
            return cls(N, U, V)
        return parse


# Largest unsigned and signed 32-bit integer values.
u32_max = 0xFFFF_FFFF
i32_max = 0x7FFF_FFFF
from array import array

def u8f(N: int, elm: int = 0):
    '''Return array('B') (unsigned char) of length N with every element equal to elm.'''
    return array('B', [elm]) * N

def u32f(N: int, elm: int = 0):
    '''Return array('I') (unsigned int) of length N with every element equal to elm.'''
    return array('I', [elm]) * N

def i32f(N: int, elm: int = 0):
    '''Return array('i') (signed int) of length N with every element equal to elm.'''
    return array('i', [elm]) * N


def elist(hint: int) -> list: ...  # signature stub for type checkers; rebound below
try:
    from __pypy__ import newlist_hint
except ImportError:
    def newlist_hint(hint):
        '''Fallback outside PyPy: return a plain empty list; the size hint is ignored.'''
        return []
# elist(hint) returns an empty list; on PyPy, newlist_hint may reserve room for ~hint items.
elist = newlist_hint
    

class PacketList(Sequence[tuple[int,int]]):
    '''Read-only sequence decoding packed ints ``hi << shift | lo`` into ``(hi, lo)`` pairs.

    ``shift`` is the bit length of ``max1``, the largest value the low field holds.
    ``A`` is the backing list of packed integers; it is shared, not copied.
    '''
    def __init__(lst, A: list[int], max1: int):
        shift = max1.bit_length()
        lst.A, lst.shift, lst.mask = A, shift, (1 << shift) - 1

    def __len__(lst):
        return len(lst.A)

    def __contains__(lst, x: tuple[int,int]):
        packed = x[0] << lst.shift | x[1]
        return packed in lst.A

    def __getitem__(lst, key) -> tuple[int,int]:
        packed = lst.A[key]
        return packed >> lst.shift, packed & lst.mask


class Que:
    '''FIFO queue: a list plus a head index, so pop() is O(1) without shifting.

    Popped items stay in ``q``; only the head ``h`` advances.  Indexing is relative
    to the current front.
    '''
    def __init__(que, v = None):
        # An int is a capacity hint (elist); any other truthy value provides initial items.
        que.q = elist(v) if isinstance(v, int) else list(v) if v else []; que.h = 0
    def push(que, item): que.q.append(item)
    def pop(que):
        # Read before advancing the head: popping an empty queue raises IndexError
        # without leaving h past the end, which would make len() negative.
        item = que.q[que.h]
        que.h += 1
        return item
    def extend(que, items): que.q.extend(items)
    def __getitem__(que, i: int): return que.q[que.h+i]
    def __setitem__(que, i: int, v): que.q[que.h+i] = v
    def __len__(que): return que.q.__len__() - que.h
    def __hash__(que): return hash(tuple(que.q[que.h:]))

from typing import Generic
import sys

def list_find(lst: list, value, start = 0, stop = sys.maxsize):
    '''Return the absolute index of the first ``value`` in ``lst[start:stop]``, or -1 if absent.

    Only "not found" (ValueError from ``index``) maps to -1.  Other errors, such as a
    sequence whose ``index`` does not accept start/stop, propagate instead of being
    silently reported as "absent".
    '''
    try:
        return lst.index(value, start, stop)
    except ValueError:
        return -1


class view(Generic[_T]):
    '''Window ``A[l:r]`` onto a backing sequence (list or array) without copying.

    ``v[i]`` reads are bounds-checked against the window and raise IndexError outside
    ``0 <= i < len(v)``.  Writes (``v[i] = x``, ``append``, ``appendleft``) are not
    bounds-checked and go straight into ``A``.
    '''
    __slots__ = 'A', 'l', 'r'
    def __init__(V, A: list[_T], l: int = 0, r: int = 0): V.A, V.l, V.r = A, l, r
    def __len__(V): return V.r - V.l
    def __getitem__(V, i: int): 
        if 0 <= i < V.r - V.l: return V.A[V.l+i]
        else: raise IndexError
    def __setitem__(V, i: int, v: _T): V.A[V.l+i] = v
    def __contains__(V, v: _T): return list_find(V.A, v, V.l, V.r) != -1
    def set_range(V, l: int, r: int): V.l, V.r = l, r
    def index(V, v: _T): return V.A.index(v, V.l, V.r) - V.l
    def reverse(V):
        '''Reverse the window in place.'''
        l, r = V.l, V.r-1
        while l < r: V.A[l], V.A[r] = V.A[r], V.A[l]; l += 1; r -= 1
    def sort(V, /, *args, **kwargs):
        '''Sort the window in place (accepts list.sort's key/reverse).'''
        # sorted() rather than slice.sort(): slicing an array.array (e.g. Graph's Va)
        # yields an array, which has no sort() method.
        A = sorted(V.A[V.l:V.r], *args, **kwargs)
        for i,a in enumerate(A,V.l): V.A[i] = a
    def pop(V): V.r -= 1; return V.A[V.r]
    def append(V, v: _T): V.A[V.r] = v; V.r += 1
    def popleft(V): V.l += 1; return V.A[V.l-1]
    def appendleft(V, v: _T): V.l -= 1; V.A[V.l] = v
    def validate(V): return 0 <= V.l <= V.r <= len(V.A)

class IOBase:
    '''Interface of the fast input/output object; every member here is a stub returning None.

    Concrete behaviour comes from subclasses.
    '''
    @property
    def char(io) -> bool:
        '''Whether reads work character/digit-wise instead of token/int-wise.'''
    @property
    def writable(io) -> bool:
        '''Whether the underlying stream accepts output.'''
    def __next__(io) -> str:
        '''Return the next token (or character in char mode).'''
    def write(io, s: str) -> None:
        '''Buffer s for output.'''
    def readline(io) -> str:
        '''Return the rest of the current line.'''
    def readtoken(io) -> str:
        '''Return the next space-separated token.'''
    def readtokens(io) -> list[str]:
        '''Return the remaining tokens of the current line.'''
    def readints(io) -> list[int]:
        '''Return the remaining integers of the current line.'''
    def readdigits(io) -> list[int]:
        '''Return the remaining digits of the current line as ints.'''
    def readnums(io) -> list[int]:
        '''Digits in char mode, integers otherwise.'''
    def readchar(io) -> str:
        '''Return the next character.'''
    def readchars(io) -> str:
        '''Return the rest of the current line without its newline.'''
    def readinto(io, lst: list[str]) -> list[str]:
        '''Extend lst with chars (char mode) or tokens; return lst.'''
    def readcharsinto(io, lst: list[str]) -> list[str]:
        '''Extend lst with the current line's characters; return lst.'''
    def readtokensinto(io, lst: list[str]) -> list[str]:
        '''Extend lst with the current line's tokens; return lst.'''
    def readintsinto(io, lst: list[int]) -> list[int]:
        '''Extend lst with the current line's integers; return lst.'''
    def readdigitsinto(io, lst: list[int]) -> list[int]:
        '''Extend lst with the current line's digits; return lst.'''
    def readnumsinto(io, lst: list[int]) -> list[int]:
        '''Extend lst with digits (char mode) or integers; return lst.'''
    def wait(io):
        '''Iterate while the current line still has unread input.'''
    def flush(io) -> None:
        '''Write out buffered output.'''
    def line(io) -> list[str]:
        '''Return the current line as tokens (or characters in char mode).'''

class Graph(GraphBase):
    '''Undirected graph in CSR form, built from parallel edge lists U and V.

    Edge e = (U[e], V[e]) gets one slot in each endpoint's range; the two slots are
    linked through ``twin`` and both carry Ea = e.  A self-loop gets a single slot
    that is its own twin.
    '''
    def __init__(G, N: int, U: list[int], V: list[int]):
        M = len(U)
        deg, Ma = u32f(N), 0
        # Pass 1: degrees and total slot count.
        for e in range(M):
            u, v = U[e], V[e]
            deg[u] += 1
            if u != v:
                deg[v] += 1
                Ma += 2
            else:
                Ma += 1
        twin, Ea = i32f(Ma), i32f(Ma)
        Ua, Va = u32f(Ma), u32f(Ma)
        La, Ra = u32f(N), u32f(N)
        # Pass 2: each vertex's slot range starts right after the previous vertex's.
        start = 0
        for u in range(N):
            La[u] = Ra[u] = start
            start += deg[u]
        # Pass 3: fill slots; Ra[u] is u's write cursor and ends at its stop index.
        for e in range(M):
            u, v = U[e], V[e]
            i, j = Ra[u], Ra[v]
            Ra[u], Ua[i], Va[i], Ea[i], twin[i] = i+1, u, v, e, j
            if i != j:
                Ra[v], Ua[j], Va[j], Ea[j], twin[j] = j+1, v, u, e, i
        super().__init__(N, M, U, V, deg, La, Ra, Ua, Va, Ea, twin)
from typing import Callable, Literal, Union, overload

class TreeBase(GraphBase):
    @overload
    def distance(T) -> list[list[int]]: ...
    @overload
    def distance(T, s: int = 0) -> list[int]: ...
    @overload
    def distance(T, s: int, g: int) -> int: ...
    def distance(T, s = None, g = None):
        if s == None:
            return [T.dfs_distance(u) for u in range(T.N)]
        else:
            return T.dfs_distance(s, g)

    @overload
    def diameter(T) -> int: ...
    @overload
    def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
    def diameter(T, endpoints = False):
        mask = (1 << (shift := T.N.bit_length())) - 1
        s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
        dg = max(d << shift | v for v,d in enumerate(T.distance(s))) 
        diam, g = dg >> shift, dg & mask
        return (diam, s, g) if endpoints else diam
    
    def dfs_distance(T, s: int, g: Union[int,None] = None):
        st, Va = elist(N := T.N), T.Va
        T.D, T.back = D, back = [inf]*N, i32f(N, -1)
        D[s] = 0
        st.append(s)
        while st:
            nd = D[u := st.pop()]+1
            if u == g: return nd-1
            for i in T.range(u):
                if nd < D[v := Va[i]]:
                    D[v], back[v] = nd, i
                    st.append(v)
        return D if g is None else inf

    def rerooting_dp(T, e: _T, 
                     merge: Callable[[_T,_T],_T], 
                     edge_op: Callable[[_T,int,int,int],_T] = lambda s,i,p,u:s,
                     s: int = 0):
        La, Ua, Va = T.La, T.Ua, T.Va
        order, dp, suf, I = T.dfs_topo(s), [e]*T.N, [e]*len(Ua), T.Ra[:]
        # up
        for i in order[::-1]:
            u,v = Ua[i], Va[i]
            # subtree v finished up pass, store value to accumulate for u
            dp[v] = new = edge_op(dp[v], i, u, v)
            dp[u] = merge(dp[u], new)
            # suffix accumulation
            if (c:=I[u]-1) > La[u]: suf[c-1] = merge(suf[c], new)
            I[u] = c
        # down
        dp[s] = e # at this point dp stores values to be merged in parent
        for i in order:
            u,v = Ua[i], Va[i]
            dp[u] = merge(pre := dp[u], dp[v])
            dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u)
            I[u] += 1
        return dp
    
    def euler_tour(T, s = 0):
        N, Va = len(T), T.Va
        tin, tout, par, back = [-1]*N,[-1]*N,[-1]*N,[0]*N
        order, delta = elist(2*N), elist(2*N)
        
        st = elist(N); st.append(s)
        while st:
            p = par[u := st.pop()]
            if tin[u] == -1:
                tin[u] = len(order)
                for i in T.range(u):
                    if (v := Va[i]) != p:
                        par[v], back[v] = u, i
                        st.append(u); st.append(v)
                delta.append(1)
            else:
                delta.append(-1)
            
            order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0
        T.tin, T.tout, T.par, T.back = tin, tout, par, back
        T.order, T.delta = order, delta

    @classmethod
    def compile(cls, N: int, shift: int = -1):
        return GraphBase.compile.__func__(cls, N, N-1, shift)
    

class Tree(TreeBase, Graph):
    '''Undirected tree: Graph's CSR construction combined with TreeBase's tree algorithms.'''

from typing import Type, Union, overload

@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_T], char=False):
    '''Parse the next value(s) from IO.stdin according to the given spec(s).

    With no specs, the rest of the current line is read as numbers into a new list.
    A single spec is compiled directly; several specs are compiled as one tuple spec.
    ``char`` is assigned to IO.stdin.char (it stays set after the call).
    '''
    src = IO.stdin
    src.char = char
    if specs:
        spec = specs[0] if len(specs) == 1 else specs
        parser: _T = Parser.compile(spec)
        return parser(src)
    return src.readnumsinto([])
from os import read as os_read, write as os_write, fstat as os_fstat
from __pypy__.builders import StringBuilder

def max2(a, b):
    '''Return a if a > b, otherwise b (so b wins ties).'''
    if a > b:
        return a
    return b

class IO(IOBase):
    '''Buffered reader/writer over a file descriptor.

    Input is read with os.read into the bytearray ``B``.  ``O`` holds the end offset
    (one past the newline) of every line found so far, ``l`` is the index of the
    current line and ``p`` the read position in ``B``.  Output is collected in the
    StringBuilder ``S`` and written by ``flush()``.
    '''
    BUFSIZE = 1 << 16; stdin: 'IO'; stdout: 'IO'
    __slots__ = 'f', 'file', 'B', 'O', 'V', 'S', 'l', 'p', 'char', 'sz', 'st', 'ist', 'writable', 'encoding', 'errors'
    def __init__(io, file):
        io.file = file
        # Read size is max(BUFSIZE, size reported by fstat); streams without a usable
        # descriptor fall back to BUFSIZE and are treated as not writable.
        try: io.f = file.fileno(); io.sz, io.writable = max2(io.BUFSIZE, os_fstat(io.f).st_size), ('x' in file.mode or 'r' not in file.mode)
        except Exception: io.f, io.sz, io.writable = -1, io.BUFSIZE, False
        # NOTE(review): B is later extended while V (a memoryview of it) exists.  CPython
        # raises BufferError when resizing an exported bytearray, so this appears to rely
        # on PyPy; confirm.
        io.B, io.O, io.S = bytearray(), [], StringBuilder(); io.V = memoryview(io.B); io.l = io.p = 0
        io.char, io.st, io.ist, io.encoding, io.errors = False, [], [], 'ascii', 'ignore'
    def _dec(io, l, r): return io.V[l:r].tobytes().decode(io.encoding, io.errors)
    def readbytes(io, sz): return os_read(io.f, sz)
    def load(io):
        '''Read input until line io.l has a recorded end offset, or EOF is reached.'''
        while io.l >= len(io.O):
            if not (b := io.readbytes(io.sz)):
                # EOF: close a final line that lacks a newline.  O is still empty when
                # no newline was ever seen, so don't index it blindly.
                if (io.O[-1] if io.O else 0) < len(io.B): io.O.append(len(io.B))
                break
            pos = len(io.B); io.B.extend(b)
            while ~(pos := io.B.find(b'\n', pos)): io.O.append(pos := pos+1)
    def __next__(io):
        if io.char: return io.readchar()
        else: return io.readtoken()
    def readchar(io):
        io.load(); r = io.O[io.l]
        c = chr(io.B[io.p])
        # Past the last character before the line end: move on to the next line.
        if io.p >= r-1: io.p = r; io.l += 1
        else: io.p += 1
        return c
    def write(io, s: str): io.S.append(s)
    def readline(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p)
    def readtoken(io):
        '''Return the next space-separated token; the last token of a line drops the newline.'''
        io.load(); r = io.O[io.l]
        if ~(p := io.B.find(b' ', io.p, r)): s = io._dec(io.p, p); io.p = p+1
        else: s = io._dec(io.p, r-1); io.p = r; io.l += 1
        return s
    def readtokens(io): io.st.clear(); return io.readtokensinto(io.st)
    def readints(io): io.ist.clear(); return io.readintsinto(io.ist)
    def readdigits(io): io.ist.clear(); return io.readdigitsinto(io.ist)
    def readnums(io): io.ist.clear(); return io.readnumsinto(io.ist)
    def readchars(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p-1)
    def readinto(io, lst):
        if io.char: return io.readcharsinto(lst)
        else: return io.readtokensinto(lst)
    def readcharsinto(io, lst): lst.extend(io.readchars()); return lst
    def readtokensinto(io, lst): 
        io.load(); r = io.O[io.l]
        while ~(p := io.B.find(b' ', io.p, r)): lst.append(io._dec(io.p, p)); io.p = p+1
        lst.append(io._dec(io.p, r-1)); io.p = r; io.l += 1; return lst
    def _readint(io, r):
        '''Parse one (optionally negative) integer before offset r; None if only whitespace remains.'''
        while io.p < r and io.B[io.p] <= 32: io.p += 1
        if io.p >= r: return None
        minus = x = 0
        if io.B[io.p] == 45: minus = 1; io.p += 1
        while io.p < r and io.B[io.p] >= 48: x = x * 10 + (io.B[io.p] & 15); io.p += 1
        io.p += 1  # skip the delimiter after the number
        return -x if minus else x
    def readintsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and (x := io._readint(r)) is not None: lst.append(x)
        io.l += 1; return lst
    def _readdigit(io): d = io.B[io.p] & 15; io.p += 1; return d
    def readdigitsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and io.B[io.p] > 32: lst.append(io._readdigit())
        if io.B[io.p] == 10: io.l += 1
        io.p += 1
        return lst
    def readnumsinto(io, lst):
        if io.char: return io.readdigitsinto(lst)
        else: return io.readintsinto(lst)
    def line(io): io.st.clear(); return io.readinto(io.st)
    def wait(io):
        '''Generator that yields while the current line still has unread bytes.'''
        io.load(); r = io.O[io.l]
        while io.p < r: yield
    def flush(io):
        '''Encode and write all buffered output (writable streams only), then reset the buffer.'''
        if io.writable: os_write(io.f, io.S.build().encode(io.encoding, io.errors)); io.S = StringBuilder()
# Install buffered IO wrappers as the process-wide stdin/stdout, also reachable
# as IO.stdin / IO.stdout.
# NOTE(review): output is only written by flush(); this presumably relies on the
# interpreter flushing sys.stdout at exit; confirm.
IO.stdin = IO(sys.stdin)
sys.stdin = IO.stdin
IO.stdout = IO(sys.stdout)
sys.stdout = IO.stdout
import typing
from numbers import Number
from typing import Callable, Collection

class Parser:
    '''Compiles input "specs" into functions ``parse(io)`` that read from an IOBase.

    A spec may be a type or generic alias (``int``, ``list[int]``, ``tuple[int, ...]``),
    a number instance (an offset), a tuple/collection instance of specs, or a callable.
    '''
    def __init__(self, spec):  self.parse = Parser.compile(spec)
    def __call__(self, io: IOBase): return self.parse(io)
    @staticmethod
    def compile_type(cls, args = ()):
        '''Build a parser for class ``cls`` given its generic arguments ``args``.'''
        if issubclass(cls, Parsable): return cls.compile(*args)
        # Checked before Collection because str is itself a Collection.
        elif issubclass(cls, (Number, str)):
            def parse(io: IOBase): return cls(next(io))              
            return parse
        # tuple is checked before Collection because tuples are collections too.
        elif issubclass(cls, tuple): return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection): return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(io: IOBase): return cls(next(io))              
            return parse
        else: raise NotImplementedError()
    @staticmethod
    def compile(spec=int):
        '''Dispatch on the kind of ``spec``.

        - type or GenericAlias: compile_type on its origin and arguments;
        - number instance: parse a token as that number's type, then add the number
          (so a spec of -1 yields the token's value minus 1);
        - tuple instance: compile_tuple; other collection instances: compile_collection;
        - other callable: apply it to the next token.
        '''
        if isinstance(spec, (type, GenericAlias)):
            cls, args = typing.get_origin(spec) or spec, typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number): 
            cls = type(spec)  
            def parse(io: IOBase): return cls(next(io)) + offset
            return parse
        elif isinstance(args := spec, tuple): return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection): return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable): 
            def parse(io: IOBase): return fn(next(io))
            return parse
        else: raise NotImplementedError()
    @staticmethod
    def compile_line(cls, spec=int):
        '''Parse the rest of the current line into ``cls``.

        int: io.readnums(); str: io.line(); anything else: apply the compiled spec
        repeatedly while io.wait() yields.
        '''
        if spec is int:
            def parse(io: IOBase): return cls(io.readnums())
        elif spec is str:
            def parse(io: IOBase): return cls(io.line())
        else:
            fn = Parser.compile(spec)
            def parse(io: IOBase): return cls((fn(io) for _ in io.wait()))
        return parse
    @staticmethod
    def compile_repeat(cls, spec, N):
        '''Parse exactly N values of ``spec`` into ``cls``.'''
        fn = Parser.compile(spec)
        def parse(io: IOBase): return cls([fn(io) for _ in range(N)])
        return parse
    @staticmethod
    def compile_children(cls, specs):
        '''Parse one value per spec in ``specs``, in order, into ``cls``.'''
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(io: IOBase): return cls([fn(io) for fn in fns])  
        return parse
    @staticmethod
    def compile_tuple(cls, specs):
        '''``(spec, ...)`` reads a whole line of ``spec``; any other tuple reads one value per spec.'''
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...: return Parser.compile_line(cls, specs[0])
        else: return Parser.compile_children(cls, specs)
    @staticmethod
    def compile_collection(cls, specs):
        '''Empty or single-spec (or set) collections read a line; ``[spec, N]`` reads N values.'''
        # NOTE(review): a set with more than one element passes several specs to
        # compile_line, which raises TypeError.
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()

def write(*args, **kwargs):
    '''print()-like output: write ``args`` joined by ``sep`` and followed by ``end``.

    Keyword arguments: sep (default " "), end (default "\\n"), file (default
    IO.stdout) and flush (default False; calls file.flush() when true).
    '''
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", IO.stdout)
    for idx, x in enumerate(args):
        if idx:
            file.write(sep)
        file.write(str(x))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()

# Run the solver only when this file is executed as a script.
if __name__ == '__main__':
    main()
Back to top page