cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ cp_library/ds/tree/ahocorasick_cls.py

Depends on

Verified with

Code

import cp_library.__header__
from typing import Optional
from collections import Counter, deque
import cp_library.ds.__header__
import cp_library.ds.tree.__header__
from cp_library.ds.tree.trie_cls import Trie

class AhoCorasick(Trie):
    """Aho-Corasick multi-pattern matcher built on top of ``Trie``.

    Insert patterns with the inherited ``add``; ``freq_table`` then builds the
    failure links and counts how often every pattern occurs in a text.
    """
    __slots__ = 'failed', 'freq'

    def __init__(T):
        super().__init__()
        # Failure link: node for the longest proper suffix of this node's
        # string that is also in the trie. Set by build(); root links to itself.
        T.failed: Optional['AhoCorasick'] = None
        # Scratch match counter written by freq_table().
        T.freq: int = 0

    def build(T) -> list['AhoCorasick']:
        """Compute failure links for every node below ``T``.

        Returns the non-root nodes in breadth-first order (parents before
        children), which ``freq_table`` reuses.

        Safe to call repeatedly (e.g. after adding more words): the root's
        children are handled explicitly instead of testing ``T.failed``, which
        is ``T`` itself after a previous build and whose truthiness depends on
        ``len(T)`` (the word count).
        """
        order: list[AhoCorasick] = T.bfs()
        T.failed = T
        for node in order:
            par: AhoCorasick = node.par
            if par is T:
                # The only proper suffix of a one-character string is "".
                node.failed = T
                continue
            c = node.chr
            # par.failed is already set because BFS visits parents first.
            link: AhoCorasick = par.failed
            while link is not T and c not in link.sub:
                link = link.failed
            node.failed = link.sub.get(c, T)
        return order

    def freq_table(T, text: str) -> Counter[str]:
        """Count occurrences (overlaps included) of every inserted word in ``text``.

        Returns a Counter mapping each inserted (non-empty) word to its number
        of occurrences; words that never occur are present with count 0.
        Counters are reset first, so repeated calls are independent.
        """
        order = T.build()
        # Clear counts left over from any previous call.
        T.freq = 0
        for node in order:
            node.freq = 0

        node: AhoCorasick = T
        for c in text:
            # Fall back along failure links until c can extend the match.
            while node is not T and c not in node.sub:
                node = node.failed
            node = node.sub.get(c, T)
            node.freq += 1

        output = Counter()
        # Deepest nodes first: push each node's count to its failure link so
        # every suffix state is also credited with the match.
        for node in reversed(order):
            node.failed.freq += node.freq
            if node.word:
                output[str(node)] = node.freq
        return output

    def bfs(T) -> list['Trie']:
        """Return every node strictly below ``T`` in breadth-first order (``T`` excluded)."""
        order, que = [], deque([T])
        while que:
            order.extend(children := que.popleft().sub.values())
            que.extend(children)
        return order
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
from typing import Optional
from collections import Counter, deque



