cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/unittests/ds/tree/seg/segtree4_cls_test.py

Depends on

Code

# verification-helper: PROBLEM https://onlinejudge.u-aizu.ac.jp/courses/lesson/2/ITP1/1/ITP1_1_A

import pytest
import random

class TestSegTree4:
    """Unit tests for SegTree4, a segment tree whose elements are 4-tuples.

    Most tests use the component-wise sum monoid built by ``_sum_tree``; the
    rest pass their own ``op``/``e`` to show that each tuple slot is combined
    independently (min, max, string/list concatenation, set union, ...).
    """

    @staticmethod
    def _add4(a, b):
        """Component-wise sum of two 4-tuples."""
        return (a[0] + b[0], a[1] + b[1], a[2] + b[2], a[3] + b[3])

    @staticmethod
    def _base():
        """Return a fresh copy of the four-row fixture shared by several tests."""
        return [(1, 10, 100, 1000), (2, 20, 200, 2000), (3, 30, 300, 3000), (4, 40, 400, 4000)]

    def _sum_tree(self, v):
        """Build a sum SegTree4 from a list of 4-tuples, or from a size (all identity)."""
        return SegTree4(self._add4, (0, 0, 0, 0), v)

    def test_initialization_with_list(self):
        """Test initialization with a list of 4-tuples"""
        values = self._base()
        seg = self._sum_tree(values)

        assert seg.n == 4
        for i, row in enumerate(values):
            assert seg[i] == row

    def test_initialization_with_size(self):
        """Test initialization with size only"""
        seg = self._sum_tree(5)

        assert seg.n == 5
        # All elements should be identity
        for i in range(5):
            assert seg[i] == (0, 0, 0, 0)

    def test_set_and_get(self):
        """Test set and get operations"""
        seg = self._sum_tree(4)
        rows = self._base()

        for i, row in enumerate(rows):
            seg[i] = row
        for i, row in enumerate(rows):
            assert seg[i] == row

    def test_prod_sum(self):
        """Test prod operation with sum"""
        seg = self._sum_tree(self._base())

        # Test various ranges
        assert seg.prod(0, 4) == (10, 100, 1000, 10000)  # Sum of all
        assert seg.prod(0, 2) == (3, 30, 300, 3000)      # First two
        assert seg.prod(1, 3) == (5, 50, 500, 5000)      # Middle two
        assert seg.prod(2, 4) == (7, 70, 700, 7000)      # Last two
        assert seg.prod(1, 2) == (2, 20, 200, 2000)      # Single element
        assert seg.prod(2, 2) == (0, 0, 0, 0)            # Empty range

    def test_prod_max(self):
        """Test prod operation with max"""
        values = [(3, 30, 300, 3000), (1, 10, 100, 1000), (4, 40, 400, 4000), (2, 20, 200, 2000)]
        seg = SegTree4(
            lambda a, b: (max(a[0], b[0]), max(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])),
            (float('-inf'), float('-inf'), float('-inf'), float('-inf')),
            values
        )

        assert seg.prod(0, 4) == (4, 40, 400, 4000)
        assert seg.prod(0, 2) == (3, 30, 300, 3000)
        assert seg.prod(1, 3) == (4, 40, 400, 4000)
        assert seg.prod(2, 4) == (4, 40, 400, 4000)

    def test_prod_min(self):
        """Test prod operation with min"""
        values = [(3, 30, 300, 3000), (1, 10, 100, 1000), (4, 40, 400, 4000), (2, 20, 200, 2000)]
        seg = SegTree4(
            lambda a, b: (min(a[0], b[0]), min(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])),
            (float('inf'), float('inf'), float('inf'), float('inf')),
            values
        )

        assert seg.prod(0, 4) == (1, 10, 100, 1000)
        assert seg.prod(0, 2) == (1, 10, 100, 1000)
        assert seg.prod(1, 3) == (1, 10, 100, 1000)
        assert seg.prod(2, 4) == (2, 20, 200, 2000)

    def test_all_prod(self):
        """Test all_prod operation"""
        seg = self._sum_tree(self._base())

        assert seg.all_prod() == (10, 100, 1000, 10000)

    def test_max_right(self):
        """Test max_right operation"""
        seg = self._sum_tree(self._base())

        # Largest r such that the predicate holds for the sum over [l, r)
        assert seg.max_right(0, lambda x: x[0] <= 3) == 2     # sum over [0, 2) is 3
        assert seg.max_right(0, lambda x: x[1] <= 30) == 2    # sum over [0, 2) is 30
        assert seg.max_right(0, lambda x: x[2] <= 300) == 2   # sum over [0, 2) is 300
        assert seg.max_right(0, lambda x: x[3] <= 3000) == 2  # sum over [0, 2) is 3000
        assert seg.max_right(0, lambda x: x[0] <= 10) == 4    # sum over [0, 4) is 10
        assert seg.max_right(1, lambda x: x[0] <= 5) == 3     # sum over [1, 3) is 5

    def test_min_left(self):
        """Test min_left operation"""
        seg = self._sum_tree(self._base())

        # Smallest l such that the predicate holds for the sum over [l, r)
        assert seg.min_left(4, lambda x: x[0] <= 4) == 3      # Only last element
        assert seg.min_left(4, lambda x: x[1] <= 40) == 3     # Only last element
        assert seg.min_left(4, lambda x: x[2] <= 400) == 3    # Only last element
        assert seg.min_left(4, lambda x: x[3] <= 4000) == 3   # Only last element
        assert seg.min_left(4, lambda x: x[0] <= 10) == 0     # All elements
        assert seg.min_left(3, lambda x: x[0] <= 5) == 1      # sum over [1, 3) is 5

    def test_search_boundaries(self):
        """Test max_right/min_left at the array ends and when the first element already fails"""
        seg = self._sum_tree(self._base())

        assert seg.max_right(4, lambda x: x[0] <= 0) == 4     # l == n: nothing to extend
        assert seg.min_left(0, lambda x: x[0] <= 0) == 0      # r == 0: nothing to extend
        assert seg.max_right(0, lambda x: x[0] <= 0) == 0     # a[0] alone violates the predicate
        assert seg.min_left(4, lambda x: x[0] <= 0) == 4      # a[3] alone violates the predicate
        assert seg.prod(4, 4) == (0, 0, 0, 0)                 # empty range at the right end

    def test_update_and_query(self):
        """Test update operations affect queries correctly"""
        seg = self._sum_tree(4)

        # Initial values
        for i, row in enumerate(self._base()):
            seg[i] = row

        assert seg.prod(0, 4) == (10, 100, 1000, 10000)

        # Update some values
        seg[1] = (5, 50, 500, 5000)
        seg[2] = (6, 60, 600, 6000)

        assert seg.prod(0, 4) == (16, 160, 1600, 16000)
        assert seg.prod(1, 3) == (11, 110, 1100, 11000)

    def test_empty_tree(self):
        """Test empty segment tree"""
        seg = self._sum_tree(0)

        assert seg.n == 0
        assert seg.all_prod() == (0, 0, 0, 0)

    def test_single_element(self):
        """Test segment tree with single element"""
        seg = self._sum_tree([(5, 50, 500, 5000)])

        assert seg.n == 1
        assert seg[0] == (5, 50, 500, 5000)
        assert seg.prod(0, 1) == (5, 50, 500, 5000)
        assert seg.all_prod() == (5, 50, 500, 5000)

    def test_large_tree(self):
        """Test with larger dataset"""
        n = 1000
        seg = self._sum_tree([(i, i * 10, i * 100, i * 1000) for i in range(n)])

        # Sum of 0..999 = 499500
        assert seg.all_prod() == (499500, 4995000, 49950000, 499500000)

        # Sum of 0..99 = 4950
        assert seg.prod(0, 100) == (4950, 49500, 495000, 4950000)

        # Replace row 500 (= 500 * (1, 10, 100, 1000)) by 1000 * (1, 10, 100, 1000)
        seg[500] = (1000, 10000, 100000, 1000000)
        assert seg.all_prod() == (
            499500 - 500 + 1000,
            4995000 - 5000 + 10000,
            49950000 - 50000 + 100000,
            499500000 - 500000 + 1000000,
        )

    def test_different_types(self):
        """Test with different data types in tuples"""
        # String concatenation, list concatenation, set union, and counting
        seg = SegTree4(
            lambda a, b: (a[0] + b[0], a[1] + b[1], a[2] | b[2], a[3] + b[3]),
            ("", [], set(), 0),
            [("a", [1], {1}, 1), ("b", [2], {2}, 1), ("c", [3], {3}, 1), ("d", [4], {4}, 1)]
        )

        assert seg.prod(0, 2) == ("ab", [1, 2], {1, 2}, 2)
        assert seg.prod(1, 4) == ("bcd", [2, 3, 4], {2, 3, 4}, 3)
        assert seg.all_prod() == ("abcd", [1, 2, 3, 4], {1, 2, 3, 4}, 4)

    def test_complex_operation(self):
        """Test with complex statistical operations"""
        # Track min, max, sum, and count simultaneously
        def combine(a, b):
            return (
                min(a[0], b[0]),  # min
                max(a[1], b[1]),  # max
                a[2] + b[2],      # sum
                a[3] + b[3]       # count
            )

        values = [(3, 3, 3, 1), (1, 1, 1, 1), (4, 4, 4, 1), (2, 2, 2, 1)]
        seg = SegTree4(combine, (float('inf'), float('-inf'), 0, 0), values)

        assert seg.prod(0, 4) == (1, 4, 10, 4)  # min=1, max=4, sum=10, count=4
        assert seg.prod(0, 2) == (1, 3, 4, 2)   # min=1, max=3, sum=4, count=2
        assert seg.prod(2, 4) == (2, 4, 6, 2)   # min=2, max=4, sum=6, count=2

    def test_stress_random_operations(self):
        """Stress test with random operations"""
        random.seed(42)
        n = 100

        # Initialize with random values
        values = [(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)) for _ in range(n)]
        seg = self._sum_tree(values)

        # Perform random operations
        for _ in range(200):
            op = random.choice(['update', 'query'])

            if op == 'update':
                idx = random.randint(0, n-1)
                new_val = (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100))
                seg[idx] = new_val
                values[idx] = new_val
            else:
                l = random.randint(0, n-1)
                r = random.randint(l, n)

                # Verify against a naive left fold over the same slice
                expected = (0, 0, 0, 0)
                for row in values[l:r]:
                    expected = self._add4(expected, row)

                assert seg.prod(l, r) == expected

