cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ cp_library/ds/wavelet/wm_monoid_cls.py

Depends on

Required by

Verified with

Code

from abc import abstractmethod
import cp_library.__header__
import cp_library.ds.__header__
import cp_library.ds.wavelet.__header__
from cp_library.ds.wavelet.wm_static_cls import WMStatic

class WMMonoid(WMStatic):
    """Wavelet matrix whose elements carry weights folded with `op` / identity `e`.

    `A[i]` is the value and `W[i]` the weight of element `i`.  Subclasses
    supply the storage behind `_build_base`, `_build_level` and `_prod_range`;
    the queries here also call `L.prod(l, r)` on every level, which this class
    does not define itself.
    NOTE(review): partial products are not combined in index order, so `op`
    is presumably expected to be commutative -- confirm against subclasses.
    """

    def __init__(wm, op, e, A: list[int], W: list[int], Amax: int = None):
        """Build from values `A` and weights `W`; `Amax` defaults to `max(A)` (0 if empty)."""
        size = len(A)
        scratch_a, scratch_w = [0]*size, [0]*size
        if Amax is None:
            Amax = max(A, default=0)
        wm._build(op, e, A, W, scratch_a, scratch_w, Amax)

    def _build(wm, op, e, A, W, nA, nW, Amax):
        """Store size, height and the monoid, then run the base and per-level builds."""
        size, height = len(A), Amax.bit_length()
        wm.N, wm.H, wm.op, wm.e = size, height, op, e
        wm._build_base(W)
        wm._build_levels(A, W, nA, nW)

    @abstractmethod
    def _build_base(wm, W):
        """Hook called once by `_build` with the weights in their original order."""

    @abstractmethod
    def _build_level(wm, L, W):
        """Hook called once per level with the weights in that level's partitioned order.

        NOTE(review): `_build_levels` never calls `L.build()` itself, so the
        implementation presumably has to finalise `L` -- confirm.
        """

    def _build_levels(wm, A, W, nA, nW):
        """Create one level per bit (highest first) and stably partition values and weights.

        `A`/`nA` and `W`/`nW` swap roles after every level, so the caller's
        `A` and `W` lists are reused as scratch space.
        """
        size = wm.N
        wm.up = [wm.Level(size, h) for h in range(wm.H)]
        wm.down = wm.up[::-1]
        for L in wm.down:
            shift = L.H
            # Elements with a 0 bit fill slots from 0, those with a 1 bit from `zeros`.
            zeros = size - sum(a >> shift & 1 for a in A)
            slot0, slot1 = 0, zeros
            for i, a in enumerate(A):
                w = W[i]
                if a >> shift & 1:
                    nA[slot1] = a
                    nW[slot1] = w
                    slot1 += 1
                    L.set1(i)
                else:
                    nA[slot0] = a
                    nW[slot0] = w
                    slot0 += 1
            A, nA = nA, A
            W, nW = nW, W
            wm._build_level(L, W)

    def prod_range(wm, l: int, r: int):
        """Product over indices `[l, r)`; `e` when the range is empty."""
        if l < r:
            return wm._prod_range(l, r)
        return wm.e

    def prod_at(wm, y: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value equals `y`."""
        if l < r:
            return wm._prod_rect(y, y+1, l, r)
        return wm.e

    def prod_below(wm, u: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value is `< u`."""
        if l < r:
            return wm._prod_below(u, l, r)
        return wm.e

    def prod_above(wm, d: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value is `>= d`."""
        if l < r:
            return wm._prod_above(d, l, r)
        return wm.e

    def prod_between(wm, d: int, u: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value lies in `[d, u)`."""
        if l < r and d < u:
            return wm._prod_rect(d, u, l, r)
        return wm.e

    def prod_corner(wm, r: int, u: int):
        """Product over indices in `[0, r)` whose value is `< u`."""
        if 0 < r:
            return wm._prod_below(u, 0, r)
        return wm.e

    def prod_rect(wm, l: int, d: int, r: int, u: int):
        """Product over indices in `[l, r)` whose value lies in `[d, u)`."""
        if l < r and d < u:
            return wm._prod_rect(d, u, l, r)
        return wm.e

    @abstractmethod
    def _prod_range(wm, l, r):
        """Hook: product of all weights at indices `[l, r)` regardless of value."""

    def _prod_below(wm, u, l, r):
        """Unchecked `prod_below`: follow the bits of `u`, folding in the 0-side whenever `u` has a 1 bit."""
        if u <= 0:
            return wm.e
        if u.bit_length() > wm.H:
            # Every stored value is below u.
            return wm._prod_range(l, r)
        acc = wm.e
        for L in wm.down:
            ones_l = L.count1(l)
            ones_r = L.count1(r)
            zero_l, zero_r = l - ones_l, r - ones_r
            if u >> L.H & 1:
                acc = wm.op(acc, L.prod(zero_l, zero_r))
                l, r = L.T0 + ones_l, L.T0 + ones_r
            else:
                l, r = zero_l, zero_r
        return acc

    def _prod_above(wm, d, l, r):
        """Unchecked `prod_above`: follow the bits of `d-1`, folding in the 1-side whenever it has a 0 bit."""
        if d <= 0:
            return wm._prod_range(l, r)
        if d.bit_length() > wm.H:
            # No stored value can reach d.
            return wm.e
        acc = wm.e
        floor = d - 1  # values strictly greater than `floor` are wanted
        for L in wm.down:
            base = L.T0
            one_l = base + L.count1(l)
            one_r = base + L.count1(r)
            if floor >> L.H & 1:
                l, r = one_l, one_r
            else:
                acc = wm.op(L.prod(one_l, one_r), acc)
                # l - one_l equals count0(l) - T0, so this lands on the 0-side position.
                l, r = base + (l - one_l), base + (r - one_r)
        return acc

    def _prod_rect(wm, d, u, l, r):
        """Unchecked product over `[l, r)` with value in `[d, u)`.

        Walks `d-1` and `u` together until their bits first differ, then
        tracks the upper bound in `(l, r)` and the lower bound in `(dl, dr)`.
        """
        if u <= 0 or wm.H < d.bit_length():
            return wm.e
        if d <= 0:
            return wm._prod_below(u, l, r)
        if wm.H < u.bit_length():
            return wm._prod_above(d, l, r)
        acc = wm.e
        floor = d - 1
        diverged = False
        for L in wm.down:
            base, shift = L.T0, L.H
            lo_bit, hi_bit = floor >> shift & 1, u >> shift & 1
            ones_l = L.count1(l)
            ones_r = L.count1(r)
            zero_l, zero_r = l - ones_l, r - ones_r
            if not diverged:
                if lo_bit != hi_bit:
                    # Paths split here: lower bound takes the 0-side, upper bound the 1-side.
                    diverged = True
                    dl, dr = zero_l, zero_r
                    l, r = base + ones_l, base + ones_r
                elif lo_bit:
                    l, r = base + ones_l, base + ones_r
                else:
                    l, r = zero_l, zero_r
            else:
                if hi_bit:
                    acc = wm.op(acc, L.prod(zero_l, zero_r))
                    l, r = base + ones_l, base + ones_r
                else:
                    l, r = zero_l, zero_r
                d_one_l = base + L.count1(dl)
                d_one_r = base + L.count1(dr)
                if lo_bit:
                    dl, dr = d_one_l, d_one_r
                else:
                    acc = wm.op(L.prod(d_one_l, d_one_r), acc)
                    dl, dr = base + (dl - d_one_l), base + (dr - d_one_r)
        return acc
from abc import abstractmethod
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''



class BitArray:
    """Fixed-size bit vector with rank (`count0`/`count1`) and select queries.

    Bit `i` lives in word `i >> 5`, most significant bit first.  After
    `build()`, `cnt[j]` holds the number of set bits in words `[0, j)`.
    The rank and select methods read `cnt`, so they are meaningful only
    after `build()`.
    NOTE(review): assumes array typecode 'I' holds at least 32 bits -- confirm
    for the target platform.
    """

    def __init__(B, N: int):
        """Allocate `N` cleared bits plus one spare word."""
        words = (N + 31) >> 5
        B.N = N
        B.Z = words
        # One extra word so that index N >> 5 is valid when N is a multiple of 32.
        B.bits = u32f(words + 1)
        B.cnt = u32f(words + 1)

    def build(B):
        """Fill the per-word prefix popcount table `cnt`."""
        words, prefix = B.bits, B.cnt
        words.pop()  # drop the spare word while accumulating
        running = prefix[0]
        for j, word in enumerate(words, 1):
            running += popcnt32(word)
            prefix[j] = running
        words.append(1)

    def __len__(B):
        return B.N

    def __getitem__(B, i: int):
        """Return bit `i` (0 or 1)."""
        word = B.bits[i >> 5]
        return (word >> (31 - (i & 31))) & 1

    def set0(B, i: int):
        """Clear bit `i`."""
        j = i >> 5
        B.bits[j] = B.bits[j] & ~(1 << (31 - (i & 31)))

    def set1(B, i: int):
        """Set bit `i`."""
        j = i >> 5
        B.bits[j] = B.bits[j] | (1 << (31 - (i & 31)))

    def count0(B, r: int):
        """Number of 0 bits at positions `[0, r)`."""
        return r - B.count1(r)

    def count1(B, r: int):
        """Number of 1 bits at positions `[0, r)`."""
        j = r >> 5
        # Keep only the top (r & 31) bits of the partially covered word.
        return B.cnt[j] + popcnt32(B.bits[j] >> (32 - (r & 31)))

    def select0(B, k: int):
        """Position of the `k+1`-th 0 bit, or -1 if there is none."""
        if k < 0 or k >= B.N - B.cnt[-1]:
            return -1
        lo, hi, need = 0, B.N, k + 1
        # Invariant: count0(lo) < need <= count0(hi).
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if B.count0(mid) < need:
                lo = mid
            else:
                hi = mid
        return lo

    def select1(B, k: int):
        """Position of the `k+1`-th 1 bit, or -1 if there is none."""
        if k < 0 or k >= B.cnt[-1]:
            return -1
        lo, hi, need = 0, B.N, k + 1
        # Invariant: count1(lo) < need <= count1(hi).
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if B.count1(mid) < need:
                lo = mid
            else:
                hi = mid
        return lo


def popcnt32(x):
    """Count the set bits of a 32-bit non-negative integer by pairwise folding."""
    folds = (
        (1, 0x55555555),
        (2, 0x33333333),
        (4, 0x0f0f0f0f),
        (8, 0x00ff00ff),
        (16, 0x0000ffff),
    )
    for shift, mask in folds:
        # Add neighbouring fields of width `shift` into fields twice as wide.
        x = ((x >> shift) & mask) + (x & mask)
    return x
# Use the built-in popcount when this interpreter provides one.
popcnt32 = getattr(int, 'bit_count', popcnt32)

from array import array
def u32f(N: int, elm: int = 0):
    """Return an `array('I')` (unsigned int) holding `N` copies of `elm`."""
    single = array('I', [elm])
    return single * N

class WMStatic:
    class Level(BitArray):
        def __init__(L, N: int, H: int):
            super().__init__(N)
            L.H = H
        def build(L):
            super().build()
            L.T0, L.T1 = L.N-L.cnt[-1], L.cnt[-1]
        def pos(L, bit: int, i: int): return L.T0+L.count1(i) if bit else L.count0(i)
        def pos2(L, bit: int, i: int, j: int): return (L.T0+L.count1(i), L.T0+L.count1(j)) if bit else (L.count0(i), L.count0(j))
    def __init__(wm,A,Amax:int=None):wm._build(A,[0]*len(A),max(A,default=0)if Amax is None else Amax)
    def _build(wm, A, nA, Amax):wm.N,wm.H=len(A),Amax.bit_length();wm._build_levels(A,nA)
    def _build_levels(wm, A, nA):
        wm.up=[wm.Level(wm.N,H) for H in range(wm.H)];wm.down=wm.up[::-1]
        for L in wm.down:
            x,y,i=-1,wm.N-1,wm.N
            while i:y-=A[i:=i-1]>>L.H&1
            for i,a in enumerate(A):
                if a>>L.H&1:nA[y:=y+1]=a;L.set1(i)
                else:nA[x:=x+1]=a
            A,nA=nA,A;L.build()
    def __getitem__(wm,i):
        y=0
        for L in wm.down:y=y<<1|(bit:=L[i]);i=L.pos(bit,i)
        return y
    def kth(wm, k: int, l: int, r: int):
        '''Returns the `k+1`-th value in sorted order of values in range `[l, r)`'''
        s=0
        for L in wm.down:
            l,r=l-(l1:=L.count1(l)),r-(r1:=L.count1(r))
            if k>=r-l:s|=1<<L.H;k-=r-l;l,r=L.T0+l1,L.T0+r1
        return s
    def select(wm, y: int, k: int, l: int = 0, r: int = -1):
        '''Returns the index of the `k+1`-th occurance of `y` in range `[l, r)`'''
        if not(0<=y<1<<wm.H):return-1
        if r==-1:r=wm.N-1
        for L in wm.down:l,r=L.pos2(L[y],l,r)
        if not l<=(i:=l+k)<r:return-1
        for L in wm.up:
            if y>>L.H&1:i=L.select1(i-L.T0)
            else:i=L.select0(i)
        return i
    def rank(wm, y: int, r: int): return wm.rank_range(y, 0, r)
    def rank_range(wm, y: int, l: int, r: int):
        if l >= r: return 0
        for L in wm.down:l,r=L.pos2(L[y],l,r)
        return r-l
    def count_at(wm, y: int, l: int, r: int):
        '''Count how many `y` values are in range `[l,r)` '''
        if l >= r: return 0
        return wm._cnt(y+1, l, r)-wm._cnt(y, l, r)
    def count_below(wm, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `A[i] < u` '''
        return wm._cnt(u, l, r)
    def count_between(wm, d: int, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `d <= A[i] < u` '''
        if l >= r or d >= u: return 0
        return wm._cnt(u, l, r)-wm._cnt(d, l, r)
    def _cnt(wm, u: int, l: int, r: int):
        if u<=0:return 0
        if wm.H<u.bit_length():return r-l
        cnt=0
        for L in wm.down:
            l,r=l-(l1:=L.count1(l)),r-(r1:=L.count1(r))
            if u>>L.H&1:cnt+=r-l;l,r=L.T0+l1,L.T0+r1
        return cnt
    def prev_val(wm,u:int,l:int,r:int):return wm.kth(cnt-1, l, r)if(cnt:=wm._cnt(u,l,r))else-1
    def next_val(wm,d:int,l:int,r:int):return wm.kth(cnt, l, r)if(cnt:=wm._cnt(d,l,r))<r-l else-1

class WMMonoid(WMStatic):
    """Wavelet matrix whose elements carry weights folded with `op` / identity `e`.

    `A[i]` is the value and `W[i]` the weight of element `i`.  Subclasses
    supply the storage behind `_build_base`, `_build_level` and `_prod_range`;
    the queries here also call `L.prod(l, r)` on every level, which this class
    does not define itself.
    NOTE(review): partial products are not combined in index order, so `op`
    is presumably expected to be commutative -- confirm against subclasses.
    """

    def __init__(wm, op, e, A: list[int], W: list[int], Amax: int = None):
        """Build from values `A` and weights `W`; `Amax` defaults to `max(A)` (0 if empty)."""
        size = len(A)
        scratch_a, scratch_w = [0]*size, [0]*size
        if Amax is None:
            Amax = max(A, default=0)
        wm._build(op, e, A, W, scratch_a, scratch_w, Amax)

    def _build(wm, op, e, A, W, nA, nW, Amax):
        """Store size, height and the monoid, then run the base and per-level builds."""
        size, height = len(A), Amax.bit_length()
        wm.N, wm.H, wm.op, wm.e = size, height, op, e
        wm._build_base(W)
        wm._build_levels(A, W, nA, nW)

    @abstractmethod
    def _build_base(wm, W):
        """Hook called once by `_build` with the weights in their original order."""

    @abstractmethod
    def _build_level(wm, L, W):
        """Hook called once per level with the weights in that level's partitioned order.

        NOTE(review): `_build_levels` never calls `L.build()` itself, so the
        implementation presumably has to finalise `L` -- confirm.
        """

    def _build_levels(wm, A, W, nA, nW):
        """Create one level per bit (highest first) and stably partition values and weights.

        `A`/`nA` and `W`/`nW` swap roles after every level, so the caller's
        `A` and `W` lists are reused as scratch space.
        """
        size = wm.N
        wm.up = [wm.Level(size, h) for h in range(wm.H)]
        wm.down = wm.up[::-1]
        for L in wm.down:
            shift = L.H
            # Elements with a 0 bit fill slots from 0, those with a 1 bit from `zeros`.
            zeros = size - sum(a >> shift & 1 for a in A)
            slot0, slot1 = 0, zeros
            for i, a in enumerate(A):
                w = W[i]
                if a >> shift & 1:
                    nA[slot1] = a
                    nW[slot1] = w
                    slot1 += 1
                    L.set1(i)
                else:
                    nA[slot0] = a
                    nW[slot0] = w
                    slot0 += 1
            A, nA = nA, A
            W, nW = nW, W
            wm._build_level(L, W)

    def prod_range(wm, l: int, r: int):
        """Product over indices `[l, r)`; `e` when the range is empty."""
        if l < r:
            return wm._prod_range(l, r)
        return wm.e

    def prod_at(wm, y: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value equals `y`."""
        if l < r:
            return wm._prod_rect(y, y+1, l, r)
        return wm.e

    def prod_below(wm, u: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value is `< u`."""
        if l < r:
            return wm._prod_below(u, l, r)
        return wm.e

    def prod_above(wm, d: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value is `>= d`."""
        if l < r:
            return wm._prod_above(d, l, r)
        return wm.e

    def prod_between(wm, d: int, u: int, l: int, r: int):
        """Product over indices in `[l, r)` whose value lies in `[d, u)`."""
        if l < r and d < u:
            return wm._prod_rect(d, u, l, r)
        return wm.e

    def prod_corner(wm, r: int, u: int):
        """Product over indices in `[0, r)` whose value is `< u`."""
        if 0 < r:
            return wm._prod_below(u, 0, r)
        return wm.e

    def prod_rect(wm, l: int, d: int, r: int, u: int):
        """Product over indices in `[l, r)` whose value lies in `[d, u)`."""
        if l < r and d < u:
            return wm._prod_rect(d, u, l, r)
        return wm.e

    @abstractmethod
    def _prod_range(wm, l, r):
        """Hook: product of all weights at indices `[l, r)` regardless of value."""

    def _prod_below(wm, u, l, r):
        """Unchecked `prod_below`: follow the bits of `u`, folding in the 0-side whenever `u` has a 1 bit."""
        if u <= 0:
            return wm.e
        if u.bit_length() > wm.H:
            # Every stored value is below u.
            return wm._prod_range(l, r)
        acc = wm.e
        for L in wm.down:
            ones_l = L.count1(l)
            ones_r = L.count1(r)
            zero_l, zero_r = l - ones_l, r - ones_r
            if u >> L.H & 1:
                acc = wm.op(acc, L.prod(zero_l, zero_r))
                l, r = L.T0 + ones_l, L.T0 + ones_r
            else:
                l, r = zero_l, zero_r
        return acc

    def _prod_above(wm, d, l, r):
        """Unchecked `prod_above`: follow the bits of `d-1`, folding in the 1-side whenever it has a 0 bit."""
        if d <= 0:
            return wm._prod_range(l, r)
        if d.bit_length() > wm.H:
            # No stored value can reach d.
            return wm.e
        acc = wm.e
        floor = d - 1  # values strictly greater than `floor` are wanted
        for L in wm.down:
            base = L.T0
            one_l = base + L.count1(l)
            one_r = base + L.count1(r)
            if floor >> L.H & 1:
                l, r = one_l, one_r
            else:
                acc = wm.op(L.prod(one_l, one_r), acc)
                # l - one_l equals count0(l) - T0, so this lands on the 0-side position.
                l, r = base + (l - one_l), base + (r - one_r)
        return acc

    def _prod_rect(wm, d, u, l, r):
        """Unchecked product over `[l, r)` with value in `[d, u)`.

        Walks `d-1` and `u` together until their bits first differ, then
        tracks the upper bound in `(l, r)` and the lower bound in `(dl, dr)`.
        """
        if u <= 0 or wm.H < d.bit_length():
            return wm.e
        if d <= 0:
            return wm._prod_below(u, l, r)
        if wm.H < u.bit_length():
            return wm._prod_above(d, l, r)
        acc = wm.e
        floor = d - 1
        diverged = False
        for L in wm.down:
            base, shift = L.T0, L.H
            lo_bit, hi_bit = floor >> shift & 1, u >> shift & 1
            ones_l = L.count1(l)
            ones_r = L.count1(r)
            zero_l, zero_r = l - ones_l, r - ones_r
            if not diverged:
                if lo_bit != hi_bit:
                    # Paths split here: lower bound takes the 0-side, upper bound the 1-side.
                    diverged = True
                    dl, dr = zero_l, zero_r
                    l, r = base + ones_l, base + ones_r
                elif lo_bit:
                    l, r = base + ones_l, base + ones_r
                else:
                    l, r = zero_l, zero_r
            else:
                if hi_bit:
                    acc = wm.op(acc, L.prod(zero_l, zero_r))
                    l, r = base + ones_l, base + ones_r
                else:
                    l, r = zero_l, zero_r
                d_one_l = base + L.count1(dl)
                d_one_r = base + L.count1(dr)
                if lo_bit:
                    dl, dr = d_one_l, d_one_r
                else:
                    acc = wm.op(L.prod(d_one_l, d_one_r), acc)
                    dl, dr = base + (dl - d_one_l), base + (dr - d_one_r)
        return acc
Back to top page