This documentation is automatically generated by online-judge-tools/verification-helper
#!/usr/bin/env python3
"""
Comprehensive benchmark comparing modular arithmetic approaches on lists:
1. Plain int list with manual modular operations
2. mlist_cls (optimized modular list)
3. List of mint_cls (modular integers)
Tests various operations to provide a fair comparison across different use cases.
"""
import random
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cp_library.perf.benchmark import Benchmark, BenchmarkConfig
from cp_library.math.mod.mlist_cls import mlist
from cp_library.math.mod.mint_ntt_cls import mint
# Setup modular arithmetic with a common modulus
MOD = 998244353
mint.set_mod(MOD)
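# Quick sanity sketch (illustrative, not part of the timed benchmark):
# mint reduces automatically, so int(mint(MOD - 1) + 2) == 1, whereas a
# plain int needs an explicit `% MOD` after every operation.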
# Configure benchmark
config = BenchmarkConfig(
name="mlist",
    sizes=[1000000, 100000, 10000, 1000, 100, 10, 1],  # descending, so any JIT warms up on the largest size first
operations=['construction', 'addition', 'multiplication', 'mixed_ops', 'elementwise_mul', 'sum_all', 'conv'],
iterations=10,
warmup=3,
output_dir="./output/benchmark_results/mlist"
)
# Create benchmark instance
benchmark = Benchmark(config)
# Data generators
@benchmark.data_generator("default")
def generate_modular_data(size: int, operation: str):
"""Generate test data for modular arithmetic operations"""
# Generate two random lists for operations
list1 = [random.randint(1, MOD-1) for _ in range(size)]
list2 = [random.randint(1, MOD-1) for _ in range(size)]
# Pre-initialize data for fair timing (exclude initialization overhead)
preinitialized = {
'list1_copy': list(list1),
'list2_copy': list(list2),
'mlist1': mlist(list(list1)),
'mlist2': mlist(list(list2)),
'mint_list1': [mint(x) for x in list1],
'mint_list2': [mint(x) for x in list2],
'result_buffer': [0] * size,
'mlist_result': mlist(size),
'constant': 12345,
'mint_constant': mint(12345)
}
return {
'list1': list1,
'list2': list2,
'size': size,
'operation': operation,
'mod': MOD,
'preinitialized': preinitialized
}
# Construction operation
@benchmark.implementation("int_list", "construction")
def construction_int_list(data):
"""Construct int list from raw data"""
list1 = list(data['list1'])
list2 = list(data['list2'])
checksum = 0
for x in list1:
checksum ^= x
for x in list2:
checksum ^= x
return checksum
@benchmark.implementation("mlist", "construction")
def construction_mlist(data):
"""Construct mlist from raw data"""
mlist1 = mlist(data['list1'])
mlist2 = mlist(data['list2'])
checksum = 0
for x in mlist1.data:
checksum ^= x
for x in mlist2.data:
checksum ^= x
return checksum
@benchmark.implementation("mint_list", "construction")
def construction_mint_list(data):
"""Construct mint list from raw data"""
mint_list1 = [mint(x) for x in data['list1']]
mint_list2 = [mint(x) for x in data['list2']]
checksum = 0
for x in mint_list1:
checksum ^= x
for x in mint_list2:
checksum ^= x
return checksum
# Addition operation
@benchmark.implementation("int_list", "addition")
def addition_int_list(data):
"""Element-wise addition with manual modulo"""
pre = data['preinitialized']
list1, list2, mod = pre['list1_copy'], pre['list2_copy'], data['mod']
checksum = 0
for i in range(data['size']):
checksum ^= (list1[i] + list2[i]) % mod
return checksum
@benchmark.implementation("mlist", "addition")
def addition_mlist(data):
"""Element-wise addition using mlist"""
pre = data['preinitialized']
list1, list2 = pre['mlist1'], pre['mlist2']
checksum = 0
for i in range(data['size']):
checksum ^= list1[i] + list2[i]
return checksum
@benchmark.implementation("mint_list", "addition")
def addition_mint_list(data):
"""Element-wise addition using mint list"""
pre = data['preinitialized']
list1, list2 = pre['mint_list1'], pre['mint_list2']
checksum = 0
for i in range(data['size']):
checksum ^= list1[i] + list2[i]
return checksum
# Multiplication operation
@benchmark.implementation("int_list", "multiplication")
def multiplication_int_list(data):
"""Element-wise multiplication with manual modulo"""
pre = data['preinitialized']
list1, list2, mod = pre['list1_copy'], pre['list2_copy'], data['mod']
checksum = 0
for i in range(data['size']):
checksum ^= (list1[i] * list2[i]) % mod
return checksum
@benchmark.implementation("mlist", "multiplication")
def multiplication_mlist(data):
"""Element-wise multiplication using mlist"""
pre = data['preinitialized']
list1, list2 = pre['mlist1'], pre['mlist2']
checksum = 0
for i in range(data['size']):
checksum ^= list1[i] * list2[i]
return checksum
@benchmark.implementation("mint_list", "multiplication")
def multiplication_mint_list(data):
"""Element-wise multiplication using mint list"""
pre = data['preinitialized']
list1, list2 = pre['mint_list1'], pre['mint_list2']
checksum = 0
for i in range(data['size']):
checksum ^= list1[i] * list2[i]
return checksum
# Mixed operations
@benchmark.implementation("int_list", "mixed_ops")
def mixed_ops_int_list(data):
"""Mix of addition, multiplication, and subtraction"""
pre = data['preinitialized']
list1, list2, mod = pre['list1_copy'], pre['list2_copy'], data['mod']
checksum = 0
for i in range(data['size']):
if i % 3 == 0:
checksum ^= (list1[i] + list2[i]) % mod
elif i % 3 == 1:
checksum ^= (list1[i] * list2[i]) % mod
else:
checksum ^= (list1[i] - list2[i]) % mod
return checksum
@benchmark.implementation("mlist", "mixed_ops")
def mixed_ops_mlist(data):
"""Mix of operations using mlist"""
pre = data['preinitialized']
list1, list2 = pre['mlist1'], pre['mlist2']
checksum = 0
for i in range(data['size']):
if i % 3 == 0:
checksum ^= list1[i] + list2[i]
elif i % 3 == 1:
checksum ^= list1[i] * list2[i]
else:
checksum ^= list1[i] - list2[i]
return checksum
@benchmark.implementation("mint_list", "mixed_ops")
def mixed_ops_mint_list(data):
"""Mix of operations using mint list"""
pre = data['preinitialized']
list1, list2 = pre['mint_list1'], pre['mint_list2']
checksum = 0
for i in range(data['size']):
if i % 3 == 0:
checksum ^= list1[i] + list2[i]
elif i % 3 == 1:
checksum ^= list1[i] * list2[i]
else:
checksum ^= list1[i] - list2[i]
return checksum
@benchmark.implementation("int_list_e", "mixed_ops")
def mixed_ops_int_list_e(data):
    """Mix of addition, multiplication, and subtraction (enumerate variant)"""
pre = data['preinitialized']
list1, list2, mod = pre['list1_copy'], pre['list2_copy'], data['mod']
checksum = 0
for i, x in enumerate(list1):
if i % 3 == 0:
checksum ^= (x + list2[i]) % mod
elif i % 3 == 1:
checksum ^= (x * list2[i]) % mod
else:
checksum ^= (x - list2[i]) % mod
return checksum
@benchmark.implementation("mlist_e", "mixed_ops")
def mixed_ops_mlist_e(data):
    """Mix of operations using mlist (enumerate variant)"""
pre = data['preinitialized']
list1, list2 = pre['mlist1'], pre['mlist2']
checksum = 0
for i, x in enumerate(list1):
if i % 3 == 0:
checksum ^= x + list2[i]
elif i % 3 == 1:
checksum ^= x * list2[i]
else:
checksum ^= x - list2[i]
return checksum
@benchmark.implementation("mint_list_e", "mixed_ops")
def mixed_ops_mint_list_e(data):
    """Mix of operations using mint list (enumerate variant)"""
pre = data['preinitialized']
list1, list2 = pre['mint_list1'], pre['mint_list2']
checksum = 0
for i, x in enumerate(list1):
if i % 3 == 0:
checksum ^= x + list2[i]
elif i % 3 == 1:
checksum ^= x * list2[i]
else:
checksum ^= x - list2[i]
return checksum
# Element-wise multiplication by constant
@benchmark.implementation("int_list", "elementwise_mul")
def elementwise_mul_int_list(data):
"""Multiply each element by a constant"""
pre = data['preinitialized']
list1, mod, constant = pre['list1_copy'], data['mod'], pre['constant']
checksum = 0
for x in list1:
checksum ^= (x * constant) % mod
return checksum
@benchmark.implementation("mlist", "elementwise_mul")
def elementwise_mul_mlist(data):
"""Multiply each element by a constant using mlist"""
pre = data['preinitialized']
list1, constant = pre['mlist1'], pre['mint_constant']
checksum = 0
for x in list1:
checksum ^= x * constant
return checksum
@benchmark.implementation("mint_list", "elementwise_mul")
def elementwise_mul_mint_list(data):
"""Multiply each element by a constant using mint list"""
pre = data['preinitialized']
list1, constant = pre['mint_list1'], pre['mint_constant']
checksum = 0
for x in list1:
result = x * constant
checksum ^= result
return checksum
# Sum all elements
@benchmark.implementation("int_list", "sum_all")
def sum_all_int_list(data):
"""Sum all elements"""
pre = data['preinitialized']
list1, mod = pre['list1_copy'], data['mod']
result = 0
for x in list1:
result = (result + x) % mod
return result
@benchmark.implementation("mlist", "sum_all")
def sum_all_mlist(data):
"""Sum all elements using mlist"""
pre = data['preinitialized']
list1 = pre['mlist1']
result = mint(0)
for x in list1:
result = result + x
return int(result)
@benchmark.implementation("mint_list", "sum_all")
def sum_all_mint_list(data):
"""Sum all elements using mint list"""
pre = data['preinitialized']
list1 = pre['mint_list1']
result = mint(0)
for x in list1:
result = result + x
return int(result)
# Convolution operation
@benchmark.implementation("int_list", "conv")
def conv_int_list(data):
"""Convolution using mint.ntt.conv with int lists"""
pre = data['preinitialized']
list1, list2 = pre['list1_copy'], pre['list2_copy']
# Use mint.ntt.conv for convolution
result = mint.ntt.conv(list1, list2, len(list1) + len(list2) - 1)
checksum = 0
for x in result:
checksum ^= x
return checksum
@benchmark.implementation("mlist", "conv")
def conv_mlist(data):
"""Convolution using mlist.conv method"""
pre = data['preinitialized']
mlist1, mlist2 = pre['mlist1'], pre['mlist2']
# Use mlist.conv method
result = mlist1.conv(mlist2, len(mlist1) + len(mlist2) - 1)
checksum = 0
for x in result.data:
checksum ^= x
return checksum
@benchmark.implementation("mint_list", "conv")
def conv_mint_list(data):
"""Convolution using mint.ntt.conv with mint lists"""
pre = data['preinitialized']
mint_list1, mint_list2 = pre['mint_list1'], pre['mint_list2']
# Convert to int lists, convolve, convert back
int_list1 = [int(x) for x in mint_list1]
int_list2 = [int(x) for x in mint_list2]
result_ints = mint.ntt.conv(int_list1, int_list2, len(int_list1) + len(int_list2) - 1)
result = [mint(x) for x in result_ints]
checksum = 0
for x in result:
checksum ^= x
return checksum
@benchmark.implementation("mint_list_direct", "conv")
def conv_mint_list_direct(data):
"""Convolution using mint.ntt.conv directly with mint lists"""
pre = data['preinitialized']
mint_list1, mint_list2 = pre['mint_list1'], pre['mint_list2']
result = mint.ntt.conv(mint_list1, mint_list2, len(mint_list1) + len(mint_list2) - 1)
checksum = 0
for x in result:
checksum ^= x
return checksum
# Custom validator for modular arithmetic results (compares XOR checksums)
@benchmark.validator("default")
def validate_modular_result(expected, actual):
"""Validate modular arithmetic results using XOR checksums"""
try:
# Compare XOR checksums directly
return int(expected) == int(actual)
except Exception:
return False
if __name__ == "__main__":
# Parse command line args and run appropriate mode
runner = benchmark.parse_args()
runner.run()
# --- Bundled cp_library dependencies (inlined by verification-helper) ---
"""
Declarative benchmark framework with minimal boilerplate.
Features:
- Decorator-based benchmark registration
- Automatic data generation and validation
- Built-in timing with warmup
- Configurable operations and sizes
- JSON results and matplotlib plotting
"""
import time
import json
import statistics
import argparse
from typing import Dict, List, Any, Callable, Union
from dataclasses import dataclass
from pathlib import Path
from collections import defaultdict
@dataclass
class BenchmarkConfig:
"""Configuration for benchmark runs"""
name: str
sizes: List[int] = None
operations: List[str] = None
iterations: int = 10
warmup: int = 2
output_dir: str = "./output/benchmark_results"
save_results: bool = True
plot_results: bool = True
plot_scale: str = "loglog" # Options: "loglog", "linear", "semilogx", "semilogy"
progressive: bool = True # Show results operation by operation across sizes
# Profiling mode
profile_mode: bool = False
profile_size: int = None
profile_operation: str = None
profile_implementation: str = None
def __post_init__(self):
if self.sizes is None:
self.sizes = [100, 1000, 10000, 100000]
if self.operations is None:
self.operations = ['default']
class Benchmark:
"""Declarative benchmark framework using decorators"""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.data_generators = {}
self.implementations = {}
self.validators = {}
self.setups = {}
self.results = []
def profile(self, operation: str = None, size: int = None, implementation: str = None):
"""Create a profiling version of this benchmark"""
profile_config = BenchmarkConfig(
name=f"{self.config.name}_profile",
sizes=self.config.sizes,
operations=self.config.operations,
profile_mode=True,
profile_operation=operation,
profile_size=size,
profile_implementation=implementation,
save_results=False,
plot_results=False
)
profile_benchmark = Benchmark(profile_config)
profile_benchmark.data_generators = self.data_generators
profile_benchmark.implementations = self.implementations
profile_benchmark.validators = self.validators
profile_benchmark.setups = self.setups
return profile_benchmark
def parse_args(self):
"""Parse command line arguments for profiling mode"""
parser = argparse.ArgumentParser(
description=f"Benchmark {self.config.name} with optional profiling mode",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Normal benchmark mode
python benchmark.py
# Profile specific operation and implementation
python benchmark.py --profile --operation random_access --implementation grid
# Profile with specific size
python benchmark.py --profile --size 1000000
# Profile all implementations of an operation
python benchmark.py --profile --operation construction
"""
)
parser.add_argument('--profile', action='store_true',
help='Run in profiling mode (minimal overhead for profilers)')
parser.add_argument('--operation', type=str,
help=f'Operation to profile. Options: {", ".join(self.config.operations)}')
parser.add_argument('--size', type=int,
help=f'Size to profile. Options: {", ".join(map(str, self.config.sizes))}')
parser.add_argument('--implementation', type=str,
help='Specific implementation to profile (default: all)')
args = parser.parse_args()
# If profile mode requested, return a profiling benchmark
if args.profile:
return self.profile(
operation=args.operation,
size=args.size,
implementation=args.implementation
)
# Otherwise return self for normal mode
return self
def data_generator(self, name: str = "default"):
"""Decorator to register data generator"""
def decorator(func):
self.data_generators[name] = func
return func
return decorator
def implementation(self, name: str, operations: Union[str, List[str]] = None):
"""Decorator to register implementation"""
if operations is None:
operations = ['default']
elif isinstance(operations, str):
operations = [operations]
def decorator(func):
for op in operations:
if op not in self.implementations:
self.implementations[op] = {}
self.implementations[op][name] = func
return func
return decorator
def validator(self, operation: str = "default"):
"""Decorator to register custom validator"""
def decorator(func):
self.validators[operation] = func
return func
return decorator
def setup(self, name: str, operations: Union[str, List[str]] = None):
"""Decorator to register setup function that runs before timing"""
if operations is None:
operations = ['default']
elif isinstance(operations, str):
operations = [operations]
def decorator(func):
for op in operations:
if op not in self.setups:
self.setups[op] = {}
self.setups[op][name] = func
return func
return decorator
def measure_time(self, func: Callable, data: Any, setup_func: Callable = None) -> tuple[Any, float]:
"""Measure execution time with warmup and optional setup"""
# Warmup runs
for _ in range(self.config.warmup):
try:
if setup_func:
setup_data = setup_func(data)
func(setup_data)
else:
func(data)
except Exception:
# If warmup fails, let the main measurement handle the error
break
# Actual measurement
start = time.perf_counter()
for _ in range(self.config.iterations):
if setup_func:
setup_data = setup_func(data)
result = func(setup_data)
else:
result = func(data)
elapsed_ms = (time.perf_counter() - start) * 1000 / self.config.iterations
return result, elapsed_ms
def validate_result(self, expected: Any, actual: Any, operation: str) -> bool:
"""Validate result using custom validator or default comparison"""
if operation in self.validators:
return self.validators[operation](expected, actual)
return expected == actual
def run(self):
"""Run all benchmarks"""
if self.config.profile_mode:
self._run_profile_mode()
else:
self._run_normal_mode()
def _run_normal_mode(self):
"""Run normal benchmark mode"""
print(f"Running {self.config.name}")
print(f"Sizes: {self.config.sizes}")
print(f"Operations: {self.config.operations}")
print("="*80)
# Always show progressive results: operation by operation across all sizes
for operation in self.config.operations:
for size in self.config.sizes:
self._run_single(operation, size)
# Save and plot results
if self.config.save_results:
self._save_results()
if self.config.plot_results:
self._plot_results()
# Print summary
self._print_summary()
def _run_profile_mode(self):
"""Run profiling mode with minimal overhead for use with vmprof"""
operation = self.config.profile_operation or self.config.operations[0]
size = self.config.profile_size or max(self.config.sizes)
impl_name = self.config.profile_implementation
print(f"PROFILING MODE: {self.config.name}")
print(f"Operation: {operation}, Size: {size}")
if impl_name:
print(f"Implementation: {impl_name}")
print("="*80)
print("Run with vmprof: vmprof --web " + ' '.join(sys.argv))
print("="*80)
# Generate test data
generator = self.data_generators.get(operation, self.data_generators.get('default'))
if not generator:
raise ValueError(f"No data generator for operation: {operation}")
test_data = generator(size, operation)
# Get implementations
impls = self.implementations.get(operation, {})
if not impls:
raise ValueError(f"No implementations for operation: {operation}")
# Filter to specific implementation if requested
if impl_name:
if impl_name not in impls:
raise ValueError(f"Implementation '{impl_name}' not found for operation '{operation}'")
impls = {impl_name: impls[impl_name]}
# Run with minimal overhead - no timing, no validation
for name, func in impls.items():
print(f"\nRunning {name}...")
sys.stdout.flush()
# Setup if needed
setup_func = self.setups.get(operation, {}).get(name)
if setup_func:
data = setup_func(test_data)
else:
data = test_data
# Run the actual function (this is what vmprof will profile)
result = func(data)
print(f"Completed {name}, result checksum: {result}")
sys.stdout.flush()
def _run_single(self, operation: str, size: int):
"""Run a single operation/size combination"""
print(f"\nOperation: {operation}, Size: {size}")
print("-" * 50)
sys.stdout.flush()
# Generate test data
generator = self.data_generators.get(operation,
self.data_generators.get('default'))
if not generator:
raise ValueError(f"No data generator for operation: {operation}")
test_data = generator(size, operation)
# Get implementations for this operation
impls = self.implementations.get(operation, {})
if not impls:
print(f"No implementations for operation: {operation}")
return
# Get setup functions for this operation
setups = self.setups.get(operation, {})
# Run reference implementation first
ref_name, ref_impl = next(iter(impls.items()))
ref_setup = setups.get(ref_name)
expected_result, _ = self.measure_time(ref_impl, test_data, ref_setup)
# Run all implementations
for impl_name, impl_func in impls.items():
try:
setup_func = setups.get(impl_name)
result, time_ms = self.measure_time(impl_func, test_data, setup_func)
correct = self.validate_result(expected_result, result, operation)
# Store result
self.results.append({
'operation': operation,
'size': size,
'implementation': impl_name,
'time_ms': time_ms,
'correct': correct,
'error': None
})
status = "OK" if correct else "FAIL"
print(f" {impl_name:<20} {time_ms:>8.3f} ms {status}")
sys.stdout.flush()
except Exception as e:
self.results.append({
'operation': operation,
'size': size,
'implementation': impl_name,
'time_ms': float('inf'),
'correct': False,
'error': str(e)
})
print(f" {impl_name:<20} ERROR: {str(e)[:40]}")
sys.stdout.flush()
def _save_results(self):
"""Save results to JSON"""
output_dir = Path(self.config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
filename = output_dir / f"{self.config.name}_{int(time.time())}.json"
with open(filename, 'w') as f:
json.dump(self.results, f, indent=2)
print(f"\nResults saved to {filename}")
def _plot_results(self):
"""Generate plots using matplotlib if available"""
try:
import matplotlib.pyplot as plt
output_dir = Path(self.config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Group and prepare data for plotting
data_by_op = self._group_results_by_operation()
# Create plots for each operation
for operation, operation_data in data_by_op.items():
self._create_performance_plot(plt, operation, operation_data, output_dir)
except ImportError:
print("Matplotlib not available - skipping plots")
except Exception as e:
print(f"Plotting failed: {e}")
def _group_results_by_operation(self) -> Dict[str, Dict[int, List[Dict[str, Any]]]]:
"""Group results by operation and size for plotting"""
data_by_op = defaultdict(lambda: defaultdict(list))
for r in self.results:
if r['time_ms'] != float('inf') and r['correct']:
data_by_op[r['operation']][r['size']].append({
'implementation': r['implementation'],
'time_ms': r['time_ms']
})
return data_by_op
def _create_performance_plot(self, plt, operation: str, operation_data: Dict[int, List[Dict[str, Any]]], output_dir: Path):
"""Create a performance plot for a single operation"""
sizes = sorted(operation_data.keys())
implementations = set()
for size_data in operation_data.values():
for entry in size_data:
implementations.add(entry['implementation'])
implementations = sorted(implementations)
plt.figure(figsize=(10, 6))
for impl in implementations:
impl_times = []
impl_sizes = []
for size in sizes:
times = [entry['time_ms'] for entry in operation_data[size]
if entry['implementation'] == impl]
if times:
impl_times.append(statistics.mean(times))
impl_sizes.append(size)
if impl_times:
plt.plot(impl_sizes, impl_times, 'o-', label=impl)
plt.xlabel('Input Size')
plt.ylabel('Time (ms)')
plt.title(f'{self.config.name} - {operation} Operation')
plt.legend()
plt.grid(True, alpha=0.3)
# Apply the configured scaling
if self.config.plot_scale == "loglog":
plt.loglog()
elif self.config.plot_scale == "linear":
pass # Default linear scale
elif self.config.plot_scale == "semilogx":
plt.semilogx()
elif self.config.plot_scale == "semilogy":
plt.semilogy()
else:
# Default to loglog if invalid option
plt.loglog()
plot_file = output_dir / f"{self.config.name}_{operation}_performance.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.close()
print(f"Plot saved: {plot_file}")
def _print_summary(self):
"""Print performance summary"""
print("\n" + "="*80)
print("PERFORMANCE SUMMARY")
print("="*80)
# Group by operation
by_operation = defaultdict(lambda: defaultdict(list))
for r in self.results:
if r['error'] is None and r['time_ms'] != float('inf'):
by_operation[r['operation']][r['implementation']].append(r['time_ms'])
print(f"{'Operation':<15} {'Best Implementation':<20} {'Avg Time (ms)':<15} {'Speedup':<10}")
print("-" * 70)
for op, impl_times in sorted(by_operation.items()):
# Calculate averages
avg_times = [(impl, statistics.mean(times))
for impl, times in impl_times.items()]
avg_times.sort(key=lambda x: x[1])
if avg_times:
best_impl, best_time = avg_times[0]
worst_time = avg_times[-1][1]
speedup = worst_time / best_time if best_time > 0 else 0
print(f"{op:<15} {best_impl:<20} {best_time:<15.3f} {speedup:<10.1f}x")
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
class mint(int):
mod: int
zero: 'mint'
one: 'mint'
two: 'mint'
cache: list['mint']
def __new__(cls, *args, **kwargs):
if 0 <= (x := int(*args, **kwargs)) < 64:
return cls.cache[x]
else:
return cls.fix(x)
@classmethod
def set_mod(cls, mod: int):
mint.mod = cls.mod = mod
mint.zero = cls.zero = cls.cast(0)
mint.one = cls.one = cls.fix(1)
mint.two = cls.two = cls.fix(2)
mint.cache = cls.cache = [cls.zero, cls.one, cls.two]
for x in range(3,64): mint.cache.append(cls.fix(x))
@classmethod
def fix(cls, x): return cls.cast(x%cls.mod)
@classmethod
def cast(cls, x): return super().__new__(cls,x)
@classmethod
def mod_inv(cls, x):
a,b,s,t = int(x), cls.mod, 1, 0
while b: a,b,s,t = b,a%b,t,s-a//b*t
if a == 1: return cls.fix(s)
raise ValueError(f"{x} is not invertible in mod {cls.mod}")
@property
def inv(self): return mint.mod_inv(self)
def __add__(self, x): return mint.fix(super().__add__(x))
def __radd__(self, x): return mint.fix(super().__radd__(x))
def __sub__(self, x): return mint.fix(super().__sub__(x))
def __rsub__(self, x): return mint.fix(super().__rsub__(x))
def __mul__(self, x): return mint.fix(super().__mul__(x))
def __rmul__(self, x): return mint.fix(super().__rmul__(x))
def __floordiv__(self, x): return self * mint.mod_inv(x)
def __rfloordiv__(self, x): return self.inv * x
def __truediv__(self, x): return self * mint.mod_inv(x)
def __rtruediv__(self, x): return self.inv * x
def __pow__(self, x):
return self.cast(super().__pow__(x, self.mod))
def __neg__(self): return mint.mod-self
def __pos__(self): return self
def __abs__(self): return self
def __class_getitem__(self, x: int): return self.cache[x]
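# mint usage sketch (assuming mint.set_mod(998244353) has been called):
#   mint(5) + mint(998244350) == mint(2)   # wraps past the modulus
#   mint(3) / 7 == mint(3) * mint(7).inv   # division multiplies by the inverse
#   -mint(1) == mint(998244352)            # negation stays in [0, mod)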
def mod_inv(x, mod):
a,b,s,t = x, mod, 1, 0
while b:
a,b,s,t = b,a%b,t,s-a//b*t
if a == 1: return s % mod
raise ValueError(f"{x} is not invertible in mod {mod}")
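# e.g. mod_inv(3, 7) == 5, since 3 * 5 == 15 ≡ 1 (mod 7)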
class NTT:
def __init__(self, mod = 998244353) -> None:
self.mod = m = mod
self.g = g = self.primitive_root(m)
self.rank2 = rank2 = ((m-1)&(1-m)).bit_length() - 1
self.root = root = [0] * (rank2 + 1)
root[rank2] = pow(g, (m - 1) >> rank2, m)
self.iroot = iroot = [0] * (rank2 + 1)
iroot[rank2] = pow(root[rank2], m - 2, m)
for i in range(rank2 - 1, -1, -1):
root[i] = root[i+1] * root[i+1] % m
iroot[i] = iroot[i+1] * iroot[i+1] % m
def rates(s):
r8,ir8 = [0]*max(0,rank2-s+1), [0]*max(0,rank2-s+1)
p = ip = 1
for i in range(rank2-s+1):
r, ir = root[i+s], iroot[i+s]
p,ip,r8[i],ir8[i]= p*ir%m,ip*r%m,r*p%m,ir*ip%m
return r8, ir8
self.rate2, self.irate2 = rates(2)
self.rate3, self.irate3 = rates(3)
def primitive_root(self, m):
if m == 2: return 1
if m == 167772161: return 3
if m == 469762049: return 3
if m == 754974721: return 11
if m == 998244353: return 3
divs = [0] * 20
cnt, divs[0], x = 1, 2, (m - 1) // 2
while x % 2 == 0: x //= 2
i=3
while i*i <= x:
if x%i == 0:
divs[cnt],cnt = i,cnt+1
while x%i==0:x//=i
i+=2
if x > 1: divs[cnt],cnt = x,cnt+1
for g in range(2,m):
for i in range(cnt):
if pow(g,(m-1)//divs[i],m)==1:break
else:return g
def fntt(self, A: list[int]):
im, r8, m, h = self.root[2],self.rate3,self.mod,(len(A)-1).bit_length()
for L in range(0,h-1,2):
p, r = 1<<(h-L-2),1
for s in range(1 << L):
r3,of=(r2:=r*r%m)*r%m,s<<(h-L)
for i in range(p):
i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
a0,a1,a2,a3 = A[i0],A[i1]*r,A[i2]*r2,A[i3]*r3
a0,a1,a2,a3 = a0+a2,a1+a3,a0-a2,(a1-a3)%m*im
A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a0-a1)%m,(a2+a3)%m,(a2-a3)%m
r=r*r8[(~s&-~s).bit_length()-1]%m
if h&1:
r, r8 = 1, self.rate2
for s in range(1<<(h-1)):
i1=(i0:=s<<1)+1
al,ar = A[i0],A[i1]*r%m
A[i0],A[i1] = (al+ar)%m,(al-ar)%m
r=r*r8[(~s&-~s).bit_length()-1]%m
return A
def _ifntt(self, A: list[int]):
im, r8, m, h = self.iroot[2],self.irate3,self.mod,(len(A)-1).bit_length()
for L in range(h,1,-2):
p,r = 1<<(h-L),1
for s in range(1<<(L-2)):
r3,of=(r2:=r*r%m)*r%m,s<<(h-L+2)
for i in range(p):
i3=(i2:=(i1:=(i0:=i+of)+p)+p)+p
a0,a1,a2,a3 = A[i0],A[i1],A[i2],A[i3]
a0,a1,a2,a3 = a0+a1,a2+a3,a0-a1,(a2-a3)*im%m
A[i0],A[i1],A[i2],A[i3] = (a0+a1)%m,(a2+a3)*r%m,(a0-a1)*r2%m,(a2-a3)*r3%m
r=r*r8[(~s&-~s).bit_length()-1]%m
if h&1:
for i0 in range(p:=1<<(h-1)):
al,ar = A[i0],A[i1:=i0+p]
A[i0],A[i1] = (al+ar)%m,(al-ar)%m
return A
def ifntt(self, A: list[int]):
self._ifntt(A)
iz = mod_inv(N:=len(A),mod:=self.mod)
for i in range(N): A[i]=A[i]*iz%mod
return A
def conv_naive(self, A, B, N):
n, m, mod = len(A),len(B),self.mod
C = [0]*N
if n < m: A,B,n,m = B,A,m,n
for i,a in enumerate(A):
for j in range(min(m,N-i)):
                C[i+j] = (C[i+j] + a*B[j]) % mod
return C
def conv_fntt(self, A, B, N):
n,m,mod=len(A),len(B),self.mod
z=1<<(n+m-2).bit_length()
self.fntt(A:=A+[0]*(z-n)), self.fntt(B:=B+[0]*(z-m))
for i, b in enumerate(B): A[i] = A[i] * b % mod
self.ifntt(A)
del A[N:]
return A
def deconv(self, C, B, N = None):
n, m = len(C), len(B)
if N is None: N = n - m + 1
z = 1 << (n + m - 2).bit_length()
self.fntt(C := C+[0]*(z-n)), self.fntt(B := B+[0]*(z - m))
A = [0] * z
for i in range(z):
if B[i] == 0:
raise ValueError("Division by zero in NTT domain - deconvolution not possible")
b_inv = mod_inv(B[i], self.mod)
A[i] = (C[i] * b_inv) % self.mod
self.ifntt(A)
return A[:N]
def conv_half(self, A, Bres):
mod = self.mod
self.fntt(A)
for i, b in enumerate(Bres): A[i] = A[i] * b % mod
self.ifntt(A)
return A
def conv(self, A, B, N = None):
n,m = len(A), len(B)
N = n+m-1 if N is None else N
if min(n,m) <= 60: return self.conv_naive(A, B, N)
return self.conv_fntt(A, B, N)
def cycle_conv(self, A, B):
n,m,mod=len(A),len(B),self.mod
assert n == m
if n==0:return[]
con,res=self.conv(A,B),[0]*n
for i in range(n-1):res[i]=(con[i]+con[i+n])%mod
res[n-1]=con[n-1]
return res
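# NTT usage sketch: NTT(998244353).conv([1, 2], [3, 4]) == [3, 10, 8], the
# coefficients of (1 + 2x)(3 + 4x). Inputs this short take the conv_naive
# path; once min(len(A), len(B)) > 60, conv dispatches to conv_fntt.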
class mint(mint):
ntt: NTT
@classmethod
def set_mod(cls, mod: int):
super().set_mod(mod)
cls.ntt = NTT(mod)
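# This subclass wires an NTT instance into set_mod, so after mint.set_mod(m)
# convolution under m is available as mint.ntt.conv(A, B, N).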
class mlist:
def __init__(lst, data): lst.data = [0]*data if isinstance(data, int) else [int(x) for x in data]
@staticmethod
def from_raw(data: list[int]):
(lst := mlist.__new__(mlist)).data = data
return lst
def __getitem__(lst, i) -> mint: return mint(lst.data[i])
def __setitem__(lst, i, x): lst.data[i] = int(x)
def __len__(lst): return len(lst.data)
def conv(A, B, N):
A = A.data
B = B.data if hasattr(B, 'data') else B
return mlist.from_raw(mint.ntt.conv(A, B, N))
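# mlist usage sketch: ml = mlist([1, 2, 3]) stores plain ints in ml.data and
# returns mint on indexing, so ml[0] * ml[1] is reduced mod mint.mod.
# mlist(5) preallocates five zeros; mlist.from_raw(data) adopts `data`
# without copying or reduction; ml.conv(other, N) convolves via mint.ntt.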