题意
Given an $n \times n$ matrix $A$ and a positive integer $k$, find the sum $S = A + A^2 + A^3 + \cdots + A^k$, with every element reported modulo $m$.
思路
因为:
$S_k=A+A^2+A^3+ \cdots + A^k$
$S_{k+1}=AS_k+A$
所以有:
$$
\begin{bmatrix}
A & E \\
O & E
\end{bmatrix}
\begin{bmatrix}
S_k \\
A
\end{bmatrix}=
\begin{bmatrix}
S_{k+1} \\
A
\end{bmatrix}
$$
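其中$E$为$n \times n$单位矩阵，$O$为$n \times n$零矩阵，$A$与$S_k$也均为$n \times n$的矩阵。由于$S_1=A$，把上式反复应用$k-1$次即得
$$
\begin{bmatrix} S_k \\ A \end{bmatrix}=
\begin{bmatrix} A & E \\ O & E \end{bmatrix}^{k-1}
\begin{bmatrix} A \\ A \end{bmatrix},
$$
其中的$k-1$次幂用矩阵快速幂计算，共需$O((2n)^3 \log k)$次运算。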
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <iostream>
#include <map>
#include <set>
//#define test
using namespace std;
// Per-dimension storage capacity of Matrix::map and mp. Indices are 1-based
// (row/column 0 unused), so the 2n x 2n matrix built in work() needs 2*n <= Nmax-1.
const int Nmax=200;
typedef long long ll;
// Modulus read from input in main(); Matrix::operator* reduces every entry by it.
int mod;
// n: order of the input matrix A; k: highest power in the series (both read in main()).
int n,k;
// Dense matrix with 1-based storage: entries live in map[1..n][1..m].
// Capacity is fixed at compile time by Nmax, so n and m must both be < Nmax.
// Multiplication reduces every entry modulo the global `mod`.
struct Matrix
{
    int n,m;               // number of rows / columns actually in use
    int map[Nmax][Nmax];   // entries, 1-based; row 0 and column 0 are unused

    // Builds an x-by-y zero matrix (only the in-use region is cleared).
    Matrix(int x,int y)
    {
        n=x;m=y;
        for(int i=1;i<=n;i++)
            for(int j=1;j<=m;j++)
                map[i][j]=0;
    }

    // Returns (*this) * b with every entry reduced modulo `mod`.
    // b is taken by const reference: a Matrix holds Nmax*Nmax ints (~160 KB),
    // so a by-value parameter would copy that much onto the stack per call.
    // Each product is formed in 64-bit arithmetic so it cannot overflow int
    // (relevant when mod is large or when un-reduced input entries are multiplied).
    // On a dimension mismatch, prints a diagnostic and terminates the program.
    Matrix operator * (const Matrix &b) const
    {
        Matrix c(n,b.m);
        if(m==b.n)
        {
            for(int i=1;i<=c.n;i++)
                for(int k=1;k<=m;k++)
                {
                    const long long a=map[i][k];
                    for(int j=1;j<=c.m;j++)
                        c.map[i][j]=static_cast<int>((c.map[i][j]+a*b.map[k][j]%mod)%mod);
                }
            return c;
        }
        printf("error!!!!!!!!!!!!!!\n");
        exit(0);
        return c;
    }

    // Prints the top-left nn x mm sub-block, one row per line, space-separated.
    void print(int nn,int mm) const
    {
        for(int i=1;i<=nn;i++)
            for(int j=1;j<=mm;j++)
                printf("%d%c",map[i][j],j==mm?'\n':' ');
    }

    // Debug dump: prints the dimensions followed by the whole in-use region.
    void show() const
    {
        printf("n:%d m:%d\n",n,m);
        for(int i=1;i<=n;i++)
            for(int j=1;j<=m;j++)
                printf("%d%c",map[i][j],j==m?'\n':' ');
    }
};
// Input matrix A, 1-based (mp[1..n][1..n]); filled in main(), read by work().
int mp[Nmax][Nmax];
int work()
{
Matrix base(2*n,2*n);
for(int i=1;i<=2*n;i++)
{
for(int j=1;j<=2*n;j++)
{
if(i<=n)
{
if(j<=n)
base.map[i][j]=mp[i][j];
else
{
if(j-n==i)
base.map[i][j]=1;
}
}
else
{
if(j>n)
{
if(j==i)
base.map[i][j]=1;
}
}
}
}
//base.show();
Matrix ans(2*n,2*n);
for(int i=1;i<=2*n;i++)
for(int j=1;j<=2*n;j++)
ans.map[i][j]=i==j?1:0;
k--;
while(k>0)
{
if(k&1)
ans=ans*base;
base=base*base;
k>>=1;
}
Matrix now(2*n,n);
for(int i=1;i<=2*n;i++)
for(int j=1;j<=n;j++)
{
if(i<=n)
now.map[i][j]=mp[i][j];
else
now.map[i][j]=mp[i-n][j];
}
//now.show();
//ans.show();
ans=ans*now;
ans.print(n,n);
//now.show();
return 0;
}
// Reads n, k, mod and then the n x n matrix A from stdin, and prints
// S = A + A^2 + ... + A^k with every entry reduced modulo mod.
int main()
{
    scanf("%d%d%d",&n,&k,&mod);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d",&mp[i][j]);
    if(k==1)
    {
        // S_1 = A, but the input may contain entries >= mod, so they must
        // still be reduced here (work() reduces them via Matrix::operator*).
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
                printf("%d%c",mp[i][j]%mod,j==n?'\n':' ');
    }
    else
        work();
    return 0;
}