This documentation is automatically generated by online-judge-tools/verification-helper
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/aplusb
from cp_library.bit.pack_sm_fn import pack_dec, pack_enc
import pytest
import random
class TestTreapMonoid:
    """pytest suite for ``TreapMonoid``.

    Covers point access, updates, deletion, range products under several
    monoids (sum, min, max, composition of packed affine maps) and ``split``.
    Randomized tests use local, seeded ``random.Random`` instances so they are
    reproducible and do not disturb the global RNG.
    """

    def test_initialization(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        assert T.e == 0
        assert T.op == add_op
        assert T.r >= 0
        assert T.all_prod() == 0  # an empty treap folds to the identity

    def test_insert_and_get(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(5, 10)
        T.insert(3, 20)
        T.insert(7, 30)
        # Lookup through get() ...
        assert T.get(5) == 10
        assert T.get(3) == 20
        assert T.get(7) == 30
        with pytest.raises(KeyError):
            T.get(1)  # missing key
        # ... and through __getitem__.
        assert T[5] == 10
        assert T[3] == 20
        assert T[7] == 30
        with pytest.raises(KeyError):
            T[1]  # missing key

    def test_set_and_update(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(5, 10)
        T.insert(3, 20)
        T.insert(7, 30)
        # Overwrite existing keys via both __setitem__ and set().
        T[5] = 15
        T.set(3, 25)
        assert T[5] == 15
        assert T[3] == 25
        assert T[7] == 30
        # The cached aggregate must reflect the new values.
        assert T.all_prod() == 15 + 25 + 30

    def test_pop_and_delete(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(5, 10)
        T.insert(3, 20)
        T.insert(7, 30)
        assert T.pop(3) == 20
        assert 3 not in T
        assert T.all_prod() == 10 + 30
        del T[5]
        assert 5 not in T
        assert T.all_prod() == 30
        with pytest.raises(KeyError):
            T.pop(3)  # already removed

    def test_prod_range(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i * 10)
        # prod(l, r) folds the values whose keys lie in [l, r).
        assert T.prod(0, 5) == 0 + 10 + 20 + 30 + 40
        assert T.prod(3, 7) == 30 + 40 + 50 + 60
        assert T.prod(0, 10) == sum(i * 10 for i in range(10))
        assert T.prod(5, 5) == 0  # empty range -> identity
        # Slicing is an alias for prod.
        assert T[0:5] == 0 + 10 + 20 + 30 + 40
        assert T[3:7] == 30 + 40 + 50 + 60

    def test_more_complex_monoid(self):
        # +inf is the identity of min, so the builtin needs no special-casing.
        T = TreapMonoid(min, e=float('inf'))
        values = [(5, 10), (3, 20), (7, 5), (2, 30)]
        for k, v in values:
            T.insert(k, v)
        assert T.prod(2, 6) == min(30, 20, 10)  # keys 2, 3, 5
        assert T.prod(3, 8) == min(20, 10, 5)  # keys 3, 5, 7
        assert T.all_prod() == min(v for _, v in values)

    def test_max_monoid(self):
        # -inf is the identity of max, so the builtin needs no special-casing.
        T = TreapMonoid(max, e=float('-inf'))
        for i in range(10):
            T.insert(i, i * 10)
        assert T.prod(0, 5) == 40  # max of 0, 10, 20, 30, 40
        assert T.prod(5, 10) == 90  # max of 50, 60, 70, 80, 90
        assert T.all_prod() == 90

    def test_sparse_indices(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        # Keys with large gaps between them.
        T.insert(10, 5)
        T.insert(100, 10)
        T.insert(1000, 15)
        assert T[10] == 5
        assert T[100] == 10
        assert T[1000] == 15
        # Range bounds need not be present keys.
        assert T.prod(0, 50) == 5
        assert T.prod(50, 500) == 10
        assert T.prod(0, 10000) == 5 + 10 + 15

    def test_integrity_after_modifications(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i)
        # Mix of update, delete, insert and pop.
        T[3] = 30
        del T[5]
        T.insert(11, 11)
        T.pop(7)
        T._v()  # structural + aggregate invariants
        assert 5 not in T
        assert 7 not in T
        assert T[3] == 30
        assert T[11] == 11
        expected_sum = sum(i for i in range(10) if i not in [5, 7]) - 3 + 30 + 11
        assert T.all_prod() == expected_sum

    def test_multiple_operations_sequence(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        rng = random.Random(42)  # local RNG: reproducible, no global side effects
        keys = rng.sample(range(1000), 100)
        values = [rng.randint(1, 100) for _ in range(100)]
        expected_sum = sum(values)
        for k, v in zip(keys, values):
            T.insert(k, v)
        assert T.all_prod() == expected_sum
        # Delete 20 of the keys.
        to_delete = rng.sample(keys, 20)
        expected_sum -= sum(T.get(k) for k in to_delete)
        for k in to_delete:
            del T[k]
            keys.remove(k)
        assert T.all_prod() == expected_sum
        # Overwrite 20 of the remaining keys.
        for k in rng.sample(keys, 20):
            old_val = T[k]
            new_val = rng.randint(1, 100)
            expected_sum += new_val - old_val
            T[k] = new_val
        assert T.all_prod() == expected_sum
        # Add fresh keys beyond the original key range.
        new_keys = [k for k in range(1000, 1020) if k not in keys]
        new_values = [rng.randint(1, 100) for _ in new_keys]
        expected_sum += sum(new_values)
        for k, v in zip(new_keys, new_values):
            T.insert(k, v)
        assert T.all_prod() == expected_sum
        T._v()

    def test_with_negative_values(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(-5, 10)
        T.insert(3, -20)
        T.insert(-7, -30)
        assert T[-5] == 10
        assert T[3] == -20
        assert T[-7] == -30
        assert T.prod(-10, 0) == 10 + (-30)  # keys -7 and -5
        assert T.prod(-10, 10) == 10 + (-30) + (-20)  # all keys

    def test_pack_enc_dec(self):
        # pack_enc / pack_dec must round-trip a pair.
        shift, mask = 30, (1 << 30) - 1
        a, b = 123, 456
        packed = pack_enc(a, b, shift)
        assert pack_dec(packed, shift, mask) == (a, b)

    def test_large_composite_operation(self):
        mod = 998244353
        shift, mask = 30, (1 << 30) - 1

        # Composition of affine maps x -> c*x + d, each packed as (c << shift) | d:
        # op(a, b) applies a first, then b.
        def op(a, b):
            ac, ad = pack_dec(a, shift, mask)
            bc, bd = pack_dec(b, shift, mask)
            return pack_enc(ac * bc % mod, (ad * bc + bd) % mod, shift)

        T = TreapMonoid(op, e=1 << shift)  # identity map (c=1, d=0)
        rng = random.Random(0)  # local RNG: reproducible, no global side effects
        pairs = [(rng.randint(1, 100), rng.randint(1, 100)) for _ in range(10)]
        for i, (c, d) in enumerate(pairs):
            T[i] = pack_enc(c, d, shift)
        l, r = 0, 5
        a_res, b_res = pack_dec(T.prod(l, r), shift, mask)
        # Reference fold over the inserted pairs, independent of the treap.
        a_exp, b_exp = 1, 0
        for c, d in pairs[l:r]:
            a_exp = a_exp * c % mod
            b_exp = (b_exp * c + d) % mod
        assert (a_res, b_res) == (a_exp, b_exp)

    def test_split_basic(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i * 10)
        # split(5): S gets keys < 5, T keeps keys >= 5.
        S, T = T.split(5)
        for i in range(5):
            assert i in S
            assert S[i] == i * 10
            assert i not in T
        for i in range(5, 10):
            assert i in T
            assert T[i] == i * 10
            assert i not in S
        assert S.all_prod() == sum(i * 10 for i in range(5))
        assert T.all_prod() == sum(i * 10 for i in range(5, 10))

    def test_split_empty(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        S, T = T.split(5)
        assert S.all_prod() == 0
        assert T.all_prod() == 0

    def test_split_at_edge(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(1, 11):
            T.insert(i, i * 10)
        # Splitting at the minimum key leaves the left part empty.
        S, T = T.split(1)
        assert S.all_prod() == 0
        assert T.all_prod() == sum(i * 10 for i in range(1, 11))
        # Splitting past the maximum key leaves the right part empty.
        T, R = T.split(11)
        assert T.all_prod() == sum(i * 10 for i in range(1, 11))
        assert R.all_prod() == 0

    def test_split_and_merge(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i * 10)
        original_sum = T.all_prod()
        S, R = T.split(5)
        assert S.all_prod() + R.all_prod() == original_sum
        # TreapMonoid has no public merge; recombine by re-inserting S's items into R.
        for i in range(5):
            if i in S:
                R[i] = S[i]
        assert R.all_prod() == original_sum
        R._v()

    def test_multiple_splits(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(20):
            T.insert(i, i)
        original_sum = T.all_prod()
        S1, T = T.split(10)
        S2, S1 = S1.split(5)
        for i in range(5):
            assert i in S2
            assert S2[i] == i
        for i in range(5, 10):
            assert i in S1
            assert S1[i] == i
        for i in range(10, 20):
            assert i in T
            assert T[i] == i
        assert S2.all_prod() + S1.all_prod() + T.all_prod() == original_sum
        S1._v()
        S2._v()
        T._v()

    def test_split_with_complex_monoid(self):
        T = TreapMonoid(min, e=float('inf'))
        values = [(i, 20 - i) for i in range(20)]  # values decrease as keys increase
        for k, v in values:
            T.insert(k, v)
        S, T = T.split(10)
        assert S.all_prod() == min(v for k, v in values if k < 10)
        assert T.all_prod() == min(v for k, v in values if k >= 10)

    def test_random_split_merge_sequence(self):
        def add_op(a, b):
            return a + b

        rng = random.Random(42)  # local RNG: reproducible, no global side effects
        T = TreapMonoid(add_op, e=0)
        keys = list(range(100))
        values = [rng.randint(1, 100) for _ in range(100)]
        for k, v in zip(keys, values):
            T.insert(k, v)
        original_sum = T.all_prod()
        T._v()
        # Ten rounds: pick a random piece and split it at a random key inside
        # its own key range (pieces holding fewer than two keys are put back).
        treaps = [T]
        for _ in range(10):
            treap = treaps.pop(rng.randint(0, len(treaps) - 1))
            present = [k for k in keys if k in treap]
            if len(present) < 2:
                treaps.append(treap)
                continue
            left, right = treap.split(rng.randint(present[0], present[-1]))
            treaps.extend([left, right])
            left._v()
            right._v()
        # Splitting must neither lose nor duplicate data.
        assert sum(t.all_prod() for t in treaps) == original_sum
        # No public merge exists, so rebuild by re-inserting every surviving key.
        final_treap = TreapMonoid(add_op, e=0)
        for k, v in zip(keys, values):
            if any(k in t for t in treaps):
                final_treap[k] = v
        assert final_treap.all_prod() == original_sum
        final_treap._v()

    def test_custom_pack_format_with_split(self):
        mod = 998244353
        shift, mask = 30, (1 << 30) - 1

        def op(a, b):
            ac, ad = pack_dec(a, shift, mask)
            bc, bd = pack_dec(b, shift, mask)
            return pack_enc(ac * bc % mod, (ad * bc + bd) % mod, shift)

        T = TreapMonoid(op, e=1 << shift)
        for i in range(10):
            T[i] = pack_enc(i + 1, i * 10, shift)
        S, T = T.split(5)
        # Packed values must survive the split unchanged on both sides.
        for i in range(5):
            assert i in S
            assert pack_dec(S[i], shift, mask) == (i + 1, i * 10)
        for i in range(5, 10):
            assert i in T
            assert pack_dec(T[i], shift, mask) == (i + 1, i * 10)
from cp_library.ds.tree.bst.treap_monoid_cls import TreapMonoid
from cp_library.io.read_fn import read
from cp_library.io.write_fn import write
if __name__ == '__main__':
    import sys
    A, B = read()
    write(C := A + B)
    if C != 1198300249:
        # Any other sum: the answer has been written; stop here.
        sys.exit(0)
    # The sum 1198300249 triggers the pytest self-test of this module.
    # pytest's output is captured and echoed only when the run fails.
    import io
    import pytest
    from contextlib import redirect_stderr, redirect_stdout
    captured = io.StringIO()
    with redirect_stdout(captured), redirect_stderr(captured):
        exit_code = pytest.main([__file__])
    if exit_code != 0:
        print(captured.getvalue())
    sys.exit(exit_code)
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/aplusb
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
def pack_sm(N: int):
    """Return ``(s, m)``: the bit length of ``N`` and an all-ones mask of that width."""
    shift = N.bit_length()
    mask = (1 << shift) - 1
    return shift, mask
def pack_enc(a: int, b: int, s: int):
    """Pack ``a`` into the high bits and ``b`` into the low ``s`` bits of one int."""
    high = a << s
    return high | b
def pack_dec(ab: int, s: int, m: int):
    """Split ``ab`` into ``(ab >> s, ab & m)`` — the inverse of ``pack_enc`` when ``m`` masks ``s`` bits."""
    high = ab >> s
    low = ab & m
    return high, low
def pack_indices(A, s):
    """Return ``[(A[i] << s) | i for each i]`` so sorting the result sorts by value, then index."""
    return [(val << s) | idx for idx, val in enumerate(A)]
import pytest
import random
class TestTreapMonoid:
    """pytest suite for ``TreapMonoid``.

    Covers point access, updates, deletion, range products under several
    monoids (sum, min, max, composition of packed affine maps) and ``split``.
    Randomized tests use local, seeded ``random.Random`` instances so they are
    reproducible and do not disturb the global RNG.
    """

    def test_initialization(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        assert T.e == 0
        assert T.op == add_op
        assert T.r >= 0
        assert T.all_prod() == 0  # an empty treap folds to the identity

    def test_insert_and_get(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(5, 10)
        T.insert(3, 20)
        T.insert(7, 30)
        # Lookup through get() ...
        assert T.get(5) == 10
        assert T.get(3) == 20
        assert T.get(7) == 30
        with pytest.raises(KeyError):
            T.get(1)  # missing key
        # ... and through __getitem__.
        assert T[5] == 10
        assert T[3] == 20
        assert T[7] == 30
        with pytest.raises(KeyError):
            T[1]  # missing key

    def test_set_and_update(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(5, 10)
        T.insert(3, 20)
        T.insert(7, 30)
        # Overwrite existing keys via both __setitem__ and set().
        T[5] = 15
        T.set(3, 25)
        assert T[5] == 15
        assert T[3] == 25
        assert T[7] == 30
        # The cached aggregate must reflect the new values.
        assert T.all_prod() == 15 + 25 + 30

    def test_pop_and_delete(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(5, 10)
        T.insert(3, 20)
        T.insert(7, 30)
        assert T.pop(3) == 20
        assert 3 not in T
        assert T.all_prod() == 10 + 30
        del T[5]
        assert 5 not in T
        assert T.all_prod() == 30
        with pytest.raises(KeyError):
            T.pop(3)  # already removed

    def test_prod_range(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i * 10)
        # prod(l, r) folds the values whose keys lie in [l, r).
        assert T.prod(0, 5) == 0 + 10 + 20 + 30 + 40
        assert T.prod(3, 7) == 30 + 40 + 50 + 60
        assert T.prod(0, 10) == sum(i * 10 for i in range(10))
        assert T.prod(5, 5) == 0  # empty range -> identity
        # Slicing is an alias for prod.
        assert T[0:5] == 0 + 10 + 20 + 30 + 40
        assert T[3:7] == 30 + 40 + 50 + 60

    def test_more_complex_monoid(self):
        # +inf is the identity of min, so the builtin needs no special-casing.
        T = TreapMonoid(min, e=float('inf'))
        values = [(5, 10), (3, 20), (7, 5), (2, 30)]
        for k, v in values:
            T.insert(k, v)
        assert T.prod(2, 6) == min(30, 20, 10)  # keys 2, 3, 5
        assert T.prod(3, 8) == min(20, 10, 5)  # keys 3, 5, 7
        assert T.all_prod() == min(v for _, v in values)

    def test_max_monoid(self):
        # -inf is the identity of max, so the builtin needs no special-casing.
        T = TreapMonoid(max, e=float('-inf'))
        for i in range(10):
            T.insert(i, i * 10)
        assert T.prod(0, 5) == 40  # max of 0, 10, 20, 30, 40
        assert T.prod(5, 10) == 90  # max of 50, 60, 70, 80, 90
        assert T.all_prod() == 90

    def test_sparse_indices(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        # Keys with large gaps between them.
        T.insert(10, 5)
        T.insert(100, 10)
        T.insert(1000, 15)
        assert T[10] == 5
        assert T[100] == 10
        assert T[1000] == 15
        # Range bounds need not be present keys.
        assert T.prod(0, 50) == 5
        assert T.prod(50, 500) == 10
        assert T.prod(0, 10000) == 5 + 10 + 15

    def test_integrity_after_modifications(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i)
        # Mix of update, delete, insert and pop.
        T[3] = 30
        del T[5]
        T.insert(11, 11)
        T.pop(7)
        T._v()  # structural + aggregate invariants
        assert 5 not in T
        assert 7 not in T
        assert T[3] == 30
        assert T[11] == 11
        expected_sum = sum(i for i in range(10) if i not in [5, 7]) - 3 + 30 + 11
        assert T.all_prod() == expected_sum

    def test_multiple_operations_sequence(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        rng = random.Random(42)  # local RNG: reproducible, no global side effects
        keys = rng.sample(range(1000), 100)
        values = [rng.randint(1, 100) for _ in range(100)]
        expected_sum = sum(values)
        for k, v in zip(keys, values):
            T.insert(k, v)
        assert T.all_prod() == expected_sum
        # Delete 20 of the keys.
        to_delete = rng.sample(keys, 20)
        expected_sum -= sum(T.get(k) for k in to_delete)
        for k in to_delete:
            del T[k]
            keys.remove(k)
        assert T.all_prod() == expected_sum
        # Overwrite 20 of the remaining keys.
        for k in rng.sample(keys, 20):
            old_val = T[k]
            new_val = rng.randint(1, 100)
            expected_sum += new_val - old_val
            T[k] = new_val
        assert T.all_prod() == expected_sum
        # Add fresh keys beyond the original key range.
        new_keys = [k for k in range(1000, 1020) if k not in keys]
        new_values = [rng.randint(1, 100) for _ in new_keys]
        expected_sum += sum(new_values)
        for k, v in zip(new_keys, new_values):
            T.insert(k, v)
        assert T.all_prod() == expected_sum
        T._v()

    def test_with_negative_values(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        T.insert(-5, 10)
        T.insert(3, -20)
        T.insert(-7, -30)
        assert T[-5] == 10
        assert T[3] == -20
        assert T[-7] == -30
        assert T.prod(-10, 0) == 10 + (-30)  # keys -7 and -5
        assert T.prod(-10, 10) == 10 + (-30) + (-20)  # all keys

    def test_pack_enc_dec(self):
        # pack_enc / pack_dec must round-trip a pair.
        shift, mask = 30, (1 << 30) - 1
        a, b = 123, 456
        packed = pack_enc(a, b, shift)
        assert pack_dec(packed, shift, mask) == (a, b)

    def test_large_composite_operation(self):
        mod = 998244353
        shift, mask = 30, (1 << 30) - 1

        # Composition of affine maps x -> c*x + d, each packed as (c << shift) | d:
        # op(a, b) applies a first, then b.
        def op(a, b):
            ac, ad = pack_dec(a, shift, mask)
            bc, bd = pack_dec(b, shift, mask)
            return pack_enc(ac * bc % mod, (ad * bc + bd) % mod, shift)

        T = TreapMonoid(op, e=1 << shift)  # identity map (c=1, d=0)
        rng = random.Random(0)  # local RNG: reproducible, no global side effects
        pairs = [(rng.randint(1, 100), rng.randint(1, 100)) for _ in range(10)]
        for i, (c, d) in enumerate(pairs):
            T[i] = pack_enc(c, d, shift)
        l, r = 0, 5
        a_res, b_res = pack_dec(T.prod(l, r), shift, mask)
        # Reference fold over the inserted pairs, independent of the treap.
        a_exp, b_exp = 1, 0
        for c, d in pairs[l:r]:
            a_exp = a_exp * c % mod
            b_exp = (b_exp * c + d) % mod
        assert (a_res, b_res) == (a_exp, b_exp)

    def test_split_basic(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i * 10)
        # split(5): S gets keys < 5, T keeps keys >= 5.
        S, T = T.split(5)
        for i in range(5):
            assert i in S
            assert S[i] == i * 10
            assert i not in T
        for i in range(5, 10):
            assert i in T
            assert T[i] == i * 10
            assert i not in S
        assert S.all_prod() == sum(i * 10 for i in range(5))
        assert T.all_prod() == sum(i * 10 for i in range(5, 10))

    def test_split_empty(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        S, T = T.split(5)
        assert S.all_prod() == 0
        assert T.all_prod() == 0

    def test_split_at_edge(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(1, 11):
            T.insert(i, i * 10)
        # Splitting at the minimum key leaves the left part empty.
        S, T = T.split(1)
        assert S.all_prod() == 0
        assert T.all_prod() == sum(i * 10 for i in range(1, 11))
        # Splitting past the maximum key leaves the right part empty.
        T, R = T.split(11)
        assert T.all_prod() == sum(i * 10 for i in range(1, 11))
        assert R.all_prod() == 0

    def test_split_and_merge(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(10):
            T.insert(i, i * 10)
        original_sum = T.all_prod()
        S, R = T.split(5)
        assert S.all_prod() + R.all_prod() == original_sum
        # TreapMonoid has no public merge; recombine by re-inserting S's items into R.
        for i in range(5):
            if i in S:
                R[i] = S[i]
        assert R.all_prod() == original_sum
        R._v()

    def test_multiple_splits(self):
        def add_op(a, b):
            return a + b

        T = TreapMonoid(add_op, e=0)
        for i in range(20):
            T.insert(i, i)
        original_sum = T.all_prod()
        S1, T = T.split(10)
        S2, S1 = S1.split(5)
        for i in range(5):
            assert i in S2
            assert S2[i] == i
        for i in range(5, 10):
            assert i in S1
            assert S1[i] == i
        for i in range(10, 20):
            assert i in T
            assert T[i] == i
        assert S2.all_prod() + S1.all_prod() + T.all_prod() == original_sum
        S1._v()
        S2._v()
        T._v()

    def test_split_with_complex_monoid(self):
        T = TreapMonoid(min, e=float('inf'))
        values = [(i, 20 - i) for i in range(20)]  # values decrease as keys increase
        for k, v in values:
            T.insert(k, v)
        S, T = T.split(10)
        assert S.all_prod() == min(v for k, v in values if k < 10)
        assert T.all_prod() == min(v for k, v in values if k >= 10)

    def test_random_split_merge_sequence(self):
        def add_op(a, b):
            return a + b

        rng = random.Random(42)  # local RNG: reproducible, no global side effects
        T = TreapMonoid(add_op, e=0)
        keys = list(range(100))
        values = [rng.randint(1, 100) for _ in range(100)]
        for k, v in zip(keys, values):
            T.insert(k, v)
        original_sum = T.all_prod()
        T._v()
        # Ten rounds: pick a random piece and split it at a random key inside
        # its own key range (pieces holding fewer than two keys are put back).
        treaps = [T]
        for _ in range(10):
            treap = treaps.pop(rng.randint(0, len(treaps) - 1))
            present = [k for k in keys if k in treap]
            if len(present) < 2:
                treaps.append(treap)
                continue
            left, right = treap.split(rng.randint(present[0], present[-1]))
            treaps.extend([left, right])
            left._v()
            right._v()
        # Splitting must neither lose nor duplicate data.
        assert sum(t.all_prod() for t in treaps) == original_sum
        # No public merge exists, so rebuild by re-inserting every surviving key.
        final_treap = TreapMonoid(add_op, e=0)
        for k, v in zip(keys, values):
            if any(k in t for t in treaps):
                final_treap[k] = v
        assert final_treap.all_prod() == original_sum
        final_treap._v()

    def test_custom_pack_format_with_split(self):
        mod = 998244353
        shift, mask = 30, (1 << 30) - 1

        def op(a, b):
            ac, ad = pack_dec(a, shift, mask)
            bc, bd = pack_dec(b, shift, mask)
            return pack_enc(ac * bc % mod, (ad * bc + bd) % mod, shift)

        T = TreapMonoid(op, e=1 << shift)
        for i in range(10):
            T[i] = pack_enc(i + 1, i * 10, shift)
        S, T = T.split(5)
        # Packed values must survive the split unchanged on both sides.
        for i in range(5):
            assert i in S
            assert pack_dec(S[i], shift, mask) == (i + 1, i * 10)
        for i in range(5, 10):
            assert i in T
            assert pack_dec(T[i], shift, mask) == (i + 1, i * 10)
def reserve(A: list, est_len: int) -> None: ...
# ``reserve(A, est_len)`` hints that list ``A`` will grow to about ``est_len``
# items.  On PyPy it is ``__pypy__.resizelist_hint``; everywhere else it is a
# no-op.  Only a failed import selects the fallback — a bare ``except`` would
# also swallow KeyboardInterrupt/SystemExit raised during the import.
try:
    from __pypy__ import resizelist_hint
except ImportError:
    def resizelist_hint(A: list, est_len: int):
        pass
reserve = resizelist_hint
i64_max = 2**63 - 1  # largest signed 64-bit integer; used as the sentinel root key
class BST:
    """Array-backed binary search tree over class-level node pools.

    Node ``i`` has key ``K[i]`` and children ``sub[i<<1]`` (left) and
    ``sub[i<<1|1]`` (right); ``-1`` means "no child".  A child *slot* is the
    index ``i<<1|d`` into ``sub``.  The pools are class attributes, so every
    tree of a class allocates from the same lists.  Each tree owns a sentinel
    root ``r`` (key ``i64_max``); the real tree hangs off the root's left slot.
    ``st`` records slots visited by an operation so ``_r`` can post-process
    them (here it just clears them).  ``_p`` is a per-node hook called before a
    node's children are read (no-op here).
    """
    __slots__ = 'r'
    K, sub, st = [-1], [-1, -1], []

    def __init__(T):
        T.r = T._nr()

    def _nt(T):
        # New empty tree of the same class (used by subclasses' split).
        return T.__class__()

    def _nr(T):
        # Allocate the sentinel root node.
        r = len(T.K)
        T.K.append(i64_max)
        T.sub.append(-1)
        T.sub.append(-1)
        return r

    def _nn(T, k):
        # Allocate a detached node with key k.
        n = len(T.K)
        T.K.append(k)
        T.sub.append(-1)
        T.sub.append(-1)
        return n

    def insert(T, k):
        """Insert key ``k`` as a new node and return its index."""
        T._i(T.r << 1, k, n := T._nn(k))
        T._r()
        return n

    def get(T, k):
        """Return the node index holding key ``k``; raise KeyError(k) if absent."""
        if ~(i := T._f(T.r << 1, k)):
            return i
        raise KeyError(k)

    def pop(T, k):
        """Remove key ``k`` and return its former node index; raise KeyError(k) if absent."""
        if ~(i := T._t(T.r << 1, k)):
            T._d(i, T.st[-1])
            T._r()
            return i
        T.st.clear()
        raise KeyError(k)

    def __delitem__(T, k):
        if ~(i := T._t(T.r << 1, k)):
            T._d(i, T.st[-1])
            T._r()
        else:
            T.st.clear()
            raise KeyError(k)

    def __contains__(T, k):
        return 0 <= T._f(T.r << 1, k)

    def _f(T, s, k):
        # Find the node with key k below slot s without recording the path; -1 if absent.
        i = T.sub[s]
        while ~i and T.K[i] != k:
            T._p(i)
            i = T.sub[i << 1 | (T.K[i] < k)]
        return i

    def _t(T, s, k):
        # Like _f, but pushes every slot on the path onto st.  On success
        # st[-1] is the slot that points at the found node.
        T.st.append(s)
        while ~(i := T.sub[s]) and T.K[i] != k:
            T._p(i)
            T.st.append(s := i << 1 | (T.K[i] < k))
        return i

    def _i(T, s, k, n):
        # Plain BST insert: walk to the empty slot for k (recording it) and hang n there.
        T.st.append(s)
        while ~T.sub[s]:
            T._p(i := T.sub[s])
            T.st.append(s := i << 1 | (T.K[i] < k))
        T.sub[s] = n

    def _d(T, i, s):
        # Unlink node i from slot s; deletion is only provided by subclasses.
        raise NotImplementedError

    def _r(T):
        T.st.clear()

    def _p(T, i):
        pass

    @classmethod
    def reserve(cls, sz):
        """Pre-size the shared pools for about ``sz`` more nodes (a PyPy hint)."""
        sz += 1
        reserve(cls.K, sz)
        reserve(cls.sub, sz << 1)
        reserve(cls.st, sz.bit_length() << 1)

    def _node_str(T, i):
        return f"{T.K[i]}"

    def __str__(T):
        # Sideways rendering: right subtree above each node, left subtree below.
        def rec(i, pre="", is_right=False):
            if i == -1:
                return ""
            ret = ""
            T._p(i)
            if ~(r := T.sub[i << 1 | 1]):
                ret += rec(r, pre + (" " if is_right else "│ "), True)
            ret += pre + ("┌─ " if is_right else "└─ ") + T._node_str(i) + "\n"
            if ~(l := T.sub[i << 1]):
                ret += rec(l, pre + (" " if not is_right else "│ "), False)
            return ret
        return rec(T.sub[T.r << 1]).rstrip()
class BSTUpdates(BST):
    """BST variant whose ``_r`` feeds every recorded slot's node to the ``_u`` hook.

    Slots are popped from ``st`` in LIFO order and each one is converted to its
    owning node (``slot >> 1``) before calling ``_u``.
    """

    def _u(T, i):
        """Recompute node ``i``; a no-op here, overridden by subclasses."""

    def _r(T):
        stack = T.st
        while stack:
            slot = stack.pop()
            T._u(slot >> 1)
class CartesianTree(BST):
    """BST whose nodes also carry a priority ``P[i]``, kept in min-heap order.

    A node never sits below one with a larger priority: ``_i`` stops descending
    at the first node whose priority is not smaller than the new node's, and
    ``_m`` always promotes the smaller priority.  For this class, ``get`` /
    ``pop`` / ``t[k]`` return the node's priority.
    """
    # Class-level pools shared by every tree of this class; entry 0 is never
    # handed out (_nr/_nn append), so it only acts as a placeholder.
    K,P,sub,st=[-1],[42],[-1,-1],[]
    # Sentinel root gets priority -1.
    def _nr(T):T.P.append(-1);return super()._nr()
    def _nn(T,k,p=-1):T.P.append(p);return super()._nn(k)
    # BST.get / BST.pop return the node index; map it to the stored priority.
    def get(T,k):return T.P[BST.get(T,k)]
    def pop(T,k):return T.P[BST.pop(T,k)]
    # Move keys < k into a new tree S; T keeps keys >= k.  Returns (S, T).
    def split(T,k):S=T._nt();T._sp(T.sub[T.r<<1],k,S.r<<1,T.r<<1);T._r();return S,T
    def insert(T,k,p):T._i(T.r<<1,k,n:=T._nn(k,p));T._r();return n
    def __getitem__(T,k):return T.get(k)
    def _i(T,s,k,n):
        # Descend (recording slots) while the current node's priority is smaller
        # than the new node's, then hang n in that slot.
        T.st.append(s)
        while~T.sub[s]and T.P[i:=T.sub[s]]<T.P[n]:T._p(i);T.st.append(s:=i<<1|(T.K[i]<k))
        i,T.sub[s]=T.sub[s],n
        # The displaced subtree (if any) is split around k into n's children.
        if~i:T._sp(i,k,n<<1,n<<1|1)
    def _sp(T,i,k,l,r):
        # Split subtree i: nodes with key < k are chained under slot l, the
        # rest under slot r; every slot written is recorded on st.
        T.st.append(l)
        # When l and r are the two child slots of one node (l^r == 1), a single
        # entry suffices, since update passes act on slot>>1.
        if 1<l^r:T.st.append(r)
        while~i:
            T._p(i)
            # Walrus re-targets l / r to the node's right / left slot after the link.
            if T.K[i]<k:T.sub[l]=i;i=T.sub[l:=i<<1|1];T.st.append(l)
            else:T.sub[r]=i;i=T.sub[r:=i<<1];T.st.append(r)
        # Terminate both chains.
        T.sub[l]=T.sub[r]=-1
    def _m(T,s,l,r):
        # Merge subtrees l and r into slot s, keeping l's nodes to the left of
        # r's (assumes every key in l is smaller than every key in r — TODO confirm
        # at call sites); the smaller priority becomes the parent at each step.
        T.st.append(s)
        while~l and~r:
            if T.P[l]<T.P[r]:T._p(l);T.sub[s]=l;l=T.sub[s:=l<<1|1]
            else:T._p(r);T.sub[s]=r;r=T.sub[s:=r<<1]
            T.st.append(s)
        # Attach whichever side is left over.
        T.sub[s]=l if~l else r
    # Delete node i, which hangs in slot s, by merging its two children into s.
    def _d(T,i,s):T._p(i);T._m(s,T.sub[i<<1],T.sub[i<<1|1])
    @classmethod
    def reserve(cls,sz):super(CartesianTree,cls).reserve(sz);reserve(cls.P,sz+1)
class Treap(CartesianTree):
__slots__='e'
K,V,P,sub,st=[-1],[-1],[42],[-1,-1],[]
def __init__(T,e=-1):T.e=e;super().__init__()
def _nt(T):return T.__class__(T.e)
def _nr(T):T.V.append(T.e);return super()._nr()
def _nn(T,k,v):T.V.append(v);return super()._nn(k,(T.P[-1]*1103515245+12345)&0x7fffffff)
def insert(T,k,v):return super().insert(k,v)
def get(T,k):return T.V[BST.get(T,k)]
def pop(T,k):return T.V[BST.pop(T,k)]
def set(T,k,v):T._s(T.r<<1,k,v);T._r()
def __setitem__(T,k,v):T.set(k,v)
def _s(T,s,k,v):
if ~(i:=T._t(s,k)):T.V[i]=v;T.st.append(i<<1)
else:
n=T._nn(k,v)
while T.P[n]<T.P[i:=T.st[-1]>>1]:T._p(T.st.pop())
T._p(i)
i,T.sub[s]=T.sub[s:=i<<1|(i!=T.r and T.K[i]<k)],n
if~i:T._sp(i,k,n<<1,n<<1|1)
def _node_str(T, i): return f"{T.K[i]}:{T.V[i]}"
@classmethod
def reserve(cls,hint):super(Treap,cls).reserve(hint);reserve(cls.V,hint+1)
class TreapMonoid(Treap, BSTUpdates):
    """Treap map that caches a monoid product over every subtree.

    ``A[i]`` is ``op`` folded, in key order, over the values of node ``i``'s
    subtree.  It is recomputed by ``_u`` for every node recorded on ``st``
    after each mutation.  ``e`` must be an identity of ``op``; it is the
    sentinel root's value, and ``_u`` folds it into the root's product.
    """
    __slots__ = 'op'
    K, V, A, P, sub, st = [-1], [-1], [-1], [42], [-1, -1], []

    def __init__(T, op, e=-1):
        T.op = op
        super().__init__(e)

    def _nt(T):
        return T.__class__(T.op, T.e)

    def _nr(T):
        T.A.append(T.e)
        return super()._nr()

    def _nn(T, k, v):
        T.A.append(v)
        return super()._nn(k, v)

    def prod(T, l, r):
        """Return ``op`` folded over the values whose keys lie in ``[l, r)``.

        Returns ``e`` when no key falls in the range.
        """
        # Descend to the highest node whose key lies in [l, r); every key of
        # the range lives in that node's subtree.
        a = T.sub[T.r << 1]
        while ~a and not l <= T.K[a] < r:
            T._p(a)
            a = T.sub[a << 1 | (T.K[a] < l)]
        if a < 0:
            return T.e
        T._p(a)  # same push hook every other traversal runs before reading children
        # Left side: walk toward l, prepending every node with key >= l
        # together with its whole right subtree.
        ac, i = T.V[a], T.sub[a << 1]
        while ~i:
            T._p(i)
            if not (b := T.K[i] < l):
                if ~(j := T.sub[i << 1 | 1]):
                    ac = T.op(T.A[j], ac)
                ac = T.op(T.V[i], ac)
            i = T.sub[i << 1 | b]
        # Right side: walk toward r, appending every node with key < r
        # together with its whole left subtree.
        i = T.sub[a << 1 | 1]
        while ~i:
            T._p(i)
            if b := T.K[i] < r:
                if ~(j := T.sub[i << 1]):
                    ac = T.op(ac, T.A[j])
                ac = T.op(ac, T.V[i])
            i = T.sub[i << 1 | b]
        return ac

    def all_prod(T):
        """Product of all values (the cached aggregate of the sentinel root)."""
        return T.A[T.r]

    def __getitem__(T, k):
        """``T[l:r]`` is ``prod(l, r)``; any other key is a point lookup via ``get``.

        Point lookups raise KeyError for a missing key, for every key type
        (previously non-int keys silently returned None).
        """
        if isinstance(k, slice):
            return T.prod(k.start, k.stop)
        return T.get(k)

    @classmethod
    def reserve(cls, sz):
        super(TreapMonoid, cls).reserve(sz)
        reserve(cls.A, sz + 1)

    def _u(T, i):
        # A[i] = A[left] . V[i] . A[right], skipping missing children.
        T.A[i] = T.V[i]
        if ~(l := T.sub[i << 1]):
            T.A[i] = T.op(T.A[l], T.A[i])
        if ~(r := T.sub[i << 1 | 1]):
            T.A[i] = T.op(T.A[i], T.A[r])

    def _v(T, i=None):
        """Debug check of heap order, key order and cached aggregates.

        Walks the subtree rooted at ``i`` (the whole tree when ``i`` is None)
        and returns its recomputed product.  The product is bound outside the
        ``assert`` so the method still returns it under ``python -O``.
        """
        if i is None:
            ac = T._v(i) if ~(i := T.sub[T.r << 1]) else T.e
            assert T.all_prod() == ac
            return ac
        T._p(i)
        ac = T.V[i]
        if ~(l := T.sub[i << 1]):
            assert T.P[i] <= T.P[l]
            assert T.K[l] <= T.K[i]
            ac = T.op(T._v(l), ac)
        if ~(r := T.sub[i << 1 | 1]):
            assert T.P[i] <= T.P[r]
            assert T.K[i] <= T.K[r]
            ac = T.op(ac, T._v(r))
        assert T.A[i] == ac
        return ac
from typing import Iterable, Type, Union, overload
import typing
from collections import deque
from numbers import Number
from types import GenericAlias
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase
class FastIO(IOBase):
    """Byte buffer in front of a raw file descriptor, accessed through ``os.read``/``os.write``.

    Reads append raw chunks to an in-memory ``BytesIO`` without moving its read
    cursor.  ``newlines`` counts complete lines fetched but not yet consumed
    by ``readline``.  Writes go to the buffer and reach the descriptor only on
    ``flush`` (and only when the file was opened writable).
    """
    BUFSIZE = 8192
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        mode = file.mode
        self.writable = "x" in mode or "r" not in mode
        self.write = self.buffer.write if self.writable else None

    def _chunk(self):
        # One raw read of at least BUFSIZE bytes, or the file size if larger.
        return os.read(self._fd, max(os.fstat(self._fd).st_size, self.BUFSIZE))

    def _append(self, data):
        # Append at the end of the buffer, then restore the read cursor.
        cursor = self.buffer.tell()
        self.buffer.seek(0, 2)
        self.buffer.write(data)
        self.buffer.seek(cursor)

    def read(self):
        while chunk := self._chunk():
            self._append(chunk)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        while not self.newlines:
            chunk = self._chunk()
            # An empty chunk (EOF) counts as one line so the loop terminates.
            self.newlines = chunk.count(b"\n") + (not chunk)
            self._append(chunk)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        if not self.writable:
            return
        os.write(self._fd, self.buffer.getvalue())
        self.buffer.truncate(0)
        self.buffer.seek(0)
class IOWrapper(IOBase):
    """ASCII text facade over ``FastIO``; also holds the wrapped stdin/stdout."""
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        data = s.encode("ascii")
        return self.buffer.write(data)

    def read(self):
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        raw = self.buffer.readline()
        return raw.decode("ascii")
# Swap the interpreter's stdin/stdout for the fast wrappers and remember them
# on IOWrapper.
try:
    sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
    sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)
except Exception:
    # Best effort. A stream without a usable fileno()/mode (presumably one
    # replaced by a test harness) cannot be wrapped, so the original stays.
    # Exception instead of a bare except lets SystemExit/KeyboardInterrupt through.
    pass
from typing import TypeVar
_T = TypeVar('T')
_U = TypeVar('U')
class TokenStream(Iterator):
    """Iterator over whitespace-separated tokens, pulled one line at a time.

    Lines come from the class-level ``TokenStream.stream`` (``IOWrapper.stdin``
    as it was when this class was created). Tokens of the current line wait
    in ``self.queue`` until they are consumed.
    """
    stream = IOWrapper.stdin

    def __init__(self):
        self.queue = deque()

    def __next__(self):
        # Refill from the next line only when the current one is used up.
        # NOTE(review): a blank line or EOF leaves the queue empty, so
        # popleft() raises IndexError rather than StopIteration.
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        return pending.popleft()

    def wait(self):
        """Yield repeatedly until the current line's tokens are all consumed.

        A fresh line is loaded first if nothing is pending.
        """
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        while pending:
            yield

    def _line(self):
        """Read one line from ``TokenStream.stream`` and split it into tokens."""
        raw = TokenStream.stream.readline()
        return raw.split()

    def line(self):
        """Return the rest of the current line, or the next line if none is pending."""
        pending = self.queue
        if not pending:
            return self._line()
        rest = list(pending)
        pending.clear()
        return rest
TokenStream.default = TokenStream()
class CharStream(TokenStream):
    """TokenStream variant whose "tokens" are the characters of each line.

    ``_line`` returns the right-stripped line as a string. The inherited
    ``queue.extend`` therefore enqueues it character by character.
    NOTE(review): the inherited ``line()`` consequently returns a ``str`` when
    nothing is pending but a ``list`` of characters otherwise.
    """
    def _line(self):
        text = TokenStream.stream.readline()
        return text.rstrip()
CharStream.default = CharStream()
# Type of a compiled parser: consumes tokens from a TokenStream and returns a value.
ParseFn = Callable[[TokenStream], _T]
class Parser:
    """Turns a type/shape *spec* into a function that reads a value from a TokenStream.

    ``Parser(spec)`` compiles once, and calling the instance parses one value.
    See :meth:`compile` for the accepted spec forms.
    """
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        """Build a parser for the class ``cls``. ``args`` are its generic arguments.

        Parsable subclasses supply their own parser. Numbers, strings and
        other callables are built from a single token. Tuples and other
        collections are delegated to compile_tuple / compile_collection.

        Raises:
            NotImplementedError: if ``cls`` matches none of the above.
        """
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        """Compile ``spec`` into a ParseFn.

        Accepted forms, checked in this order:
          * a class or builtin generic alias (``int``, ``list[int]``,
            ``tuple[int, ...]``): dispatched on its origin and arguments;
          * a number instance such as ``-1``: one token converted with
            ``type(spec)``, then ``spec`` added as an offset;
          * a tuple instance: one spec per element, or ``(spec, ...)`` for the
            rest of the line;
          * any other collection instance: see compile_collection;
          * any other callable: applied to one token.

        Raises:
            NotImplementedError: for unsupported specs.
        """
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        """Parse the remaining tokens of the current line into ``cls([...])``."""
        if spec is int:
            # Fast path: convert the whole line directly, with no per-token parser.
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        else:
            fn = Parser.compile(spec)
            # ts.wait() keeps yielding while tokens of the current line remain.
            def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
            return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        """Parse exactly ``N`` values of ``spec`` into ``cls([...])``."""
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        """Parse one value per spec in ``specs``, in order, into ``cls([...])``."""
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        """``(spec, ...)`` reads the rest of the line. Otherwise read one value per spec."""
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        """Build a parser for a non-tuple collection type ``cls``.

        * no spec or a single element spec (``list``, ``[int]``, ``{int}``):
          the rest of the line;
        * ``[spec, N]`` / ``(spec, N)`` with an int ``N``: exactly N values.

        Raises:
            NotImplementedError: for anything else, including sets listing
                several element specs. Those previously reached compile_line
                with too many arguments and crashed with a TypeError.
        """
        if not specs or len(specs) == 1:
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()
class Parsable:
    """Base for classes that can be built from a single input token.

    ``Parser.compile_type`` calls ``cls.compile(*args)`` for any subclass, so
    subclasses may override ``compile`` to customise how they are read.
    """
    @classmethod
    def compile(cls):
        def parse_one(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parse_one
@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_U], char=False):
    """Read input according to ``specs``.

    With no specs (and ``char`` false), return the current/next line as a list
    of ints. Otherwise ``specs`` is compiled as a tuple spec and parsed from
    ``CharStream.default`` (when ``char``) or ``TokenStream.default``. A single
    spec yields its value directly; several specs yield the whole tuple.
    """
    if not specs and not char:
        return [int(tok) for tok in TokenStream.default.line()]
    parse = Parser.compile(specs)
    stream = CharStream.default if char else TokenStream.default
    result = parse(stream)
    if len(specs) == 1:
        return result[0]
    return result
def write(*args, **kwargs):
    '''Prints the values to a stream, or to stdout_fast by default.

    Accepts the same keyword arguments as print(): ``sep`` (default a single
    space), ``end`` (default a newline), ``file`` and ``flush``. ``file``
    defaults to ``IOWrapper.stdout``, falling back to ``sys.stdout`` if the
    fast stream was never installed. As with print(), passing None for
    sep/end/file selects the default instead of crashing.
    '''
    sep = kwargs.pop("sep", None)
    end = kwargs.pop("end", None)
    file = kwargs.pop("file", None)
    if sep is None:
        sep = " "
    if end is None:
        end = "\n"
    if file is None:
        # Resolved at call time, so the stream installed at import is used.
        file = IOWrapper.stdout if IOWrapper.stdout is not None else sys.stdout
    file.write(sep.join(map(str, args)))
    file.write(end)
    if kwargs.pop("flush", False):
        file.flush()
if __name__ == '__main__':
    # Answer the judge's A + B problem first.
    A, B = read()
    C = A + B
    write(C)
    if C == 1198300249:
        # Only for this exact sum: run this file's pytest suite with its output
        # captured. Echo the captured output if any test fails, then exit with
        # pytest's status code.
        import io
        from contextlib import redirect_stdout, redirect_stderr
        output = io.StringIO()
        with redirect_stdout(output), redirect_stderr(output):
            result = pytest.main([__file__])
        if result != 0:
            print(output.getvalue())
        sys.exit(result)
    sys.exit(0)