cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

:heavy_check_mark: test/library-checker/data-structure/static_rectangle_add_rectangle_sum_bit_monoid.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/static_rectangle_add_rectangle_sum

def main():
    '''Solve "static rectangle add, rectangle sum" offline with a sweep over x.

    Every added rectangle (l, d, r, u, w) is split into 4 signed corner points.
    A corner (px, py) with sign c contributes c*w*(x-px)*(y-py) to the 2D prefix
    function F(x, y). Summed over the corners already swept past, F is a
    polynomial a*y + b*x + c*x*y + k. Its coefficients are kept (mod `mod`) in
    two Fenwick trees indexed by compressed y.
    Each query rectangle is answered as F(r,u) - F(l,u) - F(r,d) + F(l,d).
    '''
    # Two residues mod `mod` are packed into one int: high = v >> s, low = v & m.
    # mod < 2**30, so adding two packed values never carries out of the low 31 bits.
    mod, s, m = 998244353, 31, (1 << 31)-1
    N, Q = read()
    N4, Q4 = N<<2, Q<<2
    # Per corner: X, Y coordinates; V packs (coef of y, coef of x);
    # W packs (coef of x*y, constant term) of the corner's contribution to F.
    X, Y, V, W = [0]*N4,[0]*N4,[0]*N4,[0]*N4
    # The 4 corners of each query rectangle at which F is evaluated.
    Xq, Yq = [0]*Q4,[0]*Q4
    for i in range(N):
        l, d, r, u, w = read()
        # The walrus moves i through slots 4i..4i+3. The subscript X[...] is
        # evaluated before Y[i], V[i], W[i], so they use the updated i.
        # Corners (l,d) and (r,u) carry +w; (l,u) and (r,d) carry -w.
        X[i:=i<<2], Y[i], V[i], W[i] = l, d, (-l*w%mod)<<s|(-d*w%mod),(( w%mod)<<s)|( l*d%mod*w%mod)
        X[i:=i +1], Y[i], V[i], W[i] = l, u, ( l*w%mod)<<s|( u*w%mod),((-w%mod)<<s)|(-l*u%mod*w%mod)
        X[i:=i +1], Y[i], V[i], W[i] = r, d, ( r*w%mod)<<s|( d*w%mod),((-w%mod)<<s)|(-r*d%mod*w%mod)
        X[i:=i +1], Y[i], V[i], W[i] = r, u, (-r*w%mod)<<s|(-u*w%mod),(( w%mod)<<s)|( r*u%mod*w%mod)
    for i in range(Q):
        l, d, r, u = read()
        # Query corners in the order (l,d), (l,u), (r,d), (r,u).
        Xq[i:=i<<2], Yq[i] = l, d
        Xq[i:=i +1], Yq[i] = l, u
        Xq[i:=i +1], Yq[i] = r, d
        Xq[i:=i +1], Yq[i] = r, u
    # Keep the original query y values for polynomial evaluation.
    OYq = Yq[:]
    # In place: Y -> 0-based rank among distinct corner y values;
    # Yq -> number of distinct corner y values strictly below the query y,
    # i.e. an exclusive prefix bound for BITMonoid.sum.
    icoord_compress_with_queries(Yq,Y,x=1)
    def op(a, b):
        # Component-wise addition mod `mod` of two packed pairs.
        v = a+b
        return ((v>>s)%mod)<<s|(v&m)%mod
    # NOTE(review): ranks are < N4 and Yq <= N4, so size N4 would suffice.
    Vseg, Wseg = BITMonoid(op, 0, N4+Q4), BITMonoid(op, 0, N4+Q4)

    def poly_eval(x,y,v,w):
        # Evaluate w2 + v1*y + v2*x + w1*x*y (mod `mod`) from the packed sums.
        v1, v2 = v>>s, v&m; w1, w2 = w>>s, w&m
        return (w2+y*v1+x*v2+x*y%mod*w1)%mod

    qans = [0]*Q4
    # Sweep events by x (ties by compressed y key, then index). Corners with
    # px == x or py == y contribute 0 at (x, y) because a factor vanishes,
    # so tie order between corners and queries does not affect the result.
    for i in argsort_multi(X+Xq,Y+Yq):
        if i < N4:
            Vseg.add(Y[i],V[i]); Wseg.add(Y[i],W[i])
        else:
            i -= N4
            qans[i] = poly_eval(Xq[i],OYq[i],Vseg.sum(Yq[i]),Wseg.sum(Yq[i]))
    for i in range(Q):
        # Inclusion-exclusion: F(l,d) - F(l,u) - F(r,d) + F(r,u).
        ans = (qans[i:=i<<2]-qans[i:=i+1]-qans[i:=i+1]+qans[i:=i+1])%mod
        write(ans)

