cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/data-structure/predecessor_problem_bit.test.py

Depends on: cp_library.bit.popcnt32_fn, cp_library.ds.array_init_fn, cp_library.ds.tree.bit.bit_cls

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/predecessor_problem

def main():
    """Answer Library Checker "predecessor_problem" queries with a bitset + Fenwick tree.

    Input: "N Q", then a 0/1 string T, then Q lines "c k". Query kinds:
      0 -> insert k, 1 -> erase k, 2 -> print 1 if k is present else 0,
      3 -> print the smallest present element >= k (or -1),
      4 -> print the largest present element <= k (or -1).
    Answers are appended to the module-level string builder and written once.
    """
    N, Q = map(int, sys.stdin.readline().split())
    T = sys.stdin.readline()

    # Bit (i & 31) of word i >> 5 mirrors T[i]; T's trailing newline is never '1'.
    M = (len(T) + 31) >> 5
    B = u32f(M)
    for pos, ch in enumerate(T):
        if ch == '1':
            B[pos >> 5] |= 1 << (pos & 31)

    # Fenwick tree over per-word popcounts: bit.sum(w) = set bits in words [0, w).
    bit = BIT([popcnt32(word) for word in B])

    def rank(w, s):
        """Count set bits strictly before bit s of word w."""
        low = B[w] & ((1 << s) - 1)
        return bit.sum(w) + popcnt32(low)

    def test_bit(w, s):
        """Return 1 if bit s of word w is set, else 0."""
        return B[w] >> s & 1

    def assign(w, s, x):
        """Force bit s of word w to x (0 or 1), keeping the Fenwick counts in sync."""
        if test_bit(w, s) == x:
            return  # already in the requested state; counts must not change
        if x:
            B[w] |= 1 << s
            bit.add(w, 1)
        else:
            B[w] &= ~(1 << s)
            bit.add(w, -1)

    def successor(w, s):
        """Smallest set position >= 32*w + s, or -1 if there is none."""
        # First word whose inclusion pushes the prefix count past rank(w, s).
        j = bit.bisect_right(rank(w, s))
        if j >= M:
            return -1
        word = B[j]
        if j <= w:
            word = word >> s << s  # same word as the query: drop bits below s
        return (j << 5) | ((word & -word).bit_length() - 1)

    def predecessor(w, s):
        """Largest set position <= 32*w + s, or -1 if there is none."""
        # Word containing the rank(w, s+1)-th set bit (counting from 1).
        j = bit.bisect_left(rank(w, s + 1))
        if j < 0:
            return -1
        word = B[j]
        if j >= w:
            word &= (1 << (s + 1)) - 1  # same word as the query: drop bits above s
        return (j << 5) | (word.bit_length() - 1)

    answer = {'2': test_bit, '3': successor, '4': predecessor}
    readline = sys.stdin.readline
    for _ in range(Q):
        op, arg = readline().split()
        k = int(arg)
        w, s = k >> 5, k & 31
        if op == '0':
            assign(w, s, 1)
        elif op == '1':
            assign(w, s, 0)
        elif op in answer:
            append(str(answer[op](w, s)))
            append('\n')
    os.write(1, sb.build().encode())

from cp_library.bit.popcnt32_fn import popcnt32
from cp_library.ds.array_init_fn import u32f
from cp_library.ds.tree.bit.bit_cls import BIT

import os
# __pypy__ is PyPy's built-in module (not importable under CPython).
# StringBuilder collects all output; main() appends pieces and flushes once via os.write.
from __pypy__ import builders
sb = builders.StringBuilder()
# Bound-method alias that main() calls for every output fragment.
append = sb.append
import sys

# Run the solver only when executed as a script.
if __name__ == "__main__":
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/predecessor_problem

