cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kobejean/cp-library

✔ test/library-checker/linear-algebra/pow_of_matrix_modmat.test.py

Depends on

Code

# verification-helper: PROBLEM https://judge.yosupo.jp/problem/pow_of_matrix

def main():
    """Solve Library Checker "pow_of_matrix": print A**K modulo 998244353."""
    mint.set_mod(998244353)
    size, exponent = read()
    rows = [read() for _ in range(size)]
    result = ModMat(rows) ** exponent
    write(result)

from cp_library.math.mod.mint_cls import mint
from cp_library.math.linalg.mat.mod.modmat_cls import ModMat
from cp_library.io.read_int_fn import read
from cp_library.io.write_fn import write

if __name__ == '__main__':
    # Run the solver only when executed as a script, not when imported.
    main()
# verification-helper: PROBLEM https://judge.yosupo.jp/problem/pow_of_matrix

def main():
    """Solve Library Checker "pow_of_matrix": print A**K modulo 998244353."""
    mint.set_mod(998244353)
    size, exponent = read()
    rows = [read() for _ in range(size)]
    result = ModMat(rows) ** exponent
    write(result)

'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
    
class mint(int):
    """Residue modulo ``mint.mod``, implemented as an ``int`` subclass.

    ``mint.set_mod`` must be called before constructing values: the
    constructor serves 0, 1 and 2 from ``cache`` and reduces every other
    value modulo ``mod``. Arithmetic operators return reduced ``mint``
    values; both ``/`` and ``//`` multiply by the modular inverse.
    """
    mod: int
    zero: 'mint'
    one: 'mint'
    two: 'mint'
    cache: list['mint']

    def __new__(cls, *args, **kwargs):
        value = int(*args, **kwargs)
        if value < 0 or value > 2:
            return cls.fix(value)
        # Small non-negative values reuse the instances prepared by set_mod.
        return cls.cache[value]

    @classmethod
    def set_mod(cls, mod: int):
        """Install ``mod`` as the modulus and rebuild the constants/cache."""
        # Written to both the base class and cls so they stay in agreement.
        for owner in (mint, cls):
            owner.mod = mod
        zero, one, two = cls.cast(0), cls.fix(1), cls.fix(2)
        small = [zero, one, two]
        for owner in (mint, cls):
            owner.zero, owner.one, owner.two = zero, one, two
            owner.cache = small

    @classmethod
    def fix(cls, x):
        """Reduce ``x`` modulo ``cls.mod`` and wrap the result."""
        return cls.cast(x % cls.mod)

    @classmethod
    def cast(cls, x):
        """Wrap ``x`` unchanged (no reduction, no cache lookup)."""
        return int.__new__(cls, x)

    @classmethod
    def mod_inv(cls, x):
        """Return the inverse of ``x`` modulo ``cls.mod`` via extended Euclid.

        Raises ValueError when ``x`` and the modulus are not coprime.
        """
        r0, r1 = int(x), cls.mod
        s0, s1 = 1, 0
        while r1:
            q = r0 // r1
            r0, r1 = r1, r0 - q * r1
            s0, s1 = s1, s0 - q * s1
        if r0 != 1:
            raise ValueError(f"{x} is not invertible in mod {cls.mod}")
        return cls.fix(s0)

    @property
    def inv(self):
        """Modular inverse of this value."""
        return mint.mod_inv(self)

    def __add__(self, x):
        return mint.fix(int.__add__(self, x))

    def __radd__(self, x):
        return mint.fix(int.__radd__(self, x))

    def __sub__(self, x):
        return mint.fix(int.__sub__(self, x))

    def __rsub__(self, x):
        return mint.fix(int.__rsub__(self, x))

    def __mul__(self, x):
        return mint.fix(int.__mul__(self, x))

    def __rmul__(self, x):
        return mint.fix(int.__rmul__(self, x))

    def __floordiv__(self, x):
        # Division in the field: multiply by the inverse of the divisor.
        return self * mint.mod_inv(x)

    def __rfloordiv__(self, x):
        return self.inv * x

    def __truediv__(self, x):
        return self * mint.mod_inv(x)

    def __rtruediv__(self, x):
        return self.inv * x

    def __pow__(self, x):
        # Three-argument int pow; negative exponents yield inverse powers.
        return self.cast(int.__pow__(self, x, self.mod))

    def __neg__(self):
        return mint.mod - self

    def __pos__(self):
        return self

    def __abs__(self):
        return self
