Codeforces 1151F Sonya and Informatics

Codeforces 1151F Sonya and Informatics

题目链接

https://codeforces.com/contest/1151/problem/F

题目大意

给你一个长度为 $N$ 的 $01$ 序列,每次等概率选两个数字进行交换。问进行 $K$ 次操作之后,整个序列为不下降序列的概率是多少。

题目分析

题目要求是不下降序列,满足题意的序列是先全部是 $0$ 然后全部是 $1$。那么可以统计出来在 $0$ 位置上的 $1$ 和在 $1$ 位置上的 $0$ 共 $x$ 对。考虑一次交换,对这些不符合要求的数的影响。假设共 $a$ 个 $0$ 和 $b$ 个 $1$。

  1. 如果不符合要求的 $0$ 和 $1$ 交换,那么显然会使得不符合要求的数对 $-1$。概率为 $\frac{x \cdot x}{C_n^2}$

  2. 如果已经符合要求的 $0$ 和 $1$ 交换,那么显然会使得不符合要求的数对 $+1$。概率为 $\frac{(a – x)(b – x)}{C_n^2}$

  3. 如果是 $0$ 和 $0$ 交换或者 $1$ 和 $1$ 交换,那么显然不符合要求的数对不会变化。概率为 $\frac{C_{(a – x)}^2 + C_{(b – x)}^2}{C_n^2}$

  4. 如果是不符合要求的 $0$ 和符合要求的 $1$ 交换,显然不符合要求的数对也不会发生变化,同理符合要求的 $0$ 和不符合要求的 $1$ 交换也一样。概率为 $\frac{(a – x)x + (b – x)x}{C_n^2}$

设 $f[i][j]$ 为 $i$ 次交换后有 $j$ 对数对不符合情况的概率,那么答案显然为 $f[k][0]$。有上面这几种情况,转移的情况便可以轻松写出来。

注意到 $n$ 很小但是 $k$ 很大,因此可以考虑使用矩阵快速幂进行优化。最终可在 $O(n^3\log{k})$ 时间内得到答案。

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <string>
#include <algorithm>
 
using namespace std;
 
long long mod = 1e9 + 7;
 
int n, m, t;
 
int a[105];
long long cnt[2];
 
long long inv2;
long long inv;
 
struct matrix {
	long long f[51][51];
	
	matrix(){
		memset(f, 0, sizeof(f));
	};
};
 
matrix e;
matrix f;
matrix h;
 
matrix add(matrix x, matrix y){
	int i, j;
	matrix ret;
	for(i=0;i<=50;i++){
		for(j=0;j<=50;j++){
			ret.f[i][j] = (x.f[i][j] + y.f[i][j]) % mod;
		}
	}
	
	return ret;
}
 
matrix mul(matrix x, matrix y){
	int i, j, k;
	matrix ret;
	for(i=0;i<=50;i++){
		for(j=0;j<=50;j++){
			for(k=0;k<=50;k++){
				ret.f[i][j] = (ret.f[i][j] + x.f[i][k] * y.f[k][j]) % mod;
			}
		}
	}
	
	return ret;
}
 
matrix qpow(matrix b, long long x){
	matrix ret = e;
	
	while(x){
		if(x & 1){
			ret = mul(ret, b);
		}
		
		b = mul(b, b);
		x >>= 1;
	}
	
	return ret;
}
 
long long qpow(long long b, long long x){
	long long ret = 1;
	
	while(x){
		if(x & 1){
			ret = ret * b % mod;
		}
		
		b = b * b % mod;
		x >>= 1;
	}
	
	return ret;
}
 
long long getc(long long x){
	if(x <= 0)return 0;
	return x * (x - 1) % mod * inv2 % mod;
}
 
int read(){
	int x;
	scanf("%d", &x);
	return x;
}
 
int main(){
	long long i, j;
	long long x, y;
	
	inv2 = qpow(2, mod - 2);
	
	for(i=0;i<=50;i++){
		e.f[i][i] = 1;
	}
	
	n = read();
	m = read();
	
	for(i=1;i<=n;i++){
		a[i] = read();
		cnt[a[i]]++;
	}
	
	inv = qpow(getc(n), mod - 2);
	
	x = min(cnt[0], cnt[1]);
	y = 0;
	
	for(i=1;i<=cnt[0];i++){
		if(a[i] != 0)y++;
	}
	
	for(i=0;i<=x;i++){
		if(i - 1 >= 0)h.f[i][i - 1] = 1ll * i * i * inv % mod;
		h.f[i][i] = 1ll * (getc(cnt[0]) + getc(cnt[1]) + i * (cnt[0] - i) + i * (cnt[1] - i)) * inv % mod;
		if(i + 1 <= x)h.f[i][i + 1] = 1ll * (cnt[0] - i) * (cnt[1] - i) * inv % mod;
	}
	
	f.f[0][y] = 1;
	printf("%lld\n", mul(f, qpow(h, m)).f[0][0]);
	
	return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据