cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/data-structure/static_range_frequency_wavelet_matrix.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/static_range_frequency

def main():
    """Solve Library Checker "static_range_frequency" with a wavelet matrix.

    Reads ``N Q``, then the N array values, then Q query lines ``l r x``.
    For each query the number of elements of A[l:r] equal to x is appended
    to the global string builder, which is written to fd 1 in one call.
    """
    _, query_count = map(int, input().split())
    values = [int(token) for token in input().split()]
    matrix = WaveletMatrix(values)
    for _ in range(query_count):
        lo, hi, target = input().split()
        answer = matrix.range_freq(int(lo), int(hi), int(target))
        append(str(answer))
        append('\n')
    os.write(1, sb.build().encode())

from cp_library.ds.wavelet_matrix_cls import WaveletMatrix
import sys,os
from __pypy__ import builders
# PyPy string builder that accumulates all output; main() flushes it once.
sb = builders.StringBuilder()
append = sb.append

def input():
    """Return the next stdin line as bytes with surrounding whitespace removed."""
    line = sys.stdin.buffer.readline()
    return line.strip()

if __name__ == "__main__":
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/static_range_frequency

def main():
    """Static range frequency driver: build a wavelet matrix, answer queries.

    Input: ``N Q`` on the first line, the N values on the second, then Q
    lines ``l r x``.  Each answer (count of x in A[l:r]) goes to the global
    string builder; everything is written to fd 1 in a single os.write.
    """
    n, q = map(int, input().split())
    wm = WaveletMatrix([int(v) for v in input().split()])
    for _ in range(q):
        l, r, x = map(int, input().split())
        append(str(wm.range_freq(l, r, x)))
        append('\n')
    os.write(1, sb.build().encode())

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''

class BitArray:
    """Fixed-length bit vector with constant-time rank and binary-search select.

    Bits are packed 32 per word, most-significant bit first: bit ``i`` lives
    in word ``i >> 5`` at position ``31 - (i & 31)``.  Fill it with
    ``set0``/``set1``, then call ``build()`` once before any rank/select.

    Attributes:
        N: number of bits.
        Z: number of 32-bit words needed for N bits.
        H: caller-supplied tag (WaveletMatrix stores the bit level here).
        bits: packed words; build() appends one extra zero word.
        pre: pre[w] = number of 1 bits in words [0, w) (filled by build()).
        T0, T1: total number of 0 bits / 1 bits (set by build()).
    """
    def __init__(B, N: int, H: int):
        B.N, B.Z, B.H = N, (N+31)>>5, H
        B.bits, B.pre = u32f(B.Z), u32f(B.Z+1)

    def build(B):
        """Compute per-word prefix popcounts and the 0/1 totals."""
        for i,b in enumerate(B.bits): B.pre[i+1] = B.pre[i]+popcnt32(b)
        # Sentinel word: rank1(N) reads bits[N>>5], which is one past the
        # last real word when N is a multiple of 32.
        B.bits.append(0)
        B.T0, B.T1 = B.N-B.pre[-1], B.pre[-1]

    def __len__(B): return B.N
    def __getitem__(B, i: int): return B.bits[i>>5]>>(31-(i&31))&1
    def set0(B, i: int): B.bits[i>>5]&=~(1<<31-(i&31))
    def set1(B, i: int): B.bits[i>>5]|=1<<31-(i&31)

    def rank0(B, r: int):
        """Number of 0 bits in positions [0, r)."""
        return r-B.rank1(r)

    def rank1(B, r: int):
        """Number of 1 bits in positions [0, r)."""
        # Whole words before word r>>5, plus the top (r&31) bits of that word;
        # shifting a 32-bit word right by 32 yields 0 when r&31 == 0.
        return B.pre[r>>5]+popcnt32(B.bits[r>>5]>>32-(r&31))

    def select0(B, k: int):
        """Return the position of the k-th (0-indexed) 0 bit, or -1 if there is none."""
        if not 0<=k<B.N-B.pre[-1]: return -1
        # Search over bit positions (not words).  Invariant:
        # rank0(l) <= k < rank0(r); rank0(0) == 0 and rank0(N) == T0 > k.
        # On exit r == l+1, so bit l is a 0 with exactly k zeros before it.
        l,r=0,B.N
        while 1<r-l:
            if B.rank0(m:=(l+r)>>1)<=k:l=m
            else:r=m
        return l

    def select1(B, k: int):
        """Return the position of the k-th (0-indexed) 1 bit, or -1 if there is none."""
        if not 0<=k<B.pre[-1]: return -1
        # Same invariant as select0, using rank1: rank1(l) <= k < rank1(r).
        l,r=0,B.N
        while 1<r-l:
            if B.rank1(m:=(l+r)>>1)<=k:l=m
            else:r=m
        return l

    def next_range(B, bit: int, l: int, r: int):
        """Map [l, r) on this level to the next level, following elements whose bit is ``bit``.

        0-bit elements occupy the front of the next level (positions given by
        rank0); 1-bit elements follow after all T0 zeros.
        """
        if bit: return B.T0+B.rank1(l), B.T0+B.rank1(r)
        else: return B.rank0(l), B.rank0(r)