def main():
    """Answer Library Checker "predecessor_problem" queries with a bitset + Fenwick tree.

    Input: "N Q", then a 0/1 string T, then Q lines "c k". Query kinds:
      0 -> insert k, 1 -> erase k, 2 -> print 1 if k is present else 0,
      3 -> print the smallest present element >= k (or -1),
      4 -> print the largest present element <= k (or -1).
    Answers are appended to the module-level string builder and written once.
    """
    N, Q = map(int, sys.stdin.readline().split())
    T = sys.stdin.readline()

    # Bit (i & 31) of word i >> 5 mirrors T[i]; T's trailing newline is never '1'.
    M = (len(T) + 31) >> 5
    B = u32f(M)
    for pos, ch in enumerate(T):
        if ch == '1':
            B[pos >> 5] |= 1 << (pos & 31)

    # Fenwick tree over per-word popcounts: bit.sum(w) = set bits in words [0, w).
    bit = BIT([popcnt32(word) for word in B])

    def rank(w, s):
        """Count set bits strictly before bit s of word w."""
        low = B[w] & ((1 << s) - 1)
        return bit.sum(w) + popcnt32(low)

    def test_bit(w, s):
        """Return 1 if bit s of word w is set, else 0."""
        return B[w] >> s & 1

    def assign(w, s, x):
        """Force bit s of word w to x (0 or 1), keeping the Fenwick counts in sync."""
        if test_bit(w, s) == x:
            return  # already in the requested state; counts must not change
        if x:
            B[w] |= 1 << s
            bit.add(w, 1)
        else:
            B[w] &= ~(1 << s)
            bit.add(w, -1)

    def successor(w, s):
        """Smallest set position >= 32*w + s, or -1 if there is none."""
        # First word whose inclusion pushes the prefix count past rank(w, s).
        j = bit.bisect_right(rank(w, s))
        if j >= M:
            return -1
        word = B[j]
        if j <= w:
            word = word >> s << s  # same word as the query: drop bits below s
        return (j << 5) | ((word & -word).bit_length() - 1)

    def predecessor(w, s):
        """Largest set position <= 32*w + s, or -1 if there is none."""
        # Word containing the rank(w, s+1)-th set bit (counting from 1).
        j = bit.bisect_left(rank(w, s + 1))
        if j < 0:
            return -1
        word = B[j]
        if j >= w:
            word &= (1 << (s + 1)) - 1  # same word as the query: drop bits above s
        return (j << 5) | (word.bit_length() - 1)

    answer = {'2': test_bit, '3': successor, '4': predecessor}
    readline = sys.stdin.readline
    for _ in range(Q):
        op, arg = readline().split()
        k = int(arg)
        w, s = k >> 5, k & 31
        if op == '0':
            assign(w, s, 1)
        elif op == '1':
            assign(w, s, 0)
        elif op in answer:
            append(str(answer[op](w, s)))
            append('\n')
    os.write(1, sb.build().encode())

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''

def popcnt32(x):
    """Return the number of set bits among the low 32 bits of x.

    Branch-free SWAR reduction: each stage adds neighbouring bit-fields of
    width w into fields of width 2w, until a single 32-bit count remains.
    Bits above position 31 are masked away by the first stage.
    """
    x = (x & 0x55555555) + (x >> 1 & 0x55555555)    # 16 fields of 2 bits
    x = (x & 0x33333333) + (x >> 2 & 0x33333333)    # 8 fields of 4 bits
    x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F)    # 4 fields of 8 bits
    x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF)    # 2 fields of 16 bits
    return (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF)

from array import array

def i8f(N: int, elm: int = 0):
    """Return an array('b') (signed char) of length N filled with elm."""
    return array('b', [elm]) * N

def u8f(N: int, elm: int = 0):
    """Return an array('B') (unsigned char) of length N filled with elm."""
    return array('B', [elm]) * N

def i16f(N: int, elm: int = 0):
    """Return an array('h') (signed short) of length N filled with elm."""
    return array('h', [elm]) * N

def u16f(N: int, elm: int = 0):
    """Return an array('H') (unsigned short) of length N filled with elm."""
    return array('H', [elm]) * N

def i32f(N: int, elm: int = 0):
    """Return an array('i') (signed int) of length N filled with elm."""
    return array('i', [elm]) * N

