cp-library

This documentation is automatically generated by online-judge-tools/verification-helper.

View the Project on GitHub kobejean/cp-library

✔ cp_library/alg/dp/rerooting_iterative_cls.py

Depends on

Verified with

Code

import cp_library.alg.dp.__header__
from typing import TypeVar, Callable
from cp_library.ds.bidirectional_array_cls import BidirectionalArray

class ReRootingDP():
    ''' A class implementation of the Re-rooting Dynamic Programming technique. '''
    
    S = TypeVar('S')
    MergeOp = Callable[[S, S], S]
    AddNodeOp = Callable[[int, S], S]
    AddEdgeOp = Callable[[int, int, S], S]

    def __init__(self, T: list[list[int]], e: S,
                 merge: MergeOp, 
                 add_node: AddNodeOp = lambda u,s:s, 
                 add_edge: AddEdgeOp = lambda u,v,s:s):
        '''
        T: list[list[int]] - Adjacency list representation of the tree.
        e: S - Identity element for the merge operation.
        merge: (S,S) -> S - Function to merge two states.
        add_node: (int,S) -> S - Function to incorporate a node into the state.
        add_edge: (int,int,S) -> S - Function to incorporate an edge into the state.
        '''
        self.T = T
        self.e = e
        self.merge = merge
        self.add_node = add_node
        self.add_edge = add_edge

    def solve(self) -> list[S]:
        dp = [[self.e]*len(adj) for adj in self.T]
        ans = [self.e for _ in range(len(self.T))]
        parent_idx = [None for _ in range(len(self.T))]
        child_idx = [None for _ in range(len(self.T))]
        stack = [(2,0,None),(0,0,None)]
        while stack:
            phase, u, p = stack.pop()
            match phase:
                case 0:  # Visit children
                    if p is not None:
                        stack.append((1,u,p))
                    for i,v in enumerate(self.T[u]):
                        if v != p:
                            stack.append((0,v,u))
                            child_idx[v] = i
                        else:
                            parent_idx[u] = i
                case 1:  # Upward updates
                    val = dp[p][child_idx[u]] = self.add_edge(p, u, self.add_node(u, ans[u]))
                    ans[p] = self.merge(ans[p], val)
                case 2:  # Downward updates
                    ba = BidirectionalArray(self.e, self.merge, dp[u])
                    for i,v in enumerate(self.T[u]):
                        if v != p:
                            dp[v][parent_idx[v]] = self.add_edge(v, u, self.add_node(u, ba.out(i)))
                            stack.append((2,v,u))
                    ans[u] = ba.all()
        return ans
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
from typing import TypeVar, Callable


class BidirectionalArray:
    '''
    Prefix/suffix fold table over a fixed sequence.

    Running folds of ``data`` under ``op`` are stored from both ends. The fold
    of everything outside a contiguous index range can then be read with a
    single extra ``op`` call.
    '''
    def __init__(self, e, op, data):
        n = len(data)
        self.size = n
        self.e = e
        self.op = op

        # pre[k] == op-fold of e, data[0], ..., data[k-1]  (pre[0] is e itself)
        pre = [e] * (n + 1)
        acc = e
        for k, item in enumerate(data, 1):
            acc = op(acc, item)
            pre[k] = acc

        # suf[k] == op-fold of data[k], ..., data[n-1], e  (suf[n] is e itself)
        suf = [e] * (n + 1)
        acc = e
        for k in reversed(range(n)):
            acc = op(data[k], acc)
            suf[k] = acc

        self.prefix = pre
        self.suffix = suf

    def left(self, l):
        '''Fold of data[:l].'''
        return self.prefix[l]

    def right(self, r):
        '''Fold of data[r:].'''
        return self.suffix[r]

    def all(self):
        '''Fold of the entire sequence.'''
        return self.prefix[self.size]

    def out(self, l, r=None):
        '''Fold of everything except data[l:r]; only data[l] is excluded when r is omitted.'''
        if r is None:
            r = l + 1
        return self.op(self.prefix[l], self.suffix[r])

class ReRootingDP():
    ''' A class implementation of the Re-rooting Dynamic Programming technique. '''
    
    S = TypeVar('S')
    MergeOp = Callable[[S, S], S]
    AddNodeOp = Callable[[int, S], S]
    AddEdgeOp = Callable[[int, int, S], S]

    def __init__(self, T: list[list[int]], e: S,
                 merge: MergeOp, 
                 add_node: AddNodeOp = lambda u,s:s, 
                 add_edge: AddEdgeOp = lambda u,v,s:s):
        '''
        T: list[list[int]] - Adjacency list representation of the tree.
        e: S - Identity element for the merge operation.
        merge: (S,S) -> S - Function to merge two states.
        add_node: (int,S) -> S - Function to incorporate a node into the state.
        add_edge: (int,int,S) -> S - Function to incorporate an edge into the state.
        '''
        self.T = T
        self.e = e
        self.merge = merge
        self.add_node = add_node
        self.add_edge = add_edge

    def solve(self) -> list[S]:
        dp = [[self.e]*len(adj) for adj in self.T]
        ans = [self.e for _ in range(len(self.T))]
        parent_idx = [None for _ in range(len(self.T))]
        child_idx = [None for _ in range(len(self.T))]
        stack = [(2,0,None),(0,0,None)]
        while stack:
            phase, u, p = stack.pop()
            match phase:
                case 0:  # Visit children
                    if p is not None:
                        stack.append((1,u,p))
                    for i,v in enumerate(self.T[u]):
                        if v != p:
                            stack.append((0,v,u))
                            child_idx[v] = i
                        else:
                            parent_idx[u] = i
                case 1:  # Upward updates
                    val = dp[p][child_idx[u]] = self.add_edge(p, u, self.add_node(u, ans[u]))
                    ans[p] = self.merge(ans[p], val)
                case 2:  # Downward updates
                    ba = BidirectionalArray(self.e, self.merge, dp[u])
                    for i,v in enumerate(self.T[u]):
                        if v != p:
                            dp[v][parent_idx[v]] = self.add_edge(v, u, self.add_node(u, ba.out(i)))
                            stack.append((2,v,u))
                    ans[u] = ba.all()
        return ans
Back to top page