class WaveletMatrix:
    def __init__(W,A):
        A,W.V = icoord_compress(A)
        W.N=N=len(A); W.H=(len(W.V)-1).bit_length()
        W.L,B=[BitArray(N, H) for H in range(W.H-1,-1,-1)],[0]*N
        for L in W.L:
            x,y,j=-1,N-1,N
            while j:y-=A[j:=j-1]>>L.H&1
            for j,k in enumerate(A):
                if k>>L.H&1:B[y:=y+1]=k;L.set1(j)
                else:B[x:=x+1]=k
            A,B=B,A;L.build()

    def _fval(W, x: int, upper: bool = False):
        l,r=-1,len(W.V)
        while 1<r-l:
            if W.V[m:=(l+r)>>1]<=x:l=m
            else:r=m
        return l + (upper and W.V[l] != x)

    def __contains__(W, x: int):
        return W.V and W.V[W._fval(x)] == x

    def kth(W, l: int, r: int, k: int):
        if k < 0: k = r-l+k
        s=0
        for L in W.L:
            l, r = l-(l1:=L.rank1(l)), r-(r1:=L.rank1(r))
            if k>=r-l:s|=1<<L.H;k-=r-l;l,r=L.T0+l1,L.T0+r1
        return W.V[s]

    def rank(W, x: int, r: int): return W.range_rank(0, r, x)
    def range_rank(W, l: int, r: int, x: int):
        if l >= r or not W.V or x != W.V[x := W._fval(x)]: return -1
        for L in W.L: l, r = L.next_range(L[x], l, r)
        return r-l
    
    def range_freq(W, l: int, r: int, x: int):
        """
        l, r: Range in the original array (0-indexed, half-open)

        x: Value

        Returns: Number of elements in the range equal to x
        """
        if l >= r or not W.V or x != W.V[x := W._fval(x)]: return 0
        return W._rect_freq(l, r, x+1)-W._rect_freq(l, r, x)
    
    def rect_freq(W, l: int, r: int, a: int, b: int):
        """
        l, r: Range in the original array (0-indexed, half-open)

        a, b: Value range (half-open)

        Returns: Number of elements in the range satisfying the condition
        """
        if l >= r or not W.V or (a := W._fval(a, True)) >= (b := W._fval(b, True)): return 0
        return W._rect_freq(l, r, b)-W._rect_freq(l, r, a)

    def _rect_freq(W, l: int, r: int, u: int):
        if u.bit_length() > W.H: return r-l
        cnt = 0
        for L in W.L:
            l, r = l-(l1:=L.rank1(l)), r-(r1:=L.rank1(r))
            if u>>L.H&1:cnt+=r-l;l,r=L.T0+l1,L.T0+r1
        return cnt




