cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ cp_library/ds/tree/ahocorasick_cls.py

Depends on

Verified with

Code

import cp_library.__header__
import cp_library.ds.__header__
import cp_library.ds.tree.__header__
from cp_library.ds.tree.trie_cls import Trie

class AhoCorasick(Trie):
    """Aho-Corasick automaton built on top of ``Trie``.

    Insert patterns with ``add``, then call ``count_freq`` to count how many
    times each inserted word occurs in a text (overlapping occurrences are
    counted separately).
    """
    __slots__ = 'failed',

    def __init__(self):
        super().__init__()
        # Failure link: the deepest node whose string is a proper suffix of
        # this node's string. None until build_fail has been run.
        self.failed: 'AhoCorasick | None' = None

    def build_fail(self):
        """Compute failure links for every node reachable from ``self``.

        ``self`` is treated as the root; its own link is set to itself.
        Every other node links to the deepest node whose string is a proper
        suffix of that node's string (the root when there is none). Each call
        recomputes all links from scratch, so calling it repeatedly is safe.

        Returns:
            The nodes in BFS order (``self`` first), as returned by ``bfs``.
        """
        arr_bfs = self.bfs()
        self.failed = self
        # BFS order guarantees that the parent and every node on the parent's
        # failure chain (all strictly shallower) already have their links.
        for p in arr_bfs[1:]:
            c = p.last
            f = p.parent
            # Walk the parent's failure chain down to (and including) the
            # root, looking for a node with an outgoing edge labelled c.
            # Comparing against the root by identity, rather than relying on
            # root.failed being None, keeps this correct on repeated calls.
            while f is not self:
                f = f.failed
                if c in f.dic:
                    p.failed = f.dic[c]
                    break
            else:
                p.failed = self
        return arr_bfs

    def count_freq(self, text: str) -> dict[str, int]:
        """Count occurrences of every inserted word in ``text``.

        Failure links are rebuilt and every node's ``count`` is reset to 0
        first, so repeated calls return independent results.

        Args:
            text: The string to scan.

        Returns:
            A dict mapping each word node's ``prefix()`` to the number of
            (possibly overlapping) occurrences in ``text``. Words that never
            occur map to 0. The root (empty string) is never included.
        """
        arr_bfs = self.build_fail()
        for node in arr_bfs:
            node.count = 0

        # Feed the text through the automaton. After each character, p is
        # the deepest node whose string is a suffix of the text read so far.
        p = self
        for c in text:
            while p is not self and c not in p.dic:
                p = p.failed
            p = p.dic.get(c, self)
            p.count += 1

        # Add each node's count into its failure target, deepest nodes first.
        # Every node linking to p is deeper than p, so p's count is final by
        # the time it is read.
        output = {}
        for p in reversed(arr_bfs[1:]):
            p.failed.count += p.count
            if p.word:
                output[p.prefix()] = p.count
        return output
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''


from collections import deque
from typing import Dict, List, Optional

class Trie:
    """Prefix tree keyed by the characters of inserted strings.

    Node attributes:
        dic:    children, mapping an edge character to the child node.
        parent: the parent node (None for a node with no parent, e.g. root).
        last:   the character on the edge leading into this node.
        count:  integer counter (initialised to 0; not used by Trie itself).
        word:   True when an inserted string ends at this node.
    """
    __slots__ = 'dic', 'parent', 'last', 'count', 'word'

    def __init__(self):
        self.dic: Dict[str, Trie] = {}
        self.parent: Optional[Trie] = None
        self.last: str = ""
        self.count: int = 0
        self.word: bool = False

    def add(self, word: str) -> None:
        """Insert ``word`` below this node, creating missing nodes.

        New nodes are created with ``type(self)()`` so that subclasses get
        nodes of their own type. Adding ``""`` marks this node itself as a word.
        """
        p = self
        for c in word:
            if c not in p.dic:
                p.dic[c] = type(self)()
            parent = p
            p = p.dic[c]
            p.parent = parent
            p.last = c
        p.word = True

    def find(self, prefix: str) -> Optional['Trie']:
        """Return the node reached by following ``prefix``, or None if absent."""
        node = self
        for char in prefix:
            if char not in node.dic:
                return None
            node = node.dic[char]
        return node

    def search(self, word: str) -> bool:
        """Return True iff ``word`` was inserted (exactly) below this node."""
        node = self.find(word)
        return node is not None and node.word

    def bfs(self) -> List['Trie']:
        """Return all nodes reachable from this node in breadth-first order.

        This node comes first. Siblings appear in child-insertion order.
        """
        output = []
        queue = deque([self])
        while queue:
            p = queue.popleft()
            output.append(p)
            queue.extend(p.dic.values())
        return output

    def prefix(self) -> str:
        """Return the string spelled by the edges from the topmost ancestor."""
        output = []
        curr = self
        while curr.parent is not None:
            output.append(curr.last)
            curr = curr.parent
        return "".join(reversed(output))

class AhoCorasick(Trie):
    """Aho-Corasick automaton built on top of ``Trie``.

    Insert patterns with ``add``, then call ``count_freq`` to count how many
    times each inserted word occurs in a text (overlapping occurrences are
    counted separately).
    """
    __slots__ = 'failed',

    def __init__(self):
        super().__init__()
        # Failure link: the deepest node whose string is a proper suffix of
        # this node's string. None until build_fail has been run.
        self.failed: 'AhoCorasick | None' = None

    def build_fail(self):
        """Compute failure links for every node reachable from ``self``.

        ``self`` is treated as the root; its own link is set to itself.
        Every other node links to the deepest node whose string is a proper
        suffix of that node's string (the root when there is none). Each call
        recomputes all links from scratch, so calling it repeatedly is safe.

        Returns:
            The nodes in BFS order (``self`` first), as returned by ``bfs``.
        """
        arr_bfs = self.bfs()
        self.failed = self
        # BFS order guarantees that the parent and every node on the parent's
        # failure chain (all strictly shallower) already have their links.
        for p in arr_bfs[1:]:
            c = p.last
            f = p.parent
            # Walk the parent's failure chain down to (and including) the
            # root, looking for a node with an outgoing edge labelled c.
            # Comparing against the root by identity, rather than relying on
            # root.failed being None, keeps this correct on repeated calls.
            while f is not self:
                f = f.failed
                if c in f.dic:
                    p.failed = f.dic[c]
                    break
            else:
                p.failed = self
        return arr_bfs

    def count_freq(self, text: str) -> dict[str, int]:
        """Count occurrences of every inserted word in ``text``.

        Failure links are rebuilt and every node's ``count`` is reset to 0
        first, so repeated calls return independent results.

        Args:
            text: The string to scan.

        Returns:
            A dict mapping each word node's ``prefix()`` to the number of
            (possibly overlapping) occurrences in ``text``. Words that never
            occur map to 0. The root (empty string) is never included.
        """
        arr_bfs = self.build_fail()
        for node in arr_bfs:
            node.count = 0

        # Feed the text through the automaton. After each character, p is
        # the deepest node whose string is a suffix of the text read so far.
        p = self
        for c in text:
            while p is not self and c not in p.dic:
                p = p.failed
            p = p.dic.get(c, self)
            p.count += 1

        # Add each node's count into its failure target, deepest nodes first.
        # Every node linking to p is deeper than p, so p's count is final by
        # the time it is read.
        output = {}
        for p in reversed(arr_bfs[1:]):
            p.failed.count += p.count
            if p.word:
                output[p.prefix()] = p.count
        return output
Back to top page