from cp_library.ds.tree.bit.bit_monoid_cls import BITMonoid
from cp_library.alg.iter.arg.argsort_multi_fn import argsort_multi
from cp_library.alg.iter.cmpr.icoord_compress_with_queries_fn import icoord_compress_with_queries
from cp_library.io.write_fn import write
from cp_library.io.read_fn import read

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/static_rectangle_add_rectangle_sum

def main():
    '''Solve "static rectangle add, rectangle sum" offline with a sweep over x.

    Every added rectangle (l, d, r, u, w) is split into 4 signed corner points.
    A corner (px, py) with sign c contributes c*w*(x-px)*(y-py) to the 2D prefix
    function F(x, y). Summed over the corners already swept past, F is a
    polynomial a*y + b*x + c*x*y + k. Its coefficients are kept (mod `mod`) in
    two Fenwick trees indexed by compressed y.
    Each query rectangle is answered as F(r,u) - F(l,u) - F(r,d) + F(l,d).
    '''
    # Two residues mod `mod` are packed into one int: high = v >> s, low = v & m.
    # mod < 2**30, so adding two packed values never carries out of the low 31 bits.
    mod, s, m = 998244353, 31, (1 << 31)-1
    N, Q = read()
    N4, Q4 = N<<2, Q<<2
    # Per corner: X, Y coordinates; V packs (coef of y, coef of x);
    # W packs (coef of x*y, constant term) of the corner's contribution to F.
    X, Y, V, W = [0]*N4,[0]*N4,[0]*N4,[0]*N4
    # The 4 corners of each query rectangle at which F is evaluated.
    Xq, Yq = [0]*Q4,[0]*Q4
    for i in range(N):
        l, d, r, u, w = read()
        # The walrus moves i through slots 4i..4i+3. The subscript X[...] is
        # evaluated before Y[i], V[i], W[i], so they use the updated i.
        # Corners (l,d) and (r,u) carry +w; (l,u) and (r,d) carry -w.
        X[i:=i<<2], Y[i], V[i], W[i] = l, d, (-l*w%mod)<<s|(-d*w%mod),(( w%mod)<<s)|( l*d%mod*w%mod)
        X[i:=i +1], Y[i], V[i], W[i] = l, u, ( l*w%mod)<<s|( u*w%mod),((-w%mod)<<s)|(-l*u%mod*w%mod)
        X[i:=i +1], Y[i], V[i], W[i] = r, d, ( r*w%mod)<<s|( d*w%mod),((-w%mod)<<s)|(-r*d%mod*w%mod)
        X[i:=i +1], Y[i], V[i], W[i] = r, u, (-r*w%mod)<<s|(-u*w%mod),(( w%mod)<<s)|( r*u%mod*w%mod)
    for i in range(Q):
        l, d, r, u = read()
        # Query corners in the order (l,d), (l,u), (r,d), (r,u).
        Xq[i:=i<<2], Yq[i] = l, d
        Xq[i:=i +1], Yq[i] = l, u
        Xq[i:=i +1], Yq[i] = r, d
        Xq[i:=i +1], Yq[i] = r, u
    # Keep the original query y values for polynomial evaluation.
    OYq = Yq[:]
    # In place: Y -> 0-based rank among distinct corner y values;
    # Yq -> number of distinct corner y values strictly below the query y,
    # i.e. an exclusive prefix bound for BITMonoid.sum.
    icoord_compress_with_queries(Yq,Y,x=1)
    def op(a, b):
        # Component-wise addition mod `mod` of two packed pairs.
        v = a+b
        return ((v>>s)%mod)<<s|(v&m)%mod
    # NOTE(review): ranks are < N4 and Yq <= N4, so size N4 would suffice.
    Vseg, Wseg = BITMonoid(op, 0, N4+Q4), BITMonoid(op, 0, N4+Q4)

    def poly_eval(x,y,v,w):
        # Evaluate w2 + v1*y + v2*x + w1*x*y (mod `mod`) from the packed sums.
        v1, v2 = v>>s, v&m; w1, w2 = w>>s, w&m
        return (w2+y*v1+x*v2+x*y%mod*w1)%mod

    qans = [0]*Q4
    # Sweep events by x (ties by compressed y key, then index). Corners with
    # px == x or py == y contribute 0 at (x, y) because a factor vanishes,
    # so tie order between corners and queries does not affect the result.
    for i in argsort_multi(X+Xq,Y+Yq):
        if i < N4:
            Vseg.add(Y[i],V[i]); Wseg.add(Y[i],W[i])
        else:
            i -= N4
            qans[i] = poly_eval(Xq[i],OYq[i],Vseg.sum(Yq[i]),Wseg.sum(Yq[i]))
    for i in range(Q):
        # Inclusion-exclusion: F(l,d) - F(l,u) - F(r,d) + F(r,u).
        ans = (qans[i:=i<<2]-qans[i:=i+1]-qans[i:=i+1]+qans[i:=i+1])%mod
        write(ans)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