class Trie:
    """Prefix tree over strings with per-node pass-through counts.

    Every node stores:
      * ``sub``  -- children keyed by a single character,
      * ``par``  -- parent node (``None`` for the root),
      * ``chr``  -- character on the edge from ``par`` (``""`` for the root),
      * ``cnt``  -- number of inserted words (with multiplicity) whose path
                    passes through or ends at this node,
      * ``word`` -- True while at least one inserted word ends exactly here.

    The number of copies of a word ending exactly at ``node`` is therefore
    ``node.cnt - sum(child.cnt for child in node.sub.values())``.

    Note: ``__len__`` returns ``cnt``, so an empty node is falsy; node
    existence is always tested with ``is None`` rather than truthiness.
    """
    __slots__ = 'sub', 'par', 'chr', 'cnt', 'word'

    def __init__(T):
        T.sub: dict[str, Trie] = {}
        T.par: Optional[Trie] = None
        T.chr: str = ""
        T.cnt: int = 0
        T.word: bool = False

    def add(T, word: str):
        """Insert one occurrence of ``word``; new nodes use the same class as ``T``."""
        (node := T).cnt += 1
        for c in word:
            if c not in node.sub:
                node.sub[c] = T.__class__()
            par, node = node, node.sub[c]
            node.par, node.chr = par, c
            node.cnt += 1
        node.word = True

    def remove(T, word: str):
        """Remove one occurrence of ``word``.

        Fails an ``assert`` if ``word`` is not currently stored as a word.
        Clears ``word`` on the end node once its last copy is removed, and
        detaches every node whose count drops to zero so no empty branches
        remain.
        """
        node = T.find(word)
        assert node is not None and node.word and node.cnt >= 1
        # Copies of `word` itself = words through here minus those going deeper.
        if node.cnt - sum(child.cnt for child in node.sub.values()) == 1:
            node.word = False
        while node is not None:
            node.cnt -= 1
            if node.cnt == 0 and node.par is not None:
                del node.par.sub[node.chr]
            node = node.par

    def discard(T, word: str):
        """Delete the node for ``word`` together with its whole subtree.

        This removes every stored word that has ``word`` as a prefix (all
        copies of ``word`` itself included), not only ``word``. Ancestors'
        counts are reduced accordingly, and ancestors left empty are detached.
        Does nothing if the path for ``word`` does not exist.
        """
        node = T.find(word)
        if node is None:
            return
        removed = node.cnt
        if node.par is None:
            # Empty prefix at the root matches every word: empty it in place,
            # since there is no parent to detach it from.
            node.sub.clear()
            node.word = False
            node.cnt = 0
            return
        while node is not None:
            node.cnt -= removed
            if node.cnt == 0 and node.par is not None:
                del node.par.sub[node.chr]
            node = node.par

    def find(T, prefix: str, full = True) -> Optional['Trie']:
        """Walk ``prefix`` down from this node.

        Returns the node reached after consuming all of ``prefix``. If some
        character has no edge, returns ``None`` when ``full`` is true,
        otherwise the deepest node reached before the missing edge.
        """
        node = T
        for c in prefix:
            if c not in node.sub:
                return None if full else node
            node = node.sub[c]
        return node

    def __contains__(T, word: str) -> bool:
        """True if ``word`` is currently stored as a whole word (not just a prefix)."""
        node = T.find(word)
        return node.word if node is not None else False

    def __len__(T):
        """Number of words (with multiplicity) passing through this node."""
        return T.cnt

    def __str__(T) -> str:
        """Return the string spelled by the path from the root to this node."""
        chars, node = [], T
        while node.par is not None:
            chars.append(node.chr)
            node = node.par
        chars.reverse()
        return "".join(chars)
    

class AhoCorasick(Trie):
    """Aho-Corasick multi-pattern matcher built on top of ``Trie``.

    Insert patterns with the inherited ``add``; ``freq_table`` then builds the
    failure links and counts how often every pattern occurs in a text.
    """
    __slots__ = 'failed', 'freq'

    def __init__(T):
        super().__init__()
        # Failure link: node for the longest proper suffix of this node's
        # string that is also in the trie. Set by build(); root links to itself.
        T.failed: Optional['AhoCorasick'] = None
        # Scratch match counter written by freq_table().
        T.freq: int = 0

    def build(T) -> list['AhoCorasick']:
        """Compute failure links for every node below ``T``.

        Returns the non-root nodes in breadth-first order (parents before
        children), which ``freq_table`` reuses.

        Safe to call repeatedly (e.g. after adding more words): the root's
        children are handled explicitly instead of testing ``T.failed``, which
        is ``T`` itself after a previous build and whose truthiness depends on
        ``len(T)`` (the word count).
        """
        order: list[AhoCorasick] = T.bfs()
        T.failed = T
        for node in order:
            par: AhoCorasick = node.par
            if par is T:
                # The only proper suffix of a one-character string is "".
                node.failed = T
                continue
            c = node.chr
            # par.failed is already set because BFS visits parents first.
            link: AhoCorasick = par.failed
            while link is not T and c not in link.sub:
                link = link.failed
            node.failed = link.sub.get(c, T)
        return order

    def freq_table(T, text: str) -> Counter[str]:
        """Count occurrences (overlaps included) of every inserted word in ``text``.

        Returns a Counter mapping each inserted (non-empty) word to its number
        of occurrences; words that never occur are present with count 0.
        Counters are reset first, so repeated calls are independent.
        """
        order = T.build()
        # Clear counts left over from any previous call.
        T.freq = 0
        for node in order:
            node.freq = 0

        node: AhoCorasick = T
        for c in text:
            # Fall back along failure links until c can extend the match.
            while node is not T and c not in node.sub:
                node = node.failed
            node = node.sub.get(c, T)
            node.freq += 1

        output = Counter()
        # Deepest nodes first: push each node's count to its failure link so
        # every suffix state is also credited with the match.
        for node in reversed(order):
            node.failed.freq += node.freq
            if node.word:
                output[str(node)] = node.freq
        return output

    def bfs(T) -> list['Trie']:
        """Return every node strictly below ``T`` in breadth-first order (``T`` excluded)."""
        order, que = [], deque([T])
        while que:
            order.extend(children := que.popleft().sub.values())
            que.extend(children)
        return order
Back to top page