[HDUOJ7191]Count Set(图论 + 排列组合 + NTT + 堆优化 + DP)

发布于 2022-08-12  366 次阅读


题目传送门:Problem - 7191 (hdu.edu.cn)

题目 / Problem

给定一个[1, n]的排列p和一个数字K,求[1, n]的子集T的个数。

子集T的大小为K,对于集合T中的任意一个元素x,p[x]不在该集合中。

思路 / Thought

排列的题很容易想到将i与p[i]连接起来,一定能组成若干个环,这里要求对于T中任意一个元素x,p[x]不能在T中,也就是说在若干个环中选取K个点,且K个点不能相邻(不同环上的点是不相邻的)。

现在问题转换成了一个排列组合的问题,在由B个不同的点组成的环中选择M个不相邻的点,方法一共有多少种?

如果不是环而是链的话,答案是\(ans = C_{n - m + 1}^{m}\),但是这里是环中,选中一个点x,可以分为两种情况:选x不选x

如果选了x,那么剩下的就是在N - 3个点组成的链中选M - 1个不相邻的点,也就是\( C_{n - m - 1}^{m - 1}\)。

如果没选x,那么就是在剩下N - 1个点组成的链中选M个不相邻的点,也就是\(C_{n - m}^{m}\)。

所以在一个大小为N的环中选M个不相邻的点的方案数为\(f(n, m) = C_{n - m - 1}^{m - 1} + C_{n-m}^{m}\)。

有了这个,现在问题进一步变化,变成了在cnt个环中取K个,这个情况就非常多了,搜索dp啥的都不好写。比如我要在4个环中取5个点,可以是\({1,1,1,2}\),也可以\({0,1,3,1}\)等等情况,每一个环的取值范围都很大,复杂度爆炸而且还不好写。

于是多项式诞生了,我们设dp[i][j]表示第i个环中取j个不相邻点的方案数,比如\(dp[i] = {3,2,1}\),我们就将这个数组等价于\(f(x) = 3x^0 + 2x^1 + x^2\),这样每一个环都是一个多项式,最后只需要将多项式全部是相乘,得到的最终的多项式中\(x ^ K\)的系数就是答案。

这个\(x ^ K\)由前面的许多合法路径转移过来的,假如K = 4,那么最终的多项式中的\(x ^ 4\)可能由\(x ^ 1 * x ^ 1 * x ^ 2\)或者\(x ^ 2 * x ^ 2 * x ^ 0\)等等方案做出贡献。

那么现在问题就是如何求多项式的卷积呢?朴素的算法,一个个地相乘,复杂度是O(N ^ 2),肯定会超时,所以要用NTT(快速数论变换),复杂度是O(NlogN),利用了快速傅里叶变换的思想。

毕竟是竞赛选手,会用模板就行,不用深究原理。

知识点总结 / Knowledge

大小为N的链中取M个不相邻的点方案数:\(C_{n - m + 1}^{m}\)。

大小为N的环中取M个不相邻的点方案数:\(C_{n - m - 1}^{m - 1} + C_{n-m}^{m}\)。

