抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

NTT

有时候,题目要求对一个大质数(特别是998244353之类的数)取模,就不能用FFT,而是用NTT。

NTT采用原根替代单位根,如果不了解原根,请参考这篇博客

常见NTT质数表:

a×2b+1a\times 2^b+1 a b g
3 1 1 2
5 1 2 2
17 1 4 3
97 3 5 5
193 3 6 5
257 1 8 3
7681 15 9 17
12289 3 12 11
40961 5 13 3
65537 1 16 3
786433 3 18 10
5767169 11 19 3
7340033 7 20 3
23068673 11 21 3
104857601 25 22 3

采用NTT编写A*B problem。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
#define MAXN 240005
#define MOD 998244353
#define invG 332748118 //G的逆元
#define G 3
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int ksm(int b,int k){
int ans=1;
while (k){
if (k&1) ans=(1ll*ans*b)%MOD;
b=(1ll*b*b)%MOD;
k>>=1;
}
return ans;
}
static int r[MAXN],a[MAXN],b[MAXN];
inline void NTT(int *A,int n,int type){
for (register int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
for (register int i=1;i<n;i<<=1){
int R=i<<1;
int Gn=ksm(type==1?G:invG,(MOD-1)/R);
for (register int j=0;j<n;j+=R){
int g=1;
for (register int k=0;k<i;++k,g=(1ll*g*Gn)%MOD){
int x=A[j+k],y=(1ll*g*A[i+j+k])%MOD;
A[j+k]=(x+y)%MOD,A[i+j+k]=(x-y+MOD)%MOD;
}
}
}
}
char s1[MAXN],s2[MAXN];
int ans[MAXN];
int main(){
int n=read();
scanf("%s%s",s1+1,s2+1);
for (register int i=1;i<=n;++i) a[i-1]=s1[n-i+1]-'0';
for (register int i=1;i<=n;++i) b[i-1]=s2[n-i+1]-'0';
int m=1,L=0;
while (m<=2*n) m<<=1,L++;
for (register int i=0;i<=m;++i){
r[i]=(r[i>>1]>>1|((i&1)<<(L-1)));
}
NTT(a,m,1),NTT(b,m,1);
for (register int i=0;i<=m;++i) a[i]=(1ll*a[i]*b[i])%MOD;
NTT(a,m,-1);
int inv=ksm(m,MOD-2);
for (register int i=0;i<=m;++i){
ans[i]+=(1ll*a[i]*inv)%MOD;//要乘上m的逆元
ans[i+1]+=ans[i]/10,ans[i]%=10;
}
while (ans[m]==0) m--;
for (register int i=m;i>=0;--i) putchar(ans[i]+'0');
}

常见多项式运算

这里才是重点部分。

多项式求逆

传送门

给你一个多项式AA,要你求出一个多项式BB,满足AB=1(modxn)AB=1\pmod {x^n}

系数对998244353998244353取模。

首先,理解(modxn)\pmod {x^n}是为了去掉后面的部分。

下文中的除法全部代表向下取整。

我们假设求出了一个BB',满足AB=1(modxn2)AB'=1\pmod{x^{\frac{n}{2}}},我们要求BB,满足AB=1(modxn)AB=1 \pmod{x^n}。其实就是一个小范围的解推出大范围的解。

那么我们有后面的部分BB=0(modxn2)B-B'=0 \pmod{x^{\frac{n}{2}}}

根据套路,我们要把后面的模数搞成xnx^n,于是两边平方

(BB)2=0(modxn)(B-B')^2=0\pmod{x^n}

拆开B22BB+B2=0(modxn)B^2 -2BB'+B'^2=0 \pmod{x^n}

两边同时乘AA,发现可以消掉很多。

AB22ABB+AB2=0(modxn)AB^2-2ABB'+AB'^2=0\pmod{x^n}

显然AB=1AB=1,可以消掉,变成:

B2B+AB2=0(modxn)B-2B'+AB'^2=0 \pmod{x^n}

得出B=2BAB2=B(2AB)(modxn)B=2B'-AB'^2=B'(2-AB') \pmod{x^n}

于是可以根据这个公式从一个小的解推出大的。

别忘了边界条件A[0]=B[0]1A[0]=B[0]^{-1}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
#define MAXN 1000005
#define MOD 998244353
#define invG 332748118
#define G 3
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int ksm(int b,int k){
int ans=1;
while (k){
if (k&1) ans=(1ll*ans*b)%MOD;
b=(1ll*b*b)%MOD;
k>>=1;
}
return ans;
}
int r[MAXN],C[MAXN];
inline void NTT(int *A,int n,int type){
for (register int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
for (register int i=1;i<n;i<<=1){
int R=i<<1;
int Gn=ksm(type==1?G:invG,(MOD-1)/R);
for (register int j=0;j<n;j+=R){
int g=1;
for (register int k=0;k<i;++k,g=(1ll*g*Gn)%MOD){
int x=A[j+k],y=(1ll*g*A[i+j+k])%MOD;
A[j+k]=(x+y)%MOD,A[i+j+k]=(x-y+MOD)%MOD;
}
}
}
}
inline int get_inv(int x){
return ksm(x,MOD-2);
}
int m,L;
inline void Init(int len){
m=1,L=0;
while (m<2*len) m<<=1,L++;
for (register int i=0;i<m;++i){
r[i]=(r[i>>1]>>1|((i&1)<<(L-1)));
}
}
inline void Inv(int *A,int *B,int len){
if (len==1){
B[0]=get_inv(A[0]);
return ;
}
Inv(A,B,(len+1)>>1);//注意是向上取整
Init(len);
for (register int i=0;i<len;++i) C[i]=A[i];//只用前面一部分
for (register int i=len;i<m;++i) C[i]=0;
NTT(C,m,1),NTT(B,m,1);
for (register int i=0;i<m;++i){
B[i]=(2ll-1ll*B[i]*C[i]%MOD+MOD)*B[i]%MOD;
}
NTT(B,m,-1);
int inv=get_inv(m);
for (register int i=0;i<len;++i) B[i]=(1ll*B[i]*inv)%MOD;//推完之后B'->B
for (register int i=len;i<m;++i) B[i]=0;//多出来的部分要舍去
}
int F[MAXN],Ans[MAXN];
int main(){
int n=read();
for (register int i=0;i<n;++i) F[i]=read();
Inv(F,Ans,n);
for (register int i=0;i<n;++i) printf("%d ",Ans[i]);
}

分治FFT

传送门

此题我们使用生成函数做法。

构造生成函数f(x)=i=0f[i]×xi,g(x)=i=0g[i]×xif(x)=\sum_{i=0}^\infty f[i] \times x^i,g(x)=\sum_{i=0}^\infty g[i] \times x^i

注意到f(x)g(x)=f(x)f(x)*g(x)=f(x)

欸这个有问题吧,这个应该无解才对啊!

再看一眼题目,注意到f(x)×g(x)f(x) \times g(x)没有取到x0x^0

怎么办呢?补上一个11即可。(也可以这么理解,卷积起来之后整体后移了一位,所以要补上)

得到式子f(x)g(x)+1=f(x)f(x)*g(x)+1=f(x)

解得f(x)(1g(x))=1f(x)(1-g(x))=1

因为只用求前nn项,转化为f(x)(1g(x))=1(modxn)f(x)(1-g(x))=1 \pmod{x^n}

于是f(x)=(1g(x))1(modxn)f(x)=(1-g(x))^{-1} \pmod{x^n}

对于生成函数做法,还有一个解释,暴力做法是每次分别求出f[i]f[i],而生成函数做法是一起解出ff

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
#define MAXN 1000005
#define MOD 998244353
#define invG 332748118
#define G 3
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int ksm(int b,int k){
int ans=1;
while (k){
if (k&1) ans=(1ll*ans*b)%MOD;
b=(1ll*b*b)%MOD;
k>>=1;
}
return ans;
}
int r[MAXN],C[MAXN];
inline void NTT(int *A,int n,int type){
for (register int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
for (register int i=1;i<n;i<<=1){
int R=i<<1;
int Gn=ksm(type==1?G:invG,(MOD-1)/R);
for (register int j=0;j<n;j+=R){
int g=1;
for (register int k=0;k<i;++k,g=(1ll*g*Gn)%MOD){
int x=A[j+k],y=(1ll*g*A[i+j+k])%MOD;
A[j+k]=(x+y)%MOD,A[i+j+k]=(x-y+MOD)%MOD;
}
}
}
}
inline int get_inv(int x){
return ksm(x,MOD-2);
}
int m,L;
inline void Init(int len){
m=1,L=0;
while (m<2*len) m<<=1,L++;
for (register int i=0;i<m;++i){
r[i]=(r[i>>1]>>1|((i&1)<<(L-1)));
}
}
inline void Inv(int *A,int *B,int len){
if (len==1){
B[0]=get_inv(A[0]);
return ;
}
Inv(A,B,(len+1)>>1);
Init(len);
for (register int i=0;i<len;++i) C[i]=A[i];
for (register int i=len;i<m;++i) C[i]=0;
NTT(C,m,1),NTT(B,m,1);
for (register int i=0;i<m;++i){
B[i]=(2ll-1ll*B[i]*C[i]%MOD+MOD)*B[i]%MOD;
}
NTT(B,m,-1);
int inv=get_inv(m);
for (register int i=0;i<len;++i) B[i]=(1ll*B[i]*inv)%MOD;
for (register int i=len;i<m;++i) B[i]=0;
}
int F[MAXN],Ans[MAXN];
int main(){
int n=read();
F[0]=1;
for (register int i=1;i<n;++i) F[i]=(-read()+MOD)%MOD;
Inv(F,Ans,n);
for (register int i=0;i<n;++i) printf("%d ",Ans[i]);
}

多项式开根

传送门

给你一个n1n-1次多项式A(x)A(x),求B(x)B(x),使得B(x)2=A(x)(modxn)B(x)^2=A(x) \pmod{x^n}

还是多项式求逆的套路。

考虑现在已经求出B(x)B'(x),使得B(x)2=A(x)(modxn2)B'(x) ^2 = A(x) \pmod{x^{\frac{n}{2}}}

求出一个B(x)B(x)使得B(x)2=A(x)(modxn)B(x)^2=A(x)\pmod{x^{n}}

两边相减

B(x)2B(x)2=0(modxn2)B(x)^2-B'(x)^2=0\pmod{x^{\frac{n}{2}}}

(B(x)+B(x))(B(x)B(x))=0(modxn2)(B(x)+B'(x))(B(x)-B'(x))=0 \pmod{x^{\frac{n}{2}}}

显然B(x)+B(x)0B(x)+B'(x)\not=0

B(x)B(x)=0(modxn2)B(x)-B'(x)=0\pmod{x^\frac{n}{2}}

还是老套路,两边平方。

B(x)22B(x)B(x)+B(x)2=0(modxn)B(x)^2-2B(x)B'(x)+B'(x)^2=0\pmod{x^n}

注意到B(x)2=A(x)B(x)^2=A(x)

代入得:

A(x)2B(x)B(x)+B(x)2=0(modxn)A(x)-2B(x)B'(x)+B'(x)^2=0\pmod{x^n}

所以B(x)=A(x)+B(x)22B(x)=inv(2)(A(x)B1(x)+B(x))B(x)=\frac{A(x)+B'(x)^2}{2B'(x)}=inv(2) *(A(x)B'^{-1}(x)+B'(x))

再次套用多项式求逆的模板即可。

边界B[0]=A[0]=1B[0]=\sqrt {A[0]}=1

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <bits/stdc++.h>
#define MAXN 1000005
#define MOD 998244353
#define invG 332748118
#define G 3
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int ksm(int b,int k){
int ans=1;
while (k){
if (k&1) ans=(1ll*ans*b)%MOD;
b=(1ll*b*b)%MOD;
k>>=1;
}
return ans;
}
int r[MAXN],C[MAXN];
inline void NTT(int *A,int n,int type){
for (register int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
for (register int i=1;i<n;i<<=1){
int R=i<<1;
int Gn=ksm(type==1?G:invG,(MOD-1)/R);
for (register int j=0;j<n;j+=R){
int g=1;
for (register int k=0;k<i;++k,g=(1ll*g*Gn)%MOD){
int x=A[j+k],y=(1ll*g*A[i+j+k])%MOD;
A[j+k]=(x+y)%MOD,A[i+j+k]=(x-y+MOD)%MOD;
}
}
}
}
inline int get_inv(int x){
return ksm(x,MOD-2);
}
int m,L;
inline void Init(int len){
m=1,L=0;
while (m<2*len) m<<=1,L++;
for (register int i=0;i<m;++i){
r[i]=(r[i>>1]>>1|((i&1)<<(L-1)));
}
}
inline void Inv(int *A,int *B,int len){
if (len==1){
B[0]=get_inv(A[0]);
return ;
}
Inv(A,B,(len+1)>>1);
Init(len);
for (register int i=0;i<len;++i) C[i]=A[i];
for (register int i=len;i<m;++i) C[i]=0;
NTT(C,m,1),NTT(B,m,1);
for (register int i=0;i<m;++i){
B[i]=(2ll-1ll*B[i]*C[i]%MOD+MOD)*B[i]%MOD;
}
NTT(B,m,-1);
int inv=get_inv(m);
for (register int i=0;i<len;++i) B[i]=(1ll*B[i]*inv)%MOD;
for (register int i=len;i<m;++i) B[i]=0;
}
int D[MAXN],inv2;
inline void Sqrt(int *A,int *B,int len){
if (len==1){
B[0]=A[0];
return ;//保证a_0=1
}
Sqrt(A,B,(len+1)>>1);
for (register int i=0;i<(len<<1);++i) D[i]=0;
Inv(B,D,len);//得出B^-1(x)
Init(len);
for (register int i=0;i<len;++i) C[i]=A[i];
for (register int i=len;i<m;++i) C[i]=0;
NTT(B,m,1),NTT(C,m,1),NTT(D,m,1);//B:目标 C:A(x) D:B^-1(x)
//注意要做三次NTT
for (register int i=0;i<m;++i){
B[i]=1ll*inv2*(1ll*C[i]*D[i]%MOD+B[i])%MOD;
}
NTT(B,m,-1);
int inv=get_inv(m);
for (register int i=0;i<len;++i) B[i]=(1ll*B[i]*inv)%MOD;
for (register int i=len;i<m;++i) B[i]=0;
}
int F[MAXN],Ans[MAXN];
int main(){
inv2=ksm(2,MOD-2);
int n=read();
for (register int i=0;i<n;++i) F[i]=read();
Sqrt(F,Ans,n);
for (register int i=0;i<n;++i) printf("%d ",Ans[i]);
}

例题

例题1

P4841 城市规划

此题即是数据加强版POJ1737,题解在此

这里不再讨论dpdp方程式,而是着重讨论优化方法。

我们有F(i)=2Ci2j=1i1F(j)×2Cij2×Ci1j1F(i)=2^{C_i^2} - \sum_{j=1}^{i-1} F(j) \times 2^{C_{i-j}^2} \times C_{i-1}^{j-1}

爆拆一波CC,有Ci1j1=(i1)!(ij)!(j1)!C_{i-1}^{j-1}=\frac{(i-1)!}{(i-j)!(j-1)!}

两边同时除(i1)!(i-1)!,得 F(i)(i1)!=2Ci2(i1)!j=1i1F(j)(j1)!×2Cij2(ij)!\frac{F(i)}{(i-1)!}=\frac{2^{C_i^2}}{(i-1)!}-\sum_{j=1}^{i-1} \frac{F(j)}{(j-1)!} \times \frac {2^{C_{i-j}^2}}{(i-j)!}(剩下的(ij)!(i-j)!(j1)!(j-1)!刚好每人分一个)

A(x)=F(i)(i1)!,B(x)=2Ci2(i1)!,C(x)=2Ci2i!A(x)=\frac{F(i)}{(i-1)!},B(x)=\frac{2^{C_{i}^2}}{(i-1)!},C(x)=\frac{2^{C_{i}^2}}{i!}

推导这里发现有个致命的错误,j=1i1\sum _{j=1}^{i-1}无法解决,必须重新构造式子。

把后面一坨扔到前面去,左边部分变成F(i)+j=1i1F(j)×2Cij2×Ci1j1F(i)+\sum_{j=1}^{i-1} F(j) \times 2^{C_{i-j}^2} \times C_{i-1}^{j-1}

注意到j=ij=i时,2Cij2=20=12^{C_{i-j}^2}=2^0=1Ci1j1=1C_{i-1}^{j-1}=1

于是可以把他们两个合并起来,变成j=1iF(j)×2Cij2×Ci1j1\sum_{j=1}^iF(j)\times 2^{C_{i-j}^2} \times C_{i-1}^{j-1}

等价于A(x)B(x)=C(x)A(x)B(x)=C(x)

于是得到A(x)=C(x)B(x)1A(x)=C(x) B(x)^{-1}

可以使用多项式求逆解决。

别忘了A(x)=F(i)(i1)!A(x)=\frac {F(i)}{(i-1)!},所以答案要乘(i1)!(i-1)!

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <bits/stdc++.h>
#define MAXN 2000005
#define MOD 1004535809
#define invG 334845270
#define G 3
#define int long long
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int ksm(int b,int k){
int ans=1;
while (k){
if (k&1) ans=(1ll*ans*b)%MOD;
b=(1ll*b*b)%MOD;
k>>=1;
}
return ans;
}
inline int get_inv(int x){
return ksm(x,MOD-2);
}
int r[MAXN],C[MAXN];
inline void NTT(int *A,int n,int type){
for (register int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
for (register int i=1;i<n;i<<=1){
int R=i<<1;
int Gn=ksm(type==1?G:invG,(MOD-1)/R);
for (register int j=0;j<n;j+=R){
int g=1;
for (register int k=0;k<i;++k,g=(1ll*g*Gn)%MOD){
int x=A[j+k],y=(1ll*g*A[i+j+k])%MOD;
A[j+k]=(x+y)%MOD,A[i+j+k]=(x-y+MOD)%MOD;
}
}
}
if (type==1) return ;
int inv=get_inv(n);
for (register int i=0;i<n;++i){
A[i]=1ll*A[i]*inv%MOD;
}
}
int m,L;
inline void InitNTT(int len){
m=1,L=0;
while (m<2*len) m<<=1,L++;
for (register int i=0;i<m;++i){
r[i]=(r[i>>1]>>1|((i&1)<<(L-1)));
}
}
inline void Inv(int *A,int *B,int len){
if (len==1){
B[0]=get_inv(A[0]);
return ;
}
Inv(A,B,(len+1)>>1);
InitNTT(len);
for (register int i=0;i<len;++i) C[i]=A[i];
for (register int i=len;i<m;++i) C[i]=0;
NTT(C,m,1),NTT(B,m,1);
for (register int i=0;i<m;++i){
B[i]=((2ll-1ll*B[i]*C[i]%MOD+MOD)*B[i]%MOD+MOD)%MOD;
}
NTT(B,m,-1);
for (register int i=len;i<m;++i) B[i]=0;
}
int F[MAXN];
int A[MAXN],B[MAXN],Ans[MAXN];
int fac[MAXN],inv_fac[MAXN],n;
int inv2=ksm(2,MOD-2);
//F=A*B^-1
inline int Calc(int x){
return 1ll*x*(x-1)/2;//不能取模,要取也只能是MOD-1
}
inline void Init(){
fac[0]=1;
for (register int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%MOD;
for (register int i=0;i<=n;++i) inv_fac[i]=ksm(fac[i],MOD-2);
}
#undef int
int main(){
#define int long long
n=read();
Init();
for (register int i=1;i<=n;++i){
A[i]=1ll*ksm(2,Calc(i))*inv_fac[i-1]%MOD;
}
for (register int i=0;i<=n;++i){
B[i]=1ll*ksm(2,Calc(i))*inv_fac[i]%MOD;
}
InitNTT(n);
Inv(B,Ans,m);

InitNTT(n);
NTT(A,m,1);NTT(Ans,m,1);
for (register int i=0;i<m;++i) A[i]=1ll*A[i]*Ans[i]%MOD;
NTT(A,m,-1);

printf("%lld\n",1ll*A[n]*fac[n-1]%MOD);
}

评论