cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/aoj/dsl/dsl_2_c_kdtree.test.py

Depends on

Code

# verification-helper: PROBLEM https://onlinejudge.u-aizu.ac.jp/courses/library/3/DSL/2/DSL_2_C

def main():
    """Answer the rectangle queries of AOJ DSL_2_C with a KDTree.

    Reads N points, builds a KDTree over them, then reads Q queries of the
    form ``sx tx sy ty``.  For every query the ids of the points with
    ``sx <= x <= tx`` and ``sy <= y <= ty`` are printed in ascending order,
    one per line, followed by one empty line.
    """
    N, = read()
    points = [read() for _ in range(N)]
    tree = KDTree(points)

    Q, = read()
    for _ in range(Q):
        sx, tx, sy, ty = read()
        # The tree takes half-open slices, so the inclusive upper bounds are
        # widened by one.
        hits = sorted(tree[sx:tx + 1, sy:ty + 1])
        # The trailing '' produces the blank line that ends each answer.
        print(*hits, '', sep='\n')

from cp_library.ds.kdtree_cls import KDTree
from cp_library.io.read_fn import read

# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
# verification-helper: PROBLEM https://onlinejudge.u-aizu.ac.jp/courses/library/3/DSL/2/DSL_2_C

def main():
    """Answer the rectangle queries of AOJ DSL_2_C with a KDTree.

    Reads N points, builds a KDTree over them, then reads Q queries of the
    form ``sx tx sy ty``.  For every query the ids of the points with
    ``sx <= x <= tx`` and ``sy <= y <= ty`` are printed in ascending order,
    one per line, followed by one empty line.
    """
    N, = read()
    points = [read() for _ in range(N)]
    tree = KDTree(points)

    Q, = read()
    for _ in range(Q):
        sx, tx, sy, ty = read()
        # The tree takes half-open slices, so the inclusive upper bounds are
        # widened by one.
        hits = sorted(tree[sx:tx + 1, sy:ty + 1])
        # The trailing '' produces the blank line that ends each answer.
        print(*hits, '', sep='\n')

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
import typing
import random


class KDTreeNode:
    """One stored point of a KDTree together with links to its subtrees.

    Attributes (declared through ``__slots__``):
      id       -- index of the point in the sequence handed to KDTree
      point    -- the coordinate sequence of the point
      children -- list of child nodes; ``[left, right]`` for every node that
                  holds a point, with None marking a missing child
    """
    __slots__ = ['id', 'point', 'children']

    def __init__(self, id, point, children):
        self.id = id
        self.point = point
        self.children = children


