cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/data-structure/static_rectangle_add_rectangle_sum_wm_group_points.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/static_rectangle_add_rectangle_sum

def main():
    '''Solve Library Checker "Static Rectangle Add Rectangle Sum" with a group wavelet matrix.

    Per the problem, each rectangle [l, r) x [d, u) with weight w is split into four
    signed corner points (l, d, +w), (l, u, -w), (r, d, -w), (r, u, +w). A corner
    (x, y, c) contributes c*(X - x)*(Y - y) to the prefix value F(X, Y) over the points
    with x < X and y < Y; expanding gives c*X*Y - c*y*X - c*x*Y + c*x*y. The four
    coefficients are kept modulo 998244353 and packed two per int in 31-bit halves.
    '''
    mod = 998244353
    s = 31                      # width of one packed half
    m = (1 << s) - 1            # mask selecting the low half
    N, Q = read(tuple[int, ...])
    total = N << 2              # four corner points per rectangle
    X, Y, W = [0]*total, [0]*total, [(0, 0)]*total
    # Added before subtracting a normalized value so neither half can go negative.
    mod2 = mod << s | mod

    def pack(hi, lo):
        return hi << s | lo

    def normalize(v):
        # Reduce each 31-bit half of v modulo mod independently.
        return pack((v >> s) % mod, (v & m) % mod)

    def polynomial(x, y, c):
        # (coef of Y | coef of X), (coef of X*Y | constant term)
        return pack(-x*c % mod, -y*c % mod), pack(c % mod, x*y % mod * c % mod)

    pos = 0
    for _ in range(N):
        l, d, r, u, w = read()
        for x, y, c in ((l, d, w), (l, u, -w), (r, d, -w), (r, u, w)):
            X[pos], Y[pos], W[pos] = x, y, polynomial(x, y, c)
            pos += 1

    def op(a, b):
        return normalize(a[0] + b[0]), normalize(a[1] + b[1])

    def diff(a, b):
        return normalize(a[0] + mod2 - b[0]), normalize(a[1] + mod2 - b[1])

    wm = WMGroupPoints(op, (0, 0), diff, X, Y, W)

    def corner(x, y):
        # F(x, y): evaluate the summed coefficients of all points with X < x and Y < y.
        lin, quad = wm.prod_corner(x, y)
        return ((quad & m) + y*(lin >> s) + x*(lin & m) + x*y % mod*(quad >> s)) % mod

    for _ in range(Q):
        l, d, r, u = read()
        ld, lu, rd, ru = corner(l, d), corner(l, u), corner(r, d), corner(r, u)
        write((ru + ld - lu - rd) % mod)

from cp_library.ds.wavelet.wm_group_points_cls import WMGroupPoints
from cp_library.io.write_fn import write
from cp_library.io.read_fn import read

# Entry point: run the solver only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/static_rectangle_add_rectangle_sum

def main():
    '''Solve Library Checker "Static Rectangle Add Rectangle Sum" with a group wavelet matrix.

    Per the problem, each rectangle [l, r) x [d, u) with weight w is split into four
    signed corner points (l, d, +w), (l, u, -w), (r, d, -w), (r, u, +w). A corner
    (x, y, c) contributes c*(X - x)*(Y - y) to the prefix value F(X, Y) over the points
    with x < X and y < Y; expanding gives c*X*Y - c*y*X - c*x*Y + c*x*y. The four
    coefficients are kept modulo 998244353 and packed two per int in 31-bit halves.
    '''
    mod = 998244353
    s = 31                      # width of one packed half
    m = (1 << s) - 1            # mask selecting the low half
    N, Q = read(tuple[int, ...])
    total = N << 2              # four corner points per rectangle
    X, Y, W = [0]*total, [0]*total, [(0, 0)]*total
    # Added before subtracting a normalized value so neither half can go negative.
    mod2 = mod << s | mod

    def pack(hi, lo):
        return hi << s | lo

    def normalize(v):
        # Reduce each 31-bit half of v modulo mod independently.
        return pack((v >> s) % mod, (v & m) % mod)

    def polynomial(x, y, c):
        # (coef of Y | coef of X), (coef of X*Y | constant term)
        return pack(-x*c % mod, -y*c % mod), pack(c % mod, x*y % mod * c % mod)

    pos = 0
    for _ in range(N):
        l, d, r, u, w = read()
        for x, y, c in ((l, d, w), (l, u, -w), (r, d, -w), (r, u, w)):
            X[pos], Y[pos], W[pos] = x, y, polynomial(x, y, c)
            pos += 1

    def op(a, b):
        return normalize(a[0] + b[0]), normalize(a[1] + b[1])

    def diff(a, b):
        return normalize(a[0] + mod2 - b[0]), normalize(a[1] + mod2 - b[1])

    wm = WMGroupPoints(op, (0, 0), diff, X, Y, W)

    def corner(x, y):
        # F(x, y): evaluate the summed coefficients of all points with X < x and Y < y.
        lin, quad = wm.prod_corner(x, y)
        return ((quad & m) + y*(lin >> s) + x*(lin & m) + x*y % mod*(quad >> s)) % mod

    for _ in range(Q):
        l, d, r, u = read()
        ld, lu, rd, ru = corner(l, d), corner(l, u), corner(r, d), corner(r, u)
        write((ru + ld - lu - rd) % mod)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''




