这是一道模板题。
给你两个多项式,请输出乘起来后的多项式。
输入格式
第一行两个整数 nn 和 mm,分别表示两个多项式的次数。
第二行 n+1n+1 个整数,分别表示第一个多项式的 00 到 nn 次项前的系数。
第三行 m+1m+1 个整数,分别表示第一个多项式的 00 到 mm 次项前的系数。
输出格式
一行 n+m+1n+m+1 个整数,分别表示乘起来后的多项式的 00 到 n+mn+m 次项前的系数。
Solution
推荐一篇博客从多项式乘法到快速傅里叶变换(<–链接请点这里),这篇讲得非常好!
FFT:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> using namespace std; const double pi=acos(-1.0); int n,m,L,H; struct complex { double r,v; inline complex operator + (const complex &a) { return (complex){r+a.r,v+a.v};} inline complex operator - (const complex &a) { return (complex){r-a.r,v-a.v};} inline complex operator * (const complex &a) { return (complex){r*a.r-v*a.v,r*a.v+v*a.r};} }a[300100],b[300100],w[300100]; void FFT(complex *a,int f) { for (int i=0,j=0;i<L;i++) { if (i>j) swap(a[i],a[j]); for (int k=L>>1;(j^=k)<k;k>>=1); } for (int len=2;len<=L;len<<=1) { int l=len>>1; complex W=(complex){cos(pi/l),f*sin(pi/l)}; for (int i=1;i<l;i++) w[i]=w[i-1]*W; for (int i=0;i<L;i+=len) for (int j=0;j<l;j++) { complex x=a[i+j],y=w[j]*a[i+j+l]; a[i+j]=x+y;a[i+j+l]=x-y; } } if (f==-1) { for (int i=0;i<L;i++) a[i].r/=L; } } int main() { scanf("%d%d",&n,&m); n++;m++; w[0]=(complex){1.0,0.0}; for (int i=0;i<n;i++) scanf("%lf",&a[i].r); for (int i=0;i<m;i++) scanf("%lf",&b[i].r); L=1;H=0; while (L<n+m) L<<=1,H++; FFT(a,1);FFT(b,1); for (int i=0;i<L;i++) a[i]=a[i]*b[i]; FFT(a,-1); for (int i=0;i<n+m-1;i++) printf("%d%c",(int)(a[i].r+0.5)," \n"[i==n+m-2]); return 0; } |
NTT:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #define ll long long #define mo 998244353 #define G 3 using namespace std; int n,m,L,H; ll a[500100],b[500100],w[500100],INV; ll mi(ll x,int y) { ll ans=1; while (y) { if (y&1) ans=ans*x%mo; x=x*x%mo; y>>=1; } return ans; } void NTT(ll *a,int f) { for (int i=0,j=0;i<L;i++) { if (i>j) swap(a[i],a[j]); for (int k=L>>1;(j^=k)<k;k>>=1); } for (int len=2;len<=L;len<<=1) { int l=len>>1; ll W=mi(G,(mo-1)/len); for (int i=1;i<l;i++) w[i]=w[i-1]*W%mo; for (int i=0;i<L;i+=len) for (int j=0;j<l;j++) { ll x=a[i+j],y=w[j]*a[i+j+l]%mo; a[i+j]=(x+y>=mo)?x+y-mo:x+y;a[i+j+l]=(x-y<0)?x-y+mo:x-y; } } if (f==-1) { for (int i=1;i<L/2;i++) swap(a[i],a[L-i]); for (int i=0;i<L;i++) a[i]=a[i]*INV%mo; } } int main() { scanf("%d%d",&n,&m); n++;m++; w[0]=1; for (int i=0;i<n;i++) scanf("%lld",&a[i]); for (int i=0;i<m;i++) scanf("%lld",&b[i]); L=1;H=0; while (L<n+m) L<<=1,H++; INV=mi(L,mo-2)%mo; NTT(a,1);NTT(b,1); for (int i=0;i<L;i++) a[i]=a[i]*b[i]%mo; NTT(a,-1); for (int i=0;i<n+m-1;i++) printf("%lld%c",a[i]," \n"[i==n+m-2]); return 0; } |
任意模数NTT:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #define ll long long #define G 3 using namespace std; const int N=500100; const int mo1=469762049,mo2=998244353,mo3=1004535809; int n,m,mo,a[N],b[N],A[N],B[N],C[N],w[N],L,H; ll INV; ll mi(ll x,ll y,ll mo) { ll ans=1; while (y) { if (y&1) ans=ans*x%mo; x=x*x%mo; y>>=1; } return ans; } namespace ntt { int A[N],B[N]; void NTT(int *a,int f,int mo) { for (int i=0,j=0;i<L;i++) { if (i>j) swap(a[i],a[j]); for (int k=L>>1;(j^=k)<k;k>>=1); } for (int len=2;len<=L;len<<=1) { int l=len>>1; ll W=mi(G,(mo-1)/len,mo); for (int i=1;i<l;i++) w[i]=(ll)w[i-1]*W%mo; for (int i=0;i<L;i+=len) for (int j=0;j<l;j++) { ll x=a[i+j],y=(ll)w[j]*a[i+j+l]%mo; a[i+j]=(x+y>=mo)?x+y-mo:x+y;a[i+j+l]=(x-y<0)?x-y+mo:x-y; } } if (f==-1) { for (int i=1;i<L/2;i++) swap(a[i],a[L-i]); for (int i=0;i<L;i++) a[i]=(ll)a[i]*INV%mo; } } void pre(int l,int mo) { w[0]=1; L=1,H=0; while (L<l) L<<=1,H++; INV=mi(L,mo-2,mo)%mo; for (int i=0;i<L;i++) A[i]=B[i]=0; } void calc(int *a,int *b,int n,int m,int *C,int mo) { pre(n+m,mo); for (int i=0;i<L;i++) A[i]=a[i],B[i]=b[i]; NTT(A,1,mo);NTT(B,1,mo); for (int i=0;i<L;i++) A[i]=(ll)A[i]*B[i]%mo; NTT(A,-1,mo); for (int i=0;i<L;i++) C[i]=A[i]; } } int main() { scanf("%d%d%d",&n,&m,&mo); n++;m++; for (int i=0;i<n;i++) scanf("%d",&a[i]); for (int i=0;i<m;i++) scanf("%d",&b[i]); ntt::calc(a,b,n,m,A,mo1); ntt::calc(a,b,n,m,B,mo2); ntt::calc(a,b,n,m,C,mo3); for (int i=0;i<n+m-1;i++) { ll lcm=(ll)mo1*mo2; ll X=(((ll)(B[i]-A[i])%mo2+mo2)%mo2*mi(mo1,mo2-2,mo2)%mo2*mo1%lcm+A[i])%lcm; ll Y=(((ll)(C[i]-X)%mo3+mo3)%mo3*mi(lcm%mo3,mo3-2,mo3)%mo3*(lcm%mo)%mo+X%mo)%mo; printf("%lld%c",Y," \n"[i==n+m-2]); } return 0; } |