This documentation is automatically generated by online-judge-tools/verification-helper
import cp_library.__header__
from cp_library.alg.iter.cmpr.coord_compress_fn import coord_compress
import cp_library.ds.__header__
import cp_library.ds.wavelet.__header__
from cp_library.ds.wavelet.wm_segtree_cls import WMSegTree
from cp_library.ds.wavelet.wm_monoid_compressed_cls import WMMonoidCompressed
class WMSegTreeCompressed(WMSegTree, WMMonoidCompressed):
    '''Combination of `WMSegTree` and `WMMonoidCompressed`; adds no members of its own.'''
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
def coord_compress(A: list[int], distinct = False):
    '''Replace the values of `A` by their ranks in ascending order.

    Every element is packed together with its index as `a << s | i`, so one
    sort of plain integers orders the elements by value and then by index.

    Args:
        A: integers to compress.
        distinct: if true, every element receives its own rank (its position
            in the sorted order, ties broken by index); otherwise equal
            values share one rank.

    Returns:
        `(R, V)` where `R[i]` is the rank of `A[i]` and `V` lists the values
        in ascending order (deduplicated unless `distinct`), so that
        `V[R[i]] == A[i]`.
    '''
    N = len(A)
    # Same shift/mask as pack_sm(N-1): enough low bits to hold any index.
    s = (N - 1).bit_length()
    m = (1 << s) - 1
    R = [0] * N
    V = [a << s | i for i, a in enumerate(A)]
    V.sort()
    if distinct:
        for r, ai in enumerate(V):
            R[ai & m], V[r] = r, ai >> s
    else:
        r, p = -1, None
        for ai in V:
            a, i = ai >> s, ai & m
            # `r < 0` marks the first element; comparing against a numeric
            # sentinel would merge a genuine value equal to that sentinel.
            if r < 0 or a != p:
                r += 1
                V[r] = p = a
            R[i] = r
        # V was overwritten in place with the unique values; drop the tail.
        del V[r+1:]
    return R, V
def pack_dec(ab: int, s: int, m: int):
    '''Split a packed integer into `(ab >> s, ab & m)`, i.e. its high and low parts.'''
    high = ab >> s
    low = ab & m
    return high, low
def pack_sm(N: int):
    '''Return `(s, m)`: the bit length `s` of `N` and a mask `m` with the low `s` bits set.'''
    shift = N.bit_length()
    mask = ~(-1 << shift)
    return shift, mask
from typing import Callable, Generic, Union
from typing import TypeVar
_T = TypeVar('_T')
_U = TypeVar('_U')


