This documentation is automatically generated by online-judge-tools/verification-helper
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/aplusb
import pytest
import random
class TestWMStaticLevel:
    '''Unit tests for WMStatic.Level: bit storage, rank (count0/count1),
    select (select0/select1) and the next-level index maps (pos/pos2).'''

    def test_initialization(self):
        # A fresh level of N bits allocates ceil(N / 32) words plus one spare word.
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        assert B.N == N
        assert B.Z == 4
        assert B.H == H
        assert len(B.bits) == B.Z + 1
        assert len(B.cnt) == B.Z + 1

    def test_build(self):
        # build() computes per-word prefix popcounts and the zero/one totals.
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        for i in [5, 10, 15, 20, 25]:
            B.set1(i)
        B.build()
        assert B.cnt[0] == 0
        assert B.cnt[-1] == 5  # 5 bits set to 1 in total
        assert B.T0 == N - 5  # the remaining N - 5 bits are 0
        assert B.T1 == 5
        # build() keeps the spare trailing word, so the length is unchanged
        assert len(B.bits) == B.Z + 1

    def test_len(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        assert len(B) == N

    def test_getitem(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Initially all bits are 0
        for i in range(N):
            assert B[i] == 0
        positions = {5, 10, 15, 20, 25}
        for pos in positions:
            B.set1(pos)
        for i in range(N):
            expected = 1 if i in positions else 0
            assert B[i] == expected, f"Expected {expected} at position {i}, got {B[i]}"

    def test_set0_and_set1(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        positions = {5, 10, 15, 20, 25}
        for pos in positions:
            B.set1(pos)
        for i in range(N):
            expected = 1 if i in positions else 0
            assert B[i] == expected
        # Clear some of them again; only 10 and 20 should remain set
        for pos in [5, 15, 25]:
            B.set0(pos)
        for i in range(N):
            assert B[i] == (1 if i in (10, 20) else 0)

    def test_count0_and_count1(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Set every 10th bit to 1
        for i in range(0, N, 10):
            B.set1(i)
        B.build()
        # count1(r): number of 1-bits in [0, r)
        assert B.count1(0) == 0
        assert B.count1(10) == 1
        assert B.count1(20) == 2
        assert B.count1(50) == 5
        assert B.count1(N) == N // 10
        # count0(r): number of 0-bits in [0, r)
        assert B.count0(0) == 0
        assert B.count0(10) == 9
        assert B.count0(20) == 18
        assert B.count0(50) == 45
        assert B.count0(N) == N - N // 10

    def test_select0_and_select1(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Set every 10th bit to 1
        ones = list(range(0, N, 10))
        for i in ones:
            B.set1(i)
        B.build()
        # select1(k): position of the k-th (0-indexed) 1-bit
        for k, expected in enumerate(ones):
            pos = B.select1(k)
            assert pos == expected, f"wrong position for {k=}"
            assert B.count1(pos) == k, f"count1 at position {pos} should be {k}"
        # select0(k): position of the k-th (0-indexed) 0-bit
        zeros = [i for i in range(N) if i % 10]
        for k, expected in enumerate(zeros):
            pos = B.select0(k)
            assert pos == expected, f"wrong position for {k=}"
            assert B[pos] == 0
            assert B.count0(pos) == k, f"count0 at position {pos} should be {k}"
        # Ranks outside [0, total) are rejected with -1
        assert B.select1(len(ones)) == -1
        assert B.select1(-1) == -1
        assert B.select0(len(zeros)) == -1
        assert B.select0(-1) == -1

    def test_pos2(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Set every 10th bit to 1
        for i in range(0, N, 10):
            B.set1(i)
        B.build()
        l, r = 0, 50
        # bit=1: the range maps into the block of ones, which starts at T0
        next_l, next_r = B.pos2(1, l, r)
        assert next_l == B.T0 + B.count1(l)
        assert next_r == B.T0 + B.count1(r)
        assert (next_l, next_r) == (90, 95)
        # bit=0: the range maps into the block of zeros, which starts at 0
        next_l, next_r = B.pos2(0, l, r)
        assert next_l == B.count0(l)
        assert next_r == B.count0(r)
        assert (next_l, next_r) == (0, 45)
        # pos is the single-index form of pos2
        for i in range(N + 1):
            assert B.pos(0, i) == B.count0(i)
            assert B.pos(1, i) == B.T0 + B.count1(i)

    @pytest.mark.parametrize("N, H", [(32, 5), (33, 5), (64, 6), (65, 6), (1000, 10)])
    def test_edge_cases(self, N, H):
        B = WMStatic.Level(N, H)
        # All bits 0
        B.build()
        assert B.count1(N) == 0
        assert B.count0(N) == N
        assert B.select1(0) == -1  # no 1s at all
        assert B.select0(N - 1) == N - 1  # the last bit is the last 0
        # All bits 1
        for i in range(N):
            B.set1(i)
        B.build()
        assert B.count1(N) == N
        assert B.count0(N) == 0
        assert B.select1(N - 1) == N - 1  # the last bit is the last 1
        assert B.select1(N) == -1  # out of range
        assert B.select0(0) == -1  # no 0s if all are 1s
from cp_library.ds.wavelet.wm_static_cls import WMStatic
from cp_library.io.read_fn import read
from cp_library.io.write_fn import write
if __name__ == '__main__':
    import sys
    A, B = read()
    C = A + B
    write(C)
    # Only the specific sum 1198300249 triggers the embedded test run;
    # any other input exits successfully right after printing the sum.
    if C != 1198300249:
        sys.exit(0)
    import pytest
    import io
    from contextlib import redirect_stdout, redirect_stderr
    # Run this file's pytest suite with all of its output captured, and
    # echo the captured output only if the run did not succeed.
    captured = io.StringIO()
    with redirect_stdout(captured), redirect_stderr(captured):
        exit_code = pytest.main([__file__])
    if exit_code != 0:
        print(captured.getvalue())
    sys.exit(exit_code)
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/aplusb
import pytest
import random
class TestWMStaticLevel:
    '''Unit tests for WMStatic.Level: bit storage, rank (count0/count1),
    select (select0/select1) and the next-level index maps (pos/pos2).'''

    def test_initialization(self):
        # A fresh level of N bits allocates ceil(N / 32) words plus one spare word.
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        assert B.N == N
        assert B.Z == 4
        assert B.H == H
        assert len(B.bits) == B.Z + 1
        assert len(B.cnt) == B.Z + 1

    def test_build(self):
        # build() computes per-word prefix popcounts and the zero/one totals.
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        for i in [5, 10, 15, 20, 25]:
            B.set1(i)
        B.build()
        assert B.cnt[0] == 0
        assert B.cnt[-1] == 5  # 5 bits set to 1 in total
        assert B.T0 == N - 5  # the remaining N - 5 bits are 0
        assert B.T1 == 5
        # build() keeps the spare trailing word, so the length is unchanged
        assert len(B.bits) == B.Z + 1

    def test_len(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        assert len(B) == N

    def test_getitem(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Initially all bits are 0
        for i in range(N):
            assert B[i] == 0
        positions = {5, 10, 15, 20, 25}
        for pos in positions:
            B.set1(pos)
        for i in range(N):
            expected = 1 if i in positions else 0
            assert B[i] == expected, f"Expected {expected} at position {i}, got {B[i]}"

    def test_set0_and_set1(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        positions = {5, 10, 15, 20, 25}
        for pos in positions:
            B.set1(pos)
        for i in range(N):
            expected = 1 if i in positions else 0
            assert B[i] == expected
        # Clear some of them again; only 10 and 20 should remain set
        for pos in [5, 15, 25]:
            B.set0(pos)
        for i in range(N):
            assert B[i] == (1 if i in (10, 20) else 0)

    def test_count0_and_count1(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Set every 10th bit to 1
        for i in range(0, N, 10):
            B.set1(i)
        B.build()
        # count1(r): number of 1-bits in [0, r)
        assert B.count1(0) == 0
        assert B.count1(10) == 1
        assert B.count1(20) == 2
        assert B.count1(50) == 5
        assert B.count1(N) == N // 10
        # count0(r): number of 0-bits in [0, r)
        assert B.count0(0) == 0
        assert B.count0(10) == 9
        assert B.count0(20) == 18
        assert B.count0(50) == 45
        assert B.count0(N) == N - N // 10

    def test_select0_and_select1(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Set every 10th bit to 1
        ones = list(range(0, N, 10))
        for i in ones:
            B.set1(i)
        B.build()
        # select1(k): position of the k-th (0-indexed) 1-bit
        for k, expected in enumerate(ones):
            pos = B.select1(k)
            assert pos == expected, f"wrong position for {k=}"
            assert B.count1(pos) == k, f"count1 at position {pos} should be {k}"
        # select0(k): position of the k-th (0-indexed) 0-bit
        zeros = [i for i in range(N) if i % 10]
        for k, expected in enumerate(zeros):
            pos = B.select0(k)
            assert pos == expected, f"wrong position for {k=}"
            assert B[pos] == 0
            assert B.count0(pos) == k, f"count0 at position {pos} should be {k}"
        # Ranks outside [0, total) are rejected with -1
        assert B.select1(len(ones)) == -1
        assert B.select1(-1) == -1
        assert B.select0(len(zeros)) == -1
        assert B.select0(-1) == -1

    def test_pos2(self):
        N, H = 100, 5
        B = WMStatic.Level(N, H)
        # Set every 10th bit to 1
        for i in range(0, N, 10):
            B.set1(i)
        B.build()
        l, r = 0, 50
        # bit=1: the range maps into the block of ones, which starts at T0
        next_l, next_r = B.pos2(1, l, r)
        assert next_l == B.T0 + B.count1(l)
        assert next_r == B.T0 + B.count1(r)
        assert (next_l, next_r) == (90, 95)
        # bit=0: the range maps into the block of zeros, which starts at 0
        next_l, next_r = B.pos2(0, l, r)
        assert next_l == B.count0(l)
        assert next_r == B.count0(r)
        assert (next_l, next_r) == (0, 45)
        # pos is the single-index form of pos2
        for i in range(N + 1):
            assert B.pos(0, i) == B.count0(i)
            assert B.pos(1, i) == B.T0 + B.count1(i)

    @pytest.mark.parametrize("N, H", [(32, 5), (33, 5), (64, 6), (65, 6), (1000, 10)])
    def test_edge_cases(self, N, H):
        B = WMStatic.Level(N, H)
        # All bits 0
        B.build()
        assert B.count1(N) == 0
        assert B.count0(N) == N
        assert B.select1(0) == -1  # no 1s at all
        assert B.select0(N - 1) == N - 1  # the last bit is the last 0
        # All bits 1
        for i in range(N):
            B.set1(i)
        B.build()
        assert B.count1(N) == N
        assert B.count0(N) == 0
        assert B.select1(N - 1) == N - 1  # the last bit is the last 1
        assert B.select1(N) == -1  # out of range
        assert B.select0(0) == -1  # no 0s if all are 1s
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
class BitArray:
    '''Fixed-length bit vector packed into 32-bit words, with rank and select.

    Bit ``i`` lives in word ``i >> 5`` at bit ``31 - (i & 31)`` (the most
    significant bit of a word holds the lowest index). After the last
    ``set0``/``set1`` call, ``build()`` must run before any ``count*`` or
    ``select*`` query: it fills ``cnt`` with per-word prefix popcounts.
    '''
    def __init__(B, N: int):
        B.N = N
        B.Z = (N + 31) >> 5  # number of 32-bit words covering N bits
        # One extra slot in each array: bits gets a spare word, cnt a leading 0.
        B.bits = u32f(B.Z + 1)
        B.cnt = u32f(B.Z + 1)

    def build(B):
        words = B.bits
        words.pop()  # temporarily drop the spare word so only real words are counted
        total = 0
        for w, word in enumerate(words, 1):
            total += popcnt32(word)
            B.cnt[w] = total  # cnt[w] = number of 1-bits in words[0:w]
        words.append(1)  # restore the spare word (as value 1) so len(bits) == Z + 1

    def __len__(B):
        return B.N

    def __getitem__(B, i: int):
        word, off = i >> 5, i & 31
        return B.bits[word] >> (31 - off) & 1

    def set0(B, i: int):
        mask = 1 << (31 - (i & 31))
        B.bits[i >> 5] &= ~mask

    def set1(B, i: int):
        mask = 1 << (31 - (i & 31))
        B.bits[i >> 5] |= mask

    def count0(B, r: int):
        '''Number of 0-bits among positions ``[0, r)``.'''
        return r - B.count1(r)

    def count1(B, r: int):
        '''Number of 1-bits among positions ``[0, r)``.'''
        w = r >> 5
        # Whole words before w come from cnt; keep only the top (r & 31) bits of word w.
        return B.cnt[w] + popcnt32(B.bits[w] >> (32 - (r & 31)))

    def select0(B, k: int):
        '''Position of the k-th (0-indexed) 0-bit, or -1 if there is none.'''
        if k < 0 or k >= B.N - B.cnt[-1]:
            return -1
        target, lo, hi = k + 1, 0, B.N
        # Invariant: count0(lo) < target <= count0(hi); ends with bit lo being the target-th zero.
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if B.count0(mid) < target:
                lo = mid
            else:
                hi = mid
        return lo

    def select1(B, k: int):
        '''Position of the k-th (0-indexed) 1-bit, or -1 if there is none.'''
        if k < 0 or k >= B.cnt[-1]:
            return -1
        target, lo, hi = k + 1, 0, B.N
        # Invariant: count1(lo) < target <= count1(hi); ends with bit lo being the target-th one.
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if B.count1(mid) < target:
                lo = mid
            else:
                hi = mid
        return lo
def popcnt32(x):
    '''Return the number of set bits among the low 32 bits of ``x`` (SWAR popcount).

    Neighbouring bit fields of width 1, 2, 4, 8 and then 16 are added
    pairwise in place, until one field holds the total count.
    '''
    for width, mask in ((1, 0x55555555), (2, 0x33333333), (4, 0x0f0f0f0f),
                        (8, 0x00ff00ff), (16, 0x0000ffff)):
        x = (x & mask) + ((x >> width) & mask)
    return x
if hasattr(int, 'bit_count'):
popcnt32 = int.bit_count
from array import array
def u32f(N: int, elm: int = 0):
    '''Return an unsigned-int array (typecode ``'I'``) of length ``N`` filled with ``elm``.'''
    filled = array('I', [elm])
    return filled * N
class WMStatic:
    '''Static wavelet matrix over a sequence of (assumed non-negative) integers.

    There are ``H = Amax.bit_length()`` levels, one per bit. ``down`` lists
    them from the most significant bit, ``up`` from the least significant bit.
    Each level stably partitions the elements so that those with a 0 in that
    bit come first (``T0`` of them), followed by those with a 1.
    All query ranges ``[l, r)`` are half-open.
    '''
    class Level(BitArray):
        '''One bit-plane: bit ``H`` of every element, in this level's order.'''
        def __init__(L, N: int, H: int):
            super().__init__(N)
            L.H = H

        def build(L):
            super().build()
            # T0 = number of 0-bits, T1 = number of 1-bits on this level.
            L.T0, L.T1 = L.N - L.cnt[-1], L.cnt[-1]

        def pos(L, bit: int, i: int):
            '''Next-level index of the element at index ``i``, whose bit on this level is ``bit``.'''
            return L.T0 + L.count1(i) if bit else L.count0(i)

        def pos2(L, bit: int, i: int, j: int):
            '''Map ``[i, j)`` to the next-level range holding its elements whose bit here is ``bit``.'''
            return (L.T0 + L.count1(i), L.T0 + L.count1(j)) if bit else (L.count0(i), L.count0(j))

    def __init__(wm, A, Amax: int = None):
        '''Build the matrix from the values in ``A`` (any iterable; it is not modified).

        ``Amax`` bounds the values and fixes the number of levels
        (``Amax.bit_length()``). It defaults to ``max(A)``, or 0 for empty ``A``.
        '''
        # _build_levels uses its first argument as scratch space and overwrites
        # it. Work on a copy so the caller's list is not scrambled, and so an
        # immutable sequence (e.g. a tuple) does not fail once a second level
        # writes to it.
        A = list(A)
        wm._build(A, [0] * len(A), max(A, default=0) if Amax is None else Amax)

    def _build(wm, A, nA, Amax):
        wm.N, wm.H = len(A), Amax.bit_length()
        wm._build_levels(A, nA)

    def _build_levels(wm, A, nA):
        '''Create and fill one Level per bit, using ``A``/``nA`` as alternating buffers.'''
        wm.up = [wm.Level(wm.N, H) for H in range(wm.H)]
        wm.down = wm.up[::-1]
        for L in wm.down:
            # Stable partition of A by bit L.H into nA: zeros go to nA[0:T0] and
            # ones to nA[T0:N], both in their original relative order.
            # x / y are the last-written slots of each part, so y starts at T0 - 1.
            x, y, i = -1, wm.N - 1, wm.N
            while i:
                i -= 1
                y -= A[i] >> L.H & 1
            for i, a in enumerate(A):
                if a >> L.H & 1:
                    y += 1
                    nA[y] = a
                    L.set1(i)
                else:
                    x += 1
                    nA[x] = a
            # The partitioned array feeds the next level; the old one becomes scratch.
            A, nA = nA, A
            L.build()

    def __getitem__(wm, i):
        '''Return the value stored at index ``i``.'''
        y = 0
        for L in wm.down:
            bit = L[i]
            y = y << 1 | bit
            i = L.pos(bit, i)
        return y

    def kth(wm, k: int, l: int, r: int):
        '''Returns the `k+1`-th value in sorted order of values in range `[l, r)`

        ``k`` is 0-indexed and is not bounds-checked (expects ``0 <= k < r - l``).
        '''
        s = 0
        for L in wm.down:
            l1, r1 = L.count1(l), L.count1(r)
            l, r = l - l1, r - r1  # the range restricted to 0-bits
            zeros = r - l
            if k >= zeros:
                # The answer has a 1 in this bit: skip the zeros and follow the ones.
                s |= 1 << L.H
                k -= zeros
                l, r = L.T0 + l1, L.T0 + r1
        return s

    def select(wm, y: int, k: int, l: int = 0, r: int = -1):
        '''Returns the index of the `k+1`-th occurrence of `y` in range `[l, r)`

        ``k`` is 0-indexed; ``r=-1`` means "to the end" (``r = N``).
        Returns -1 if ``y`` is out of range or has fewer than ``k+1`` occurrences.
        '''
        if not (0 <= y < 1 << wm.H):
            return -1
        if r == -1:
            r = wm.N  # r is exclusive, so the whole suffix is [l, N)
        for L in wm.down:
            # Follow the bits of the value y (not the stored bit at index y).
            l, r = L.pos2(y >> L.H & 1, l, r)
        if not l <= (i := l + k) < r:
            return -1
        # Walk back up, undoing each level's partition with select.
        for L in wm.up:
            if y >> L.H & 1:
                i = L.select1(i - L.T0)
            else:
                i = L.select0(i)
        return i

    def rank(wm, y: int, r: int):
        '''Count occurrences of ``y`` in ``[0, r)``.'''
        return wm.rank_range(y, 0, r)

    def rank_range(wm, y: int, l: int, r: int):
        '''Count occurrences of ``y`` in ``[l, r)``.'''
        if l >= r or not (0 <= y < 1 << wm.H):
            return 0
        for L in wm.down:
            # Follow the bits of the value y (not the stored bit at index y).
            l, r = L.pos2(y >> L.H & 1, l, r)
        return r - l

    def count_at(wm, y: int, l: int, r: int):
        '''Count how many `y` values are in range `[l,r)` '''
        if l >= r:
            return 0
        return wm._cnt(y + 1, l, r) - wm._cnt(y, l, r)

    def count_below(wm, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `A[i] < u` '''
        if l >= r:
            return 0
        return wm._cnt(u, l, r)

    def count_between(wm, d: int, u: int, l: int, r: int):
        '''Count `i`'s in `[l,r)` such that `d <= A[i] < u` '''
        if l >= r or d >= u:
            return 0
        return wm._cnt(u, l, r) - wm._cnt(d, l, r)

    def _cnt(wm, u: int, l: int, r: int):
        '''Number of values ``< u`` in ``[l, r)``.'''
        if u <= 0:
            return 0
        if wm.H < u.bit_length():
            return r - l  # u exceeds every representable value
        cnt = 0
        for L in wm.down:
            l1, r1 = L.count1(l), L.count1(r)
            l, r = l - l1, r - r1
            if u >> L.H & 1:
                # Elements with a 0 where u has a 1 (and an equal prefix) are < u.
                cnt += r - l
                l, r = L.T0 + l1, L.T0 + r1
        return cnt

    def prev_val(wm, u: int, l: int, r: int):
        '''Largest value ``< u`` in ``[l, r)``, or -1 if there is none.'''
        cnt = wm._cnt(u, l, r)
        return wm.kth(cnt - 1, l, r) if cnt else -1

    def next_val(wm, d: int, l: int, r: int):
        '''Smallest value ``>= d`` in ``[l, r)``, or -1 if there is none.'''
        cnt = wm._cnt(d, l, r)
        return wm.kth(cnt, l, r) if cnt < r - l else -1
from typing import Iterable, Type, Union, overload
import typing
from collections import deque
from numbers import Number
from types import GenericAlias
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase
class FastIO(IOBase):
    '''Byte-level buffered I/O working directly on a file descriptor.

    Input is pulled from the descriptor with ``os.read`` in large chunks and
    appended to an in-memory ``BytesIO``. Output goes into the same
    ``BytesIO`` (via ``write``) and is pushed to the descriptor by ``flush``.
    '''
    BUFSIZE = 8192
    newlines = 0  # buffered line endings not yet consumed by readline (EOF counts as one)

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Plain bool attribute (it shadows IOBase.writable()). It is true when
        # the mode contains 'x' or does not contain 'r'.
        mode = file.mode
        self.writable = "x" in mode or "r" not in mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        '''Read everything up to EOF and return all bytes not yet consumed.'''
        chunk = self.BUFSIZE
        while b := os.read(self._fd, max(os.fstat(self._fd).st_size, chunk)):
            pos = self.buffer.tell()
            self.buffer.seek(0, 2)  # append at the end ...
            self.buffer.write(b)
            self.buffer.seek(pos)  # ... without moving the read position
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        '''Return the next line, reading from the descriptor only when no full line is buffered.'''
        chunk = self.BUFSIZE
        while not self.newlines:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, chunk))
            # An empty read (EOF) counts as one line ending so the loop terminates.
            self.newlines = b.count(b"\n") + (not b)
            pos = self.buffer.tell()
            self.buffer.seek(0, 2)
            self.buffer.write(b)
            self.buffer.seek(pos)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        '''Write all buffered output to the descriptor and reset the buffer (no-op when not writable).'''
        if not self.writable:
            return
        os.write(self._fd, self.buffer.getvalue())
        self.buffer.truncate(0)
        self.buffer.seek(0)
class IOWrapper(IOBase):
    '''ASCII text interface over a FastIO byte buffer.

    ``write`` encodes with the ``"ascii"`` codec; ``read``/``readline`` decode
    with it. ``flush`` and ``writable`` are taken straight from the FastIO.
    The class attributes ``stdin``/``stdout`` hold the shared wrappers once
    they are installed, and are ``None`` until then.
    '''
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        data = s.encode("ascii")
        return self.buffer.write(data)

    def read(self):
        data = self.buffer.read()
        return data.decode("ascii")

    def readline(self):
        data = self.buffer.readline()
        return data.decode("ascii")
try:
    # Replace the standard streams with the FastIO-backed wrappers and
    # remember them on IOWrapper. If wrapping stdin fails, stdout is left
    # unwrapped too, because the second assignment never runs.
    sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
    sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)
except Exception:
    # Best effort: a stream whose fileno()/mode is unusable (e.g. a
    # pseudo-file) keeps the builtin I/O. Catch Exception, not everything,
    # so KeyboardInterrupt/SystemExit still propagate.
    pass
from typing import TypeVar
# Type variables for the generic parsing helpers. Each TypeVar's name must
# match the variable it is assigned to (PEP 484); type checkers reject a mismatch.
_T = TypeVar('_T')
_U = TypeVar('_U')
class TokenStream(Iterator):
    '''Iterator over whitespace-separated tokens read from ``TokenStream.stream``.

    The stream defaults to ``IOWrapper.stdin``. Tokens of one input line are
    buffered in ``queue``, and a new line is read only when the buffer is empty.
    '''
    stream = IOWrapper.stdin

    def __init__(self):
        self.queue = deque()

    def __next__(self):
        queue = self.queue
        if not queue:
            queue.extend(self._line())
        return queue.popleft()

    def wait(self):
        '''Yield once per buffered token, first reading one line if the buffer is empty.

        The consumer is expected to take a token on every step; iteration
        ends as soon as the buffer is drained.
        '''
        if not self.queue:
            self.queue.extend(self._line())
        while self.queue:
            yield

    def _line(self):
        return TokenStream.stream.readline().split()

    def line(self):
        '''Return the buffered tokens if any remain, otherwise the tokens of a fresh line.'''
        if not self.queue:
            return self._line()
        rest = list(self.queue)
        self.queue.clear()
        return rest
TokenStream.default = TokenStream()
class CharStream(TokenStream):
    '''TokenStream whose tokens are single characters of a line, trailing whitespace removed.'''
    def _line(self):
        raw = TokenStream.stream.readline()
        # Returned as a str; extending the queue with it yields one token per character.
        return raw.rstrip()
CharStream.default = CharStream()
# A compiled parser: consumes tokens from a TokenStream and returns a value.
ParseFn = Callable[[TokenStream],_T]
class Parser:
    '''Compile an input "spec" into a function that parses it from a TokenStream.

    Accepted specs (dispatched by ``compile``):
      * a class such as ``int`` or ``str``: one token, converted by the class;
      * a ``Parsable`` subclass: delegated to its own ``compile``;
      * a number instance ``k``: one token converted by ``type(k)``, plus ``k``
        (e.g. ``-1`` turns 1-based input into 0-based);
      * a tuple of specs ``(s1, s2, ...)``: a tuple with one value per spec;
      * ``(spec, ...)``, ``[spec]``, ``[]`` or ``list[spec]``: the tokens of the
        current line (a new line is read if none are buffered), each parsed
        with ``spec`` (default ``int``);
      * ``[spec, N]``: ``N`` consecutive values parsed with ``spec``;
      * any other callable ``f``: ``f`` applied to one token.
    '''
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        '''Build a parser for the class ``cls``, with its generic arguments ``args`` (if any).'''
        if issubclass(cls, Parsable):
            # Parsable classes supply their own parser.
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        '''Build a parser for ``spec``; see the class docstring for the accepted forms.'''
        if isinstance(spec, (type, GenericAlias)):
            # Split e.g. list[int] into its origin (list) and arguments ((int,)).
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            # A number instance: parse a value of its type, then add the number.
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        '''Parse the tokens of the current line with ``spec`` and pack the results into ``cls``.'''
        if spec is int:
            # Fast path: convert the whole line directly, skipping per-token dispatch.
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
        return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        '''Parse ``N`` values with ``spec`` and pack them into ``cls``.'''
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        '''Parse one value per child spec, in order, and pack them into ``cls``.'''
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        '''``(spec, ...)`` reads a whole line; any other tuple reads one value per child spec.'''
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        '''``[]``, ``[spec]`` or a set spec read a whole line; ``[spec, N]`` reads ``N`` values.'''
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()
class Parsable:
    '''Mixin for classes that know how to parse themselves from a TokenStream.

    ``Parser.compile_type`` hands any subclass over to ``compile``. The
    default implementation builds an instance from a single token; override
    ``compile`` to consume more input.
    '''
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser
@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_U], char=False):
    '''Read input according to ``specs`` (see ``Parser`` for the spec forms).

    With no specs and ``char=False``, the buffered or next line is returned as
    a list of ints. Otherwise the specs are compiled as a tuple and parsed from
    ``TokenStream.default``, or from ``CharStream.default`` when ``char`` is
    true. A single spec yields its value directly; otherwise the parsed tuple
    is returned.
    '''
    if specs or char:
        parser = Parser.compile(specs)
        source = CharStream.default if char else TokenStream.default
        parsed = parser(source)
        return parsed[0] if len(specs) == 1 else parsed
    return [int(tok) for tok in TokenStream.default.line()]
def write(*args, **kwargs):
    '''Print ``args`` like ``print`` does, but to the fast stdout wrapper by default.

    Optional keyword arguments, as in ``print``:
      sep   -- separator between values (default ``" "``; ``None`` means the default)
      end   -- written after the last value (default ``"\\n"``; ``None`` means the default)
      file  -- target stream. When omitted or ``None``, ``IOWrapper.stdout`` is
               used, or ``sys.stdout`` if that wrapper was never installed.
      flush -- if true, flush the target stream afterwards
    Any other keyword arguments are ignored.
    '''
    sep = kwargs.pop("sep", " ")
    if sep is None:
        sep = " "
    end = kwargs.pop("end", "\n")
    if end is None:
        end = "\n"
    file = kwargs.pop("file", None)
    if file is None:
        # IOWrapper.stdout stays None when the fast stdio wrappers could not be
        # installed. Writing to it would crash, so use the regular stdout then.
        file = IOWrapper.stdout if IOWrapper.stdout is not None else sys.stdout
    file.write(sep.join(map(str, args)))
    file.write(end)
    if kwargs.pop("flush", False):
        file.flush()
if __name__ == '__main__':
    A, B = read()
    C = A + B
    write(C)
    # Only the specific sum 1198300249 triggers the embedded test run;
    # any other input exits successfully right after printing the sum.
    if C != 1198300249:
        sys.exit(0)
    import io
    from contextlib import redirect_stdout, redirect_stderr
    # Run this file's pytest suite with all of its output captured, and
    # echo the captured output only if the run did not succeed.
    captured = io.StringIO()
    with redirect_stdout(captured), redirect_stderr(captured):
        exit_code = pytest.main([__file__])
    if exit_code != 0:
        print(captured.getvalue())
    sys.exit(exit_code)