ErikTse Runtime

  • 首页 / Home
  • | 算法学习 / Algorithm
    • 所有 / All
    • 简单 / Easy
    • 中等 / Medium
    • 困难 / Hard
  • | 技术分享 / Technology
    • 所有 / All
    • 网络技术 / NetWork
    • 资源共享 / Resource
    • 项目实践 / Event
  • ETOJ在线评测系统
Keep Going.
温故而知新.
  1. 首页
  2. 算法学习
  3. 困难
  4. 正文

[HDUOJ7191]Count Set(图论 + 排列组合 + NTT + 堆优化 + DP)

2022年8月12日 190点热度 0人点赞 0条评论

题目传送门:Problem - 7191 (hdu.edu.cn)

题目 / Problem

给定一个[1, n]的排列p和一个数字K,求[1, n]的子集T的个数。

子集T的大小为K,对于集合T中的任意一个元素x,p[x]不在该集合中。

思路 / Thought

排列的题很容易想到将i与p[i]连接起来,一定能组成若干个环,这里要求对于T中任意一个元素x,p[x]不能在T中,也就是说在若干个环中选取K个点,且K个点不能相邻(不同环上的点是不相邻的)。

现在问题转换成了一个排列组合的问题,在由B个不同的点组成的环中选择M个不相邻的点,方法一共有多少种?

如果不是环而是链的话,答案是\(ans = C_{n - m + 1}^{m}\),但是这里是环中,选中一个点x,可以分为两种情况:选x和不选x。

如果选了x,那么剩下的就是在N - 3个点组成的链中选M - 1个不相邻的点,也就是\( C_{n - m - 1}^{m - 1}\)。

如果没选x,那么就是在剩下N - 1个点组成的链中选M个不相邻的点,也就是\(C_{n - m}^{m}\)。

所以在一个大小为N的环中选M个不相邻的点的方案数为\(f(n, m) = C_{n - m - 1}^{m - 1} + C_{n-m}^{m}\)。

有了这个,现在问题进一步变化,变成了在cnt个环中取K个,这个情况就非常多了,搜索dp啥的都不好写。比如我要在4个环中取5个点,可以是\({1,1,1,2}\),也可以\({0,1,3,1}\)等等情况,每一个环的取值范围都很大,复杂度爆炸而且还不好写。

于是多项式诞生了,我们设dp[i][j]表示第i个环中取j个不相邻点的方案数,比如\(dp[i] = {3,2,1}\),我们就将这个数组等价于\(f(x) = 3x^0 + 2x^1 + x^2\),这样每一个环都是一个多项式,最后只需要将多项式全部是相乘,得到的最终的多项式中\(x ^ K\)的系数就是答案。

这个\(x ^ K\)由前面的许多合法路径转移过来的,假如K = 4,那么最终的多项式中的\(x ^ 4\)可能由\(x ^ 1 * x ^ 1 * x ^ 2\)或者\(x ^ 2 * x ^ 2 * x ^ 0\)等等方案做出贡献。

那么现在问题就是如何求多项式的卷积呢?朴素的算法,一个个地相乘,复杂度是O(N ^ 2),肯定会超时,所以要用NTT(快速数论变换),复杂度是O(NlogN),利用了快速傅里叶变换的思想。

毕竟是竞赛选手,会用模板就行,不用深究原理。

知识点总结 / Knowledge

大小为N的链中取M个不相邻的点方案数:\(C_{n - m + 1}^{m}\)。

大小为N的环中取M个不相邻的点方案数:\(C_{n - m - 1}^{m - 1} + C_{n-m}^{m}\)。

NTT模板(取自多项式各种板子的快速实现 - mioi - 博客园 (cnblogs.com))

namespace NTT
{
	#define SZ(v) ((int)(v).size())
	typedef vector<int> Poly;
	const int N = 2e6 + 9, G = 3;//N要开4倍 
	int qmi(int a, int b,int mod)
	{
		int res = 1;
		while(b) 
		{
			if(b & 1) res = res * a % mod;
			a = a * a % mod, b >>= 1;
		}
		return res;
	}
	