class KDTree:
    """Static k-d tree answering axis-aligned, half-open range queries.

    The tree is built once from a sequence of equally long coordinate
    sequences.  The splitting axis cycles with the depth, and the node placed
    at each level is chosen by partitioning around a random pivot.

    Query with one slice per axis, e.g. ``tree[x0:x1, y0:y1]``; the result is
    the list of ids (positions in the input sequence) of all points ``p`` with
    ``start <= p[axis] < stop`` on every axis, in unspecified order.

    Attributes (declared through ``__slots__``):
      k     -- number of coordinates per point (0 for an empty tree)
      nodes -- all KDTreeNode objects, reordered in place while building
      root  -- top node of the tree, or None when there are no points
    """
    __slots__ = ['k', 'nodes', 'root']

    def __init__(self, points):
        """Build the tree over ``points`` (a sized, indexable sequence)."""
        # An empty input has no first point to infer the dimension from.
        self.k = len(points[0]) if len(points) else 0
        self.build_tree(points)

    def median_of_three(self, l, r, axis):
        """Return whichever of l, r or their midpoint holds the median value.

        The three candidates are compared by their coordinate on ``axis``.
        """
        m = (l + r) // 2
        nodes = self.nodes
        a, b, c = nodes[l].point[axis], nodes[m].point[axis], nodes[r].point[axis]
        if a <= b <= c or c <= b <= a:
            return m
        if b <= a <= c or c <= a <= b:
            return l
        return r

    def partition(self, l, r, axis, pi):
        """Partition ``nodes[l..r]`` around the node at ``pi``; return its final index.

        Afterwards every node left of the returned index has a coordinate on
        ``axis`` that is <= the pivot's, and every node to the right has one
        that is >= the pivot's.
        """
        nodes = self.nodes
        # Park the pivot at the right end while the rest is scanned.
        nodes[pi], nodes[r] = nodes[r], nodes[pi]
        pi = l
        pivot = nodes[r].point[axis]
        for j in range(l, r):
            v = nodes[j].point[axis]
            # Ties go left only from odd positions, which spreads runs of
            # equal coordinates over both sides instead of piling them up.
            if v < pivot or (v == pivot and j & 1):
                nodes[pi], nodes[j] = nodes[j], nodes[pi]
                pi += 1
        nodes[pi], nodes[r] = nodes[r], nodes[pi]
        return pi

    def build_tree(self, points):
        """(Re)create ``nodes`` and ``root`` from ``points`` without recursion."""
        self.nodes = [KDTreeNode(idx, point, [None, None]) for idx, point in enumerate(points)]
        if not self.nodes:
            # Nothing to partition; queries on this tree return [].
            self.root = None
            return

        # Sentinel parent so the real root is attached like any other child.
        root = KDTreeNode(-1, None, [None])
        # Work items: (first index, last index, depth, parent node, child slot).
        stack = [(0, len(self.nodes) - 1, 0, root, 0)]

        while stack:
            l, r, depth, parent, child = stack.pop()
            axis = depth % self.k

            pi = self.partition(l, r, axis, random.randint(l, r))
            node = self.nodes[pi]
            parent.children[child] = node

            if pi < r:
                stack.append((pi + 1, r, depth + 1, node, 1))
            if l < pi:
                stack.append((l, pi - 1, depth + 1, node, 0))
        self.root = root.children[0]

    def __getitem__(self, ranges: typing.Tuple[slice, ...]):
        """Return the ids of all points inside the half-open box ``ranges``.

        ``ranges`` holds one slice per axis; a single bare slice is accepted
        for one-dimensional trees.  A missing ``start`` or ``stop`` leaves
        that side unbounded (this relies on the coordinates being comparable
        with ``float('inf')``).  Slice steps are ignored.
        """
        result = []
        if self.root is None:
            return result
        if isinstance(ranges, slice):
            # tree[a:b] passes the slice itself rather than a 1-tuple.
            ranges = (ranges,)

        k = self.k
        inf = float('inf')
        # Hoist the bounds out of the traversal and resolve open ends once.
        lo = [-inf if ranges[i].start is None else ranges[i].start for i in range(k)]
        hi = [inf if ranges[i].stop is None else ranges[i].stop for i in range(k)]

        stack = [(self.root, 0)]
        while stack:
            node, depth = stack.pop()
            axis = depth % k
            point = node.point

            # Report the node itself when it lies inside the box on every axis.
            if all(lo[i] <= point[i] < hi[i] for i in range(k)):
                result.append(node.id)

            left, right = node.children
            v = point[axis]
            # Right subtree holds coordinates >= v, so it can only matter
            # while v is still below the upper bound.
            if right is not None and v < hi[axis]:
                stack.append((right, depth + 1))
            # Left subtree holds coordinates <= v, so it can only matter
            # when v has reached the lower bound.
            if left is not None and lo[axis] <= v:
                stack.append((left, depth + 1))

        return result
from typing import Type, Union, overload
from typing import TypeVar

# The string given to TypeVar() has to equal the name it is assigned to;
# type checkers reject a mismatch such as `_T = TypeVar('T')`.
_S = TypeVar('_S')
_T = TypeVar('_T')
_U = TypeVar('_U')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_T3 = TypeVar('_T3')
_T4 = TypeVar('_T4')
_T5 = TypeVar('_T5')
_T6 = TypeVar('_T6')


@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_T], char=False):
    """Read and parse the next piece of standard input.

    Without ``specs`` the numbers of the current line are returned in a new
    list.  With a single spec that spec is compiled by ``Parser.compile`` and
    applied to ``IO.stdin``; with several specs the whole tuple of specs is
    compiled as one spec.  ``char`` is stored on ``IO.stdin`` before reading
    and switches it between token-wise and character-wise input.
    """
    stdin = IO.stdin
    stdin.char = char
    if not specs:
        return stdin.readnumsinto([])
    spec = specs[0] if len(specs) == 1 else specs
    parser: _T = Parser.compile(spec)
    return parser(stdin)
from os import read as os_read, write as os_write, fstat as os_fstat
import sys
from __pypy__.builders import StringBuilder



def max2(a, b):
    """Return the larger of ``a`` and ``b``; ``b`` is returned on a tie."""
    if a > b:
        return a
    return b

