cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

:heavy_check_mark: cp_library/alg/tree/lca_table_recursive_cls.py

Depends on

Verified with

Code

import cp_library.alg.tree.__header__
import cp_library.misc.setrecursionlimit
from cp_library.ds.sparse_table_cls import SparseTable

class LCATable(SparseTable):
    """Lowest-common-ancestor queries via an Euler tour and a range-minimum table.

    A recursive DFS from ``root`` records a node every time the walk enters
    it or returns to it from a child, together with that node's depth.  The
    minimum ``(depth, node)`` pair between the first occurrences of two
    nodes in that tour identifies their lowest common ancestor.

    Attributes:
        start: ``start[u]`` is the tour index of the first occurrence of
            node ``u`` (``-1`` if the DFS never reached ``u``).
        depth: ``depth[i]`` is the depth of the node recorded at tour
            position ``i`` (indexed by tour position, not by node).
    """

    def __init__(self, T, root):
        """Run the Euler tour of ``T`` from ``root`` and build the sparse table.

        Args:
            T: adjacency structure; ``T[u]`` iterates over the neighbours of
                node ``u`` and ``len(T)`` is the number of nodes.
            root: node the DFS starts from (depth 0).
        """
        self.start = [-1] * len(T)
        # Kept on the instance because `distance` needs the endpoint depths.
        self.depth = depths = []
        euler_tour = []

        def dfs(u: int, p: int, depth: int):
            self.start[u] = len(euler_tour)
            euler_tour.append(u)
            depths.append(depth)

            for child in T[u]:
                # Only the edge back to the parent is skipped; this relies on
                # T being a tree (no cycles, no duplicate edges).
                if child != p:
                    dfs(child, u, depth + 1)
                    # Record u again on the way back so that any tour range
                    # between two descendants passes through their ancestor.
                    euler_tour.append(u)
                    depths.append(depth)

        dfs(root, -1, 0)
        # Tuples compare by depth first, so `min` yields the shallowest node.
        super().__init__(min, list(zip(depths, euler_tour)))

    def _tour_range(self, u, v):
        """Return the half-open tour range ``(l, r)`` covering the first visits of ``u`` and ``v``."""
        su, sv = self.start[u], self.start[v]
        if su > sv:
            su, sv = sv, su
        return su, sv + 1

    def query(self, u, v) -> tuple[int,int]:
        """Return ``(lca, depth_of_lca)`` for nodes ``u`` and ``v``."""
        l, r = self._tour_range(u, v)
        d, a = super().query(l, r)
        return a, d

    def distance(self, u, v) -> int:
        """Return the number of edges on the path between ``u`` and ``v``."""
        l, r = self._tour_range(u, v)
        d, _ = super().query(l, r)
        # r is exclusive, so the later endpoint sits at tour position r - 1.
        return self.depth[l] + self.depth[r - 1] - 2*d
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''


# Runtime setup for the bundled file: LCATable.__init__ runs a recursive DFS
# that goes one Python frame deeper per tree level, so the interpreter's
# recursion limit is raised well above the default.
import sys
sys.setrecursionlimit(10**6)
# NOTE(review): `pypyjit` is imported unconditionally, so this bundle appears
# to target PyPy only — confirm the intended interpreter.
import pypyjit
# Presumably lifts the JIT's cap on unrolling recursive calls — TODO confirm
# the meaning of -1 against the PyPy JIT parameter documentation.
pypyjit.set_param("max_unroll_recursion=-1")
from typing import Generic, Callable
from typing import TypeVar
# Type variables for the generic containers; only _T is referenced in the
# code shown here (by SparseTable).
# NOTE(review): the TypeVar name strings ('S', 'T', 'U') differ from the
# variables they are bound to (_S, _T, _U).
_S = TypeVar('S')
_T = TypeVar('T')
_U = TypeVar('U')