from typing import Union, List, Tuple

class ModMat:
    """Dense matrix whose entries are integers taken modulo ``mint.mod``.

    ``data`` holds the rows as lists; ``R`` and ``C`` are the row and column
    counts. Every arithmetic method reduces results with ``mint.mod``, so
    ``mint.set_mod`` must be called before doing arithmetic.
    """
    __slots__ = 'data', 'R', 'C'

    def __init__(self, data: 'List[List[Union[int, mint]]]'):
        # An empty row list is a valid 0x0 matrix (consistent with __delitem__).
        self.data, self.R, self.C = data, len(data), (len(data[0]) if data else 0)

    @classmethod
    def identity(cls, N) -> 'ModMat':
        """Return the N x N identity matrix."""
        return ModMat([[int(i == j) for j in range(N)] for i in range(N)])

    @classmethod
    def zeros(cls, R, C) -> 'ModMat':
        """Return an R x C matrix of zeros."""
        return ModMat([[0] * C for _ in range(R)])

    def inv(self) -> 'ModMat':
        """Return the inverse via Gauss-Jordan elimination modulo ``mint.mod``.

        Raises:
            AssertionError: if the matrix is not square.
            ValueError: if the matrix is singular modulo ``mint.mod``.
        """
        # The previous check was inverted (R != C) and rejected every square matrix.
        assert self.R == self.C
        mod = mint.mod
        N = self.R
        # Reduce up front so the zero-pivot test also catches entries that are
        # congruent to zero but not literally 0 (e.g. equal to mod, or negative).
        A = [[x % mod for x in row] for row in self.data]
        I = [[int(i == j) for j in range(N)] for i in range(N)]

        for i in range(N):
            if A[i][i] == 0:
                # Swap in a lower row with a non-zero entry in this column.
                for j in range(i + 1, N):
                    if A[j][i] != 0:
                        A[i], A[j] = A[j], A[i]
                        I[i], I[j] = I[j], I[i]
                        break
                else:
                    raise ValueError("Matrix is not invertible")

            # pow(..., -1, mod) raises ValueError if the pivot is not a unit.
            inv = pow(A[i][i], -1, mod)
            Ai, Ii = A[i], I[i]
            for j in range(N):
                Ai[j] = Ai[j] * inv % mod
                Ii[j] = Ii[j] * inv % mod

            # Eliminate column i from every other row.
            for j in range(N):
                factor = A[j][i]
                if j == i or factor == 0:
                    continue
                Aj, Ij = A[j], I[j]
                for k in range(N):
                    Aj[k] = (Aj[k] - factor * Ai[k]) % mod
                    Ij[k] = (Ij[k] - factor * Ii[k]) % mod

        return ModMat(I)

    def T(self) -> 'ModMat':
        """Return the transpose."""
        return ModMat(list(map(list, zip(*self.data))))

    def elem_wise(self, func, other):
        """Apply ``func(a, b)`` entry-wise against a same-shaped matrix or an int.

        Returns NotImplemented for any other operand type so Python can try
        the reflected operation.
        """
        if isinstance(other, ModMat):
            # zip would silently truncate mismatched shapes; fail loudly instead.
            assert self.R == other.R and self.C == other.C
            return ModMat([[func(a, b) for a, b in zip(Ai, Bi)]
                           for Ai, Bi in zip(self.data, other.data)])
        if isinstance(other, int):
            return ModMat([[func(a, other) for a in Ai] for Ai in self.data])
        return NotImplemented

    def __str__(self): return '\n'.join(' '.join(map(str, row)) for row in self.data)
    # iter() requires an iterator; returning the list itself raised TypeError.
    def __iter__(self): return iter(self.data)
    def __copy__(self): return ModMat([row[:] for row in self.data])
    def copy(self): return ModMat([row[:] for row in self.data])
    def __add__(self, other): return self.elem_wise(lambda a, b: (a + b) % mint.mod, other)
    def __radd__(self, other): return self.__add__(other)
    def __sub__(self, other): return self.elem_wise(lambda a, b: (a - b) % mint.mod, other)
    # Reflected subtraction is other - self (previously computed self - other).
    def __rsub__(self, other): return self.elem_wise(lambda a, b: (b - a) % mint.mod, other)
    def __mul__(self, other): return self.elem_wise(lambda a, b: (a * b) % mint.mod, other)
    def __rmul__(self, other): return self.__mul__(other)
    def __truediv__(self, other): return self.elem_wise(lambda a, b: a * pow(b, -1, mint.mod) % mint.mod, other)
    def __rtruediv__(self, other): return self.elem_wise(lambda a, b: pow(a, -1, mint.mod) * b % mint.mod, other)

    def __matmul__(self, other: 'ModMat'):
        """Matrix product modulo ``mint.mod``."""
        assert self.C == other.R
        mod, B = mint.mod, other.data
        R = []
        for Ai in self.data:
            # Python ints do not overflow, so accumulate and reduce once per row.
            acc = [0] * other.C
            for Aik, Bk in zip(Ai, B):
                if Aik:
                    for j, Bkj in enumerate(Bk):
                        acc[j] += Aik * Bkj
            R.append([x % mod for x in acc])
        return ModMat(R)

    def __pow__(self, K):
        """Return ``self ** K`` by binary exponentiation.

        ``K == 0`` gives the identity; a negative ``K`` raises the inverse to
        ``-K`` (so the matrix must then be invertible).
        """
        assert isinstance(K, int)
        assert self.R == self.C
        if K < 0:
            # bit_length ignores the sign, so the loop below cannot handle K < 0.
            return self.inv() ** -K
        A = self.copy()
        R = A if K & 1 else ModMat.identity(self.R)
        for i in range(1, K.bit_length()):
            A = A @ A
            if K >> i & 1:
                R = R @ A
        return R

    def __getitem__(self, key: Union[int, Tuple[int, int], slice, Tuple[slice, slice]]):
        """Index rows, single entries (as ``mint``), or rectangular sub-matrices."""
        if isinstance(key, int):
            return self.data[key]
        elif isinstance(key, tuple):
            if len(key) == 2:
                if all(isinstance(k, int) for k in key):
                    return mint(self.data[key[0]][key[1]])
                elif all(isinstance(k, slice) for k in key):
                    return ModMat([[self.data[i][j] for j in range(*key[1].indices(self.C))]
                                   for i in range(*key[0].indices(self.R))])
            raise IndexError("Invalid index")
        elif isinstance(key, slice):
            return ModMat([row[:] for row in self.data[key]])
        raise IndexError("Invalid index")

    def __setitem__(self, key: Union[Tuple[int, int], slice, Tuple[slice, slice]], value):
        """Assign an entry, a block, or whole rows; stored values are reduced mod."""
        if isinstance(key, tuple):
            if len(key) == 2:
                if all(isinstance(k, int) for k in key):
                    self.data[key[0]][key[1]] = value % mint.mod
                elif all(isinstance(k, slice) for k in key):
                    if isinstance(value, ModMat):
                        for i, row in enumerate(range(*key[0].indices(self.R))):
                            for j, col in enumerate(range(*key[1].indices(self.C))):
                                self.data[row][col] = value.data[i][j] % mint.mod
                    else:
                        for row in range(*key[0].indices(self.R)):
                            for col in range(*key[1].indices(self.C)):
                                self.data[row][col] = value % mint.mod
                else:
                    raise IndexError("Invalid index")
            else:
                raise IndexError("Invalid index")
        elif isinstance(key, slice):
            if isinstance(value, ModMat):
                for i, row in enumerate(range(*key.indices(self.R))):
                    self.data[row] = [v % mint.mod for v in value.data[i]]
            else:
                for row in range(*key.indices(self.R)):
                    self.data[row] = [value % mint.mod] * self.C
        else:
            raise IndexError("Invalid index")

    def __delitem__(self, key: Union[int, slice]):
        """Delete row(s) and update the shape."""
        if isinstance(key, (int, slice)):
            del self.data[key]
            self.R = len(self.data)
            if self.R == 0:
                self.C = 0
        else:
            raise IndexError("Invalid index")
    