class IOBase:
    """Interface of the reader/writer objects handed to the parsers.

    Every member here is a stub that returns None; a subclass supplies the
    behaviour.  The one-line descriptions state the role each member is
    expected to play in such a subclass.
    """

    @property
    def char(io) -> bool:
        """Whether input is consumed character-wise instead of token-wise."""

    @property
    def writable(io) -> bool:
        """Whether buffered output can be written out by ``flush``."""

    def __next__(io) -> str:
        """Return the next input item (a character or a token)."""

    def write(io, s: str) -> None:
        """Queue ``s`` for output."""

    def readline(io) -> str:
        """Return the remainder of the current line."""

    def readtoken(io) -> str:
        """Return the next token of the current line."""

    def readtokens(io) -> list[str]:
        """Return the tokens of the current line."""

    def readints(io) -> list[int]:
        """Return the integers of the current line."""

    def readdigits(io) -> list[int]:
        """Return the digits of the current line, one per character."""

    def readnums(io) -> list[int]:
        """Return the numbers of the current line (digits or integers)."""

    def readchar(io) -> str:
        """Return the next character."""

    def readchars(io) -> str:
        """Return the characters of the current line."""

    def readinto(io, lst: list[str]) -> list[str]:
        """Append the items of the current line to ``lst`` and return it."""

    def readcharsinto(io, lst: list[str]) -> list[str]:
        """Append the characters of the current line to ``lst`` and return it."""

    def readtokensinto(io, lst: list[str]) -> list[str]:
        """Append the tokens of the current line to ``lst`` and return it."""

    def readintsinto(io, lst: list[int]) -> list[int]:
        """Append the integers of the current line to ``lst`` and return it."""

    def readdigitsinto(io, lst: list[int]) -> list[int]:
        """Append the digits of the current line to ``lst`` and return it."""

    def readnumsinto(io, lst: list[int]) -> list[int]:
        """Append the numbers of the current line to ``lst`` and return it."""

    def wait(io):
        """Iterate for as long as the current line has unread input."""

    def flush(io) -> None:
        """Write out everything queued by ``write``."""

    def line(io) -> list[str]:
        """Return the items of the current line."""