class SegTree(Generic[_T]):
    '''Bottom-up segment tree over a monoid `(op, e)`.

    Leaves live at `d[sz + i]`; internal node `i` holds
    `op(d[2*i], d[2*i + 1])`. Unused leaves hold the identity `e`.
    '''

    def __init__(seg, op: Callable[[_T, _T], _T], e: _T, v: Union[int, list[_T]]) -> None:
        '''Build from a list of values, or from a length (all elements `e`).'''
        if isinstance(v, int):
            v = [e] * v
        n = len(v)
        # NOTE(review): log is one more than ceil(log2(n)), so the tree has
        # twice the minimal number of leaves; kept because log/sz are
        # instance attributes.
        log = (n - 1).bit_length() + 1
        sz = 1 << log
        d = [e] * (sz << 1)
        d[sz:sz + n] = v
        for i in range(sz - 1, 0, -1):
            d[i] = op(d[i << 1], d[i << 1 | 1])
        seg.op, seg.e, seg.n = op, e, n
        seg.log, seg.sz, seg.d = log, sz, d

    def set(seg, p: int, x: _T) -> None:
        '''Assign `x` to position `p` and recompute every ancestor.'''
        op, d = seg.op, seg.d
        p += seg.sz
        d[p] = x
        for _ in range(seg.log):
            p >>= 1
            d[p] = op(d[p << 1], d[p << 1 | 1])
    __setitem__ = set

    def get(seg, p: int) -> _T:
        '''Return the value stored at position `p`.'''
        return seg.d[p + seg.sz]
    __getitem__ = get

    def prod(seg, l: int, r: int) -> _T:
        '''Return the product of positions `[l, r)` (`e` when the range is empty).'''
        op, d = seg.op, seg.d
        sml = smr = seg.e
        l += seg.sz
        r += seg.sz
        while l < r:
            if l & 1:
                sml = op(sml, d[l])
                l += 1
            if r & 1:
                r -= 1
                smr = op(d[r], smr)
            l >>= 1
            r >>= 1
        return op(sml, smr)

    def all_prod(seg) -> _T:
        '''Return the product of the whole tree (the root node).'''
        return seg.d[1]

    def max_right(seg, l: int, f: Callable[[_T], bool]) -> int:
        '''Return the index where the search for the longest prefix `[l, r)` with a true `f(prod)` stops.

        Asserts `0 <= l <= n` and `f(e)`. Returns `n` when no node makes `f` false.
        '''
        assert 0 <= l <= seg.n
        assert f(seg.e)
        if l == seg.n:
            return seg.n
        sz, op, d, sm = seg.sz, seg.op, seg.d, seg.e
        l += sz
        while True:
            while (l & 1) == 0:
                l >>= 1
            if not f(op(sm, d[l])):
                # The answer lies inside node l: walk down to the leaf.
                while l < sz:
                    l <<= 1
                    if f(op(sm, d[l])):
                        sm, l = op(sm, d[l]), l + 1
                return l - sz
            sm, l = op(sm, d[l]), l + 1
            if l & -l == l:
                # l is a power of two: the right edge of the tree was passed.
                return seg.n

    def min_left(seg, r: int, f: Callable[[_T], bool]) -> int:
        '''Return the index where the search for the longest suffix `[l, r)` with a true `f(prod)` stops.

        Asserts `0 <= r <= n` and `f(e)`. Returns `0` when no node makes `f` false.
        '''
        assert 0 <= r <= seg.n
        assert f(seg.e)
        if r == 0:
            return 0
        sz, op, d, sm = seg.sz, seg.op, seg.d, seg.e
        r += sz
        while True:
            r -= 1
            while r > 1 and r & 1:
                r >>= 1
            if not f(op(d[r], sm)):
                # The answer lies inside node r: walk down to the leaf.
                while r < sz:
                    r = r << 1 | 1
                    if f(op(d[r], sm)):
                        sm, r = op(d[r], sm), r - 1
                return r + 1 - sz
            sm = op(d[r], sm)
            if (r & -r) == r:
                return 0
from abc import abstractmethod
class BitArray:
    '''Bit vector of length `N` packed into 32-bit words, with rank and select.

    Bit `i` is kept in word `i >> 5` at bit position `31 - (i & 31)`.
    `cnt` only holds valid prefix counts after `build()` has been called.
    '''

    def __init__(B, N: int):
        '''Allocate an all-zero vector of `N` bits (plus one spare word).'''
        B.N, B.Z = N, (N + 31) >> 5
        B.bits, B.cnt = u32f(B.Z + 1), u32f(B.Z + 1)

    def build(B):
        '''Fill `cnt` so that `cnt[w]` is the number of set bits in words `[0, w)`.'''
        # The spare word is removed while counting and put back afterwards so
        # that `bits[r >> 5]` stays indexable when `r` is a multiple of 32.
        B.bits.pop()
        for i, b in enumerate(B.bits):
            B.cnt[i + 1] = B.cnt[i] + popcnt32(b)
        B.bits.append(1)

    def __len__(B):
        return B.N

    def __getitem__(B, i: int):
        '''Return bit `i` (0 or 1).'''
        return B.bits[i >> 5] >> (31 - (i & 31)) & 1

    def set0(B, i: int):
        '''Clear bit `i`.'''
        B.bits[i >> 5] &= ~(1 << (31 - (i & 31)))

    def set1(B, i: int):
        '''Set bit `i`.'''
        B.bits[i >> 5] |= 1 << (31 - (i & 31))

    def count0(B, r: int):
        '''Return the number of zero bits in `[0, r)`.'''
        return r - B.count1(r)

    def count1(B, r: int):
        '''Return the number of set bits in `[0, r)`.'''
        # Shifting by 32 - (r & 31) keeps only the first r & 31 bits of the word.
        return B.cnt[r >> 5] + popcnt32(B.bits[r >> 5] >> (32 - (r & 31)))

    def select0(B, k: int):
        '''Return the index of the `k`-th (0-based) zero bit, or -1 if there is none.'''
        if not 0 <= k < B.N - B.cnt[-1]:
            return -1
        lo, hi, k = 0, B.N, k + 1
        while 1 < hi - lo:
            mid = (lo + hi) >> 1
            if B.count0(mid) < k:
                lo = mid
            else:
                hi = mid
        return lo

    def select1(B, k: int):
        '''Return the index of the `k`-th (0-based) set bit, or -1 if there is none.'''
        if not 0 <= k < B.cnt[-1]:
            return -1
        lo, hi, k = 0, B.N, k + 1
        while 1 < hi - lo:
            mid = (lo + hi) >> 1
            if B.count1(mid) < k:
                lo = mid
            else:
                hi = mid
        return lo
