cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

⚠️ cp_library/ds/tree/bit/bir_cls.py

Depends on

Code

import cp_library.__header__
from typing import Sequence
import cp_library.ds.__header__
import cp_library.ds.tree.__header__
import cp_library.ds.tree.bit.__header__

class BIR(Sequence[int]):
    '''Range-add / range-sum array of `size` numbers, all initially 0.

    Backed by two Fenwick trees (`BIT`). `add` keeps the invariant
    bit2[p] == p * bit1[p] for every position p, which gives
        prefix sum   sum(i) = i * bit1.sum(i) - bit2.sum(i)
        point value  a[i]   = bit1.sum(i + 1)
    '''
    def __init__(bir, size: int):
        bir.size, bir.bit1, bir.bit2 = size, BIT(size), BIT(size)

    def __len__(bir):
        return bir.size

    def _index(bir, i: int) -> int:
        '''Normalize i like a list index (negatives count from the end); raise IndexError if out of range.'''
        j = i + bir.size if i < 0 else i
        if not 0 <= j < bir.size: raise IndexError('BIR index out of range')
        return j

    def add(bir, l, r, x) -> None:
        '''Add x to all elements in range [l, r).

        A bound >= size is a no-op on that side of the underlying trees.
        Raises IndexError if l or r is negative.
        '''
        # Negative positions would make BIT.add spin forever (i |= i+1 sticks at -1).
        if l < 0 or r < 0: raise IndexError('BIR range bounds must be non-negative')
        bir.bit1.add(l, x), bir.bit1.add(r, -x)
        # bit2 receives p * (the delta bit1 receives at p), preserving bit2[p] == p * bit1[p].
        bir.bit2.add(l, x * l), bir.bit2.add(r, -x * r)

    def sum(bir, i):
        '''Get sum of elements in range [0, i)'''
        return i * bir.bit1.sum(i) - bir.bit2.sum(i)

    def range_sum(bir, l, r):
        '''Get sum of elements in range [l, r)'''
        return bir.sum(r) - bir.sum(l)

    def get(bir, i):
        '''Get the value at index i (negative i counts from the end).

        Equals sum(i+1) - sum(i); because bit2[p] == p * bit1[p], that
        difference reduces to the prefix sum of bit1's deltas through i.
        Raises IndexError when i is out of range, which also terminates
        Sequence iteration.
        '''
        return bir.bit1.sum(bir._index(i) + 1)
    __getitem__ = get

    def set(bir, i, x):
        '''Set the value at index i (negative i counts from the end) to x.'''
        i = bir._index(i)
        bir.add(i, i+1, x - bir.bit1.sum(i+1))
    __setitem__ = set
        
from cp_library.ds.tree.bit.bit_cls import BIT
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
from typing import Sequence


'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
            ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓            
            ┃                                    7 ┃            
            ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━┛            
            ┏━━━━━━━━━━━━━━━━━━┓                 │              
            ┃                3 ┃◄────────────────┤              
            ┗━━━━━━━━━━━━━━━━┯━┛                 │              
            ┏━━━━━━━━┓       │  ┏━━━━━━━━┓       │              
            ┃      1 ┃◄──────┤  ┃      5 ┃◄──────┤              
            ┗━━━━━━┯━┛       │  ┗━━━━━━┯━┛       │              
            ┏━━━┓  │  ┏━━━┓  │  ┏━━━┓  │  ┏━━━┓  │              
            ┃ 0 ┃◄─┤  ┃ 2 ┃◄─┤  ┃ 4 ┃◄─┤  ┃ 6 ┃◄─┤              
            ┗━┯━┛  │  ┗━┯━┛  │  ┗━┯━┛  │  ┗━┯━┛  │              
              │    │    │    │    │    │    │    │              
              ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼              
            ┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓            
            ┃ 0 ┃┃ 1 ┃┃ 2 ┃┃ 3 ┃┃ 4 ┃┃ 5 ┃┃ 6 ┃┃ 7 ┃            
            ┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛            
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
           Data Structure - Tree - Binary Index Tree            