class IO(IOBase):
    """Line-indexed input buffer and string-builder output over a file descriptor.

    Input bytes are appended to ``B`` as they are read and the offset just
    past every newline is recorded in ``O``; the read methods then consume
    one indexed line at a time.  Output is collected in ``S`` and written to
    the descriptor by ``flush``.

    Slots:
      f         -- file descriptor, or -1 when none could be obtained
      file      -- the wrapped file object
      B, V      -- bytearray of all input read so far and a memoryview of it
      O         -- end offsets of the indexed lines (position after each newline)
      S         -- StringBuilder holding pending output
      l, p      -- index into O of the current line, and read position in B
      char      -- read character-wise (True) or token-wise (False)
      sz        -- number of bytes requested per read
      st, ist   -- lists reused by the non-``into`` token / number readers
      writable  -- whether ``flush`` writes to the descriptor
      encoding, errors -- codec settings used for decoding and encoding
    """
    BUFSIZE = 1 << 16; stdin: 'IO'; stdout: 'IO'
    __slots__ = 'f', 'file', 'B', 'O', 'V', 'S', 'l', 'p', 'char', 'sz', 'st', 'ist', 'writable', 'encoding', 'errors'
    def __init__(io, file):
        io.file = file
        # Read size is the larger of BUFSIZE and the file's reported size; the
        # stream counts as writable when its mode has 'x' or lacks 'r'.
        try: io.f = file.fileno(); io.sz, io.writable = max2(io.BUFSIZE, os_fstat(io.f).st_size), ('x' in file.mode or 'r' not in file.mode)
        # NOTE(review): bare except — any failure above (including
        # KeyboardInterrupt) silently falls back to a descriptor-less object.
        except: io.f, io.sz, io.writable = -1, io.BUFSIZE, False
        io.B, io.O, io.S = bytearray(), [], StringBuilder(); io.V = memoryview(io.B); io.l = io.p = 0
        io.char, io.st, io.ist, io.encoding, io.errors = False, [], [], 'ascii', 'ignore'
    # Decode the bytes B[l:r] into a str.
    def _dec(io, l, r): return io.V[l:r].tobytes().decode(io.encoding, io.errors)
    # Read up to sz bytes from the descriptor.
    def readbytes(io, sz): return os_read(io.f, sz)
    # Keep reading until the current line (index l) has an end offset in O.
    def load(io):
        while io.l >= len(io.O):
            if not (b := io.readbytes(io.sz)):
                # End of input: close an unterminated final line by recording
                # the buffer length as its end offset.
                # NOTE(review): io.O[-1] raises IndexError when O is still
                # empty here, i.e. when no newline was seen before EOF.
                if io.O[-1] < len(io.B): io.O.append(len(io.B))
                break
            pos = len(io.B); io.B.extend(b)
            # find() returns -1 when no newline is left and ~(-1) == 0 ends the
            # loop; each recorded offset points just past a newline.
            while ~(pos := io.B.find(b'\n', pos)): io.O.append(pos := pos+1)
    def __next__(io):
        if io.char: return io.readchar()
        else: return io.readtoken()
    # Return the character at p; on reaching the last byte before the line end
    # the position jumps to the line end and the next line becomes current.
    def readchar(io):
        io.load(); r = io.O[io.l]
        c = chr(io.B[io.p])
        if io.p >= r-1: io.p = r; io.l += 1
        else: io.p += 1
        return c
    # Queue s in the output builder; nothing is written until flush().
    def write(io, s: str): io.S.append(s)
    # Return everything from p up to the end offset of the current line
    # (the newline, if present, is included) and move to the next line.
    def readline(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p)
    # Return the text up to the next space, or up to one byte before the line
    # end when no space is left, in which case the next line becomes current.
    def readtoken(io):
        io.load(); r = io.O[io.l]
        if ~(p := io.B.find(b' ', io.p, r)): s = io._dec(io.p, p); io.p = p+1
        else: s = io._dec(io.p, r-1); io.p = r; io.l += 1
        return s
    # The four readers below clear and refill a list owned by the object, so
    # the list returned by one call is overwritten by the next.
    def readtokens(io): io.st.clear(); return io.readtokensinto(io.st)
    def readints(io): io.ist.clear(); return io.readintsinto(io.ist)
    def readdigits(io): io.ist.clear(); return io.readdigitsinto(io.ist)
    def readnums(io): io.ist.clear(); return io.readnumsinto(io.ist)
    # Like readline(), but the last byte of the line is left out of the result.
    def readchars(io): io.load(); l, io.p = io.p, io.O[io.l]; io.l += 1; return io._dec(l, io.p-1)
    def readinto(io, lst):
        if io.char: return io.readcharsinto(lst)
        else: return io.readtokensinto(lst)
    def readcharsinto(io, lst): lst.extend(io.readchars()); return lst
    # Append every space-separated token of the current line to lst; the last
    # token ends one byte before the line end.
    def readtokensinto(io, lst): 
        io.load(); r = io.O[io.l]
        while ~(p := io.B.find(b' ', io.p, r)): lst.append(io._dec(io.p, p)); io.p = p+1
        lst.append(io._dec(io.p, r-1)); io.p = r; io.l += 1; return lst
    # Parse one integer starting at p without going past r; returns None when
    # only bytes <= 32 (space, newline, ...) remain before r.
    def _readint(io, r):
        while io.p < r and io.B[io.p] <= 32: io.p += 1
        if io.p >= r: return None
        minus = x = 0
        # 45 is '-'.
        if io.B[io.p] == 45: minus = 1; io.p += 1
        # Bytes >= 48 ('0') are taken as digits; & 15 extracts the digit value.
        while io.p < r and io.B[io.p] >= 48: x = x * 10 + (io.B[io.p] & 15); io.p += 1
        # Step over the byte that ended the number.
        io.p += 1
        return -x if minus else x
    def readintsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and (x := io._readint(r)) is not None: lst.append(x)
        io.l += 1; return lst
    # Value of the low four bits of the byte at p, then advance by one.
    def _readdigit(io): d = io.B[io.p] & 15; io.p += 1; return d
    # Append one digit per byte until a byte <= 32 is met; that byte is
    # skipped, and if it is a newline (10) the next line becomes current.
    def readdigitsinto(io, lst):
        io.load(); r = io.O[io.l]
        while io.p < r and io.B[io.p] > 32: lst.append(io._readdigit())
        if io.B[io.p] == 10: io.l += 1
        io.p += 1
        return lst
    def readnumsinto(io, lst):
        if io.char: return io.readdigitsinto(lst)
        else: return io.readintsinto(lst)
    def line(io): io.st.clear(); return io.readinto(io.st)
    # Generator that yields for as long as p has not reached the end offset
    # the current line had when iteration started.
    def wait(io):
        io.load(); r = io.O[io.l]
        while io.p < r: yield
    # Write the pending output to the descriptor and start a fresh builder;
    # does nothing when the object is not writable.
    def flush(io):
        if io.writable: os_write(io.f, io.S.build().encode(io.encoding, io.errors)); io.S = StringBuilder()
