Codeforces 1151F Sonya and Informatics
题目链接
https://codeforces.com/contest/1151/problem/F
题目大意
给你一个长度为 $N$ 的 $01$ 序列,每次等概率选两个数字进行交换。问进行 $K$ 次操作之后,整个序列为不下降序列的概率是多少。
题目分析
题目要求是不下降序列,满足题意的序列是先全部是 $0$ 然后全部是 $1$。那么可以统计出来在 $0$ 位置上的 $1$ 和在 $1$ 位置上的 $0$ 共 $x$ 对。考虑一次交换,对这些不符合要求的数的影响。假设共 $a$ 个 $0$ 和 $b$ 个 $1$。
-
如果不符合要求的 $0$ 和 $1$ 交换,那么显然会使得不符合要求的数对 $-1$。概率为 $\frac{x \cdot x}{C_n^2}$
-
如果已经符合要求的 $0$ 和 $1$ 交换,那么显然会使得不符合要求的数对 $+1$。概率为 $\frac{(a – x)(b – x)}{C_n^2}$
-
如果是 $0$ 和 $0$ 交换或者 $1$ 和 $1$ 交换,那么显然不符合要求的数对不会变化。概率为 $\frac{C_{(a – x)}^2 + C_{(b – x)}^2}{C_n^2}$
-
如果是不符合要求的 $0$ 和符合要求的 $1$ 交换,显然不符合要求的数对也不会发生变化,同理符合要求的 $0$ 和不符合要求的 $1$ 交换也一样。概率为 $\frac{(a – x)x + (b – x)x}{C_n^2}$
设 $f[i][j]$ 为 $i$ 次交换后有 $j$ 对数对不符合情况的概率,那么答案显然为 $f[k][0]$。有上面这几种情况,转移的情况便可以轻松写出来。
注意到 $n$ 很小但是 $k$ 很大,因此可以考虑使用矩阵快速幂进行优化。最终可在 $O(n^3\log{k})$ 时间内得到答案。
代码
#include <iostream> #include <cstdio> #include <cstdlib> #include <cmath> #include <cstring> #include <string> #include <algorithm> using namespace std; long long mod = 1e9 + 7; int n, m, t; int a[105]; long long cnt[2]; long long inv2; long long inv; struct matrix { long long f[51][51]; matrix(){ memset(f, 0, sizeof(f)); }; }; matrix e; matrix f; matrix h; matrix add(matrix x, matrix y){ int i, j; matrix ret; for(i=0;i<=50;i++){ for(j=0;j<=50;j++){ ret.f[i][j] = (x.f[i][j] + y.f[i][j]) % mod; } } return ret; } matrix mul(matrix x, matrix y){ int i, j, k; matrix ret; for(i=0;i<=50;i++){ for(j=0;j<=50;j++){ for(k=0;k<=50;k++){ ret.f[i][j] = (ret.f[i][j] + x.f[i][k] * y.f[k][j]) % mod; } } } return ret; } matrix qpow(matrix b, long long x){ matrix ret = e; while(x){ if(x & 1){ ret = mul(ret, b); } b = mul(b, b); x >>= 1; } return ret; } long long qpow(long long b, long long x){ long long ret = 1; while(x){ if(x & 1){ ret = ret * b % mod; } b = b * b % mod; x >>= 1; } return ret; } long long getc(long long x){ if(x <= 0)return 0; return x * (x - 1) % mod * inv2 % mod; } int read(){ int x; scanf("%d", &x); return x; } int main(){ long long i, j; long long x, y; inv2 = qpow(2, mod - 2); for(i=0;i<=50;i++){ e.f[i][i] = 1; } n = read(); m = read(); for(i=1;i<=n;i++){ a[i] = read(); cnt[a[i]]++; } inv = qpow(getc(n), mod - 2); x = min(cnt[0], cnt[1]); y = 0; for(i=1;i<=cnt[0];i++){ if(a[i] != 0)y++; } for(i=0;i<=x;i++){ if(i - 1 >= 0)h.f[i][i - 1] = 1ll * i * i * inv % mod; h.f[i][i] = 1ll * (getc(cnt[0]) + getc(cnt[1]) + i * (cnt[0] - i) + i * (cnt[1] - i)) * inv % mod; if(i + 1 <= x)h.f[i][i + 1] = 1ll * (cnt[0] - i) * (cnt[1] - i) * inv % mod; } f.f[0][y] = 1; printf("%lld\n", mul(f, qpow(h, m)).f[0][0]); return 0; }