from typing import Callable, Generic, Union
from typing import TypeVar
# Type variables shared by the generic helpers in this module. The name
# passed to TypeVar must match the variable it is bound to (PEP 484);
# type checkers reject a mismatch such as TypeVar('T') bound to _T.
_T = TypeVar('_T')
_U = TypeVar('_U')


'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
            ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓            
            ┃                                    7 ┃            
            ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━┛            
            ┏━━━━━━━━━━━━━━━━━━┓                 │              
            ┃                3 ┃◄────────────────┤              
            ┗━━━━━━━━━━━━━━━━┯━┛                 │              
            ┏━━━━━━━━┓       │  ┏━━━━━━━━┓       │              
            ┃      1 ┃◄──────┤  ┃      5 ┃◄──────┤              
            ┗━━━━━━┯━┛       │  ┗━━━━━━┯━┛       │              
            ┏━━━┓  │  ┏━━━┓  │  ┏━━━┓  │  ┏━━━┓  │              
            ┃ 0 ┃◄─┤  ┃ 2 ┃◄─┤  ┃ 4 ┃◄─┤  ┃ 6 ┃◄─┤              
            ┗━┯━┛  │  ┗━┯━┛  │  ┗━┯━┛  │  ┗━┯━┛  │              
              │    │    │    │    │    │    │    │              
              ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼              
            ┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓┏━━━┓            
            ┃ 0 ┃┃ 1 ┃┃ 2 ┃┃ 3 ┃┃ 4 ┃┃ 5 ┃┃ 6 ┃┃ 7 ┃            
            ┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛┗━━━┛            
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
           Data Structure - Tree - Binary Indexed Tree          