def u32f(N: int, elm: int = 0):
    """Return an array('I') (unsigned int) of length N filled with elm."""
    return array('I', [elm]) * N

def i64f(N: int, elm: int = 0):
    """Return an array('q') (signed long long) of length N filled with elm."""
    return array('q', [elm]) * N

def u64f(N: int, elm: int = 0):
    """Return an array('Q') (unsigned long long) of length N filled with elm."""
    return array('Q', [elm]) * N

def f32f(N: int, elm: float = 0.0):
    """Return an array('f') (float) of length N filled with elm."""
    return array('f', [elm]) * N

def f64f(N: int, elm: float = 0.0):
    """Return an array('d') (double) of length N filled with elm."""
    return array('d', [elm]) * N

def i8a(init=None):
    """Return an array('b') (signed char): empty, or built from init."""
    if init is None:
        return array('b')
    return array('b', init)

def u8a(init=None):
    """Return an array('B') (unsigned char): empty, or built from init."""
    if init is None:
        return array('B')
    return array('B', init)

def i16a(init=None):
    """Return an array('h') (signed short): empty, or built from init."""
    if init is None:
        return array('h')
    return array('h', init)

def u16a(init=None):
    """Return an array('H') (unsigned short): empty, or built from init."""
    if init is None:
        return array('H')
    return array('H', init)

def i32a(init=None):
    """Return an array('i') (signed int): empty, or built from init."""
    if init is None:
        return array('i')
    return array('i', init)

def u32a(init=None):
    """Return an array('I') (unsigned int): empty, or built from init."""
    if init is None:
        return array('I')
    return array('I', init)

def i64a(init=None):
    """Return an array('q') (signed long long): empty, or built from init."""
    if init is None:
        return array('q')
    return array('q', init)

def u64a(init=None):
    """Return an array('Q') (unsigned long long): empty, or built from init."""
    if init is None:
        return array('Q')
    return array('Q', init)

def f32a(init=None):
    """Return an array('f') (float): empty, or built from init."""
    if init is None:
        return array('f')
    return array('f', init)

def f64a(init=None):
    """Return an array('d') (double): empty, or built from init."""
    if init is None:
        return array('d')
    return array('d', init)

# Largest values of 8/16/32/64-bit signed (i*) and unsigned (u*) integers.
i8_max = 0x7F
u8_max = 0xFF
i16_max = 0x7FFF
u16_max = 0xFFFF
i32_max = 0x7FFF_FFFF
u32_max = 0xFFFF_FFFF
i64_max = 0x7FFF_FFFF_FFFF_FFFF
u64_max = 0xFFFF_FFFF_FFFF_FFFF

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
            ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓            
            ┃                                    7 ┃            
            ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━┛            
            ┏━━━━━━━━━━━━━━━━━━┓                 │              
            ┃                3 ┃◄────────────────┤              
            ┗━━━━━━━━━━━━━━━━┯━┛                 │              
            ┏━━━━━━━━┓       │  ┏━━━━━━━━┓       │              
            ┃      1 ┃◄──────┤  ┃      5 ┃◄──────┤              
            ┗━━━━━━┯━┛       │  ┗━━━━━━┯━┛       │              
            ┏━━━┓  │  ┏━━━┓  │  ┏━━━┓  │  ┏━━━┓  │              
            ┃ 0 ┃◄─┤  ┃ 2 ┃◄─┤  ┃ 4 ┃◄─┤  ┃ 6 ┃◄─┤              
            ┗━┯━┛  │  ┗━┯━┛  │  ┗━┯━┛  │  ┗━┯━┛  │              
              │    │    │    │    │    │    │    │              
              ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼              
            ┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓            
            ┃ 0 ┃┃ 1 ┃┃ 2 ┃┃ 3 ┃┃ 4 ┃┃ 5 ┃┃ 6 ┃┃ 7 ┃            
            ┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛            
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
           Data Structure - Tree - Binary Index Tree            