def coord_compress(A: list[int], distinct = False):
    '''Coordinate-compress `A`; works for any integers, including negatives.

    Returns `(R, V)` where `V` is `A` sorted and `R[i]` is the position of `A[i]` in `V`.
    - `distinct=False`: `V` keeps each value once and equal elements share a rank.
    - `distinct=True`: duplicates stay in `V` and ties are broken by original index,
      so `R` is a permutation of `range(len(A))`.
    '''
    N = len(A)
    # Pack (value, index) into a single int so one C-level sort orders by value, then index.
    s = (N-1).bit_length(); m = (1<<s)-1
    R, V = [0]*N, [a<<s|i for i,a in enumerate(A)]; V.sort()
    if distinct:
        for r, ai in enumerate(V): R[ai&m], V[r] = r, ai>>s
    else:
        # p starts as None rather than -1 so that a smallest value of -1 still opens rank 0.
        r, p = -1, None
        for ai in V:
            a, i = ai>>s, ai&m
            if a != p: r += 1; V[r] = p = a
            R[i] = r
        del V[r+1:]
    return R, V



def pack_dec(ab: int, s: int, m: int):
    '''Split a packed int into its high part (`ab >> s`) and low part (`ab & m`).'''
    high = ab >> s
    low = ab & m
    return high, low
def pack_sm(N: int):
    '''Return `(s, m)`: the bit width `s` of `N` and the all-ones mask `m` of that width.'''
    width = N.bit_length()
    mask = (1 << width) - 1
    return width, mask



class Presum:
    '''Prefix aggregates over `A` under `op`: `pre[0] = e`, `pre[i+1] = op(pre[i], A[i])`.

    `prod(l, r)` recovers the aggregate of `A[l:r]` as `diff(pre[r], pre[l])`, so `diff`
    must undo `op` (a group inverse).
    '''
    def __init__(P, op, e, diff, A: list):
        P.N = len(A)
        P.op, P.e, P.diff = op, e, diff
        prefix = [e]
        for item in A:
            prefix.append(op(prefix[-1], item))
        P.pre = prefix

    def __getitem__(P, i):
        return P.pre[i]

    def prod(P, l: int, r: int):
        return P.diff(P.pre[r], P.pre[l])
from abc import abstractmethod

class BitArray:
    '''Fixed-length bit vector with rank (`count0`/`count1`) and select queries.

    Bits are packed MSB-first into 32-bit words: bit `i` lives in word `i >> 5` at
    position `31 - (i & 31)`. `cnt[w]` is the number of ones in words `[0, w)` and is
    only valid after `build()`.
    '''
    def __init__(B, N: int):
        words = (N + 31) >> 5
        B.N, B.Z = N, words
        # One spare word so index Z (reached by count1(N) when N % 32 == 0) is valid.
        B.bits = u32f(words + 1)
        B.cnt = u32f(words + 1)

    def build(B):
        # Drop the spare word while accumulating the per-word counts, then restore a word.
        B.bits.pop()
        running = 0
        for w, word in enumerate(B.bits):
            running += popcnt32(word)
            B.cnt[w + 1] = running
        B.bits.append(1)

    def __len__(B):
        return B.N

    def __getitem__(B, i: int):
        word = B.bits[i >> 5]
        return (word >> (31 - (i & 31))) & 1

    def set0(B, i: int):
        B.bits[i >> 5] &= ~(1 << (31 - (i & 31)))

    def set1(B, i: int):
        B.bits[i >> 5] |= 1 << (31 - (i & 31))

    def count0(B, r: int):
        return r - B.count1(r)

    def count1(B, r: int):
        w = r >> 5
        # The top (r & 31) bits of word w precede position r; a shift by 32 leaves 0.
        return B.cnt[w] + popcnt32(B.bits[w] >> (32 - (r & 31)))

    def _search(B, count, k: int):
        # Position p with count(p) <= k < count(p + 1): where the (k+1)-th matching bit sits.
        lo, hi, need = 0, B.N, k + 1
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if count(mid) < need:
                lo = mid
            else:
                hi = mid
        return lo

    def select0(B, k: int):
        if not 0 <= k < B.N - B.cnt[-1]:
            return -1
        return B._search(B.count0, k)

    def select1(B, k: int):
        if not 0 <= k < B.cnt[-1]:
            return -1
        return B._search(B.count1, k)