# Replace the interpreter's standard streams with IO wrappers and keep the
# same objects reachable as IO.stdin / IO.stdout.
sys.stdin = IO.stdin = IO(sys.stdin)
sys.stdout = IO.stdout = IO(sys.stdout)
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection

class Parsable:
    """Base class for types that know how to build their own input parser."""

    @classmethod
    def compile(cls):
        """Return a function that builds ``cls`` from the next item of a reader."""
        def parser(io: 'IOBase'):
            item = next(io)
            return cls(item)
        return parser

    @classmethod
    def __class_getitem__(cls, item):
        """Support ``cls[item]`` by wrapping both in a ``GenericAlias``."""
        return GenericAlias(cls, item)

class Parser:
    """Turns a type or value specification into a function that reads it.

    Every ``compile*`` method returns a callable taking a reader and
    returning the parsed value.  An instance wraps the callable compiled from
    its spec and forwards calls to it.
    """

    def __init__(self, spec):
        self.parse = Parser.compile(spec)

    def __call__(self, io: IOBase):
        return self.parse(io)

    @staticmethod
    def compile_type(cls, args = ()):
        """Compile a parser for the class ``cls`` with generic arguments ``args``.

        The checks run in this order: Parsable subclasses compile themselves,
        numbers and strings are built from the next item, tuples and other
        collections are delegated, and any remaining callable class is
        applied to the next item.
        """
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        if issubclass(cls, (Number, str)):
            def parse(io: IOBase):
                return cls(next(io))
            return parse
        if issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        if issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        if callable(cls):
            def parse(io: IOBase):
                return cls(next(io))
            return parse
        raise NotImplementedError()

    @staticmethod
    def compile(spec=int):
        """Compile a parser for ``spec``.

        ``spec`` may be a class or generic alias, a number (the parsed value
        is converted to the number's type and the number is added to it), a
        tuple or other collection of specs, or a callable applied to the next
        item.  Anything else raises NotImplementedError.
        """
        if isinstance(spec, (type, GenericAlias)):
            origin = typing.get_origin(spec) or spec
            params = typing.get_args(spec) or tuple()
            return Parser.compile_type(origin, params)
        if isinstance(spec, Number):
            offset = spec
            cls = type(spec)
            def parse(io: IOBase):
                return cls(next(io)) + offset
            return parse
        if isinstance(spec, tuple):
            return Parser.compile_tuple(type(spec), spec)
        if isinstance(spec, Collection):
            return Parser.compile_collection(type(spec), spec)
        if isinstance(spec, Callable):
            fn = spec
            def parse(io: IOBase):
                return fn(next(io))
            return parse
        raise NotImplementedError()

    @staticmethod
    def compile_line(cls, spec=int):
        """Compile a parser that builds ``cls`` from one whole line of ``spec`` items."""
        if spec is int:
            def parse(io: IOBase):
                return cls(io.readnums())
            return parse
        if spec is str:
            def parse(io: IOBase):
                return cls(io.line())
            return parse
        fn = Parser.compile(spec)
        def parse(io: IOBase):
            # io.wait() keeps iterating while the line has unread input, so
            # one item is parsed per iteration until the line is used up.
            return cls((fn(io) for _ in io.wait()))
        return parse

    @staticmethod
    def compile_repeat(cls, spec, N):
        """Compile a parser that builds ``cls`` from ``N`` consecutive ``spec`` values."""
        fn = Parser.compile(spec)
        def parse(io: IOBase):
            return cls([fn(io) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls, specs):
        """Compile a parser that reads one value per spec and passes the list to ``cls``."""
        fns = tuple(map(Parser.compile, specs))
        def parse(io: IOBase):
            return cls([fn(io) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls, specs):
        """Compile a tuple parser; ``(spec, ...)`` means a whole line of ``spec``."""
        whole_line = isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...
        if whole_line:
            return Parser.compile_line(cls, specs[0])
        return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        """Compile a collection parser.

        No spec, a single spec or a set of specs reads one whole line; a
        two-item tuple or list ``(spec, n)`` with an int ``n`` reads ``n``
        values.  Any other shape raises NotImplementedError.
        """
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        raise NotImplementedError()

# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Back to top page