'''

class BITMonoid(Generic[_T]):
    '''Fenwick tree (binary indexed tree) over a monoid ``(op, e)``.

    ``d[i]`` holds the fold of the underlying values over ``[i & (i+1), i]``.
    ``add(i, x)`` folds ``x`` into position ``i``; ``sum(r)`` folds positions
    ``[0, r)``. ``sum`` combines blocks from right to left, so ``op`` should be
    associative and commutative. ``bisect_left``/``bisect_right`` additionally
    require prefix folds to be non-decreasing and comparable with ``<``/``<=``.
    '''
    def __init__(bit, op: Callable[[_T,_T],_T], e: _T, v: Union[int,list[_T]]):
        '''Create ``v`` identity slots, or build from the list ``v``.

        A list ``v`` becomes the tree's storage and is modified in place.
        '''
        # op/e must be set before build(), which folds entries with bit.op.
        bit.op, bit.e = op, e
        if isinstance(v, int):
            bit.d, bit.n = [e]*v, v
            # Power of two strictly greater than n: first step of the bisect descents.
            bit.lb = 1 << v.bit_length()
        else:
            bit.build(v)

    def __len__(bit) -> int:
        return bit.n

    def build(bit, d: list[_T]) -> None:
        '''Take ``d`` as the initial values and convert it in place to tree form in O(n).'''
        bit.d, bit.n = d, len(d)
        # Power of two strictly greater than n: first step of the bisect descents.
        bit.lb = 1 << bit.n.bit_length()
        for i in range(bit.n):
            # Push each node's fold into the next node whose range covers it.
            if (r := i|(i+1)) < bit.n: d[r] = bit.op(d[i], d[r])

    def add(bit, i: int, x: _T) -> None:
        '''Fold ``x`` into position ``i`` (``a[i] <- op(a[i], x)``).'''
        assert 0 <= i < bit.n
        while i < bit.n:
            bit.d[i] = bit.op(bit.d[i], x)
            i |= i+1

    def sum(bit, r: int) -> _T:
        '''Return the fold of positions ``[0, r)`` (``e`` when ``r == 0``).'''
        assert 0 <= r <= bit.n
        s = bit.e
        while r: s, r = bit.op(s,bit.d[r-1]), r&r-1
        return s

    def prelist(bit) -> list[_T]:
        '''Return all prefix folds: ``pre[i] == sum(i)`` for ``0 <= i <= n``.'''
        pre = [bit.e]+bit.d
        for i in range(bit.n+1): pre[i] = bit.op(pre[i&(i-1)], pre[i])
        return pre

    def bisect_left(bit, v) -> int:
        '''Return the largest ``i`` with ``sum(i) < v`` (0 when ``v <= e``).'''
        if v <= bit.e: return 0
        i, s = 0, bit.e
        ni = m = bit.lb
        while m:
            if ni <= bit.n and (ns:=bit.op(s,bit.d[ni-1])) < v: s, i = ns, ni
            ni = (m:=m>>1)|i
        return i

    def bisect_right(bit, v) -> int:
        '''Return the largest ``i`` with ``sum(i) <= v`` (0 if there is none).'''
        i, s = 0, bit.e
        ni = m = bit.lb
        while m:
            if ni <= bit.n and (ns:=bit.op(s,bit.d[ni-1])) <= v: s, i = ns, ni
            ni = (m:=m>>1)|i
        return i




def argsort_multi(*A: list[int], reverse=False):
    '''Return the index permutation that sorts rows by the key columns in A.

    Rows are compared lexicographically by (A[0][i], A[1][i], ..., A[-1][i]).
    Rows with equal keys keep ascending index order, for both ascending and
    descending (``reverse=True``) sorts. Keys must be ints; negatives are fine.

    This is an LSD scheme: sort by the last column, then stably re-sort by
    each earlier column. Stability comes from packing the current position
    into the low ``s`` bits of every sort key.
    '''
    n = len(A[0])
    s, m = pack_sm(n-1)
    cur, prev = [0]*n, list(range(n))
    if reverse:
        # Store m^pos so a descending sort breaks ties by ascending position.
        keys = [(a << s) | (m ^ p) for p, a in enumerate(A[-1])]
        keys.sort(reverse=True)
        for col in reversed(A[:-1]):
            for p, key in enumerate(keys):
                j = prev[m ^ (key & m)]
                keys[p] = (col[j] << s) | (m ^ p)
                cur[p] = j
            cur, prev = prev, cur
            keys.sort(reverse=True)
        for p, key in enumerate(keys):
            cur[p] = prev[m ^ (key & m)]
    else:
        keys = [(a << s) | p for p, a in enumerate(A[-1])]
        keys.sort()
        for col in reversed(A[:-1]):
            for p, key in enumerate(keys):
                j = prev[key & m]
                keys[p] = (col[j] << s) | p
                cur[p] = j
            cur, prev = prev, cur
            keys.sort()
        for p, key in enumerate(keys):
            cur[p] = prev[key & m]
    return cur


def pack_sm(N: int):
    '''Return ``(s, m)``: the bit length of N and the all-ones mask of that width.'''
    shift = N.bit_length()
    mask = (1 << shift) - 1
    return shift, mask


def max2(a, b):
    '''Return the larger of two values; on a tie, ``b`` is returned.'''
    if a > b:
        return a
    return b


def icoord_compress_with_queries(*A: list[int], x=0, distinct=False):
    '''Coordinate-compress integer lists in place, mapping queries to bisect positions.

    Lists ``A[x:]`` are data. Each value becomes its 0-based rank among the
    distinct data values. With ``distinct`` true, every entry gets its own
    rank, and ties are ordered by list, then position.
    Lists ``A[:x]`` are queries. Each value q becomes the number of data ranks
    assigned to values strictly less than q (``bisect_left`` in the compressed
    data), usable as an exclusive prefix bound.

    Values may be negative. Every list in A is modified and A (a tuple) is returned.
    '''
    N = sum(map(len, A))
    mx = max(map(len, A), default=0)
    # Pack (value, list i, position j) into one int so that one sort orders by
    # value, then queries (i < x) before data of the same value, then position.
    si = (mx-1).bit_length(); mi = (1 << si) - 1
    # The (i, j) field must hold the largest j even when there is only one
    # list (i == 0 everywhere), hence "| mi". Without it sj could be 0 and
    # positions would overlap the value bits.
    sj = ((len(A)-1) << si | mi).bit_length(); mj = (1 << sj) - 1
    S = [0]*N; k = 0
    for i, Ai in enumerate(A):
        for j, a in enumerate(Ai):
            S[k] = a << sj | i << si | j; k += 1
    S.sort()
    # p = None means "no data value seen yet". A numeric sentinel such as -1
    # would collide with a real data value of -1 and give it rank -1.
    r, p = -1, None
    for aji in S:
        a, i, j = aji >> sj, (aji & mj) >> si, aji & mi
        if x <= i and (distinct or a != p):
            r += 1; p = a
        A[i][j] = r + (i < x)
    return A

import os
import sys
from io import BytesIO, IOBase


class FastIO(IOBase):
    '''Byte-level buffered I/O over a raw file descriptor.

    Reads pull chunks with ``os.read`` into an in-memory ``BytesIO``. Writes go
    into the same ``BytesIO`` and are emitted by a single ``os.write`` in
    ``flush()``.
    '''
    BUFSIZE = 8192  # minimum number of bytes requested per os.read call
    # Set by readline() to the newline count of the last chunk read (+1 at EOF),
    # then decremented per returned line; it is refilled only when it hits 0.
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Treat the file as writable if opened in "x" mode or not in read mode.
        self.writable = "x" in file.mode or "r" not in file.mode
        # Writes append straight to the in-memory buffer; None for read-only files.
        self.write = self.buffer.write if self.writable else None

    def read(self):
        '''Read until EOF and return all not-yet-consumed buffered bytes.'''
        BUFSIZE = self.BUFSIZE
        while True:
            # Ask for at least the file's current size so a regular file can be
            # read in one call.
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end of the buffer, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        '''Return the next line as bytes (with its newline, if present).'''
        BUFSIZE = self.BUFSIZE
        # Read chunks until the latest chunk contains a newline or EOF is hit.
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty read means EOF; count it as 1 so the loop terminates.
            self.newlines = b.count(b"\n") + (not b)
            # Append at the end of the buffer, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        '''Emit all buffered output with one os.write and reset the buffer.

        No-op for non-writable files.
        NOTE(review): the return value of os.write (bytes actually written)
        is not checked.
        '''
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    '''ASCII text front end for FastIO.

    The class attributes ``stdin``/``stdout`` hold the process-wide wrappers
    once they have been installed; until then they are None.
    '''
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        # Flushing and the writable flag come straight from the FastIO layer.
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        '''Encode ``s`` as ASCII and append it to the underlying buffer.'''
        data = s.encode("ascii")
        return self.buffer.write(data)

    def read(self):
        '''Read the rest of the input and decode it as ASCII.'''
        data = self.buffer.read()
        return data.decode("ascii")

    def readline(self):
        '''Read one line (newline included, if present) and decode it as ASCII.'''
        data = self.buffer.readline()
        return data.decode("ascii")
# Replace the process-wide stdin/stdout with the buffered FastIO wrappers.
# This is best effort: if a stream cannot be wrapped (e.g. it has no usable
# fileno), the originals stay in place. Only ordinary exceptions are
# suppressed, so KeyboardInterrupt and SystemExit still propagate.
try:
    sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
    sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)
except Exception:
    pass

def write(*args, **kwargs):
    '''Print the values to a stream, by default the fast stdout wrapper.

    Keyword arguments follow the builtin ``print``:
      sep   -- separator between values (default " "; None also means default)
      end   -- text appended after the last value (default newline; None also
               means default)
      file  -- object with a ``write(str)`` method. Defaults to
               ``IOWrapper.stdout``, or ``sys.stdout`` if that wrapper was
               never installed.
      flush -- if true, call ``file.flush()`` afterwards.
    Unrecognized keyword arguments are ignored.
    '''
    sep = kwargs.pop("sep", " ")
    if sep is None:
        sep = " "
    end = kwargs.pop("end", "\n")
    if end is None:
        end = "\n"
    file = kwargs.pop("file", None)
    if file is None:
        # IOWrapper.stdout stays None when stdout could not be wrapped.
        file = IOWrapper.stdout if IOWrapper.stdout is not None else sys.stdout
    file.write(sep.join(map(str, args)) + end)
    if kwargs.pop("flush", False):
        file.flush()

from typing import Iterable, Type, Union, overload
import typing
from collections import deque
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection, Iterator, Union

class TokenStream(Iterator):
    '''Iterator over whitespace-separated tokens read line by line from ``stream``.

    Tokens of the current line are buffered in ``queue``; a new line is read
    only when the buffer is empty. Subclasses change tokenization via ``_line``.
    '''
    # IOWrapper.stdin stays None when stdin could not be wrapped. Fall back to
    # the plain sys.stdin so reading still works instead of failing with
    # AttributeError on None.
    stream = IOWrapper.stdin if IOWrapper.stdin is not None else sys.stdin

    def __init__(self):
        self.queue = deque()

    def __next__(self):
        '''Return the next token, reading a new line if the buffer is empty.'''
        if not self.queue: self.queue.extend(self._line())
        return self.queue.popleft()

    def wait(self):
        '''Yield repeatedly while tokens remain on the current line.

        Reads a line first if the buffer is empty. The caller is expected to
        consume tokens between yields; iteration stops once the buffer is empty.
        '''
        if not self.queue: self.queue.extend(self._line())
        while self.queue: yield

    def _line(self):
        '''Read one line from the stream and split it into tokens.'''
        return TokenStream.stream.readline().split()

    def line(self):
        '''Return the remaining tokens of the current line, or the next line's tokens.'''
        if self.queue:
            A = list(self.queue)
            self.queue.clear()
            return A
        return self._line()