def popcnt32(x):
    '''Return the number of set bits in the 32-bit unsigned integer `x`.

    Portable SWAR fallback: each step sums adjacent bit-field counts of doubling width.
    '''
    x = (x & 0x55555555) + (x >> 1 & 0x55555555)      # 2-bit fields
    x = (x & 0x33333333) + (x >> 2 & 0x33333333)      # 4-bit fields
    x = (x & 0x0f0f0f0f) + (x >> 4 & 0x0f0f0f0f)      # 8-bit fields
    x = (x & 0x00ff00ff) + (x >> 8 & 0x00ff00ff)      # 16-bit fields
    return (x & 0x0000ffff) + (x >> 16 & 0x0000ffff)  # whole word
# Use the builtin when available; int.bit_count(x) works when called as popcnt32(x).
if hasattr(int, 'bit_count'):
    popcnt32 = int.bit_count

from array import array
def u32f(N: int, elm: int = 0):
    '''Return an `array('I')` (C unsigned int) of length `N` with every slot set to `elm`.'''
    seed = array('I', [elm])
    return seed * N

class WMStatic:
    class Level(BitArray):
        def __init__(L, N: int, H: int):
            super().__init__(N)
            L.H = H
        def build(L):
            super().build()
            L.T0, L.T1 = L.N-L.cnt[-1], L.cnt[-1]
        def pos(L, bit: int, i: int): return L.T0+L.count1(i) if bit else L.count0(i)
        def pos2(L, bit: int, i: int, j: int): return (L.T0+L.count1(i), L.T0+L.count1(j)) if bit else (L.count0(i), L.count0(j))
    def __init__(wm,A,Amax:int=None):wm._build(A,[0]*len(A),max(A,default=0)if Amax is None else Amax)
    def _build(wm, A, nA, Amax):wm.N,wm.H=len(A),Amax.bit_length();wm._build_levels(A,nA)
    def _build_levels(wm, A, nA):
        wm.up=[wm.Level(wm.N,H) for H in range(wm.H)];wm.down=wm.up[::-1]
        for L in wm.down:
            x,y,i=-1,wm.N-1,wm.N
            while i:y-=A[i:=i-1]>>L.H&1
            for i,a in enumerate(A):
                if a>>L.H&1:nA[y:=y+1]=a;L.set1(i)
                else:nA[x:=x+1]=a
            A,nA=nA,A;L.build()
    def __getitem__(wm,i):
        y=0
        for L in wm.down:y=y<<1|(bit:=L[i]);i=L.pos(bit,i)
        return y
    def kth(wm, k: int, l: int, r: int):
        '''Returns the `k+1`-th value in sorted order of values in range `[l, r)`'''
        s=0
        for L in wm.down:
            l,r=l-(l1:=L.count1(l)),r-(r1:=L.count1(r))
            if k>=r-l:s|=1<<L.H;k-=r-l;l,r=L.T0+l1,L.T0+r1
        return s
    def select(wm, y: int, k: int, l: int = 0, r: int = -1):
        '''Returns the index of the `k+1`-th occurance of `y` in range `[l, r)`'''
        if not(0<=y<1<<wm.H):return-1
        if r==-1:r=wm.N-1
        for L in wm.down:l,r=L.pos2(L[y],l,r)
        if not l<=(i:=l+k)<r:return-1
        for L in wm.up:
            if y>>L.H&1:i=L.select1(i-L.T0)
            else:i=L.select0(i)
        return i
    def rank(wm, y: int, r: int): return wm.rank_range(y, 0, r)
    def rank_range(wm, y: int, l: int, r: int):
        if l >= r: return 0
        for L in wm.down:l,r=L.pos2(L[y],l,r)
        return r-l
    def count_at(wm, y: int, l: int, r: int):
        '''Count how many `y` values are in range `[l,r)` '''
        if l >= r: return 0
        return wm._cnt(y+1, l, r)-wm._cnt(y, l, r)
    def count_below(wm, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `A[i] < u` '''
        return wm._cnt(u, l, r)
    def count_between(wm, d: int, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `d <= A[i] < u` '''
        if l >= r or d >= u: return 0
        return wm._cnt(u, l, r)-wm._cnt(d, l, r)
    def _cnt(wm, u: int, l: int, r: int):
        if u<=0:return 0
        if wm.H<u.bit_length():return r-l
        cnt=0
        for L in wm.down:
            l,r=l-(l1:=L.count1(l)),r-(r1:=L.count1(r))
            if u>>L.H&1:cnt+=r-l;l,r=L.T0+l1,L.T0+r1
        return cnt
    def prev_val(wm,u:int,l:int,r:int):return wm.kth(cnt-1, l, r)if(cnt:=wm._cnt(u,l,r))else-1
    def next_val(wm,d:int,l:int,r:int):return wm.kth(cnt, l, r)if(cnt:=wm._cnt(d,l,r))<r-l else-1

class WMMonoid(WMStatic):
    def __init__(wm,op,e,A:list[int],W:list[int],Amax:int=None):wm._build(op,e,A,W,[0]*len(A),[0]*len(A),max(A,default=0)if Amax is None else Amax)
    def _build(wm,op,e,A,W,nA,nW,Amax):wm.N,wm.H,wm.op,wm.e=len(A),Amax.bit_length(),op,e;wm._build_base(W);wm._build_levels(A,W,nA,nW)
    @abstractmethod
    def _build_base(wm,W):...
    @abstractmethod
    def _build_level(wm,L,W):...
    def _build_levels(wm,A,W,nA,nW):
        wm.up=[wm.Level(wm.N,H)for H in range(wm.H)];wm.down=wm.up[::-1]
        for L in wm.down:
            x,y,i=-1,wm.N-1,wm.N
            while i:y-=A[i:=i-1]>>L.H&1
            for i,a in enumerate(A):
                if a>>L.H&1:nA[y:=y+1],nW[y]=a,W[i];L.set1(i)
                else:nA[x:=x+1],nW[x]=a,W[i]
            A,nA,W,nW=nA,A,nW,W;wm._build_level(L,W)
    def prod_range(wm,l:int,r:int):return wm._prod_range(l,r)if l<r else wm.e
    def prod_at(wm,y:int,l:int,r:int):return wm._prod_rect(y,y+1,l,r)if l<r else wm.e
    def prod_below(wm,u:int,l:int,r:int):return wm._prod_below(u,l,r)if l<r else wm.e
    def prod_above(wm,d:int,l:int,r:int):return wm._prod_above(d,l,r)if l<r else wm.e
    def prod_between(wm,d:int,u:int,l:int,r:int):return wm._prod_rect(d,u,l,r)if l<r and d<u else wm.e
    def prod_corner(wm,r:int,u:int):return wm._prod_below(u,0,r)if 0<r else wm.e
    def prod_rect(wm,l:int,d:int,r:int,u:int):return wm._prod_rect(d,u,l,r)if l<r and d<u else wm.e
    @abstractmethod
    def _prod_range(wm,l,r):...
    def _prod_below(wm,u,l,r):
        if u<=0:return wm.e
        elif wm.H<u.bit_length():return wm._prod_range(l,r)
        prod=wm.e
        for L in wm.down:
            l,r=l-(l1:=L.count1(l)),r-(r1:=L.count1(r))
            if u>>L.H&1:prod=wm.op(prod,L.prod(l,r));l,r=L.T0+l1,L.T0+r1
        return prod
    def _prod_above(wm,d,l,r):
        if d<=0: return wm._prod_range(l, r)
        elif d.bit_length() > wm.H: return wm.e
        prod, d = wm.e, d-1
        for L in wm.down:
            l0,r0=l-(l:=L.T0+L.count1(l)),r-(r:=L.T0+L.count1(r))
            if d>>L.H&1==0:prod=wm.op(L.prod(l,r),prod);l,r=L.T0+l0,L.T0+r0
        return prod
    def _prod_rect(wm,d,u,l,r):
        if u<=0 or wm.H<d.bit_length():return wm.e
        elif d<=0:return wm._prod_below(u,l,r)
        elif wm.H<u.bit_length():return wm._prod_above(d,l,r)
        same,prod,d=1,wm.e,d-1
        for L in wm.down:
            db,ub,l,r=d>>L.H&1,u>>L.H&1,l-(l1:=L.count1(l)),r-(r1:=L.count1(r))
            if same:
                if db!=ub:same,dl,dr,l,r=0,l,r,L.T0+l1,L.T0+r1
                elif db:l,r=L.T0+l1,L.T0+r1
            else:
                if ub:prod=wm.op(prod,L.prod(l,r));l,r=L.T0+l1,L.T0+r1
                dl0,dr0=dl-(dl:=L.T0+L.count1(dl)),dr-(dr:=L.T0+L.count1(dr))
                if not db:prod=wm.op(L.prod(dl,dr),prod);dl,dr=L.T0+dl0,L.T0+dr0
        return prod

class WMGroup(WMMonoid):
    '''Wavelet matrix whose weights form a group (`op` with inverse `diff`).

    Every level, and the base array, keeps a `Presum`, so any contiguous range
    aggregate is a single `diff` of two prefixes.
    '''
    class Level(WMStatic.Level):
        def build(L, op, e, diff, W):
            super().build()
            L.W = Presum(op, e, diff, W)

        def prod(L, l: int, r: int):
            return L.W.prod(l, r)

    def __init__(wm, op, e, diff, A, W, Amax=None):
        n = len(A)
        top = max(A, default=0) if Amax is None else Amax
        wm._build(op, e, diff, A, W, [0]*n, [0]*n, top)

    def _build(wm, op, e, diff, A, W, nA, nW, Amax):
        # diff must be set before the base/levels are built, since both use it.
        wm.diff = diff
        super()._build(op, e, A, W, nA, nW, Amax)

    def _build_base(wm, W):
        wm.W = Presum(wm.op, wm.e, wm.diff, W)

    def _build_level(wm, L, W):
        L.build(wm.op, wm.e, wm.diff, W)

    def _prod_range(wm, l: int, r: int):
        return wm.W.prod(l, r)


def bisect_left(A, x, l, r):
    '''Return the leftmost insertion point for `x` in the sorted slice `A[l:r]`.

    The result `i` satisfies `l <= i <= r`, with `A[j] < x` for `l <= j < i` and
    `A[j] >= x` for `i <= j < r`.
    '''
    lo, hi = l, r
    while lo < hi:
        mid = (lo + hi) >> 1
        if A[mid] < x:
            lo = mid + 1
        else:
            hi = mid
    return lo

class WMCompressed(WMStatic):
    '''Wavelet matrix over arbitrary integers, stored as ranks into `Y` (sorted distinct values).

    Value arguments are translated to ranks before delegating to `WMStatic`; `kth`
    maps the resulting rank back to its value.
    '''
    def __init__(wm, A):
        ranks, wm.Y = coord_compress(A)
        super().__init__(ranks, len(wm.Y)-1)

    def _didx(wm, y: int):
        # Number of distinct values strictly below y, usable as a rank bound.
        return bisect_left(wm.Y, y, 0, len(wm.Y))

    def _yidx(wm, y: int):
        # Rank of y if it occurs, otherwise -1.
        i = wm._didx(y)
        if i < len(wm.Y) and wm.Y[i] == y:
            return i
        return -1

    def __contains__(wm, y: int):
        i = wm._didx(y)
        return i < len(wm.Y) and wm.Y[i] == y

    def kth(wm, k, l, r):
        rank = super().kth(k, l, r)
        return wm.Y[rank]

    def select(wm, y, k, l=0, r=-1):
        rank = wm._yidx(y)
        if rank < 0:
            return -1
        return super().select(rank, k, l, r)

    def rank_range(wm, y, l, r):
        rank = wm._yidx(y)
        if rank < 0:
            return 0
        return super().rank_range(rank, l, r)

    def count_at(wm, y, l, r):
        rank = wm._yidx(y)
        if rank < 0:
            return 0
        return super().count_at(rank, l, r)

    def count_below(wm, u, l, r):
        return super().count_below(wm._didx(u), l, r)

    def count_between(wm, d, u, l, r):
        lo, hi = wm._didx(d), wm._didx(u)
        return super().count_between(lo, hi, l, r)

    def prev_val(wm, u, l, r):
        return super().prev_val(wm._didx(u), l, r)

    def next_val(wm, d, l, r):
        return super().next_val(wm._didx(d), l, r)

class WMMonoidCompressed(WMMonoid, WMCompressed):
    '''`WMMonoid` over coordinate-compressed y-values; every y argument is a raw value.

    Value bounds are translated to ranks with `_didx` (first rank >= value) before
    delegating, so `d <= y < u` on raw values becomes the same test on ranks.
    '''
    def __init__(wm,op,e,A:list[int],W:list[int]):A,wm.Y=coord_compress(A);WMMonoid.__init__(wm,op,e,A,W,len(wm.Y)-1)
    def prod_at(wm,y,l,r):
        # A y-value that never occurs contributes nothing, which is the identity `e`
        # (not the literal 0, which is wrong for non-additive monoids or tuple weights).
        return super().prod_at(y,l,r)if~(y:=wm._yidx(y))else wm.e
    def prod_below(wm,u,l,r):return super().prod_below(wm._didx(u),l,r)
    def prod_above(wm,d,l,r):
        # Compress d so the base class compares ranks, not raw values.
        return super().prod_above(wm._didx(d),l,r)
    def prod_between(wm,d,u,l,r):return super().prod_between(wm._didx(d),wm._didx(u),l,r)
    def prod_corner(wm,r,u):return super().prod_corner(r,wm._didx(u))
    def prod_rect(wm,l,d,r,u):return super().prod_rect(l,wm._didx(d),r,wm._didx(u))

class WMGroupCompressed(WMGroup, WMMonoidCompressed):
    '''Group-weighted wavelet matrix over coordinate-compressed y-values.'''
    def __init__(wm, op, e, diff, A: list[int], W: list):
        ranks, wm.Y = coord_compress(A)
        WMGroup.__init__(wm, op, e, diff, ranks, W, len(wm.Y)-1)

class WMPoints(WMCompressed):
    '''Wavelet matrix over 2-D points `(X[i], Y[i])`; query arguments are raw coordinates.

    Points are stored in x-sorted order (ties by input index); `wm.I[i]` is the stored
    position of input point `i`. An x-range `[l, r)` is mapped to stored positions with
    `_lidx`, and y-values are handled by `WMCompressed`.
    '''
    def __init__(wm,X,Y):
        wm.I,wm.X=coord_compress(X,distinct=True);A,wm.Y=coord_compress(Y);nA=[0]*len(Y)
        for i,j in enumerate(wm.I):nA[j]=A[i]
        wm._build(nA,A,len(wm.Y)-1)
    def _lidx(wm,x):return bisect_left(wm.X,x,0,len(wm.X))
    def __getitem__(wm,i):return super().__getitem__(wm.I[i])
    def kth(wm,k,l,r):return super().kth(k,wm._lidx(l),wm._lidx(r))
    def select(wm,y,k,l=0,r=-1):
        # Keep r=-1 as the "through the last point" sentinel rather than compressing -1 as an
        # x-coordinate, which turned the default into an empty or truncated range.
        # NOTE(review): the returned index is a position in x-sorted order, not an input index.
        return super().select(y,k,wm._lidx(l),len(wm.X) if r==-1 else wm._lidx(r))
    def rank_range(wm,y,l,r):return super().rank_range(y,wm._lidx(l),wm._lidx(r))
    def count_at(wm,y,l,r):return super().count_at(y,wm._lidx(l),wm._lidx(r))
    def count_below(wm,u,l,r):return super().count_below(u,wm._lidx(l),wm._lidx(r))
    def count_between(wm,d,u,l,r):return super().count_between(d,u,wm._lidx(l),wm._lidx(r))
    def prev_val(wm,u,l,r):
        '''Largest y-value < u among points with l <= x < r, or -1.'''
        # Compress the x-range exactly once. Delegating to super().prev_val would call
        # wm.kth, i.e. this class's kth, and compress the already-compressed indices again.
        l,r=wm._lidx(l),wm._lidx(r)
        cnt=wm._cnt(wm._didx(u),l,r)
        return wm.Y[WMStatic.kth(wm,cnt-1,l,r)] if cnt else -1
    def next_val(wm,d,l,r):
        '''Smallest y-value >= d among points with l <= x < r, or -1.'''
        l,r=wm._lidx(l),wm._lidx(r)
        cnt=wm._cnt(wm._didx(d),l,r)
        return wm.Y[WMStatic.kth(wm,cnt,l,r)] if cnt<r-l else -1

class WMMonoidPoints(WMMonoidCompressed,WMPoints):
    '''`WMMonoid` over weighted 2-D points `(X[i], Y[i], W[i])`; all query arguments are raw
    coordinates (`l, r` bound x and `d, u` bound y, half-open).

    NOTE: the list passed as `W` is used as scratch space during construction and may be
    overwritten.
    '''
    def __init__(wm,op,e,X:list[int],Y:list[int],W:list[int]):
        wm.I,wm.X=coord_compress(X,distinct=True);A,wm.Y=coord_compress(Y);nA,nW=[0]*(N:=len(A)),[0]*N
        # Reorder y-ranks and weights into x-sorted order.
        for i,j in enumerate(wm.I):nA[j],nW[j]=A[i],W[i]
        wm._build(op,e,nA,nW,A,W,len(wm.Y)-1)
    def prod_range(wm,l,r):return super().prod_range(wm._lidx(l),wm._lidx(r))
    def prod_at(wm,y,l,r):return super().prod_at(y,wm._lidx(l),wm._lidx(r))
    def prod_below(wm,u,l,r):return super().prod_below(u,wm._lidx(l),wm._lidx(r))
    def prod_above(wm,d,l,r):
        # Compress both axes here and call WMMonoid directly, so d is translated to a rank
        # exactly once regardless of what intermediate classes do.
        return WMMonoid.prod_above(wm,wm._didx(d),wm._lidx(l),wm._lidx(r))
    def prod_between(wm,d,u,l,r):return super().prod_between(d,u,wm._lidx(l),wm._lidx(r))
    def prod_corner(wm,r,u):return super().prod_corner(wm._lidx(r),u)
    def prod_rect(wm,l,d,r,u):return super().prod_rect(wm._lidx(l),d,wm._lidx(r),u)

class WMGroupPoints(WMGroupCompressed, WMMonoidPoints):
    '''Group-weighted wavelet matrix over 2-D points; queries take raw coordinates.

    NOTE: the list passed as `W` is used as scratch space during construction and may be
    overwritten.
    '''
    def __init__(wm, op, e, diff, X: list[int], Y: list[int], W: list):
        wm.I, wm.X = coord_compress(X, distinct=True)
        ys, wm.Y = coord_compress(Y)
        n = len(ys)
        by_x_ys, by_x_ws = [0]*n, [0]*n
        # Move every point's y-rank and weight to its x-sorted slot.
        for src, dst in enumerate(wm.I):
            by_x_ys[dst], by_x_ws[dst] = ys[src], W[src]
        wm._build(op, e, diff, by_x_ys, by_x_ws, ys, W, len(wm.Y)-1)

import os
import sys
from io import BytesIO, IOBase


class FastIO(IOBase):
    """Byte-level buffered stream over a raw file descriptor.

    Input is pulled with os.read into an in-memory BytesIO; output accumulates in the
    same BytesIO and is sent with a single os.write on flush().
    """
    # Minimum number of bytes requested per os.read call.
    BUFSIZE = 8192
    # Complete lines buffered but not yet returned by readline().
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # NOTE: shadows IOBase.writable() with a bool derived from file.mode
        # ('x' present, or 'r' absent).
        self.writable = "x" in file.mode or "r" not in file.mode
        # Writes go straight into the BytesIO; nothing reaches the fd until flush().
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read until EOF and return everything after the current read position."""
        BUFSIZE = self.BUFSIZE
        while True:
            # Ask for max(file size, BUFSIZE) bytes; st_size is 0 for pipes/ttys.
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line (bytes), reading more chunks until one is complete or EOF."""
        BUFSIZE = self.BUFSIZE
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty chunk means EOF; counting it as one "line" ends the loop.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        # NOTE(review): consumed input is never discarded from the BytesIO (only flush()
        # truncates), so the whole input stays in memory.
        return self.buffer.readline()

    def flush(self):
        """Send all buffered output to the fd and empty the buffer (no-op if not writable)."""
        if self.writable:
            # NOTE(review): os.write's return value (bytes actually written) is not checked,
            # so a partial write would go unnoticed.
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    """Text facade over FastIO that encodes and decodes ASCII at the byte boundary."""
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        data = s.encode("ascii")
        return self.buffer.write(data)

    def read(self):
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        raw = self.buffer.readline()
        return raw.decode("ascii")
# Swap in the fast buffered wrappers when the standard streams expose a usable fileno()
# and mode. If wrapping fails, keep the original stream so IOWrapper.stdin / stdout still
# point at something readable / writable for read() and write(), instead of staying None.
# Each stream is handled separately so one failure does not prevent wrapping the other.
# `except Exception` (not a bare except) lets KeyboardInterrupt / SystemExit propagate.
try:
    sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
except Exception:
    IOWrapper.stdin = sys.stdin
try:
    sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)