def read(shift=0, base=10):
    """Read one line from stdin and return its whitespace-separated tokens as ints.

    Each token is parsed with ``int(token, base)`` and then offset by ``shift``.
    """
    parsed = []
    for token in input().split():
        parsed.append(int(token, base) + shift)
    return parsed
import os
import sys
from io import BytesIO, IOBase


class FastIO(IOBase):
    """Buffered byte stream over a raw file descriptor.

    Input is pulled from the descriptor with os.read in large chunks and
    appended to an in-memory BytesIO. Output accumulates in the same BytesIO
    and reaches the descriptor only when flush() is called.
    """
    BUFSIZE = 8192  # lower bound on the chunk size requested per os.read call
    newlines = 0  # newlines buffered but not yet handed out by readline()

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Instance attribute shadows IOBase.writable() with a bool: True when the
        # mode string contains "x" or does not contain "r".
        self.writable = "x" in file.mode or "r" not in file.mode
        # Writes go straight into the in-memory buffer (None when not writable).
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read from the descriptor until an empty read, then return all unread buffered bytes."""
        BUFSIZE = self.BUFSIZE
        while True:
            # Request at least BUFSIZE bytes, or the file's size if that is larger.
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end of the buffer, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next buffered line as bytes, reading more from the descriptor if needed."""
        BUFSIZE = self.BUFSIZE
        # Keep reading chunks until at least one newline is buffered. An empty
        # read counts as one (`not b`), so the loop also terminates at EOF.
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered output to the descriptor and empty the buffer."""
        if self.writable:
            # NOTE(review): os.write may write fewer bytes than requested; the
            # return value is not checked here — confirm this is acceptable.
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    """ASCII text facade over a FastIO byte buffer.

    The class attributes ``stdin`` and ``stdout`` hold the process-wide
    wrapped streams once they are installed.
    """
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None

    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable

    def write(self, s):
        """Encode ``s`` as ASCII and append it to the byte buffer."""
        encoded = s.encode("ascii")
        return self.buffer.write(encoded)

    def read(self):
        """Return all remaining input decoded as ASCII."""
        raw = self.buffer.read()
        return raw.decode("ascii")

    def readline(self):
        """Return the next input line decoded as ASCII."""
        raw = self.buffer.readline()
        return raw.decode("ascii")

# Replace the interpreter's standard streams with buffered IOWrapper objects and
# keep references on the class; write() uses IOWrapper.stdout as its default file.
# NOTE(review): output only reaches fd 1 on flush(); presumably this relies on the
# interpreter flushing sys.stdout at exit — confirm.
sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin)
sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout)

def write(*args, **kwargs):
    '''Print the values like print(), defaulting to IOWrapper.stdout.

    Supported keywords: sep (default " "), end (default "\\n"),
    file (default IOWrapper.stdout) and flush (default False).
    Unknown keywords are ignored.
    '''
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", IOWrapper.stdout)
    for index, value in enumerate(args):
        if index:
            file.write(sep)
        file.write(str(value))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()

if __name__ == '__main__':
    # Run the solver only when executed as a script, not when imported.
    main()
Back to top page