Codeforces 1096G Lucky Tickets
题目链接
http://codeforces.com/contest/1096/problem/G
题目大意
一个长度为 $n$ 位的卡号每一位都由给你的 $k$ 个数字中的一个构成,如果一个卡号的前 $\frac{n}{2}$ 个数的和与后 $\frac{n}{2}$ 个数字的和一样,那么它就是一个幸运卡。问一共有多少个幸运卡。
题目分析
显然,如果由这 $k$ 个数所组成长度为 $\frac{n}{2}$ 的串的和为 $i$ 共有 $S_i$ 种可能的话,那么答案就为 $\sum{{S_i} ^ 2}$。那么如何求每种总和各有多少种情况呢。第一想法显然是背包。但是背包的复杂度太高了,并不能满足这道题。
考虑使用生成函数。设这 $k$个数为 $a_1, a_2, …, a_k$ ,则有
$$f(x) = (\sum\limits_{i=1}^{k}{x ^ {a_i}}) ^ {\frac{n}{2}}$$
计算出展开式,最后每项的系数的平方加和即为答案。计算多项式乘法可以使用 $NTT$ 或者 $FFT$ 优化,利用快速幂计算该多项式的 $\frac{n}{2}$ 次方。
代码
#include <iostream> #include <cstdio> #include <cstdlib> #include <cmath> #include <cstring> #include <string> #include <algorithm> #include <vector> using namespace std; int mod = 998244353; int ww[(1 << 24)], *e = ww + (1 << 23); int qpow(int b, int x){ int ret = 1; while(x){ if(x & 1){ ret = 1ll * ret * b % mod; } b = 1ll * b * b % mod; x >>= 1; } return ret; } int DFT(vector <int> &A, int N, int C){ int w, l, r; vector <int> B(N); int i, j, k; for(i=N;i>1;i>>=1,swap(A, B)){ for(j=0;j<N;j+=i){ for(k=0;k<i;k+=2){ B[j + (k >> 1)] = A[j + k]; B[j + (k >> 1) + (i >> 1)] = A[j + k + 1]; } } } for(i=2;i<=N;i<<=1,swap(A, B)){ for(w=0,k=0;k<(i>>1);k++,w+=N/i*C){ for(j=0;j<N;j+=i){ l = A[j + k], r = 1ll * e[w] * A[j + (i >> 1) + k] % mod; B[j + k] = ((l + r >= mod) ? l + r - mod : l + r); B[j + (i >> 1) + k] = ((l - r < 0) ? l - r + mod : l - r); } } } return 0; } vector <int> NTT(vector <int> &A, vector <int> &B){ int AZ = A.size(); int BZ = B.size(); int N; for(N=1;N<AZ+BZ;N<<=1); vector <int> AN(N); vector <int> BN(N); copy(A.begin(), A.end(), AN.begin()); copy(B.begin(), B.end(), BN.begin()); e[0] = e[-N] = 1;e[1] = e[1 - N] = qpow(3, (mod - 1) / N); for(int i=2;i<N;i++)e[i - N] = e[i] = 1ll * e[i - 1] * e[1] % mod; DFT(AN, N, 1);DFT(BN, N, 1); for(int i=0;i<N;i++)AN[i] = 1ll * AN[i] * BN[i] % mod; DFT(AN, N, -1); int NI = qpow(N, mod - 2); for(int i=0;i<N;i++)AN[i] = 1ll * AN[i] * NI % mod; return AN; } vector <int> NTT2(vector <int> &A){ int AZ = A.size(); int N; for(N=1;N<AZ+AZ;N<<=1); vector <int> AN(N); copy(A.begin(), A.end(), AN.begin()); e[0] = e[-N] = 1;e[1] = e[1 - N] = qpow(3, (mod - 1) / N); for(int i=2;i<N;i++)e[i - N] = e[i] = 1ll * e[i - 1] * e[1] % mod; DFT(AN, N, 1); for(int i=0;i<N;i++)AN[i] = 1ll * AN[i] * AN[i] % mod; DFT(AN, N, -1); int NI = qpow(N, mod - 2); for(int i=0;i<N;i++)AN[i] = 1ll * AN[i] * NI % mod; return AN; } int n, m; vector <int> a; vector <int> b; int main(){ int i, j; int x; int ans = 0; scanf("%d%d", &n, &m); for(i=0;i<10;i++){ a.push_back(0); } for(i=1;i<=m;i++){ scanf("%d", &x); a[x] = 1; } x = n / 2 - 1; b = a; while(x){ if(x & 1){ b = NTT(b, a); for(i=b.size()-1;i>=0;i--){ if(!b[i]){ b.pop_back(); }else{ break; } } } //a = NTT(a, a); a = NTT2(a); for(i=a.size()-1;i>=0;i--){ if(!a[i]){ a.pop_back(); }else{ break; } } x >>= 1; } for(i=0;i<b.size();i++){ ans = (ans + 1ll * b[i] * b[i] % mod) % mod; } printf("%d\n", ans); return 0; }