except Exception:
    IOWrapper.stdout = sys.stdout

def write(*args, **kwargs):
    """Print-like output: str() of each arg joined by `sep`, then `end`.

    Keyword options: sep=" ", end="\\n", file=IOWrapper.stdout, flush=False.
    Any other keyword arguments are ignored.
    """
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", IOWrapper.stdout)
    for index, value in enumerate(args):
        if index:
            file.write(sep)
        file.write(str(value))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()

from typing import Iterable, Type, Union, overload
import typing
from collections import deque
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection, Iterator, Union
from typing import TypeVar
_T = TypeVar('T')
_U = TypeVar('U')

class TokenStream(Iterator):
    """Iterator over whitespace-separated tokens, read one line at a time from `TokenStream.stream`."""
    stream = IOWrapper.stdin

    def __init__(self):
        # Tokens of the current line not consumed yet.
        self.queue = deque()

    def __next__(self):
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        # NOTE(review): a blank line or EOF leaves the deque empty, so popleft raises IndexError.
        return pending.popleft()

    def wait(self):
        """Yield while the current line still has unconsumed tokens.

        Loads a new line first if nothing is pending; the consumer pulls tokens between
        yields, so the loop ends once the line is used up.
        """
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        while pending:
            yield

    def _line(self):
        # Always reads from the class-level stream attribute, even in subclasses.
        return TokenStream.stream.readline().split()

    def line(self):
        """Return the rest of the current line, or the next line if nothing is pending."""
        pending = self.queue
        if not pending:
            return self._line()
        rest = list(pending)
        pending.clear()
        return rest
