POJ 3233 Matrix Power Series

[题目在这里>_<](http://poj.org/problem?id=3233)

题意

Given an $n \times n$ matrix $A$ and positive integers $k$ and $m$, compute the sum $S = A + A^2 + A^3 + \cdots + A^k$ and output every element of $S$ modulo $m$.

思路

因为:
$S_k=A+A^2+A^3+ \cdots + A^k$
$S_{k+1}=AS_k+A$
所以有:

$$
\begin{bmatrix}
A & E \\\\
O & E
\end{bmatrix}
\begin{bmatrix}
S_k \\\\
A
\end{bmatrix}=
\begin{bmatrix}
S_{k+1} \\\\
A
\end{bmatrix}
$$

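其中$A,S,E,O$均为$n \times n$的矩阵，$E$为单位矩阵，$O$为零矩阵。由于$S_1=A$，反复左乘上面的$2n \times 2n$矩阵可得

$$
\begin{bmatrix}
S_k \\\\
A
\end{bmatrix}=
\begin{bmatrix}
A & E \\\\
O & E
\end{bmatrix}^{k-1}
\begin{bmatrix}
A \\\\
A
\end{bmatrix}
$$

因此只需用矩阵快速幂求出$k-1$次幂，时间复杂度为$O(n^3 \log k)$；所有运算都在模$m$意义下进行。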

代码

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;

// Capacity of Matrix::map per dimension. Indices are 1-based, so a Matrix can
// hold at most (Nmax - 1) rows/columns; work() needs 2n of them.
const int Nmax=200;
typedef long long ll;
int mod;   // modulus read from input; every product in Matrix::operator* is reduced by it
int n,k;   // size of A and the number of series terms (k is consumed by work())

// Dense matrix with 1-based indexing (rows 1..n, columns 1..m) stored in a
// fixed Nmax x Nmax array; entries outside that range are not touched.
struct Matrix
{
    int n,m;
    int map[Nmax][Nmax];

    // Creates an x-by-y matrix with every used entry set to 0.
    Matrix(int x,int y)
    {
        n=x;m=y;
        for(int i=1;i<=n;i++)
            for(int j=1;j<=m;j++)
                map[i][j]=0;
    }

    // Returns (*this) * b with every entry reduced modulo the global `mod`.
    // Prints an error and terminates the program when the inner dimensions
    // (this->m and b.n) differ.
    // b is taken by const reference: a by-value Matrix is a ~160 KB copy.
    Matrix operator * (const Matrix& b) const
    {
        if(m!=b.n)
        {
            printf("error!!!!!!!!!!!!!!\n");
            exit(0);
        }
        Matrix c(n,b.m);
        for(int i=1;i<=c.n;i++)
            for(int t=1;t<=m;t++)
            {
                // Multiply in 64 bits: two int entries can overflow an int product.
                const ll a=map[i][t];
                for(int j=1;j<=c.m;j++)
                    c.map[i][j]=static_cast<int>((c.map[i][j]+a*b.map[t][j]%mod)%mod);
            }
        return c;
    }

    // Prints the top-left nn x mm sub-matrix, one row per line, space-separated.
    void print(int nn,int mm) const
    {
        for(int i=1;i<=nn;i++)
            for(int j=1;j<=mm;j++)
                printf("%d%c",map[i][j],j==mm?'\n':' ');
    }

    // Debug helper: prints the dimensions, then the whole matrix.
    void show() const
    {
        printf("n:%d m:%d\n",n,m);
        for(int i=1;i<=n;i++)
            for(int j=1;j<=m;j++)
                printf("%d%c",map[i][j],j==m?'\n':' ');
    }
};
int mp[Nmax][Nmax];   // input matrix A, 1-based
int work()
{
    Matrix base(2*n,2*n);
    for(int i=1;i<=2*n;i++)
    {
        for(int j=1;j<=2*n;j++)
        {
            if(i<=n)
            {
                if(j<=n)
                    base.map[i][j]=mp[i][j];
                else
                {
                    if(j-n==i)
                        base.map[i][j]=1;
                }
            }
            else
            {
                if(j>n)
                {
                    if(j==i)
                        base.map[i][j]=1;
                }
            }

        }
    }
    //base.show();
    Matrix ans(2*n,2*n);
    for(int i=1;i<=2*n;i++)
        for(int j=1;j<=2*n;j++)
            ans.map[i][j]=i==j?1:0;
    k--;
    while(k>0)
    {
        if(k&1)
            ans=ans*base;
        base=base*base;
        k>>=1;
    }
    Matrix now(2*n,n);
    for(int i=1;i<=2*n;i++)
        for(int j=1;j<=n;j++)
        {
            if(i<=n)
                now.map[i][j]=mp[i][j];
            else 
                now.map[i][j]=mp[i-n][j];
        }
    //now.show();
    //ans.show();
    ans=ans*now;
    ans.print(n,n);
    //now.show();
    return 0;
}
// Reads n, k and the modulus, then the n x n matrix A (stored 1-based in mp),
// and prints S = A + A^2 + ... + A^k modulo the modulus via work().
int main()
{
    //freopen("e.in","r",stdin);
    if(scanf("%d%d%d",&n,&k,&mod)!=3)
        return 0;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
        {
            scanf("%d",&mp[i][j]);
            // Input entries may be >= the modulus. Reduce them up front so the
            // k == 1 answer (S = A) is printed modulo m like every other k.
            mp[i][j]%=mod;
        }
    // No k == 1 special case is needed: work() raises the transition matrix to
    // the power k-1 = 0 (the identity) and therefore prints A mod m.
    work();
    return 0;
}