def icoord_compress(A: list[int]):
    s, m = pack_sm((N := len(A))-1)
    R, V = [0]*N, [0]*N
    for i,a in enumerate(A): A[i] = a<<s|i
    A.sort()
    r = p = -1
    for ai in A:
        a, i = pack_dec(ai, s, m)
        if a != p: V[r:=r+1] = p = a
        R[i] = r
    del V[r+1:]
    return R, V



def pack_sm(N: int):
    """Return ``(shift, mask)`` wide enough to pack integers in ``[0, N]`` into low bits."""
    shift = N.bit_length()
    mask = (1 << shift) - 1
    return shift, mask

def pack_enc(a: int, b: int, s: int):
    """Combine ``a`` (high part) and ``b`` (low ``s`` bits) into one integer."""
    high = a << s
    return high | b
    
def pack_dec(ab: int, s: int, m: int):
    """Split ``ab`` into ``(ab >> s, ab & m)``: the high part and the masked low part."""
    high = ab >> s
    low = ab & m
    return high, low

def pack_indices(A, s):
    """Return a new list whose i-th entry is ``A[i] << s | i``."""
    return [(val << s) | idx for idx, val in enumerate(A)]

def popcnt32(x):
    """Count the set bits among the low 32 bits of ``x`` (SWAR reduction)."""
    # Each step adds adjacent fields of width 1, 2, 4, 8, 16 bits.
    # The masks drop anything above bit 31.
    x = (x & 0x55555555) + ((x >> 1) & 0x55555555)
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
    x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f)
    x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff)
    x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff)
    return x
from array import array

# Filled-array factories: each returns a length-N typed array with every slot set to elm.
def i8f(N: int, elm: int = 0):      return array('b', [elm]) * N  # signed char
def u8f(N: int, elm: int = 0):      return array('B', [elm]) * N  # unsigned char
def i16f(N: int, elm: int = 0):     return array('h', [elm]) * N  # signed short
def u16f(N: int, elm: int = 0):     return array('H', [elm]) * N  # unsigned short
def i32f(N: int, elm: int = 0):     return array('i', [elm]) * N  # signed int
def u32f(N: int, elm: int = 0):     return array('I', [elm]) * N  # unsigned int
def i64f(N: int, elm: int = 0):     return array('q', [elm]) * N  # signed long long
def u64f(N: int, elm: int = 0):     return array('Q', [elm]) * N  # unsigned long long
def f32f(N: int, elm: float = 0.0): return array('f', [elm]) * N  # float
def f64f(N: int, elm: float = 0.0): return array('d', [elm]) * N  # double

# Typed-array constructors: an empty array when init is None, otherwise one built from init.
def i8a(init = None):  return array('b', () if init is None else init)  # signed char
def u8a(init = None):  return array('B', () if init is None else init)  # unsigned char
def i16a(init = None): return array('h', () if init is None else init)  # signed short
def u16a(init = None): return array('H', () if init is None else init)  # unsigned short
def i32a(init = None): return array('i', () if init is None else init)  # signed int
def u32a(init = None): return array('I', () if init is None else init)  # unsigned int
def i64a(init = None): return array('q', () if init is None else init)  # signed long long
def u64a(init = None): return array('Q', () if init is None else init)  # unsigned long long
def f32a(init = None): return array('f', () if init is None else init)  # float
def f64a(init = None): return array('d', () if init is None else init)  # double

# Largest representable value for each fixed-width integer type.
i8_max = 0x7F
u8_max = 0xFF
i16_max = 0x7FFF
u16_max = 0xFFFF
i32_max = 0x7FFF_FFFF
u32_max = 0xFFFF_FFFF
i64_max = 0x7FFF_FFFF_FFFF_FFFF
u64_max = 0xFFFF_FFFF_FFFF_FFFF
import sys,os
from __pypy__ import builders
# Output is buffered in a PyPy StringBuilder and written out once by main().
sb = builders.StringBuilder()
append = sb.append  # bound method kept as a global for the per-query output calls

def input():
    """Fast stdin reader: next line as bytes, stripped of surrounding whitespace."""
    return sys.stdin.buffer.readline().strip()

if __name__ == "__main__":
    main()
Back to top page