cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

:heavy_check_mark: cp_library/alg/tree/lca_table_weighted_iterative_cls.py

Depends on

Required by

Verified with

Code

import cp_library.__header__
import cp_library.alg.__header__
import cp_library.alg.tree.__header__
from cp_library.alg.iter.presum_fn import presum
from cp_library.alg.tree.lca_table_iterative_cls import LCATable

class LCATableWeighted(LCATable):
    """LCA table whose `distance` sums edge weights instead of counting edges.

    Reads `T.Wdelta` after the base class has been initialised and turns it
    into a running total the first time `distance` is called.
    NOTE(review): `T.Wdelta` is presumably the per-Euler-position signed
    weight change, aligned with the positions stored in `tin` - confirm
    against the tree class.
    """
    def __init__(lca, T, root = 0):
        super().__init__(T, root)
        lca.weights = T.Wdelta
        # Filled in lazily by distance(); None means "not computed yet".
        lca.weighted_depth = None

    def distance(lca, u, v) -> int:
        """Return weighted_depth at u plus at v minus twice that at their LCA."""
        wd = lca.weighted_depth
        if wd is None:
            # First call: build the running totals once and cache them.
            wd = lca.weighted_depth = presum(lca.weights)
        lo, hi, anc, _ = lca._query(u, v)
        # _query returns an exclusive right bound, hence hi-1.
        top = lca.tin[anc]
        return wd[lo] + wd[hi-1] - 2*wd[top]
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''



import operator
from itertools import accumulate
from typing import Callable, Iterable, TypeVar
# The name passed to TypeVar must match the variable it is assigned to;
# static type checkers reject a mismatch.
_T = TypeVar('_T')

def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    """Return the running accumulation of `iter` as a list.

    With `step == 1` this is `list(itertools.accumulate(iter, func, initial=initial))`.
    With `step >= 2` each element is combined with the result `step` positions
    before it, so the output holds `step` interleaved accumulations.

    Args:
        iter: values to accumulate.
        func: binary combiner; addition when None.
        initial: optional seed placed in front of the values (skipped when None).
        step: distance between combined positions; must be 1 or at least 2.

    Returns:
        A new list of accumulated values.
    """
    if step != 1:
        assert step >= 2
        combine = operator.add if func is None else func
        values = list(iter)
        if initial is not None:
            values = [initial] + values
        # In-place left-to-right pass: values[pos-step] is already accumulated.
        for pos in range(step, len(values)):
            values[pos] = combine(values[pos], values[pos-step])
        return values
    return list(accumulate(iter, func, initial=initial))


def sort2(a, b):
    """Return the two arguments as a tuple in ascending order.

    When `a < b` is false (including ties) the result is `(b, a)`.
    """
    if a < b:
        return a, b
    return b, a
# from typing import Generic
# from cp_library.misc.typing import _T

def min2(a, b):
    """Return the smaller of two values; `b` is returned unless `a < b`."""
    if a < b:
        return a
    return b



class MinSparseTable:
    """Static range-minimum table stored in a single flat list.

    `data` holds `log` rows of width `N`; row k, column j is the minimum of
    `arr[j : j + 2**k]` for every j where that window fits inside the array.
    """
    def __init__(st, arr: list):
        size = len(arr)
        levels = size.bit_length()
        st.N, st.log = size, levels
        st.data = table = [0] * (levels*size)
        # Row 0 is the input itself.
        table[:size] = arr
        half = 1  # width of the windows in the previous row
        for lvl in range(1, levels):
            dst = lvl*size
            src = dst - size
            # A window of width 2*half is the min of two adjacent half-windows.
            for j in range(size - 2*half + 1):
                table[dst+j] = min2(table[src+j], table[src+half+j])
            half *= 2

    def query(st, l: int, r: int):
        """Return the minimum over the half-open range [l, r)."""
        k = (r-l).bit_length() - 1
        row = k*st.N
        # Two windows of width 2**k, one anchored at l and one ending at r,
        # together cover the whole range.
        left = st.data[row + l]
        right = st.data[row + r - (1<<k)]
        return min2(left, right)
    

class LCATable(MinSparseTable):
    """Lowest-common-ancestor queries via an Euler tour and a min sparse table.

    Each Euler-tour position is encoded as one integer `depth << shift | vertex`,
    so the minimum over a range of positions is the shallowest vertex in it.

    NOTE(review): relies on `T.euler_tour(root)` populating `T.delta`, `T.tin`,
    `T.tout` and `T.order`, and on `T.par` being a parent array for `path()` -
    confirm against the tree class.
    """
    def __init__(lca, T, root = 0):
        N = len(T)
        T.euler_tour(root)
        # Keep the tree itself: path() walks its parent array via lca.T.par.
        # Without this attribute path() raised AttributeError.
        lca.T = T
        lca.depth = depth = presum(T.delta)
        # Copies, so the table does not alias the tree's own lists.
        lca.tin, lca.tout = T.tin[:], T.tout[:]
        lca.mask = (1 << (shift := N.bit_length()))-1
        lca.shift = shift
        order = T.order
        # Depth in the high bits, vertex in the low `shift` bits.
        packets = [depth[i] << shift | order[i] for i in range(len(order))]
        super().__init__(packets)

    def _query(lca, u, v):
        """Return (l, r, ancestor, ancestor_depth) for vertices u and v.

        `l` is the smaller of the two `tin` values and `r` is the larger plus
        one, i.e. the half-open range handed to the sparse table.
        """
        l, r = sort2(lca.tin[u], lca.tin[v]); r += 1
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift

    def query(lca, u, v) -> tuple[int,int]:
        """Return (lowest common ancestor of u and v, its depth)."""
        _, _, a, d = lca._query(u, v)
        return a, d
    
    def distance(lca, u, v) -> int:
        """Return depth at u plus depth at v minus twice the depth of their LCA."""
        l, r, _, d = lca._query(u, v)
        # r is exclusive, so r-1 is the larger of the two tin positions.
        return lca.depth[l] + lca.depth[r-1] - 2*d
    
    def path(lca, u, v):
        """Return the list of vertices from u up to the LCA and down to v."""
        par, anc = lca.T.par, lca.query(u, v)[0]
        # Climb from u to the ancestor (ancestor included).
        path, c = [], u
        while c != anc:
            path.append(c)
            c = par[c]
        path.append(anc)
        # Climb from v to just below the ancestor, then append in reverse.
        rev_path, c = [], v
        while c != anc:
            rev_path.append(c)
            c = par[c]
        path.extend(reversed(rev_path))
        return path

class LCATableWeighted(LCATable):
    """LCA table whose `distance` sums edge weights instead of counting edges.

    Reads `T.Wdelta` after the base class has been initialised and turns it
    into a running total the first time `distance` is called.
    NOTE(review): `T.Wdelta` is presumably the per-Euler-position signed
    weight change, aligned with the positions stored in `tin` - confirm
    against the tree class.
    """
    def __init__(lca, T, root = 0):
        super().__init__(T, root)
        lca.weights = T.Wdelta
        # Filled in lazily by distance(); None means "not computed yet".
        lca.weighted_depth = None

    def distance(lca, u, v) -> int:
        """Return weighted_depth at u plus at v minus twice that at their LCA."""
        wd = lca.weighted_depth
        if wd is None:
            # First call: build the running totals once and cache them.
            wd = lca.weighted_depth = presum(lca.weights)
        lo, hi, anc, _ = lca._query(u, v)
        # _query returns an exclusive right bound, hence hi-1.
        top = lca.tin[anc]
        return wd[lo] + wd[hi-1] - 2*wd[top]
Back to top page