TokenStream.default = TokenStream()

class CharStream(TokenStream):
    '''TokenStream variant whose tokens are the individual characters of a line.'''
    def _line(self):
        # Strip trailing whitespace (including the newline). The base class
        # extends its queue with the resulting string, one character per token.
        text = TokenStream.stream.readline()
        return text.rstrip()
CharStream.default = CharStream()


# Type of a compiled parser: a callable that consumes tokens from a
# TokenStream and returns the parsed value.
ParseFn = Callable[[TokenStream],_T]
class Parser:
    '''Compiles a parse "spec" into a function that reads values from a TokenStream.

    A spec may be:
      - a type: int, str, a Parsable subclass, or a generic such as tuple[...]
        or list[...];
      - a number instance: parse a token as that number's type, then add the
        number as an offset;
      - a tuple or other collection of specs;
      - any callable, applied to one token.
    '''
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)

    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> ParseFn[_T]:
        '''Build a parser for the class ``cls`` with type arguments ``args``.'''
        if issubclass(cls, Parsable):
            # Parsable subclasses provide their own parser.
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            # Scalars: convert a single token.
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            # Any other class: construct it from a single token.
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            # NOTE(review): likely unreachable, since classes are always callable.
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        '''Dispatch on the kind of ``spec`` and return the matching parser.'''
        if isinstance(spec, (type, GenericAlias)):
            # A type or builtin generic alias (e.g. list[int]): split into origin and args.
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            # A number instance, e.g. -1: parse a token as type(spec), then add spec.
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            # Any other callable is applied to a single token.
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        '''Parse the rest of the current line into ``cls`` of ``spec`` values.'''
        if spec is int:
            # Fast path: convert every token of the line with int().
            # NOTE(review): fn is unused in this branch.
            fn = Parser.compile(spec)
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        else:
            # General path: parse with fn while tokens remain on the current line.
            fn = Parser.compile(spec)
            def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
            return parse

    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        '''Parse exactly ``N`` values of ``spec`` into ``cls``.'''
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse

    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        '''Parse one value per spec, in order, into ``cls``.'''
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        '''``(spec, ...)`` parses a whole line; any other spec tuple parses element-wise.'''
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)

    @staticmethod
    def compile_collection(cls, specs):
        '''No spec, one spec, or a set parses a line; ``(spec, N)`` parses N values.'''
        # NOTE(review): a set with more than one spec would pass extra
        # positional arguments to compile_line and raise TypeError.
        if not specs or len(specs) == 1 or isinstance(specs, set):
            return Parser.compile_line(cls, *specs)
        elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        else:
            raise NotImplementedError()

class Parsable:
    '''Mixin for classes that can be parsed from input tokens.

    Parser calls ``cls.compile(*args)`` for subclasses. The default parser
    builds the instance from one token as ``cls(token)``.
    '''
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser

@overload
def read() -> list[int]: ...
@overload
def read(spec: Type[_T], char=False) -> _T: ...
@overload
def read(spec: _U, char=False) -> _U: ...
@overload
def read(*specs: Type[_T], char=False) -> tuple[_T, ...]: ...
@overload
def read(*specs: _U, char=False) -> tuple[_U, ...]: ...
def read(*specs: Union[Type[_T],_U], char=False):
    '''Read and parse input according to ``specs``.

    With no specs and ``char`` false, the rest of the current line is returned
    as a list of ints. Otherwise the specs are compiled by Parser and run on
    the token stream (or the character stream when ``char`` is true). A single
    spec returns its value directly; several specs return the parsed tuple.
    '''
    if specs or char:
        source = CharStream.default if char else TokenStream.default
        result = Parser.compile(specs)(source)
        return result[0] if len(specs) == 1 else result
    return [int(tok) for tok in TokenStream.default.line()]

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Back to top page