from cp_library.ds.tree.seg.segtree4_cls import SegTree4

if __name__ == '__main__':
    # Entry point used by oj-verify: print the AOJ ITP1_1_A answer, then run
    # this file's pytest suite and exit with its status code.
    import cp_library.test.unittest_helper as _unittest_helper
    _unittest_helper.run_verification_helper_unittest()
# verification-helper: PROBLEM https://onlinejudge.u-aizu.ac.jp/courses/lesson/2/ITP1/1/ITP1_1_A

import pytest
import random

class TestSegTree4:
    """Unit tests for SegTree4, a segment tree whose elements are 4-tuples.

    Most tests use the component-wise sum monoid built by ``_sum_tree``; the
    rest pass their own ``op``/``e`` to show that each tuple slot is combined
    independently (min, max, string/list concatenation, set union, ...).
    """

    @staticmethod
    def _add4(a, b):
        """Component-wise sum of two 4-tuples."""
        return (a[0] + b[0], a[1] + b[1], a[2] + b[2], a[3] + b[3])

    @staticmethod
    def _base():
        """Return a fresh copy of the four-row fixture shared by several tests."""
        return [(1, 10, 100, 1000), (2, 20, 200, 2000), (3, 30, 300, 3000), (4, 40, 400, 4000)]

    def _sum_tree(self, v):
        """Build a sum SegTree4 from a list of 4-tuples, or from a size (all identity)."""
        return SegTree4(self._add4, (0, 0, 0, 0), v)

    def test_initialization_with_list(self):
        """Test initialization with a list of 4-tuples"""
        values = self._base()
        seg = self._sum_tree(values)

        assert seg.n == 4
        for i, row in enumerate(values):
            assert seg[i] == row

    def test_initialization_with_size(self):
        """Test initialization with size only"""
        seg = self._sum_tree(5)

        assert seg.n == 5
        # All elements should be identity
        for i in range(5):
            assert seg[i] == (0, 0, 0, 0)

    def test_set_and_get(self):
        """Test set and get operations"""
        seg = self._sum_tree(4)
        rows = self._base()

        for i, row in enumerate(rows):
            seg[i] = row
        for i, row in enumerate(rows):
            assert seg[i] == row

    def test_prod_sum(self):
        """Test prod operation with sum"""
        seg = self._sum_tree(self._base())

        # Test various ranges
        assert seg.prod(0, 4) == (10, 100, 1000, 10000)  # Sum of all
        assert seg.prod(0, 2) == (3, 30, 300, 3000)      # First two
        assert seg.prod(1, 3) == (5, 50, 500, 5000)      # Middle two
        assert seg.prod(2, 4) == (7, 70, 700, 7000)      # Last two
        assert seg.prod(1, 2) == (2, 20, 200, 2000)      # Single element
        assert seg.prod(2, 2) == (0, 0, 0, 0)            # Empty range

    def test_prod_max(self):
        """Test prod operation with max"""
        values = [(3, 30, 300, 3000), (1, 10, 100, 1000), (4, 40, 400, 4000), (2, 20, 200, 2000)]
        seg = SegTree4(
            lambda a, b: (max(a[0], b[0]), max(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])),
            (float('-inf'), float('-inf'), float('-inf'), float('-inf')),
            values
        )

        assert seg.prod(0, 4) == (4, 40, 400, 4000)
        assert seg.prod(0, 2) == (3, 30, 300, 3000)
        assert seg.prod(1, 3) == (4, 40, 400, 4000)
        assert seg.prod(2, 4) == (4, 40, 400, 4000)

    def test_prod_min(self):
        """Test prod operation with min"""
        values = [(3, 30, 300, 3000), (1, 10, 100, 1000), (4, 40, 400, 4000), (2, 20, 200, 2000)]
        seg = SegTree4(
            lambda a, b: (min(a[0], b[0]), min(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])),
            (float('inf'), float('inf'), float('inf'), float('inf')),
            values
        )

        assert seg.prod(0, 4) == (1, 10, 100, 1000)
        assert seg.prod(0, 2) == (1, 10, 100, 1000)
        assert seg.prod(1, 3) == (1, 10, 100, 1000)
        assert seg.prod(2, 4) == (2, 20, 200, 2000)

    def test_all_prod(self):
        """Test all_prod operation"""
        seg = self._sum_tree(self._base())

        assert seg.all_prod() == (10, 100, 1000, 10000)

    def test_max_right(self):
        """Test max_right operation"""
        seg = self._sum_tree(self._base())

        # Largest r such that the predicate holds for the sum over [l, r)
        assert seg.max_right(0, lambda x: x[0] <= 3) == 2     # sum over [0, 2) is 3
        assert seg.max_right(0, lambda x: x[1] <= 30) == 2    # sum over [0, 2) is 30
        assert seg.max_right(0, lambda x: x[2] <= 300) == 2   # sum over [0, 2) is 300
        assert seg.max_right(0, lambda x: x[3] <= 3000) == 2  # sum over [0, 2) is 3000
        assert seg.max_right(0, lambda x: x[0] <= 10) == 4    # sum over [0, 4) is 10
        assert seg.max_right(1, lambda x: x[0] <= 5) == 3     # sum over [1, 3) is 5

    def test_min_left(self):
        """Test min_left operation"""
        seg = self._sum_tree(self._base())

        # Smallest l such that the predicate holds for the sum over [l, r)
        assert seg.min_left(4, lambda x: x[0] <= 4) == 3      # Only last element
        assert seg.min_left(4, lambda x: x[1] <= 40) == 3     # Only last element
        assert seg.min_left(4, lambda x: x[2] <= 400) == 3    # Only last element
        assert seg.min_left(4, lambda x: x[3] <= 4000) == 3   # Only last element
        assert seg.min_left(4, lambda x: x[0] <= 10) == 0     # All elements
        assert seg.min_left(3, lambda x: x[0] <= 5) == 1      # sum over [1, 3) is 5

    def test_search_boundaries(self):
        """Test max_right/min_left at the array ends and when the first element already fails"""
        seg = self._sum_tree(self._base())

        assert seg.max_right(4, lambda x: x[0] <= 0) == 4     # l == n: nothing to extend
        assert seg.min_left(0, lambda x: x[0] <= 0) == 0      # r == 0: nothing to extend
        assert seg.max_right(0, lambda x: x[0] <= 0) == 0     # a[0] alone violates the predicate
        assert seg.min_left(4, lambda x: x[0] <= 0) == 4      # a[3] alone violates the predicate
        assert seg.prod(4, 4) == (0, 0, 0, 0)                 # empty range at the right end

    def test_update_and_query(self):
        """Test update operations affect queries correctly"""
        seg = self._sum_tree(4)

        # Initial values
        for i, row in enumerate(self._base()):
            seg[i] = row

        assert seg.prod(0, 4) == (10, 100, 1000, 10000)

        # Update some values
        seg[1] = (5, 50, 500, 5000)
        seg[2] = (6, 60, 600, 6000)

        assert seg.prod(0, 4) == (16, 160, 1600, 16000)
        assert seg.prod(1, 3) == (11, 110, 1100, 11000)

    def test_empty_tree(self):
        """Test empty segment tree"""
        seg = self._sum_tree(0)

        assert seg.n == 0
        assert seg.all_prod() == (0, 0, 0, 0)

    def test_single_element(self):
        """Test segment tree with single element"""
        seg = self._sum_tree([(5, 50, 500, 5000)])

        assert seg.n == 1
        assert seg[0] == (5, 50, 500, 5000)
        assert seg.prod(0, 1) == (5, 50, 500, 5000)
        assert seg.all_prod() == (5, 50, 500, 5000)

    def test_large_tree(self):
        """Test with larger dataset"""
        n = 1000
        seg = self._sum_tree([(i, i * 10, i * 100, i * 1000) for i in range(n)])

        # Sum of 0..999 = 499500
        assert seg.all_prod() == (499500, 4995000, 49950000, 499500000)

        # Sum of 0..99 = 4950
        assert seg.prod(0, 100) == (4950, 49500, 495000, 4950000)

        # Replace row 500 (= 500 * (1, 10, 100, 1000)) by 1000 * (1, 10, 100, 1000)
        seg[500] = (1000, 10000, 100000, 1000000)
        assert seg.all_prod() == (
            499500 - 500 + 1000,
            4995000 - 5000 + 10000,
            49950000 - 50000 + 100000,
            499500000 - 500000 + 1000000,
        )

    def test_different_types(self):
        """Test with different data types in tuples"""
        # String concatenation, list concatenation, set union, and counting
        seg = SegTree4(
            lambda a, b: (a[0] + b[0], a[1] + b[1], a[2] | b[2], a[3] + b[3]),
            ("", [], set(), 0),
            [("a", [1], {1}, 1), ("b", [2], {2}, 1), ("c", [3], {3}, 1), ("d", [4], {4}, 1)]
        )

        assert seg.prod(0, 2) == ("ab", [1, 2], {1, 2}, 2)
        assert seg.prod(1, 4) == ("bcd", [2, 3, 4], {2, 3, 4}, 3)
        assert seg.all_prod() == ("abcd", [1, 2, 3, 4], {1, 2, 3, 4}, 4)

    def test_complex_operation(self):
        """Test with complex statistical operations"""
        # Track min, max, sum, and count simultaneously
        def combine(a, b):
            return (
                min(a[0], b[0]),  # min
                max(a[1], b[1]),  # max
                a[2] + b[2],      # sum
                a[3] + b[3]       # count
            )

        values = [(3, 3, 3, 1), (1, 1, 1, 1), (4, 4, 4, 1), (2, 2, 2, 1)]
        seg = SegTree4(combine, (float('inf'), float('-inf'), 0, 0), values)

        assert seg.prod(0, 4) == (1, 4, 10, 4)  # min=1, max=4, sum=10, count=4
        assert seg.prod(0, 2) == (1, 3, 4, 2)   # min=1, max=3, sum=4, count=2
        assert seg.prod(2, 4) == (2, 4, 6, 2)   # min=2, max=4, sum=6, count=2

    def test_stress_random_operations(self):
        """Stress test with random operations"""
        random.seed(42)
        n = 100

        # Initialize with random values
        values = [(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)) for _ in range(n)]
        seg = self._sum_tree(values)

        # Perform random operations
        for _ in range(200):
            op = random.choice(['update', 'query'])

            if op == 'update':
                idx = random.randint(0, n-1)
                new_val = (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100))
                seg[idx] = new_val
                values[idx] = new_val
            else:
                l = random.randint(0, n-1)
                r = random.randint(l, n)

                # Verify against a naive left fold over the same slice
                expected = (0, 0, 0, 0)
                for row in values[l:r]:
                    expected = self._add4(expected, row)

                assert seg.prod(l, r) == expected

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
from typing import TypeVar