class SparseTable(Generic[_T]):
    """Static range-query table stored as one flat list.

    Row ``i`` occupies ``data[i*N:(i+1)*N]``; its entry ``j`` holds ``op``
    folded over ``arr[j:j + 2**i]``.  Slots of a row whose window would run
    past the end of ``arr`` keep their initial ``0``.
    """

    def __init__(st, op: Callable[[_T,_T],_T], arr: list[_T]):
        """Precompute every power-of-two window of ``arr`` under ``op``."""
        size = len(arr)
        levels = size.bit_length()
        st.N = size
        st.log, st.op = levels, op
        table = st.data = [0] * (levels * size)
        table[:size] = arr
        for level in range(1, levels):
            half = 1 << (level - 1)
            dst = level * size
            src = dst - size
            # A window of width 2*half is the combination of two adjacent
            # windows of width half taken from the previous row.
            for j in range(size - 2 * half + 1):
                table[dst + j] = op(table[src + j], table[src + half + j])

    def query(st, l: int, r: int) -> _T:
        """Combine ``arr[l:r]`` (half-open) using two windows of width ``2**k``.

        The two windows may overlap, so the result is only meaningful for
        operations where combining an element twice is harmless (e.g. ``min``).
        """
        k = (r - l).bit_length() - 1
        row = k * st.N
        left = st.data[row + l]
        right = st.data[row + r - (1 << k)]
        return st.op(left, right)

class LCATable(SparseTable):
    """Lowest-common-ancestor queries via an Euler tour and a range-minimum table.

    A recursive DFS from ``root`` records a node every time the walk enters
    it or returns to it from a child, together with that node's depth.  The
    minimum ``(depth, node)`` pair between the first occurrences of two
    nodes in that tour identifies their lowest common ancestor.

    Attributes:
        start: ``start[u]`` is the tour index of the first occurrence of
            node ``u`` (``-1`` if the DFS never reached ``u``).
        depth: ``depth[i]`` is the depth of the node recorded at tour
            position ``i`` (indexed by tour position, not by node).
    """

    def __init__(self, T, root):
        """Run the Euler tour of ``T`` from ``root`` and build the sparse table.

        Args:
            T: adjacency structure; ``T[u]`` iterates over the neighbours of
                node ``u`` and ``len(T)`` is the number of nodes.
            root: node the DFS starts from (depth 0).
        """
        self.start = [-1] * len(T)
        # Kept on the instance because `distance` needs the endpoint depths.
        self.depth = depths = []
        euler_tour = []

        def dfs(u: int, p: int, depth: int):
            self.start[u] = len(euler_tour)
            euler_tour.append(u)
            depths.append(depth)

            for child in T[u]:
                # Only the edge back to the parent is skipped; this relies on
                # T being a tree (no cycles, no duplicate edges).
                if child != p:
                    dfs(child, u, depth + 1)
                    # Record u again on the way back so that any tour range
                    # between two descendants passes through their ancestor.
                    euler_tour.append(u)
                    depths.append(depth)

        dfs(root, -1, 0)
        # Tuples compare by depth first, so `min` yields the shallowest node.
        super().__init__(min, list(zip(depths, euler_tour)))

    def _tour_range(self, u, v):
        """Return the half-open tour range ``(l, r)`` covering the first visits of ``u`` and ``v``."""
        su, sv = self.start[u], self.start[v]
        if su > sv:
            su, sv = sv, su
        return su, sv + 1

    def query(self, u, v) -> tuple[int,int]:
        """Return ``(lca, depth_of_lca)`` for nodes ``u`` and ``v``."""
        l, r = self._tour_range(u, v)
        d, a = super().query(l, r)
        return a, d

    def distance(self, u, v) -> int:
        """Return the number of edges on the path between ``u`` and ``v``."""
        l, r = self._tour_range(u, v)
        d, _ = super().query(l, r)
        # r is exclusive, so the later endpoint sits at tour position r - 1.
        return self.depth[l] + self.depth[r - 1] - 2*d
Back to top page