# This documentation is automatically generated by online-judge-tools/verification-helper
# verification-helper: PROBLEM https://onlinejudge.u-aizu.ac.jp/courses/lesson/2/ITP1/1/ITP1_1_A
import pytest
import random
from operator import add
class TestSegTree2:
    """Tests for SegTree2, a segment tree whose values are 2-tuples."""

    def test_initialization_with_list(self):
        """Test initialization with a list of tuples"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        assert seg.n == 4
        assert seg[0] == (1, 10)
        assert seg[1] == (2, 20)
        assert seg[2] == (3, 30)
        assert seg[3] == (4, 40)

    def test_initialization_with_size(self):
        """Test initialization with size only"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 5)
        assert seg.n == 5
        # All elements should be identity
        for i in range(5):
            assert seg[i] == (0, 0)

    def test_set_and_get(self):
        """Test set and get operations"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 4)
        seg[0] = (1, 10)
        seg[1] = (2, 20)
        seg[2] = (3, 30)
        seg[3] = (4, 40)
        assert seg[0] == (1, 10)
        assert seg[1] == (2, 20)
        assert seg[2] == (3, 30)
        assert seg[3] == (4, 40)

    def test_prod_sum(self):
        """Test prod operation with sum"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Test various ranges
        assert seg.prod(0, 4) == (10, 100)  # Sum of all
        assert seg.prod(0, 2) == (3, 30)  # First two
        assert seg.prod(1, 3) == (5, 50)  # Middle two
        assert seg.prod(2, 4) == (7, 70)  # Last two
        assert seg.prod(1, 2) == (2, 20)  # Single element
        assert seg.prod(2, 2) == (0, 0)  # Empty range

    def test_prod_max(self):
        """Test prod operation with max"""
        values = [(3, 30), (1, 10), (4, 40), (2, 20)]
        seg = SegTree2(lambda a, b: (max(a[0], b[0]), max(a[1], b[1])), (float('-inf'), float('-inf')), values)
        assert seg.prod(0, 4) == (4, 40)
        assert seg.prod(0, 2) == (3, 30)
        assert seg.prod(1, 3) == (4, 40)
        assert seg.prod(2, 4) == (4, 40)

    def test_prod_min(self):
        """Test prod operation with min"""
        values = [(3, 30), (1, 10), (4, 40), (2, 20)]
        seg = SegTree2(lambda a, b: (min(a[0], b[0]), min(a[1], b[1])), (float('inf'), float('inf')), values)
        assert seg.prod(0, 4) == (1, 10)
        assert seg.prod(0, 2) == (1, 10)
        assert seg.prod(1, 3) == (1, 10)
        assert seg.prod(2, 4) == (2, 20)

    def test_all_prod(self):
        """Test all_prod operation"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        assert seg.all_prod() == (10, 100)

    def test_max_right(self):
        """Test max_right operation"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Largest r such that the first components summed over [l, r) stay <= threshold
        assert seg.max_right(0, lambda x: x[0] <= 3) == 2  # [0, 2) sums to 3, [0, 3) to 6
        assert seg.max_right(0, lambda x: x[0] <= 6) == 3  # [0, 3) sums to 6
        assert seg.max_right(0, lambda x: x[0] <= 10) == 4  # [0, 4) sums to 10
        assert seg.max_right(1, lambda x: x[0] <= 5) == 3  # [1, 3) sums to 5
        assert seg.max_right(0, lambda x: x[0] <= 0) == 0  # Even one element is too much
        assert seg.max_right(4, lambda x: x[0] <= 0) == 4  # Starting at n returns n

    def test_min_left(self):
        """Test min_left operation"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Smallest l such that the first components summed over [l, r) stay <= threshold
        assert seg.min_left(4, lambda x: x[0] <= 4) == 3  # [3, 4) sums to 4
        assert seg.min_left(4, lambda x: x[0] <= 7) == 2  # [2, 4) sums to 7
        assert seg.min_left(4, lambda x: x[0] <= 10) == 0  # [0, 4) sums to 10
        assert seg.min_left(3, lambda x: x[0] <= 3) == 2  # [2, 3) sums to 3
        assert seg.min_left(4, lambda x: x[0] <= 0) == 4  # Only the empty range qualifies
        assert seg.min_left(0, lambda x: x[0] <= 0) == 0  # Ending at 0 returns 0

    def test_update_and_query(self):
        """Test update operations affect queries correctly"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 4)
        # Initial values
        seg[0] = (1, 10)
        seg[1] = (2, 20)
        seg[2] = (3, 30)
        seg[3] = (4, 40)
        assert seg.prod(0, 4) == (10, 100)
        # Update some values
        seg[1] = (5, 50)
        seg[2] = (6, 60)
        assert seg.prod(0, 4) == (16, 160)
        assert seg.prod(1, 3) == (11, 110)

    def test_empty_tree(self):
        """Test empty segment tree"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 0)
        assert seg.n == 0
        assert seg.all_prod() == (0, 0)
        # Queries on an empty tree must still be well-defined.
        assert seg.prod(0, 0) == (0, 0)
        assert seg.max_right(0, lambda x: True) == 0
        assert seg.min_left(0, lambda x: True) == 0

    def test_single_element(self):
        """Test segment tree with single element"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), [(5, 50)])
        assert seg.n == 1
        assert seg[0] == (5, 50)
        assert seg.prod(0, 1) == (5, 50)
        assert seg.all_prod() == (5, 50)

    def test_large_tree(self):
        """Test with larger dataset"""
        n = 1000
        values = [(i, i * 10) for i in range(n)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Sum of 0..999 = 499500
        assert seg.all_prod() == (499500, 4995000)
        # Sum of 0..99 = 4950
        assert seg.prod(0, 100) == (4950, 49500)
        # Update and verify
        seg[500] = (1000, 10000)
        expected_sum = 499500 - 500 + 1000
        assert seg.all_prod() == (expected_sum, 4995000 - 5000 + 10000)

    def test_different_types(self):
        """Test with different data types in tuples"""
        # String concatenation and list concatenation
        seg = SegTree2(
            lambda a, b: (a[0] + b[0], a[1] + b[1]),
            ("", []),
            [("a", [1]), ("b", [2]), ("c", [3]), ("d", [4])]
        )
        assert seg.prod(0, 2) == ("ab", [1, 2])
        assert seg.prod(1, 4) == ("bcd", [2, 3, 4])
        assert seg.all_prod() == ("abcd", [1, 2, 3, 4])

    def test_non_commutative_operation(self):
        """Test with non-commutative operations"""
        # (a0, a1) encodes the affine map x -> a0*x + a1, and matrix_mult(a, b)
        # composes two maps (apply a, then b). Composition is associative with
        # identity (1, 0) but not commutative, so the tree must preserve the
        # left-to-right order of its operands.
        def matrix_mult(a, b):
            return (a[0] * b[0], a[1] * b[0] + b[1])
        # Maps of the form (k, k - 1) all commute with each other, so the data
        # is chosen to be genuinely order-sensitive.
        values = [(2, 3), (5, 1), (1, 4), (3, 2)]
        seg = SegTree2(matrix_mult, (1, 0), values)
        assert matrix_mult(values[0], values[1]) != matrix_mult(values[1], values[0])
        assert seg.prod(0, 2) == (10, 16)  # (2*5, 3*5+1)
        assert seg.prod(1, 3) == (5, 5)  # (5*1, 1*1+4)
        assert seg.prod(2, 4) == (3, 14)  # (1*3, 4*3+2)
        assert seg.prod(0, 4) == (30, 62)
        assert seg.all_prod() == (30, 62)
        # Replacing an element by the identity removes it from the composition.
        seg[1] = (1, 0)
        assert seg.all_prod() == (6, 23)

    def test_stress_random_operations(self):
        """Stress test with random operations"""
        # A private RNG keeps the test deterministic without reseeding the
        # process-wide `random` state.
        rng = random.Random(42)
        n = 100
        # Initialize with random values
        values = [(rng.randint(1, 100), rng.randint(1, 100)) for _ in range(n)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Perform random operations
        for _ in range(200):
            op = rng.choice(['update', 'query'])
            if op == 'update':
                idx = rng.randint(0, n - 1)
                new_val = (rng.randint(1, 100), rng.randint(1, 100))
                seg[idx] = new_val
                values[idx] = new_val
            else:
                l = rng.randint(0, n - 1)
                r = rng.randint(l, n)
                # Verify against naive calculation
                expected = (0, 0)
                for i in range(l, r):
                    expected = (expected[0] + values[i][0], expected[1] + values[i][1])
                assert seg.prod(l, r) == expected
        # After all updates the root must agree with the naive total.
        assert seg.all_prod() == (sum(a for a, _ in values), sum(b for _, b in values))
from cp_library.ds.tree.seg.segtree2_cls import SegTree2
if __name__ == '__main__':
    # When executed directly (as oj-verify does), print the dummy AOJ answer
    # and run this file's tests through pytest via the shared helper.
    from cp_library.test.unittest_helper import run_verification_helper_unittest
    run_verification_helper_unittest()
# verification-helper: PROBLEM https://onlinejudge.u-aizu.ac.jp/courses/lesson/2/ITP1/1/ITP1_1_A
import pytest
import random
from operator import add
class TestSegTree2:
    """Tests for SegTree2, a segment tree whose values are 2-tuples."""

    def test_initialization_with_list(self):
        """Test initialization with a list of tuples"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        assert seg.n == 4
        assert seg[0] == (1, 10)
        assert seg[1] == (2, 20)
        assert seg[2] == (3, 30)
        assert seg[3] == (4, 40)

    def test_initialization_with_size(self):
        """Test initialization with size only"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 5)
        assert seg.n == 5
        # All elements should be identity
        for i in range(5):
            assert seg[i] == (0, 0)

    def test_set_and_get(self):
        """Test set and get operations"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 4)
        seg[0] = (1, 10)
        seg[1] = (2, 20)
        seg[2] = (3, 30)
        seg[3] = (4, 40)
        assert seg[0] == (1, 10)
        assert seg[1] == (2, 20)
        assert seg[2] == (3, 30)
        assert seg[3] == (4, 40)

    def test_prod_sum(self):
        """Test prod operation with sum"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Test various ranges
        assert seg.prod(0, 4) == (10, 100)  # Sum of all
        assert seg.prod(0, 2) == (3, 30)  # First two
        assert seg.prod(1, 3) == (5, 50)  # Middle two
        assert seg.prod(2, 4) == (7, 70)  # Last two
        assert seg.prod(1, 2) == (2, 20)  # Single element
        assert seg.prod(2, 2) == (0, 0)  # Empty range

    def test_prod_max(self):
        """Test prod operation with max"""
        values = [(3, 30), (1, 10), (4, 40), (2, 20)]
        seg = SegTree2(lambda a, b: (max(a[0], b[0]), max(a[1], b[1])), (float('-inf'), float('-inf')), values)
        assert seg.prod(0, 4) == (4, 40)
        assert seg.prod(0, 2) == (3, 30)
        assert seg.prod(1, 3) == (4, 40)
        assert seg.prod(2, 4) == (4, 40)

    def test_prod_min(self):
        """Test prod operation with min"""
        values = [(3, 30), (1, 10), (4, 40), (2, 20)]
        seg = SegTree2(lambda a, b: (min(a[0], b[0]), min(a[1], b[1])), (float('inf'), float('inf')), values)
        assert seg.prod(0, 4) == (1, 10)
        assert seg.prod(0, 2) == (1, 10)
        assert seg.prod(1, 3) == (1, 10)
        assert seg.prod(2, 4) == (2, 20)

    def test_all_prod(self):
        """Test all_prod operation"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        assert seg.all_prod() == (10, 100)

    def test_max_right(self):
        """Test max_right operation"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Largest r such that the first components summed over [l, r) stay <= threshold
        assert seg.max_right(0, lambda x: x[0] <= 3) == 2  # [0, 2) sums to 3, [0, 3) to 6
        assert seg.max_right(0, lambda x: x[0] <= 6) == 3  # [0, 3) sums to 6
        assert seg.max_right(0, lambda x: x[0] <= 10) == 4  # [0, 4) sums to 10
        assert seg.max_right(1, lambda x: x[0] <= 5) == 3  # [1, 3) sums to 5
        assert seg.max_right(0, lambda x: x[0] <= 0) == 0  # Even one element is too much
        assert seg.max_right(4, lambda x: x[0] <= 0) == 4  # Starting at n returns n

    def test_min_left(self):
        """Test min_left operation"""
        values = [(1, 10), (2, 20), (3, 30), (4, 40)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Smallest l such that the first components summed over [l, r) stay <= threshold
        assert seg.min_left(4, lambda x: x[0] <= 4) == 3  # [3, 4) sums to 4
        assert seg.min_left(4, lambda x: x[0] <= 7) == 2  # [2, 4) sums to 7
        assert seg.min_left(4, lambda x: x[0] <= 10) == 0  # [0, 4) sums to 10
        assert seg.min_left(3, lambda x: x[0] <= 3) == 2  # [2, 3) sums to 3
        assert seg.min_left(4, lambda x: x[0] <= 0) == 4  # Only the empty range qualifies
        assert seg.min_left(0, lambda x: x[0] <= 0) == 0  # Ending at 0 returns 0

    def test_update_and_query(self):
        """Test update operations affect queries correctly"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 4)
        # Initial values
        seg[0] = (1, 10)
        seg[1] = (2, 20)
        seg[2] = (3, 30)
        seg[3] = (4, 40)
        assert seg.prod(0, 4) == (10, 100)
        # Update some values
        seg[1] = (5, 50)
        seg[2] = (6, 60)
        assert seg.prod(0, 4) == (16, 160)
        assert seg.prod(1, 3) == (11, 110)

    def test_empty_tree(self):
        """Test empty segment tree"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), 0)
        assert seg.n == 0
        assert seg.all_prod() == (0, 0)
        # Queries on an empty tree must still be well-defined.
        assert seg.prod(0, 0) == (0, 0)
        assert seg.max_right(0, lambda x: True) == 0
        assert seg.min_left(0, lambda x: True) == 0

    def test_single_element(self):
        """Test segment tree with single element"""
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), [(5, 50)])
        assert seg.n == 1
        assert seg[0] == (5, 50)
        assert seg.prod(0, 1) == (5, 50)
        assert seg.all_prod() == (5, 50)

    def test_large_tree(self):
        """Test with larger dataset"""
        n = 1000
        values = [(i, i * 10) for i in range(n)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Sum of 0..999 = 499500
        assert seg.all_prod() == (499500, 4995000)
        # Sum of 0..99 = 4950
        assert seg.prod(0, 100) == (4950, 49500)
        # Update and verify
        seg[500] = (1000, 10000)
        expected_sum = 499500 - 500 + 1000
        assert seg.all_prod() == (expected_sum, 4995000 - 5000 + 10000)

    def test_different_types(self):
        """Test with different data types in tuples"""
        # String concatenation and list concatenation
        seg = SegTree2(
            lambda a, b: (a[0] + b[0], a[1] + b[1]),
            ("", []),
            [("a", [1]), ("b", [2]), ("c", [3]), ("d", [4])]
        )
        assert seg.prod(0, 2) == ("ab", [1, 2])
        assert seg.prod(1, 4) == ("bcd", [2, 3, 4])
        assert seg.all_prod() == ("abcd", [1, 2, 3, 4])

    def test_non_commutative_operation(self):
        """Test with non-commutative operations"""
        # (a0, a1) encodes the affine map x -> a0*x + a1, and matrix_mult(a, b)
        # composes two maps (apply a, then b). Composition is associative with
        # identity (1, 0) but not commutative, so the tree must preserve the
        # left-to-right order of its operands.
        def matrix_mult(a, b):
            return (a[0] * b[0], a[1] * b[0] + b[1])
        # Maps of the form (k, k - 1) all commute with each other, so the data
        # is chosen to be genuinely order-sensitive.
        values = [(2, 3), (5, 1), (1, 4), (3, 2)]
        seg = SegTree2(matrix_mult, (1, 0), values)
        assert matrix_mult(values[0], values[1]) != matrix_mult(values[1], values[0])
        assert seg.prod(0, 2) == (10, 16)  # (2*5, 3*5+1)
        assert seg.prod(1, 3) == (5, 5)  # (5*1, 1*1+4)
        assert seg.prod(2, 4) == (3, 14)  # (1*3, 4*3+2)
        assert seg.prod(0, 4) == (30, 62)
        assert seg.all_prod() == (30, 62)
        # Replacing an element by the identity removes it from the composition.
        seg[1] = (1, 0)
        assert seg.all_prod() == (6, 23)

    def test_stress_random_operations(self):
        """Stress test with random operations"""
        # A private RNG keeps the test deterministic without reseeding the
        # process-wide `random` state.
        rng = random.Random(42)
        n = 100
        # Initialize with random values
        values = [(rng.randint(1, 100), rng.randint(1, 100)) for _ in range(n)]
        seg = SegTree2(lambda a, b: (a[0] + b[0], a[1] + b[1]), (0, 0), values)
        # Perform random operations
        for _ in range(200):
            op = rng.choice(['update', 'query'])
            if op == 'update':
                idx = rng.randint(0, n - 1)
                new_val = (rng.randint(1, 100), rng.randint(1, 100))
                seg[idx] = new_val
                values[idx] = new_val
            else:
                l = rng.randint(0, n - 1)
                r = rng.randint(l, n)
                # Verify against naive calculation
                expected = (0, 0)
                for i in range(l, r):
                    expected = (expected[0] + values[i][0], expected[1] + values[i][1])
                assert seg.prod(l, r) == expected
        # After all updates the root must agree with the naive total.
        assert seg.all_prod() == (sum(a for a, _ in values), sum(b for _, b in values))
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
from typing import TypeVar
# Module-level type variables for the generic classes in this file
# (e.g. Generic[_T], Generic[_T1, _T2]). Each TypeVar's name string matches
# the variable it is bound to, as PEP 484 type checkers require.
_S = TypeVar('_S')
_T = TypeVar('_T')
_U = TypeVar('_U')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_T3 = TypeVar('_T3')
_T4 = TypeVar('_T4')
_T5 = TypeVar('_T5')
_T6 = TypeVar('_T6')
from typing import Generic
def argsort(A: list[int], reverse=False):
    """Return the indices that would sort the integer sequence A.

    Each value is packed together with its index into a single int (value in
    the high bits, index in the low bits), so one plain list.sort() orders
    them. Equal values keep their original relative order. With reverse=True
    the values are negated before packing, giving a descending order. A itself
    is not modified.
    """
    keys = list(A)
    packer = Packer(len(keys) - 1)
    packer.ienumerate(keys, reverse)
    keys.sort()
    # Strip the value bits, leaving only the original indices.
    packer.iindices(keys)
    return keys
class Packer:
    """Pack an int value and a small non-negative index into one int.

    The low `s` bits hold the index (wide enough for any index in [0, mx])
    and `m` is the mask that selects them; the value sits in the high bits.
    """
    __slots__ = 's', 'm'

    def __init__(P, mx: int):
        P.s = mx.bit_length()
        P.m = (1 << P.s) - 1

    def enc(P, a: int, b: int):
        # Value a in the high bits, index b in the low P.s bits.
        return a << P.s | b

    def dec(P, x: int) -> tuple[int, int]:
        return x >> P.s, x & P.m

    def enumerate(P, A, reverse=False):
        """Return a new list holding enc(A[i], i) (or enc(-A[i], i) if reverse)."""
        packed = list(A)
        P.ienumerate(packed, reverse)
        return packed

    def ienumerate(P, A, reverse=False):
        """In place: replace each A[i] by enc(A[i], i), negating A[i] if reverse."""
        # `enumerate` below is the builtin; method bodies do not see class scope.
        for idx, val in enumerate(A):
            A[idx] = P.enc(-val if reverse else val, idx)

    def indices(P, A: list[int]):
        """Return a new list with only the packed index bits of each element."""
        unpacked = list(A)
        P.iindices(unpacked)
        return unpacked

    def iindices(P, A):
        """In place: keep only the low index bits of each packed element."""
        mask = P.m
        for idx, packed in enumerate(A):
            A[idx] = mask & packed
def isort_parallel(*L: list, reverse=False):
    """Sort several parallel lists in place, ordered by the values of L[0].

    Every list in L receives the same permutation, so entry i across all the
    lists stays together. The order (including tie-breaking and `reverse`)
    is the one produced by argsort(L[0], reverse=reverse); argsort packs
    values with bit shifts, so L[0] must hold ints. Returns the tuple L.
    """
    # order[i]: current position of the element that belongs at position i.
    # inv[k]:   final position of the element currently at position k.
    inv, order = [0]*len(L[0]), argsort(L[0], reverse=reverse)
    for i, j in enumerate(order): inv[j] = i
    # Iterating `order` while updating it is intentional: the list iterator
    # reads order[i] live, so it sees updates made in earlier iterations.
    for i, j in enumerate(order):
        # Bring the element destined for position i into place; the element
        # that was at i moves to position j.
        for A in L: A[i], A[j] = A[j], A[i]
        # Track the displaced element: its target inv[i] is now fed from j,
        # and position j now carries that target.
        order[inv[i]], inv[j] = j, inv[i]
    return L
class list2(Generic[_T1, _T2]):
    """A sequence of pairs stored column-wise in two parallel lists.

    Item i is the tuple (A1[i], A2[i]); reads and writes touch both columns.
    Membership tests (`in`) and index() raise NotImplementedError.
    """
    __slots__ = 'A1', 'A2'

    def __init__(lst, A1: list[_T1], A2: list[_T2]):
        # Keeps references to the given lists; nothing is copied.
        lst.A1 = A1
        lst.A2 = A2

    def __len__(lst):
        return len(lst.A1)

    def __getitem__(lst, i: int):
        first = lst.A1[i]
        second = lst.A2[i]
        return first, second

    def __setitem__(lst, i: int, v: tuple[_T1, _T2]):
        first, second = v
        lst.A1[i] = first
        lst.A2[i] = second

    def __contains__(lst, v: tuple[_T1, _T2]):
        raise NotImplementedError

    def index(lst, v: tuple[_T1, _T2]):
        raise NotImplementedError

    def reverse(lst):
        lst.A1.reverse()
        lst.A2.reverse()

    def sort(lst, reverse=False):
        # Reorders both columns together, keyed on the values of A1.
        isort_parallel(lst.A1, lst.A2, reverse=reverse)

    def pop(lst):
        first = lst.A1.pop()
        second = lst.A2.pop()
        return first, second

    def append(lst, v: tuple[_T1, _T2]):
        first, second = v
        lst.A1.append(first)
        lst.A2.append(second)

    def add(lst, i: int, v: tuple[_T1, _T2]):
        # Component-wise in-place addition at position i.
        lst.A1[i] += v[0]
        lst.A2[i] += v[1]
from typing import Callable, Generic, Union
class SegTree(Generic[_T]):
    """Point-update / range-product segment tree in the AtCoder Library style.

    op(a, b) combines two values and e is its identity; op is expected to be
    associative, with op(e, x) == op(x, e) == x. Storage is a flat array d of
    size 2*sz: leaves live at d[sz:sz+n], unused leaves hold e, and internal
    node i holds op(d[2i], d[2i+1]). A subclass may set `_lst` to a
    column-oriented container (such as list2), which is built with one column
    per component of e.
    """
    _lst = list

    def __init__(seg, op: Callable[[_T, _T], _T], e: _T, v: Union[int, list[_T]]) -> None:
        """Build over the values in v, or over v copies of e if v is an int."""
        if isinstance(v, int):
            n, v = v, None
        else:
            n = len(v)
        seg.op, seg.e, seg.n = op, e, n
        # For n >= 1, sz is twice the smallest power of two >= n.
        seg.log = log = (n - 1).bit_length() + 1
        seg.sz = sz = 1 << log
        if seg._lst is list:
            seg.d = [e] * (sz << 1)
        else:
            # Column-oriented storage: one list per component of e.
            seg.d = seg._lst(*([e_] * (sz << 1) for e_ in e))
        if v:
            seg._build(v)

    def _build(seg, v):
        """Copy v into the leaves, then fill every internal node bottom-up."""
        for i in range(seg.n):
            seg.d[seg.sz + i] = v[i]
        for i in range(seg.sz - 1, 0, -1):
            seg._merge(i, i << 1, i << 1 | 1)

    def _merge(seg, i, j, k):
        seg.d[i] = seg.op(seg.d[j], seg.d[k])

    def set(seg, p: int, x: _T) -> None:
        """Assign x to position p (0 <= p < n) and recompute its ancestors."""
        # Without this check a negative p would overwrite an internal node.
        assert 0 <= p < seg.n
        p += seg.sz
        seg.d[p] = x
        for _ in range(seg.log):
            p ^= p & 1  # left node of the sibling pair
            seg._merge(p >> 1, p, p | 1)
            p >>= 1
    __setitem__ = set

    def get(seg, p: int) -> _T:
        """Return the value at position p (0 <= p < n)."""
        # Without this check a negative p would read an internal node.
        assert 0 <= p < seg.n
        return seg.d[p + seg.sz]
    __getitem__ = get

    def prod(seg, l: int, r: int) -> _T:
        """Return op over positions [l, r), in order; e when l == r."""
        assert 0 <= l <= r <= seg.n
        sml = smr = seg.e
        l, r = l + seg.sz, r + seg.sz
        while l < r:
            if l & 1:
                sml, l = seg.op(sml, seg.d[l]), l + 1
            if r & 1:
                r -= 1
                smr = seg.op(seg.d[r], smr)
            l, r = l >> 1, r >> 1
        # Left and right partial products are kept separate so that a
        # non-commutative op sees its operands in order.
        return seg.op(sml, smr)

    def all_prod(seg) -> _T:
        return seg.d[1]

    def max_right(seg, l: int, f: Callable[[_T], bool]) -> int:
        """Return the largest r with f(prod(l, r)) true, assuming f is monotone."""
        assert 0 <= l <= seg.n
        assert f(seg.e)
        if l == seg.n:
            return seg.n
        sz, op, d, sm = seg.sz, seg.op, seg.d, seg.e
        l += sz
        while True:
            while (l & 1) == 0:
                l >>= 1
            if not f(op(sm, d[l])):
                # The answer lies inside node l: descend toward it.
                while l < sz:
                    l <<= 1
                    if f(op(sm, d[l])):
                        sm, l = op(sm, d[l]), l + 1
                return l - sz
            sm, l = op(sm, d[l]), l + 1
            if (l & -l) == l:
                return seg.n

    def min_left(seg, r: int, f: Callable[[_T], bool]) -> int:
        """Return the smallest l with f(prod(l, r)) true, assuming f is monotone."""
        assert 0 <= r <= seg.n
        assert f(seg.e)
        if r == 0:
            return 0
        sz, op, d, sm = seg.sz, seg.op, seg.d, seg.e
        r += sz
        while True:
            r -= 1
            while r > 1 and r & 1:
                r >>= 1
            if not f(op(d[r], sm)):
                # The answer lies inside node r: descend toward it.
                while r < sz:
                    r = r << 1 | 1
                    if f(op(d[r], sm)):
                        sm, r = op(d[r], sm), r - 1
                return r + 1 - sz
            sm = op(d[r], sm)
            if (r & -r) == r:
                return 0
class SegTree2(SegTree[_T]):
    # SegTree whose node storage is a list2: each value is a 2-tuple whose
    # components are kept in two parallel lists, one per component of the
    # identity e. get/set and op still see ordinary 2-tuples.
    _lst = list2
if __name__ == '__main__':
    # Helper for making unittest files compatible with verification-helper.
    # It prints the answer to a dummy AOJ ITP1_1_A problem so that oj-verify
    # has an output to judge, and runs the real tests with pytest.

    def run_verification_helper_unittest():
        """
        Run a dummy AOJ ITP1_1_A test for verification-helper compatibility.

        This function should be called in the __main__ block of unittest files
        that need to be compatible with verification-helper.

        The function:
        1. Prints "Hello World" (AOJ ITP1_1_A solution)
        2. Runs pytest for the calling test file
        3. Exits with the pytest result code

        pytest's own output is captured and printed only if the run fails.
        """
        import io
        import sys
        from contextlib import redirect_stderr, redirect_stdout

        # Imported locally so the helper does not depend on the surrounding
        # module having imported pytest at top level.
        import pytest

        # Print "Hello World" for AOJ ITP1_1_A problem
        print("Hello World")

        # Capture all output during test execution
        output = io.StringIO()
        with redirect_stdout(output), redirect_stderr(output):
            # Get the calling module's file path
            frame = sys._getframe(1)
            test_file = frame.f_globals.get('__file__')
            if test_file is None:
                test_file = sys.argv[0]
            result = pytest.main([test_file])
        if result != 0:
            print(output.getvalue())
        sys.exit(result)

    run_verification_helper_unittest()