TokenStream.default = TokenStream()

class CharStream(TokenStream):
    """TokenStream whose tokens are the individual characters of each line."""
    def _line(self):
        # rstrip() removes the trailing newline (and other trailing whitespace); extending
        # the deque with the resulting str queues one character per element.
        text = TokenStream.stream.readline()
        return text.rstrip()
CharStream.default = CharStream()


# A compiled parser: takes a TokenStream and returns the parsed value.
ParseFn = Callable[[TokenStream],_T]
class Parser:
    """Compiles a parsing *spec* into a function `ts -> value`.

    Spec forms handled by `compile`, in dispatch order:
    - a type or GenericAlias (`int`, `tuple[int, ...]`, `list[str]`, ...): handled by
      `compile_type` according to its origin class;
    - a Number instance: parse one token with `type(spec)` and add `spec` as an offset;
    - a tuple instance: `compile_tuple`;
    - any other Collection instance: `compile_collection`;
    - a callable: apply it to the next token.
    """
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        """Build a parser for class `cls` with generic arguments `args`.

        The branch order matters: `str` is a Collection, so the Number/str check must
        precede the tuple/Collection checks.
        """
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        """Dispatch on the kind of `spec` (see the class docstring)."""
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            # e.g. read(-1): parse an int token and add -1 (1-based -> 0-based input).
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        """Parse the rest of the current line into `cls`, one element per `spec`."""
        if spec is int:
            # Fast path: convert the line's tokens directly.
            # NOTE(review): `fn` is computed but unused in this branch.
            fn = Parser.compile(spec)
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        else:
            fn = Parser.compile(spec)
            # ts.wait() keeps yielding while tokens of the current line remain unconsumed.
            def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
            return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        """Parse exactly `N` items with `spec` into `cls`."""
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        """Parse one value per element of `specs`, in order, into `cls`."""
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        """`(spec, ...)` reads the rest of the line; otherwise one parser per element."""
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        """No args, one arg, or a set -> rest of the line; `(spec, N)` -> exactly N items."""
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()

class Parsable:
    """Mixin for types built from a single token: the compiled parser returns `cls(next(ts))`."""
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser

@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_U], char=False):
    """Read and parse input according to `specs`.

    With no specs (and char=False) the rest of the current line is returned as a list of
    ints. Otherwise the specs are compiled with `Parser.compile` and applied to the token
    stream (the per-character stream when char=True); a single spec is returned unwrapped.
    """
    if not specs and not char:
        return [int(token) for token in TokenStream.default.line()]
    parse = Parser.compile(specs)
    source = CharStream.default if char else TokenStream.default
    result = parse(source)
    if len(specs) == 1:
        return result[0]
    return result

# Entry point: run the solver only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Back to top page