NTT模板(取自多项式各种板子的快速实现 - mioi - 博客园 (cnblogs.com)

namespace NTT
{
	#define SZ(v) ((int)(v).size())
	typedef vector<int> Poly;
	const int N = 2e6 + 9, G = 3;//N要开4倍 
	int qmi(int a, int b,int mod)
	{
		int res = 1;
		while(b) 
		{
			if(b & 1) res = res * a % mod;
			a = a * a % mod, b >>= 1;
		}
		return res;
	}
	
	void ntt(int a[],int lim,int inv)
	{
		//lim是预先处理好的一个大于最大长度的2的倍数 
		for(int i = 0,j = 0;i < lim; ++ i)
		{
			if(i < j) swap(a[i], a[j]);
			for(int l = (lim >> 1); (j ^= l) < l; l >>= 1);
		}
		
		for(int m = 1,k = 2;k <= lim; m = k, k <<= 1) 
		{
			int gn = qmi(G, (P - 1) / k, P);
			if(inv == -1) gn = qmi(gn, P-2, P);
			
			for(int i = 0;i < lim;i += k) 
			{
				int g = 1;
				for(int j = 0;j < m;++ j,g = g * gn % P) 
				{
					int u = a[i + j], v = a[i + j + m];
					a[i + j] = u + g * v;
					a[i + j + m]= u - g * v;
				}
			}
			
			for(int i = 0;i < lim; ++ i) a[i] = (a[i] % P + P) % P;
		}
		
		if(inv == 1) return;
		int inv_ = qmi(lim, P - 2, P);
		for(int i = 0;i < lim; ++ i) a[i] = a[i] * inv_ % P;
	}
	
	Poly mul(Poly a,Poly b)
	{
		static int A[N], B[N], C[N];
		int lim = 1;while(lim < SZ(a) + SZ(b) - 1) lim <<= 1;
		//标准化AB数组 
		for(int i = 0;  i < SZ(a); ++ i) A[i] = a[i];
		for(int i = SZ(a);i < lim; ++ i) A[i] = 0;
		for(int i = 0;  i < SZ(b); ++ i) B[i] = b[i];
		for(int i = SZ(b);i < lim; ++ i) B[i] = 0;
		
		ntt(A, lim, 1), ntt(B, lim, 1);//系数表达转成点值表达 
		for(int i = 0;i < lim; ++ i) C[i] = A[i] * B[i] % P;
		ntt(C, lim, -1);//点值表达转换成系数表达 
		//将结果存入C数组 
		Poly c;for(int i = 0;i < SZ(a) + SZ(b) - 1; ++ i)c.push_back(C[i]);
		return c;
	}
	//堆的排序方法,较小的先积 
	struct cmp{bool operator ()(const Poly &a, const Poly &b)const{return SZ(a) > SZ(b);}};
}

代码 / Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 2e6 + 9, P = 998244353;
int fac[maxn], a[maxn], b[maxn], cnt = 0;

namespace NTT
{
	#define SZ(v) ((int)(v).size())
	#define int long long 
	typedef vector<int> Poly;//N要开4倍 
	const int N = 2e6 + 9, G = 3, P = 998244353;
	int qmi(int a, int b,int mod)
	{
		int res = 1;
		while(b) 
		{
			if(b & 1) res = res * a % mod;
			a = a * a % mod, b >>= 1;
		}
		return res;
	}
	
	void ntt(int a[],int lim,int inv)
	{
		//lim是预先处理好的一个大于最大长度的2的倍数 
		for(int i = 0,j = 0;i < lim; ++ i)
		{
			if(i < j) swap(a[i], a[j]);
			for(int l = (lim >> 1); (j ^= l) < l; l >>= 1);
		}
		
		for(int m = 1,k = 2;k <= lim; m = k, k <<= 1) 
		{
			int gn = qmi(G, (P - 1) / k, P);
			if(inv == -1) gn = qmi(gn, P-2, P);
			
			for(int i = 0;i < lim;i += k) 
			{
				int g = 1;
				for(int j = 0;j < m;++ j,g = g * gn % P) 
				{
					int u = a[i + j], v = a[i + j + m];
					a[i + j] = u + g * v;
					a[i + j + m]= u - g * v;
				}
			}
			
			for(int i = 0;i < lim; ++ i) a[i] = (a[i] % P + P) % P;
		}
		
		if(inv == 1) return;
		int inv_ = qmi(lim, P - 2, P);
		for(int i = 0;i < lim; ++ i) a[i] = a[i] * inv_ % P;
	}
	
	Poly mul(Poly a,Poly b)
	{
		static int A[N], B[N], C[N];
		int lim = 1;while(lim < SZ(a) + SZ(b) - 1) lim <<= 1;
		//标准化AB数组 
		for(int i = 0;  i < SZ(a); ++ i) A[i] = a[i];
		for(int i = SZ(a);i < lim; ++ i) A[i] = 0;
		for(int i = 0;  i < SZ(b); ++ i) B[i] = b[i];
		for(int i = SZ(b);i < lim; ++ i) B[i] = 0;
		
		ntt(A, lim, 1), ntt(B, lim, 1);//系数表达转成点值表达 
		for(int i = 0;i < lim; ++ i) C[i] = A[i] * B[i] % P;
		ntt(C, lim, -1);//点值表达转换成系数表达 
		//将结果存入C数组 
		Poly c;for(int i = 0;i < SZ(a) + SZ(b) - 1; ++ i)c.push_back(C[i]);
		return c;
	}
	//堆的排序方法,较小的先卷积 
	struct cmp{bool operator ()(const Poly &a, const Poly &b)const{return SZ(a) > SZ(b);}};
}

int inv(int x){return NTT::qmi(x, P - 2, P);}
int C(int n,int m){return (n < 0 || m < 0 || n < m) ? 0 : fac[n] * inv(fac[m]) % P * inv(fac[n - m]) % P;}
int f(int n,int m){return (C(n - m - 1, m - 1) + C(n - m, m)) % P;}
int dfs(int x, bitset<maxn> &vis)
{//求环的大小的一个简单递归 
	if(vis[x])return 0;vis[x] = 1;
	return dfs(a[x], vis) + 1;
}


void solve()
{
	int N, K;cin >> N >> K; 
	cnt = 0;
	bitset<maxn> vis;
	
	for(int i = 1;i <= N; ++ i)cin >> a[i];
	for(int i = 1;i <= N; ++ i)if(!vis[i])b[++ cnt] = dfs(i, vis);
	//环更新完毕 ,b[i]表示第i个环的大小,一共cnt个环 
	priority_queue<NTT::Poly, vector<NTT::Poly>, NTT::cmp> pq;
	for(int i = 1;i <= cnt; ++ i)
	{
		NTT::Poly tmp;//求得多项式并存入堆中 
		for(int j = 0;j <= b[i] / 2 && j <= K; ++ j)tmp.push_back(f(b[i], j));
		pq.push(tmp);
	}
	
	while(pq.size() >= 2)
	{
		//每次选出项数最少的两个多项式进行合并,复杂度最小 
		NTT::Poly t1 = pq.top();pq.pop();
		NTT::Poly t2 = pq.top();pq.pop();
		pq.push(NTT::mul(t1, t2)); 
	}
	
	NTT::Poly ans(pq.top()); //注意输出的时候需要判断如果无法选择K个,那么结果就0 
	cout << (K < ans.size() ? ans[K] : 0) << '\n';
}

signed main()
{
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	//初始化阶乘 
	fac[0] = 1;
	for(int i = 1;i <= 5e5 + 1; ++ i)fac[i] = i * fac[i - 1] % P;
	
	int _;cin >> _;
	while(_ --)solve();
	return 0;
}