抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

可以这么写ww维前缀和(原文):

1
2
3
4
5
6
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
a[i][j]+=a[i][j-1];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
a[i][j]+=a[i-1][j];
1
2
3
4
5
6
7
8
9
10
11
12
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
a[i][j][k]+=a[i-1][j][k];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
a[i][j][k]+=a[i][j-1][k];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
a[i][j][k]+=a[i][j][k-1];

这样其实是O(w×nw)O(w \times n^w),但是如果用容斥原理的话,是O(2w×nw)O(2^w \times n^w)2n=i=0nCni2^n = \sum _{i=0}^n C_n^i)。

注意到nwn^w相比2w2^w来说增长很快,所以这个优化很屑吗?

很多情况我们使用的是n=2n=2,这样第一种是O(w×2w)O(w \times 2^w),第二种是O(2w×2w)O(2^w \times 2^w)。这样差的就很大了。

比如说子集统计,有mm个大小为2020的集合SiS_i。(m1000000m \le 1000000)要你求F(S)=i=1m[SSi]F(S)=\sum _{i=1}^m [S \in S_i],容易看出w=20,n=2w=20,n=2,代表每个元素选或者不选。

可以写出以下的代码:

1
2
3
4
5
6
7
8
9
10
for(int i=1;i<=1;i++)
for(int j=0;j<=1;j++)
for(int k=0;k<=1;k++)
......
a[i][j][k][][][]+=a[i-1][j][k][][][];
for(int i=0;i<=1;i++)
for(int j=1;j<=1;j++)
for(int k=0;k<=1;k++)
......
a[i][j][k][][][]+=a[i][j-1][k][][][];

但是显然这样又臭又长。

考虑将a[][][][]a[][][][]数组后面的ww维压成11维,用二进制表示。

可以这么写:

1
2
3
4
5
for (int i=0;i<w;++i){
for (int j=0;j<max(a[i]);++j){
if (j&(1<<i)) f[j^(1<<i)]=(f[j^(1<<i)]+f[j])%MOD;
}
}

例题

CF449D Jzzhu and Numbers

考虑一个容斥:令g(x)g(x)为使得这些a[ij]a[i_j]与起来有至少xx11的方案数。

那么显然有ans=g(i)(1)ians=\sum g(i) (-1)^i

那么考虑预处理g(x)g(x),可以考虑预处理与起来为statusstatus的方案数f(status)f(status)(这时g(x)==bitcount(i)==xf(i)g(x)==\sum _{bitcount(i)==x} f(i)),那就是要求statusa[ij]status \in a[i_j],假设我们有cntcntstatusa[i]status \in a[i],那么答案就是2cnt12^{cnt}-1(去掉什么都不选的)

可以用上面说的前缀和方法预处理出来f(status)f(status),就做完了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
#define MAXN 1000005
#define MAXM 25
#define MOD 1000000007
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
int f[MAXN];
#define lowbit(x) x&-x
int cnt[MAXN],pow2[MAXN];
int main(){
int n=read();
cnt[0]=0;
for (register int i=1;i<MAXN;++i){
cnt[i]=cnt[i^(lowbit(i))]+1;
}
for (register int i=1;i<=n;++i){
int x=read();
f[x]++;
}
pow2[0]=1;
for (register int i=1;i<MAXN;++i){
pow2[i]=(pow2[i-1]*2)%MOD;
}
for (register int i=0;i<MAXM;++i){
for (register int j=0;j<MAXN;++j){
if (j&(1<<i)) f[j^(1<<i)]=(f[j^(1<<i)]+f[j])%MOD;
}
}
int ans=0;
for (register int i=0;i<MAXN;++i){
ans=(ans+(cnt[i]&1?-1:1)*(pow2[f[i]]-1))%MOD;
}
printf("%d\n",(ans%MOD+MOD)%MOD);
return 0;
}

评论