def popcnt32(x):
    '''Return the number of set bits of a 32-bit non-negative integer.'''
    # Sum adjacent bit fields pairwise, doubling the field width each step.
    x = ((x >> 1) & 0x55555555) + (x & 0x55555555)
    x = ((x >> 2) & 0x33333333) + (x & 0x33333333)
    x = ((x >> 4) & 0x0f0f0f0f) + (x & 0x0f0f0f0f)
    x = ((x >> 8) & 0x00ff00ff) + (x & 0x00ff00ff)
    x = ((x >> 16) & 0x0000ffff) + (x & 0x0000ffff)
    return x


# Use the built-in implementation when the interpreter provides one.
if hasattr(int, 'bit_count'):
    popcnt32 = int.bit_count
from array import array
def u32f(N: int, elm: int = 0):
    '''Return an `array` with type code 'I' (unsigned int) holding `N` copies of `elm`.'''
    seed = array('I', (elm,))
    return seed * N
class WMStatic:
    '''Static wavelet matrix over non-negative integers.

    One `Level` bit vector is kept per bit position; `down` lists the levels
    from the highest bit to the lowest, `up` the other way round.
    '''

    class Level(BitArray):
        '''Bit vector of one bit position `H`; `T0`/`T1` are its zero/one counts.'''

        def __init__(L, N: int, H: int):
            super().__init__(N)
            L.H = H

        def build(L):
            super().build()
            L.T0, L.T1 = L.N - L.cnt[-1], L.cnt[-1]

        def pos(L, bit: int, i: int):
            '''Map index `i` to its index on the next level for the given `bit`.'''
            return L.T0 + L.count1(i) if bit else L.count0(i)

        def pos2(L, bit: int, i: int, j: int):
            '''Map both ends of `[i, j)` to the next level for the given `bit`.'''
            if bit:
                return L.T0 + L.count1(i), L.T0 + L.count1(j)
            return L.count0(i), L.count0(j)

    def __init__(wm, A, Amax: int = None):
        '''Build from values `A`; `Amax` defaults to `max(A)` (0 when empty).'''
        # Work on a copy: the build swaps A with a scratch buffer on every
        # level and would otherwise overwrite the caller's list.
        A = list(A)
        wm._build(A, [0] * len(A), max(A, default=0) if Amax is None else Amax)

    def _build(wm, A, nA, Amax):
        wm.N, wm.H = len(A), Amax.bit_length()
        wm._build_levels(A, nA)

    def _build_levels(wm, A, nA):
        '''Fill every level, stably partitioning `A` by the level's bit (zeros first).'''
        wm.up = [wm.Level(wm.N, H) for H in range(wm.H)]
        wm.down = wm.up[::-1]
        for L in wm.down:
            ones = sum(a >> L.H & 1 for a in A)
            # x / y: last written slot of the zero block / the one block.
            x, y = -1, wm.N - 1 - ones
            for i, a in enumerate(A):
                if a >> L.H & 1:
                    y += 1
                    nA[y] = a
                    L.set1(i)
                else:
                    x += 1
                    nA[x] = a
            A, nA = nA, A
            L.build()

    def __getitem__(wm, i):
        '''Reconstruct the value at index `i` from the level bits.'''
        y = 0
        for L in wm.down:
            bit = L[i]
            y = y << 1 | bit
            i = L.pos(bit, i)
        return y

    def kth(wm, k: int, l: int, r: int):
        '''Returns the `k+1`-th value in sorted order of values in range `[l, r)`'''
        s = 0
        for L in wm.down:
            l1, r1 = L.count1(l), L.count1(r)
            l, r = l - l1, r - r1
            if k >= r - l:
                s |= 1 << L.H
                k -= r - l
                l, r = L.T0 + l1, L.T0 + r1
        return s

    def select(wm, y: int, k: int, l: int = 0, r: int = -1):
        '''Returns the index of the `k+1`-th occurrence of `y` in range `[l, r)`

        `r == -1` stands for the end of the sequence. Returns -1 when `y` is
        out of range or occurs fewer than `k+1` times in the range.
        '''
        if not (0 <= y < 1 << wm.H):
            return -1
        if r == -1:
            # The range is half-open, so the default end is N.
            r = wm.N
        for L in wm.down:
            # Follow the H-th bit of the value y (not the stored bit at index y).
            l, r = L.pos2(y >> L.H & 1, l, r)
        i = l + k
        if not l <= i < r:
            return -1
        for L in wm.up:
            if y >> L.H & 1:
                i = L.select1(i - L.T0)
            else:
                i = L.select0(i)
        return i

    def rank(wm, y: int, r: int):
        '''Count occurrences of `y` in `[0, r)`.'''
        return wm.rank_range(y, 0, r)

    def rank_range(wm, y: int, l: int, r: int):
        '''Count occurrences of `y` in `[l, r)`.'''
        if l >= r or not (0 <= y < 1 << wm.H):
            return 0
        for L in wm.down:
            # Follow the H-th bit of the value y (not the stored bit at index y).
            l, r = L.pos2(y >> L.H & 1, l, r)
        return r - l

    def count_at(wm, y: int, l: int, r: int):
        '''Count how many `y` values are in range `[l,r)` '''
        if l >= r:
            return 0
        return wm._cnt(y + 1, l, r) - wm._cnt(y, l, r)

    def count_below(wm, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `A[i] < u` '''
        return wm._cnt(u, l, r)

    def count_between(wm, d: int, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `d <= A[i] < u` '''
        if l >= r or d >= u:
            return 0
        return wm._cnt(u, l, r) - wm._cnt(d, l, r)

    def _cnt(wm, u: int, l: int, r: int):
        '''Count values `< u` in `[l, r)`.'''
        if u <= 0:
            return 0
        if wm.H < u.bit_length():
            return r - l
        cnt = 0
        for L in wm.down:
            l1, r1 = L.count1(l), L.count1(r)
            l, r = l - l1, r - r1
            if u >> L.H & 1:
                # Everything on the zero side is below u at this bit.
                cnt += r - l
                l, r = L.T0 + l1, L.T0 + r1
        return cnt

    def prev_val(wm, u: int, l: int, r: int):
        '''Return the largest value `< u` in `[l, r)`, or -1 if there is none.'''
        cnt = wm._cnt(u, l, r)
        return wm.kth(cnt - 1, l, r) if cnt else -1

    def next_val(wm, d: int, l: int, r: int):
        '''Return the smallest value `>= d` in `[l, r)`, or -1 if there is none.'''
        cnt = wm._cnt(d, l, r)
        return wm.kth(cnt, l, r) if cnt < r - l else -1
class WMMonoid(WMStatic):
    '''Wavelet matrix whose elements carry weights `W` combined by a monoid `(op, e)`.

    Subclasses supply the per-level storage: `_build_base`, `_build_level`,
    `_prod_range`, and a `Level.prod(l, r)` method.
    '''

    def __init__(wm, op, e, A: list[int], W: list[int], Amax: int = None):
        '''Build from values `A` with weights `W`; `Amax` defaults to `max(A)` (0 when empty).'''
        # Work on copies: the build swaps A and W with scratch buffers on
        # every level and would otherwise overwrite the caller's lists.
        A, W = list(A), list(W)
        wm._build(op, e, A, W, [0] * len(A), [0] * len(A),
                  max(A, default=0) if Amax is None else Amax)

    def _build(wm, op, e, A, W, nA, nW, Amax):
        wm.N, wm.H, wm.op, wm.e = len(A), Amax.bit_length(), op, e
        wm._build_base(W)
        wm._build_levels(A, W, nA, nW)

    @abstractmethod
    def _build_base(wm, W): ...

    @abstractmethod
    def _build_level(wm, L, W): ...

    def _build_levels(wm, A, W, nA, nW):
        '''Fill every level, moving each weight along with its value.'''
        wm.up = [wm.Level(wm.N, H) for H in range(wm.H)]
        wm.down = wm.up[::-1]
        for L in wm.down:
            ones = sum(a >> L.H & 1 for a in A)
            # x / y: last written slot of the zero block / the one block.
            x, y = -1, wm.N - 1 - ones
            for i, a in enumerate(A):
                if a >> L.H & 1:
                    y += 1
                    nA[y], nW[y] = a, W[i]
                    L.set1(i)
                else:
                    x += 1
                    nA[x], nW[x] = a, W[i]
            A, nA, W, nW = nA, A, nW, W
            # W is now in the order produced by this level.
            wm._build_level(L, W)

    def prod_range(wm, l: int, r: int):
        '''Product of the weights at indices `[l, r)`.'''
        return wm._prod_range(l, r) if l < r else wm.e

    def prod_at(wm, y: int, l: int, r: int):
        '''Product of the weights in `[l, r)` whose value equals `y`.'''
        return wm._prod_rect(y, y + 1, l, r) if l < r else wm.e

    def prod_below(wm, u: int, l: int, r: int):
        '''Product of the weights in `[l, r)` whose value is `< u`.'''
        return wm._prod_below(u, l, r) if l < r else wm.e

    def prod_above(wm, d: int, l: int, r: int):
        '''Product of the weights in `[l, r)` whose value is `>= d`.'''
        return wm._prod_above(d, l, r) if l < r else wm.e

    def prod_between(wm, d: int, u: int, l: int, r: int):
        '''Product of the weights in `[l, r)` whose value lies in `[d, u)`.'''
        return wm._prod_rect(d, u, l, r) if l < r and d < u else wm.e

    def prod_corner(wm, r: int, u: int):
        '''Product of the weights in `[0, r)` whose value is `< u`.'''
        return wm._prod_below(u, 0, r) if 0 < r else wm.e

    def prod_rect(wm, l: int, d: int, r: int, u: int):
        '''Product of the weights in `[l, r)` whose value lies in `[d, u)`.'''
        return wm._prod_rect(d, u, l, r) if l < r and d < u else wm.e

    @abstractmethod
    def _prod_range(wm, l, r): ...

    def _prod_below(wm, u, l, r):
        if u <= 0:
            return wm.e
        elif wm.H < u.bit_length():
            return wm._prod_range(l, r)
        prod = wm.e
        for L in wm.down:
            l1, r1 = L.count1(l), L.count1(r)
            l, r = l - l1, r - r1
            if u >> L.H & 1:
                # The whole zero side is below u at this bit.
                prod = wm.op(prod, L.prod(l, r))
                l, r = L.T0 + l1, L.T0 + r1
        return prod

    def _prod_above(wm, d, l, r):
        if d <= 0:
            return wm._prod_range(l, r)
        elif d.bit_length() > wm.H:
            return wm.e
        # Walk along d-1 and collect everything strictly above it.
        prod, d = wm.e, d - 1
        for L in wm.down:
            l1, r1 = L.count1(l), L.count1(r)
            l0, r0 = l - l1, r - r1
            l, r = L.T0 + l1, L.T0 + r1
            if (d >> L.H & 1) == 0:
                # The whole one side is above d-1 at this bit.
                prod = wm.op(L.prod(l, r), prod)
                l, r = l0, r0
        return prod

    def _prod_rect(wm, d, u, l, r):
        if u <= 0 or wm.H < d.bit_length():
            return wm.e
        elif d <= 0:
            return wm._prod_below(u, l, r)
        elif wm.H < u.bit_length():
            return wm._prod_above(d, l, r)
        # `same` stays set while d-1 and u share their bit prefix; afterwards
        # (l, r) follows u and (dl, dr) follows d-1.
        same, prod, d = 1, wm.e, d - 1
        for L in wm.down:
            db, ub = d >> L.H & 1, u >> L.H & 1
            l1, r1 = L.count1(l), L.count1(r)
            l, r = l - l1, r - r1
            if same:
                if db != ub:
                    same, dl, dr, l, r = 0, l, r, L.T0 + l1, L.T0 + r1
                elif db:
                    l, r = L.T0 + l1, L.T0 + r1
            else:
                if ub:
                    prod = wm.op(prod, L.prod(l, r))
                    l, r = L.T0 + l1, L.T0 + r1
                dl1, dr1 = L.count1(dl), L.count1(dr)
                dl0, dr0 = dl - dl1, dr - dr1
                dl, dr = L.T0 + dl1, L.T0 + dr1
                if not db:
                    prod = wm.op(L.prod(dl, dr), prod)
                    dl, dr = dl0, dr0
        return prod
class WMSegTree(WMMonoid):
    '''`WMMonoid` that keeps the weights of the base order and of every level in a `SegTree`.'''

    class Level(WMStatic.Level):
        '''Level bit vector plus a `SegTree` over the weights in this level's output order.'''

        def build(L, op, e, W):
            super().build()
            L.W = SegTree(op, e, W)

        def prod(L, l: int, r: int):
            return L.W.prod(l, r)

    def _build_base(wm, W):
        wm.W = SegTree(wm.op, wm.e, W)

    def _build_level(wm, L, W):
        L.build(wm.op, wm.e, W)

    def _prod_range(wm, l: int, r: int):
        return wm.W.prod(l, r)

    def set(wm, i: int, w: int):
        '''Set the weight at index `i` to `w` in the base tree and on every level.'''
        wm.W.set(i, w)
        for L in wm.down:
            # Follow the element to its index after this level's partition.
            i = L.pos(L[i], i)
            L.W.set(i, w)

    def get(wm, i: int):
        '''Return the weight at index `i`.'''
        return wm.W.get(i)
def bisect_left(A, x, l, r):
    '''Return the first index `i` in `[l, r]` with `A[i] >= x`, assuming `A[l:r]` is sorted.

    Returns `r` when every element of `A[l:r]` is smaller than `x`.
    '''
    while l < r:
        m = (l + r) >> 1
        if A[m] < x:
            l = m + 1
        else:
            r = m
    return l
class WMCompressed(WMStatic):
    '''`WMStatic` over coordinate-compressed values.

    `Y` holds the sorted distinct values; queries take and return original
    values and are translated to and from indices into `Y`.
    '''

    def __init__(wm, A):
        A, wm.Y = coord_compress(A)
        # Clamp to 0: for empty input len(Y)-1 is -1, whose bit_length() is 1
        # and would allocate a level that holds nothing.
        super().__init__(A, max(len(wm.Y) - 1, 0))

    def _didx(wm, y: int):
        '''Index of the first entry of `Y` that is `>= y` (`len(Y)` if none).'''
        return bisect_left(wm.Y, y, 0, len(wm.Y))

    def _yidx(wm, y: int):
        '''Index of `y` in `Y`, or -1 when `y` is not present.'''
        i = wm._didx(y)
        return i if i < len(wm.Y) and wm.Y[i] == y else -1

    def __contains__(wm, y: int):
        i = wm._didx(y)
        return i < len(wm.Y) and wm.Y[i] == y

    def kth(wm, k, l, r):
        return wm.Y[super().kth(k, l, r)]

    def select(wm, y, k, l=0, r=-1):
        y = wm._yidx(y)
        return -1 if y < 0 else super().select(y, k, l, r)

    def rank_range(wm, y, l, r):
        y = wm._yidx(y)
        return 0 if y < 0 else super().rank_range(y, l, r)

    def count_at(wm, y, l, r):
        y = wm._yidx(y)
        return 0 if y < 0 else super().count_at(y, l, r)

    def count_below(wm, u, l, r):
        return super().count_below(wm._didx(u), l, r)

    def count_between(wm, d, u, l, r):
        return super().count_between(wm._didx(d), wm._didx(u), l, r)

    def prev_val(wm, u, l, r):
        return super().prev_val(wm._didx(u), l, r)

    def next_val(wm, d, l, r):
        return super().next_val(wm._didx(d), l, r)
class WMMonoidCompressed(WMMonoid, WMCompressed):
    '''`WMMonoid` over coordinate-compressed values; value arguments are original values.'''

    def __init__(wm, op, e, A: list[int], W: list[int]):
        A, wm.Y = coord_compress(A)
        # Clamp to 0: for empty input len(Y)-1 is -1, whose bit_length() is 1.
        WMMonoid.__init__(wm, op, e, A, W, max(len(wm.Y) - 1, 0))

    def prod_at(wm, y, l, r):
        '''Product of the weights in `[l, r)` whose value equals `y`.'''
        y = wm._yidx(y)
        # An absent value yields the monoid identity, like every other empty product.
        return wm.e if y < 0 else super().prod_at(y, l, r)

    def prod_below(wm, u, l, r):
        return super().prod_below(wm._didx(u), l, r)

    def prod_between(wm, d, u, l, r):
        return super().prod_between(wm._didx(d), wm._didx(u), l, r)

    def prod_corner(wm, r, u):
        return super().prod_corner(r, wm._didx(u))

    def prod_rect(wm, l, d, r, u):
        return super().prod_rect(l, wm._didx(d), r, wm._didx(u))
class WMSegTreeCompressed(WMSegTree, WMMonoidCompressed):
    '''Combination of `WMSegTree` and `WMMonoidCompressed`; adds no members of its own.'''