_S = TypeVar('S'); _T = TypeVar('T'); _U = TypeVar('U'); _T1 = TypeVar('T1'); _T2 = TypeVar('T2'); _T3 = TypeVar('T3'); _T4 = TypeVar('T4'); _T5 = TypeVar('T5'); _T6 = TypeVar('T6')

from typing import Generic




def argsort(A: list[int], reverse=False):
    """Return the indices of ``A`` ordered by their values.

    Ascending by default, descending when ``reverse`` is true. Equal values
    always appear in ascending index order. Each (value, index) pair is packed
    into one int with ``Packer`` so a single plain ``list.sort`` does the work,
    which means the values must be ints.
    """
    keys = list(A)
    packer = Packer(len(keys) - 1)  # low bits must hold any index 0..len-1
    packer.ienumerate(keys, reverse)
    keys.sort()
    packer.iindices(keys)           # strip the values, keep only the indices
    return keys



class Packer:
    """Pack an int key and a small non-negative int into a single int.

    The low ``s`` bits hold the small value (at most ``mx``, e.g. an index) and
    the remaining high bits hold the key, so comparing packed ints compares by
    key first and by the small value second.
    """
    __slots__ = 's', 'm'

    def __init__(P, mx: int):
        P.s = mx.bit_length()   # width of the low field
        P.m = (1 << P.s) - 1    # mask selecting the low field

    def enc(P, a: int, b: int):
        """Pack key ``a`` with low value ``b`` (``0 <= b <= m``)."""
        return (a << P.s) | b

    def dec(P, x: int) -> tuple[int, int]:
        """Split a packed int back into ``(key, low value)``."""
        high = x >> P.s
        low = x & P.m
        return high, low

    def enumerate(P, A, reverse=False):
        """Return a new list of ``enc(a, i)`` (or ``enc(-a, i)`` when reversed)."""
        A = list(A)
        P.ienumerate(A, reverse)
        return A

    def ienumerate(P, A, reverse=False):
        """In place: replace each ``A[i]`` by ``enc(A[i], i)``, negating keys if ``reverse``."""
        enc = P.enc
        if not reverse:
            for i, a in enumerate(A):
                A[i] = enc(a, i)
        else:
            for i, a in enumerate(A):
                A[i] = enc(-a, i)

    def indices(P, A: list[int]):
        """Return a new list holding only the low field of every packed int."""
        A = list(A)
        P.iindices(A)
        return A

    def iindices(P, A):
        """In place: keep only the low field of every packed int in ``A``."""
        mask = P.m
        for i, a in enumerate(A):
            A[i] = a & mask