	void ntt(int a[],int lim,int inv)
	{
		//lim是预先处理好的一个大于最大长度的2的倍数 
		for(int i = 0,j = 0;i < lim; ++ i)
		{
			if(i < j) swap(a[i], a[j]);
			for(int l = (lim >> 1); (j ^= l) < l; l >>= 1);
		}
		
		for(int m = 1,k = 2;k <= lim; m = k, k <<= 1) 
		{
			int gn = qmi(G, (P - 1) / k, P);
			if(inv == -1) gn = qmi(gn, P-2, P);
			
			for(int i = 0;i < lim;i += k) 
			{
				int g = 1;
				for(int j = 0;j < m;++ j,g = g * gn % P) 
				{
					int u = a[i + j], v = a[i + j + m];
					a[i + j] = u + g * v;
					a[i + j + m]= u - g * v;
				}
			}
			
			for(int i = 0;i < lim; ++ i) a[i] = (a[i] % P + P) % P;
		}
		
		if(inv == 1) return;
		int inv_ = qmi(lim, P - 2, P);
		for(int i = 0;i < lim; ++ i) a[i] = a[i] * inv_ % P;
	}
	
	Poly mul(Poly a,Poly b)
	{
		static int A[N], B[N], C[N];
		int lim = 1;while(lim < SZ(a) + SZ(b) - 1) lim <<= 1;
		//标准化AB数组 
		for(int i = 0;  i < SZ(a); ++ i) A[i] = a[i];
		for(int i = SZ(a);i < lim; ++ i) A[i] = 0;
		for(int i = 0;  i < SZ(b); ++ i) B[i] = b[i];
		for(int i = SZ(b);i < lim; ++ i) B[i] = 0;
		
		ntt(A, lim, 1), ntt(B, lim, 1);//系数表达转成点值表达 
		for(int i = 0;i < lim; ++ i) C[i] = A[i] * B[i] % P;
		ntt(C, lim, -1);//点值表达转换成系数表达 
		//将结果存入C数组 
		Poly c;for(int i = 0;i < SZ(a) + SZ(b) - 1; ++ i)c.push_back(C[i]);
		return c;
	}
	//堆的排序方法,较小的先积 
	struct cmp{bool operator ()(const Poly &a, const Poly &b)const{return SZ(a) > SZ(b);}};
}

代码 / Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 2e6 + 9, P = 998244353;
int fac[maxn], a[maxn], b[maxn], cnt = 0;

namespace NTT
{
	#define SZ(v) ((int)(v).size())
	#define int long long 
	typedef vector<int> Poly;//N要开4倍 
	const int N = 2e6 + 9, G = 3, P = 998244353;
	int qmi(int a, int b,int mod)
	{
		int res = 1;
		while(b) 
		{
			if(b & 1) res = res * a % mod;
			a = a * a % mod, b >>= 1;
		}
		return res;
	}
	
	void ntt(int a[],int lim,int inv)
	{
		//lim是预先处理好的一个大于最大长度的2的倍数 
		for(int i = 0,j = 0;i < lim; ++ i)
		{
			if(i < j) swap(a[i], a[j]);
			for(int l = (lim >> 1); (j ^= l) < l; l >>= 1);
		}
		
		for(int m = 1,k = 2;k <= lim; m = k, k <<= 1) 
		{
			int gn = qmi(G, (P - 1) / k, P);
			if(inv == -1) gn = qmi(gn, P-2, P);
			
			for(int i = 0;i < lim;i += k) 
			{
				int g = 1;
				for(int j = 0;j < m;++ j,g = g * gn % P) 
				{
					int u = a[i + j], v = a[i + j + m];
					a[i + j] = u + g * v;
					a[i + j + m]= u - g * v;
				}
			}
			
			for(int i = 0;i < lim; ++ i) a[i] = (a[i] % P + P) % P;
		}
		
		if(inv == 1) return;
		int inv_ = qmi(lim, P - 2, P);
		for(int i = 0;i < lim; ++ i) a[i] = a[i] * inv_ % P;
	}
	
