The 2018 ACM-ICPC Asia Qingdao Regional Contest L.Sub-cycle Graph
题目链接
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=4069
题目大意
告诉你一个图现在有 $n$ 个点 $m$ 条边,问你有多少种不同的图使得这些图可以通过加非负整数条边变成一个简单环。
题目分析
对于 $n < m$ 的情况,肯定不可能存在可行解。
对于 $n = m$ 的情况,所有的点已经连成一个简单环,那么答案肯定为 $\frac{n!}{2n}$。
对于 $n > m$ 的情况肯定是由这 $n$ 个点组成 $k = n – m$ 条链。
首先考虑由 $i$ 个点组成 $1$ 条链由多少种情况。很明显有 $a_1 = 1, a_i = \frac{i!}{2}(i \geq 2)$,则可以写成指数型生成函数
设 $f(x) = \sum\frac{a_i}{i!}x ^ i$
则有
$$f(x) = \frac{1}{2}(2x + x ^ 2 + x ^ 3 + …)$$ 即
$$f(x) = \frac{1}{2}(x – 1 + \frac{1}{1 – x})$$
可得
$$f(x) = \frac{1}{2}(\frac{-x ^ 2 + 2x}{1 – x}) = x(\frac{-\frac{1}{2}x + 1}{1 – x})$$
得到了一条链的生成函数之后便可以很容易得到 $k$ 条链的指数型生成函数
$$g(x) = [f(x)] ^ k = [x(\frac{-\frac{1}{2}x + 1}{1 – x})] ^ k$$
则有 $g(x) = \sum\frac{b_i}{i!}x ^ i$ ,系数 $\frac{b_n}{k!}$(需要对 $k$ 条链进行消序处理) 即为答案
将 $g(x)$ 化简,得到
$$g(x) = x ^ k (-\frac{1}{2}x + 1) ^ k (\frac{1}{1 – x}) ^ k$$
对这些项进行展开,计算 $x ^ n$ 项系数即可。
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <string>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
long long n, m, k, t;
long long mod = 1e9 + 7;
long long fac[maxn];
long long inv[maxn];
long long tp[maxn];
long long tpinv[maxn];
int read(){
int x = 0;
char ch = getchar();
while('0' > ch or ch > '9'){
ch = getchar();
}
while('0' <= ch and ch <= '9'){
x = x * 10 + ch - '0';
ch = getchar();
}
return x;
}
long long qpow(long long b, long long x){
long long ret = 1;
while(x){
if(x & 1){
ret = ret * b % mod;
}
b = b * b % mod;
x >>= 1;
}
return ret;
}
long long getc(long long x, long long y){
if(y > x){
return 0;
}
return ((fac[x] * inv[x - y] % mod) * inv[y]) % mod;
}
int init(){
long long i, j;
fac[0] = 1;
inv[0] = 1;
tp[0] = 1;
tpinv[0] = 1;
for(i=1;i<=100000;i++){
fac[i] = i * fac[i - 1] % mod;
tp[i] = 2ll * tp[i - 1] % mod;
}
inv[100000] = qpow(fac[100000], mod - 2);
tpinv[100000] = qpow(tp[100000], mod - 2);
for(i=100000-1;i>0;i--){
inv[i] = inv[i + 1] * (i + 1) % mod;
tpinv[i] = tpinv[i + 1] * 2ll % mod;
}
return 0;
}
int main(){
int i, j;
long long ans = 0;
long long tmp;
init();
t = read();
while(t--){
n = read();
m = read();
k = n - m;
if(m > n){
printf("0\n");
continue;
}
if(n == m){
printf("%lld\n", fac[n - 1] * tpinv[1] % mod);
continue;
}
ans = 0;
for(i=0;i<=min(m,k);i++){
tmp = tpinv[i] * getc(k, i) % mod;
tmp = tmp * getc(n - i - 1, m - i) % mod;
if(i & 1){
tmp = (mod - tmp) % mod;
}
ans = (ans + tmp) % mod;
}
ans = ans * fac[n] % mod * inv[k] % mod;
printf("%lld\n", ans);
}
return 0;
}