def isort_parallel(*L: list, reverse=False):
    """Sort several equal-length lists in place, all by the order of ``L[0]``.

    The target order is ``argsort(L[0], reverse=reverse)``, so equal keys keep
    their original relative order. The permutation is applied to every list
    with in-place swaps, and ``inv`` records where each element currently
    stored at a position still has to go. Returns ``L``, the tuple of the same
    list objects.
    """
    inv = [0] * len(L[0])
    order = argsort(L[0], reverse=reverse)
    for dst, src in enumerate(order):
        inv[src] = dst              # the element now at src belongs at dst
    # enumerate() reads order lazily, so the rewrites below are seen later.
    for pos, src in enumerate(order):
        for A in L:
            A[pos], A[src] = A[src], A[pos]
        # The element that used to sit at pos now sits at src: point the slot
        # that wanted it at src, and carry its destination along with it.
        target = inv[pos]
        order[target] = src
        inv[src] = target
    return L


class list4(Generic[_T1, _T2, _T3, _T4]):
    """A sequence of 4-tuples stored as four parallel column lists.

    Row ``i`` is ``(A1[i], A2[i], A3[i], A4[i])``. The column lists passed in
    are used directly, not copied, so every mutation is visible through them.
    Membership tests and ``index`` are deliberately unsupported.
    """
    __slots__ = 'A1', 'A2', 'A3', 'A4'

    def __init__(lst, A1: list[_T1], A2: list[_T2], A3: list[_T3], A4: list[_T4]):
        lst.A1 = A1
        lst.A2 = A2
        lst.A3 = A3
        lst.A4 = A4

    def __len__(lst):
        # Only A1 is measured; the columns are expected to have equal length.
        return len(lst.A1)

    def __getitem__(lst, i: int):
        c1, c2, c3, c4 = lst.A1, lst.A2, lst.A3, lst.A4
        return c1[i], c2[i], c3[i], c4[i]

    def __setitem__(lst, i: int, v: tuple[_T1, _T2, _T3, _T4]):
        x1, x2, x3, x4 = v
        lst.A1[i] = x1
        lst.A2[i] = x2
        lst.A3[i] = x3
        lst.A4[i] = x4

    def __contains__(lst, v: tuple[_T1, _T2, _T3, _T4]):
        raise NotImplementedError

    def index(lst, v: tuple[_T1, _T2, _T3, _T4]):
        raise NotImplementedError

    def reverse(lst):
        """Reverse all four columns in place."""
        for col in (lst.A1, lst.A2, lst.A3, lst.A4):
            col.reverse()

    def sort(lst, reverse=False):
        """Reorder the rows in place by the values of column A1 only (ties keep their order)."""
        isort_parallel(lst.A1, lst.A2, lst.A3, lst.A4, reverse=reverse)

    def pop(lst):
        """Remove and return the last row."""
        first = lst.A1.pop()
        second = lst.A2.pop()
        third = lst.A3.pop()
        fourth = lst.A4.pop()
        return first, second, third, fourth

    def append(lst, v: tuple[_T1, _T2, _T3, _T4]):
        """Append the 4-tuple ``v`` as a new row."""
        x1, x2, x3, x4 = v
        lst.A1.append(x1)
        lst.A2.append(x2)
        lst.A3.append(x3)
        lst.A4.append(x4)

    def add(lst, i: int, v: tuple[_T1, _T2, _T3, _T4]):
        """Add ``v`` component-wise onto row ``i``."""
        lst.A1[i] += v[0]
        lst.A2[i] += v[1]
        lst.A3[i] += v[2]
        lst.A4[i] += v[3]