	Poly mul(Poly a,Poly b)
	{
		static int A[N], B[N], C[N];
		int lim = 1;while(lim < SZ(a) + SZ(b) - 1) lim <<= 1;
		//标准化AB数组 
		for(int i = 0;  i < SZ(a); ++ i) A[i] = a[i];
		for(int i = SZ(a);i < lim; ++ i) A[i] = 0;
		for(int i = 0;  i < SZ(b); ++ i) B[i] = b[i];
		for(int i = SZ(b);i < lim; ++ i) B[i] = 0;
		
		ntt(A, lim, 1), ntt(B, lim, 1);//系数表达转成点值表达 
		for(int i = 0;i < lim; ++ i) C[i] = A[i] * B[i] % P;
		ntt(C, lim, -1);//点值表达转换成系数表达 
		//将结果存入C数组 
		Poly c;for(int i = 0;i < SZ(a) + SZ(b) - 1; ++ i)c.push_back(C[i]);
		return c;
	}
	//堆的排序方法,较小的先卷积 
	struct cmp{bool operator ()(const Poly &a, const Poly &b)const{return SZ(a) > SZ(b);}};
}

int inv(int x){return NTT::qmi(x, P - 2, P);}
int C(int n,int m){return (n < 0 || m < 0 || n < m) ? 0 : fac[n] * inv(fac[m]) % P * inv(fac[n - m]) % P;}
int f(int n,int m){return (C(n - m - 1, m - 1) + C(n - m, m)) % P;}
int dfs(int x, bitset<maxn> &vis)
{//求环的大小的一个简单递归 
	if(vis[x])return 0;vis[x] = 1;
	return dfs(a[x], vis) + 1;
}


void solve()
{
	int N, K;cin >> N >> K; 
	cnt = 0;
	bitset<maxn> vis;
	
	for(int i = 1;i <= N; ++ i)cin >> a[i];
	for(int i = 1;i <= N; ++ i)if(!vis[i])b[++ cnt] = dfs(i, vis);
	//环更新完毕 ,b[i]表示第i个环的大小,一共cnt个环 
	priority_queue<NTT::Poly, vector<NTT::Poly>, NTT::cmp> pq;
	for(int i = 1;i <= cnt; ++ i)
	{
		NTT::Poly tmp;//求得多项式并存入堆中 
		for(int j = 0;j <= b[i] / 2 && j <= K; ++ j)tmp.push_back(f(b[i], j));
		pq.push(tmp);
	}
	
	while(pq.size() >= 2)
	{
		//每次选出项数最少的两个多项式进行合并,复杂度最小 
		NTT::Poly t1 = pq.top();pq.pop();
		NTT::Poly t2 = pq.top();pq.pop();
		pq.push(NTT::mul(t1, t2)); 
	}
	
	NTT::Poly ans(pq.top()); //注意输出的时候需要判断如果无法选择K个,那么结果就0 
	cout << (K < ans.size() ? ans[K] : 0) << '\n';
}

signed main()
{
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	//初始化阶乘 
	fac[0] = 1;
	for(int i = 1;i <= 5e5 + 1; ++ i)fac[i] = i * fac[i - 1] % P;
	
	int _;cin >> _;
	while(_ --)solve();
	return 0;
} 
本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可
标签: C++ DP NTT 动态规划 图论 堆优化 快速数论变换 排列组合 算法竞赛 集合论
最后更新:2022年10月9日

Eriktse

18岁,性别未知,ACM-ICPC现役选手,ICPC亚洲区域赛银牌摆烂人,CCPC某省赛铜牌蒟蒻,武汉某院校计算机科学与技术专业本科在读。

点赞
< 上一篇
下一篇 >

文章评论

取消回复

Eriktse

18岁,性别未知,ACM-ICPC现役选手,ICPC亚洲区域赛银牌摆烂人,CCPC某省赛铜牌蒟蒻,武汉某院校计算机科学与技术专业本科在读。

文章目录
  • 题目 / Problem
  • 思路 / Thought
  • 知识点总结 / Knowledge
  • 代码 / Code

友情链接 | 站点地图

COPYRIGHT © 2022 ErikTse Runtime. ALL RIGHTS RESERVED.

Theme Kratos | Hosted In TENCENT CLOUD

赣ICP备2022001555号-1

赣公网安备 36092402000057号