'''

class BIT:
    """Fenwick tree (binary indexed tree) over a 0-indexed sequence.

    Internal layout: ``_d[i]`` holds the sum of the original elements on the
    index range ``[i & (i+1), i]``; ``_n`` is the length; ``_lb`` is a power of
    two strictly greater than ``_n`` that seeds the descent in ``bisect_right``.
    """

    def __init__(bit, v):
        """Create a tree of v zeros if v is an int, otherwise build from the list v.

        A list passed as v is adopted and overwritten in place (see ``build``).
        """
        if isinstance(v, int):
            bit._d, bit._n = [0]*v, v
            bit._lb = 1 << v.bit_length()
        else:
            bit.build(v)

    def build(bit, data):
        """Replace the contents with data in O(n).

        data becomes the internal array and is overwritten with partial sums,
        so callers must not reuse it. The search bound ``_lb`` is recomputed
        here so that rebuilding with a different length keeps
        ``bisect_right`` correct.
        """
        bit._d, bit._n = data, len(data)
        for i in range(bit._n):
            # Push node i's total into the next node whose range contains it.
            if (r := i|i+1) < bit._n: bit._d[r] += bit._d[i]
        bit._lb = 1 << bit._n.bit_length()

    def add(bit, i, x):
        """Add x to element i (0-indexed) in O(log n)."""
        while i < bit._n:
            bit._d[i] += x
            i |= i+1

    def sum(bit, n: int) -> int:
        """Return the sum of the first n elements (indices [0, n)) in O(log n)."""
        s = 0
        while n: s, n = s+bit._d[n-1], n&n-1
        return s

    def range_sum(bit, l, r):
        """Return the sum of elements on indices [l, r) in O(log n)."""
        s = 0
        while r: s, r = s+bit._d[r-1], r&r-1
        while l: s, l = s-bit._d[l-1], l&l-1
        return s

    def __len__(bit) -> int:
        return bit._n

    def __getitem__(bit, i: int) -> int:
        """Return the current value of element i in O(log n)."""
        # _d[i] covers [l, i]; subtract the nodes covering [l, i) to isolate element i.
        s, l = bit._d[i], i&(i+1)
        while l != i: s, i = s-bit._d[i-1], i-(i&-i)
        return s
    get = __getitem__

    def __setitem__(bit, i: int, x: int) -> None:
        """Set element i to x by adding the difference from its current value."""
        bit.add(i, x-bit[i])
    set = __setitem__

    def prelist(bit) -> list[int]:
        """Return all prefix sums: pre[k] = sum of the first k elements, k in [0, n]."""
        pre = [0]+bit._d
        for i in range(bit._n+1): pre[i] += pre[i&i-1]
        return pre

    def bisect_left(bit, v) -> int:
        """Return bisect_right(v-1), or -1 when v <= 0.

        For integer data with non-negative elements, this is the index of the
        element at which the running sum first reaches v.
        """
        return bit.bisect_right(v-1) if v>0 else -1

    def bisect_right(bit, v, key=None) -> int:
        """Return the largest i in [0, n] with sum(i) <= v (key(sum(i)) <= v if key given).

        Assumes prefix sums (after key) are non-decreasing, e.g. non-negative
        elements; runs in O(log n) by descending powers of two below ``_lb``.
        """
        i = s = 0; m = bit._lb
        if key is not None:
            while m := m>>1:
                if (ni := m|i) <= bit._n and key(ns:=s+bit._d[ni-1]) <= v: s, i = ns, ni
        else:
            while m := m>>1:
                if (ni := m|i) <= bit._n and (ns:=s+bit._d[ni-1]) <= v: s, i = ns, ni
        return i

import os
# __pypy__ is PyPy's built-in module (not importable under CPython).
# StringBuilder collects all output; main() appends pieces and flushes once via os.write.
from __pypy__ import builders
sb = builders.StringBuilder()
# Bound-method alias that main() calls for every output fragment.
append = sb.append
import sys

# Run the solver only when executed as a script.
if __name__ == "__main__":
    main()
Back to top page