This documentation is automatically generated by online-judge-tools/verification-helper
import cp_library.__header__
import cp_library.alg.__header__
import cp_library.alg.tree.__header__
from cp_library.alg.iter.presum_fn import presum
from cp_library.alg.tree.lca_table_iterative_cls import LCATable
class LCATableWeighted(LCATable):
    """LCATable variant whose ``distance`` uses prefix sums of ``T.Wdelta``.

    ``T.Wdelta`` is captured at construction time. Its prefix sums are only
    computed on the first ``distance`` call and are then cached in
    ``weighted_depth``. The sums are indexed by tour positions (via ``tin``),
    just like the base class's ``depth``. Presumably ``Wdelta`` holds the
    per-step change in weighted depth along the Euler tour (TODO confirm
    against the tree class).
    """

    def __init__(lca, T, root = 0):
        super().__init__(T, root)
        # NOTE(review): this aliases T.Wdelta instead of copying it, unlike
        # tin/tout in the base class -- confirm a later T.euler_tour() call
        # cannot rewrite it before the first distance() call.
        lca.weights = T.Wdelta
        # None means "prefix sums not built yet"; see distance().
        lca.weighted_depth = None

    def distance(lca, u, v) -> int:
        """Return the weighted length of the tree path between ``u`` and ``v``.

        Computed as wd(u) + wd(v) - 2 * wd(lca(u, v)), where wd is the
        cached prefix sum of the weights read at each vertex's ``tin``.
        """
        wd = lca.weighted_depth
        if wd is None:
            # First call: build and cache the prefix sums.
            wd = lca.weighted_depth = presum(lca.weights)
        lo, hi, anc, _ = lca._query(u, v)
        mid = lca.tin[anc]
        return wd[lo] + wd[hi - 1] - 2 * wd[mid]
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
https://kobejean.github.io/cp-library
'''
import operator
from itertools import accumulate
from typing import Callable, Iterable, TypeVar
_T = TypeVar('_T')
def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    """Return the running (prefix) combination of ``iter`` as a new list.

    Args:
        iter: values to combine.
        func: binary combining function; ``None`` means addition.
        initial: if not ``None``, prepended as the first element.
        step: stride.
            With ``step == 1`` this is ``itertools.accumulate``:
            ``out[i] = func(out[i-1], x[i])``.
            With ``step >= 2``, each element is combined with the result
            ``step`` positions earlier:
            ``out[i] = func(x[i], out[i-step])``.
            This gives an independent prefix sum for each residue class of
            the index modulo ``step``. Here ``x`` already has ``initial``
            prepended, so ``initial`` only seeds residue class 0.

    Returns:
        A list with one entry per input element, plus one if ``initial``
        is given.

    Raises:
        ValueError: if ``step`` is not 1 and is smaller than 2 (e.g. 0 or
            negative). This used to be an ``assert``, which is stripped
            under ``python -O`` and then silently produced garbage.
    """
    if step == 1:
        return list(accumulate(iter, func, initial=initial))
    if step < 2:
        raise ValueError(f"presum: step must be 1 or >= 2, got {step!r}")
    if func is None:
        func = operator.add
    A = list(iter)
    if initial is not None:
        A = [initial] + A
    # NOTE(review): this passes (current, previous) to func, while the
    # accumulate branch passes (previous, current). The two agree only for
    # commutative func.
    for i in range(step, len(A)):
        A[i] = func(A[i], A[i - step])
    return A
def sort2(a, b):
    """Return the two arguments as an ascending pair.

    The original order is kept only when ``a < b``. Otherwise, including
    when the values compare equal, the pair is swapped to ``(b, a)``.
    """
    if not a < b:
        return b, a
    return a, b
# from typing import Generic
# from cp_library.misc.typing import _T
def min2(a, b):
    """Return the smaller of two values.

    Only ``a < b`` is evaluated; when it is false (including ties), ``b``
    is returned.
    """
    if a < b:
        return a
    return b
class MinSparseTable:
    """Static range-minimum table (sparse table) over a fixed sequence.

    Level ``k`` stores, for every start index ``j``, the minimum of
    ``arr[j : j + 2**k]``. All levels are laid out back to back in one
    flat list ``data`` of length ``log * N``. Building takes O(N log N)
    comparisons, and a query is two lookups plus one comparison.
    """

    def __init__(st, arr: list):
        n = len(arr)
        levels = n.bit_length()
        table = [0] * (levels * n)
        table[:n] = arr
        width = 1  # half-width of a window at the level being built
        for lvl in range(1, levels):
            prev = (lvl - 1) * n
            cur = lvl * n
            # Window [j, j + 2*width) is the union of two halves from the
            # previous level: [j, j + width) and [j + width, j + 2*width).
            for j in range(n - 2 * width + 1):
                x = table[prev + j]
                y = table[prev + width + j]
                table[cur + j] = x if x < y else y
            width *= 2
        st.N = n
        st.log = levels
        st.data = table

    def query(st, l: int, r: int):
        """Return the minimum of ``arr[l:r]``.

        The range must be non-empty, i.e. ``0 <= l < r <= N``. Two
        (possibly overlapping) power-of-two windows cover it exactly.
        """
        k = (r - l).bit_length() - 1
        base = k * st.N
        left = st.data[base + l]
        right = st.data[base + r - (1 << k)]
        return left if left < right else right
class LCATable(MinSparseTable):
    """Lowest-common-ancestor queries via range-minimum over an Euler tour.

    Each tour position ``i`` is packed into one integer,
    ``depth[i] << shift | order[i]``. A sparse-table minimum over a tour
    range therefore yields the shallowest vertex in that range (ties broken
    by smaller vertex id), which is the LCA for a standard Euler tour.

    Assumed tree interface (inferred from usage here -- confirm against the
    tree class):

    - ``T.euler_tour(root)`` populates ``T.order``, the vertex id at each
      tour position.
    - ``T.delta`` has prefix sums equal to the depth at each tour position.
    - ``T.tin`` and ``T.tout`` give per-vertex tour positions.
    - ``T.par`` holds parent pointers; only ``path`` uses it.
    """

    def __init__(lca, T, root = 0):
        N = len(T)
        T.euler_tour(root)
        # Keep the tree so path() can reach its parent array (T.par).
        # Without this, path() raised AttributeError on ``lca.T``.
        lca.T = T
        lca.depth = depth = presum(T.delta)
        # Copies, so this table keeps its own tour positions even if T's
        # arrays change later.
        lca.tin, lca.tout = T.tin[:], T.tout[:]
        # Low ``shift`` bits hold the vertex id. Assuming ids are < N,
        # N.bit_length() bits are enough.
        lca.mask = (1 << (shift := N.bit_length())) - 1
        lca.shift = shift
        order = T.order
        # Pack (depth, vertex) into one int so a single comparison orders by
        # depth first and vertex id second.
        packets = [depth[i] << shift | order[i] for i in range(len(order))]
        super().__init__(packets)

    def _query(lca, u, v):
        """Return ``(l, r, a, d)`` for vertices ``u`` and ``v``.

        ``[l, r)`` is the tour range spanning ``tin[u]`` and ``tin[v]``.
        ``a`` is the LCA vertex and ``d`` is its depth.
        """
        l, r = sort2(lca.tin[u], lca.tin[v])
        r += 1  # make the range half-open and include the later endpoint
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift

    def query(lca, u, v) -> tuple[int,int]:
        """Return ``(lca_vertex, lca_depth)`` for vertices ``u`` and ``v``."""
        _, _, a, d = lca._query(u, v)
        return a, d

    def distance(lca, u, v) -> int:
        """Return the number of edges on the tree path between ``u`` and ``v``."""
        l, r, _, d = lca._query(u, v)
        # depth[l] and depth[r-1] are the depths of u and v (in tin order).
        return lca.depth[l] + lca.depth[r - 1] - 2 * d

    def path(lca, u, v):
        """Return the vertices on the tree path from ``u`` to ``v``, both inclusive.

        Walks parent pointers ``T.par`` from each endpoint up to the LCA.
        The ``u`` side is emitted upward and the ``v`` side is reversed so
        the result runs u -> ... -> lca -> ... -> v.
        """
        par = lca.T.par
        top = lca.query(u, v)[0]
        route = []
        c = u
        while c != top:
            route.append(c)
            c = par[c]
        route.append(top)
        tail = []
        c = v
        while c != top:
            tail.append(c)
            c = par[c]
        route.extend(reversed(tail))
        return route
class LCATableWeighted(LCATable):
    """LCATable variant whose ``distance`` uses prefix sums of ``T.Wdelta``.

    ``T.Wdelta`` is captured at construction time. Its prefix sums are only
    computed on the first ``distance`` call and are then cached in
    ``weighted_depth``. The sums are indexed by tour positions (via ``tin``),
    just like the base class's ``depth``. Presumably ``Wdelta`` holds the
    per-step change in weighted depth along the Euler tour (TODO confirm
    against the tree class).
    """

    def __init__(lca, T, root = 0):
        super().__init__(T, root)
        # NOTE(review): this aliases T.Wdelta instead of copying it, unlike
        # tin/tout in the base class -- confirm a later T.euler_tour() call
        # cannot rewrite it before the first distance() call.
        lca.weights = T.Wdelta
        # None means "prefix sums not built yet"; see distance().
        lca.weighted_depth = None

    def distance(lca, u, v) -> int:
        """Return the weighted length of the tree path between ``u`` and ``v``.

        Computed as wd(u) + wd(v) - 2 * wd(lca(u, v)), where wd is the
        cached prefix sum of the weights read at each vertex's ``tin``.
        """
        wd = lca.weighted_depth
        if wd is None:
            # First call: build and cache the prefix sums.
            wd = lca.weighted_depth = presum(lca.weights)
        lo, hi, anc, _ = lca._query(u, v)
        mid = lca.tin[anc]
        return wd[lo] + wd[hi - 1] - 2 * wd[mid]