'''

class BIR(Sequence[int]):
    '''Range-add / range-sum array of `size` numbers, all initially 0.

    Backed by two Fenwick trees (`BIT`). `add` keeps the invariant
    bit2[p] == p * bit1[p] for every position p, which gives
        prefix sum   sum(i) = i * bit1.sum(i) - bit2.sum(i)
        point value  a[i]   = bit1.sum(i + 1)
    '''
    def __init__(bir, size: int):
        bir.size, bir.bit1, bir.bit2 = size, BIT(size), BIT(size)

    def __len__(bir):
        return bir.size

    def _index(bir, i: int) -> int:
        '''Normalize i like a list index (negatives count from the end); raise IndexError if out of range.'''
        j = i + bir.size if i < 0 else i
        if not 0 <= j < bir.size: raise IndexError('BIR index out of range')
        return j

    def add(bir, l, r, x) -> None:
        '''Add x to all elements in range [l, r).

        A bound >= size is a no-op on that side of the underlying trees.
        Raises IndexError if l or r is negative.
        '''
        # Negative positions would make BIT.add spin forever (i |= i+1 sticks at -1).
        if l < 0 or r < 0: raise IndexError('BIR range bounds must be non-negative')
        bir.bit1.add(l, x), bir.bit1.add(r, -x)
        # bit2 receives p * (the delta bit1 receives at p), preserving bit2[p] == p * bit1[p].
        bir.bit2.add(l, x * l), bir.bit2.add(r, -x * r)

    def sum(bir, i):
        '''Get sum of elements in range [0, i)'''
        return i * bir.bit1.sum(i) - bir.bit2.sum(i)

    def range_sum(bir, l, r):
        '''Get sum of elements in range [l, r)'''
        return bir.sum(r) - bir.sum(l)

    def get(bir, i):
        '''Get the value at index i (negative i counts from the end).

        Equals sum(i+1) - sum(i); because bit2[p] == p * bit1[p], that
        difference reduces to the prefix sum of bit1's deltas through i.
        Raises IndexError when i is out of range, which also terminates
        Sequence iteration.
        '''
        return bir.bit1.sum(bir._index(i) + 1)
    __getitem__ = get

    def set(bir, i, x):
        '''Set the value at index i (negative i counts from the end) to x.'''
        i = bir._index(i)
        bir.add(i, i+1, x - bir.bit1.sum(i+1))
    __setitem__ = set
        
from typing import Union

class BIT:
    '''Fenwick tree (binary indexed tree) over 0-indexed positions 0..n-1.

    Node d[i] stores the sum of the original values in [i & (i+1), i].
    Point add and prefix sum each walk O(log n) nodes.
    '''
    def __init__(bit, v: Union[int, list[int]]):
        '''Create a tree of v zeros (v is an int) or build one from the list v.'''
        if isinstance(v, int): bit.d, bit.n = [0]*v, v
        else: bit.build(v)
        # Highest power of two <= n: the first step of bisect_right's binary lifting.
        # An empty tree has no such power (1 << -1 would raise), so use 0 and skip the search.
        bit.lb = 1<<(bit.n.bit_length()-1) if bit.n else 0

    def build(bit, data):
        '''Build the tree in O(n). NOTE: `data` is reused as storage and mutated in place.'''
        bit.d, bit.n = data, len(data)
        for i in range(bit.n):
            # Push each node's total into its parent node i | (i+1).
            if (r := i|i+1) < bit.n: bit.d[r] += bit.d[i]

    def add(bit, i, x):
        '''Add x to the value at position i (no-op when i >= n).

        Raises IndexError for negative i.
        '''
        # For negative i, `i |= i+1` converges to -1 and never exits the loop.
        if i < 0: raise IndexError('BIT index out of range')
        while i < bit.n:
            bit.d[i] += x
            i |= i+1

    def sum(bit, n: int) -> int:
        '''Return the sum of the values at positions [0, n).'''
        s = 0
        while n: s, n = s+bit.d[n-1], n&n-1
        return s

    def range_sum(bit, l, r):
        '''Return the sum of the values at positions [l, r).'''
        s = 0
        while r: s, r = s+bit.d[r-1], r&r-1
        while l: s, l = s-bit.d[l-1], l&l-1
        return s

    def __len__(bit) -> int:
        return bit.n

    def __getitem__(bit, i: int) -> int:
        '''Return the single value at position i.

        Works by subtracting the child nodes that tile [i & (i+1), i-1] from d[i].
        '''
        s, l = bit.d[i], i&(i+1)
        while l != i: s, i = s-bit.d[i-1], i-(i&-i)
        return s
    get = __getitem__

    def __setitem__(bit, i: int, x: int) -> None:
        '''Set the value at position i to x.'''
        bit.add(i, x-bit[i])
    set = __setitem__

    def prelist(bit) -> list[int]:
        '''Return all prefix sums: pre[k] is the sum of positions [0, k), len(pre) == n+1.'''
        pre = [0]+bit.d
        # pre[k] = (node k-1, covering [k & (k-1), k-1]) + pre[k & (k-1)]; pre[0] stays 0.
        for i in range(1, bit.n+1): pre[i] += pre[i&i-1]
        return pre

    def bisect_left(bit, v) -> int:
        '''Return the largest i with sum(i) < v; 0 when v <= 0.

        Assumes integer, non-negative values (it delegates to bisect_right(v-1)).
        '''
        return bit.bisect_right(v-1) if v>0 else 0

    def bisect_right(bit, v) -> int:
        '''Return the largest i with sum(i) <= v.

        Assumes non-negative values so prefix sums are non-decreasing.
        '''
        i = s = 0; ni = m = bit.lb
        while m:
            if ni <= bit.n and (ns:=s+bit.d[ni-1]) <= v: s, i = ns, ni
            ni = (m:=m>>1)|i
        return i
Back to top page