from typing import Callable, Generic, Union

class SegTree(Generic[_T]):
    """Iterative (bottom-up) segment tree over a monoid ``(op, e)``.

    Leaves live at ``d[sz:sz + n]`` and node ``i`` holds ``op(d[2i], d[2i+1])``,
    so ``d[1]`` is the product of the whole array. Padding leaves hold ``e``.
    Subclasses may set ``_lst`` to a column-store class (e.g. ``list4``). ``d``
    is then built as ``_lst`` of one list per component of ``e``.
    """
    _lst = list

    def __init__(seg, op: Callable[[_T, _T], _T], e: _T, v: Union[int, list[_T]]) -> None:
        """Build over the list ``v``, or over ``v`` copies of ``e`` when ``v`` is an int."""
        if isinstance(v, int):
            n, v = v, None
        else:
            n = len(v)
        seg.op, seg.e, seg.n = op, e, n
        # NOTE(review): (n-1).bit_length() alone would give the smallest power of
        # two >= n; the extra level doubles sz without changing query results.
        seg.log = (n - 1).bit_length() + 1
        seg.sz = 1 << seg.log
        cells = seg.sz << 1
        if seg._lst is list:
            seg.d = [e] * cells
        else:
            seg.d = seg._lst(*([comp] * cells for comp in e))
        if v:
            seg._build(v)

    def _build(seg, v):
        """Copy ``v`` into the leaves, then recompute every internal node bottom-up."""
        d, base = seg.d, seg.sz
        for i in range(seg.n):
            d[base + i] = v[i]
        for node in reversed(range(1, base)):
            seg._merge(node, node << 1, (node << 1) | 1)

    def _merge(seg, i, j, k):
        """Store ``op(d[j], d[k])`` into ``d[i]``."""
        d = seg.d
        d[i] = seg.op(d[j], d[k])

    def set(seg, p: int, x: _T) -> None:
        """Assign ``a[p] = x`` and refresh the ``log`` ancestors of that leaf."""
        node = p + seg.sz
        seg.d[node] = x
        for _ in range(seg.log):
            node &= ~1  # move to the left child of the pair
            seg._merge(node >> 1, node, node | 1)
            node >>= 1
    __setitem__ = set

    def get(seg, p: int) -> _T:
        """Return ``a[p]``."""
        return seg.d[seg.sz + p]
    __getitem__ = get

    def prod(seg, l: int, r: int) -> _T:
        """Return ``op(a[l], ..., a[r-1])``, or ``e`` when ``l >= r``."""
        op, d = seg.op, seg.d
        left = right = seg.e
        lo, hi = l + seg.sz, r + seg.sz
        while lo < hi:
            if lo & 1:
                left = op(left, d[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                right = op(d[hi], right)
            lo >>= 1
            hi >>= 1
        return op(left, right)

    def all_prod(seg) -> _T:
        """Return the product of the whole array (the root node)."""
        return seg.d[1]

    def max_right(seg, l: int, f: Callable[[_T], bool]) -> int:
        """Return an ``r`` in ``[l, n]`` with ``f(prod(l, r))`` true and, unless
        ``r == n``, ``f(prod(l, r + 1))`` false. This is the largest such ``r``
        when ``f`` is monotone. Requires ``f(e)`` to be true.
        """
        assert 0 <= l <= seg.n
        assert f(seg.e)
        if l == seg.n:
            return seg.n
        size, op, d = seg.sz, seg.op, seg.d
        node, acc = l + size, seg.e
        while True:
            # Climb while node is a left child: its parent covers the same start.
            while not node & 1:
                node >>= 1
            if not f(op(acc, d[node])):
                # Descend to the first leaf whose inclusion breaks f.
                while node < size:
                    node <<= 1
                    if f(op(acc, d[node])):
                        acc = op(acc, d[node])
                        node += 1
                return node - size
            acc = op(acc, d[node])
            node += 1
            if node & -node == node:  # wrapped past the right edge of the tree
                return seg.n

    def min_left(seg, r: int, f: Callable[[_T], bool]) -> int:
        """Return an ``l`` in ``[0, r]`` with ``f(prod(l, r))`` true and, unless
        ``l == 0``, ``f(prod(l - 1, r))`` false. This is the smallest such ``l``
        when ``f`` is monotone. Requires ``f(e)`` to be true.
        """
        assert 0 <= r <= seg.n
        assert f(seg.e)
        if r == 0:
            return 0
        size, op, d = seg.sz, seg.op, seg.d
        node, acc = r + size, seg.e
        while True:
            node -= 1
            # Climb while node is a right child: its parent covers the same end.
            while node > 1 and node & 1:
                node >>= 1
            if not f(op(d[node], acc)):
                # Descend to the last leaf whose inclusion breaks f.
                while node < size:
                    node = (node << 1) | 1
                    if f(op(d[node], acc)):
                        acc = op(d[node], acc)
                        node -= 1
                return node + 1 - size
            acc = op(d[node], acc)
            if node & -node == node:  # reached the left edge of the tree
                return 0

class SegTree4(SegTree[_T]):
    """SegTree whose elements are 4-tuples, stored column-wise in a ``list4``.

    Because ``_lst`` is not ``list``, ``SegTree.__init__`` builds ``d`` from one
    list per component of ``e``, so ``e`` must be a 4-tuple. Reads through
    ``list4.__getitem__`` (``get``/``prod``/``all_prod``) yield 4-tuples.
    """
    _lst = list4

if __name__ == '__main__':
    # Inlined copy of the unittest helper. It lets oj-verify treat this pytest
    # file as a submission for the AOJ problem ITP1_1_A ("Hello World").

    def run_verification_helper_unittest():
        """
        Print the expected ITP1_1_A answer, run pytest on the calling file with
        all of its output captured, and exit with pytest's result code.

        The captured pytest output is echoed only when the run fails. This must
        be called directly from the test module's top level, because the
        caller's frame is used to find the file to test.
        """
        import io
        import sys
        from contextlib import redirect_stdout, redirect_stderr

        print("Hello World")  # accepted output for AOJ ITP1_1_A

        captured = io.StringIO()
        with redirect_stdout(captured), redirect_stderr(captured):
            caller_globals = sys._getframe(1).f_globals
            target = caller_globals.get('__file__')
            if target is None:
                target = sys.argv[0]
            exit_code = pytest.main([target])

        if exit_code != 0:
            print(captured.getvalue())
        sys.exit(exit_code)

    run_verification_helper_unittest()
Back to top page