题目传送门:Problem - 7191 (hdu.edu.cn)
题目 / Problem
给定一个[1, n]的排列p和一个数字K,求[1, n]的子集T的个数。
子集T的大小为K,对于集合T中的任意一个元素x,p[x]不在该集合中。
思路 / Thought
排列的题很容易想到将i与p[i]连接起来,一定能组成若干个环,这里要求对于T中任意一个元素x,p[x]不能在T中,也就是说在若干个环中选取K个点,且K个点不能相邻(不同环上的点是不相邻的)。
现在问题转换成了一个排列组合的问题,在由B个不同的点组成的环中选择M个不相邻的点,方法一共有多少种?
如果不是环而是链的话,答案是\(ans = C_{n - m + 1}^{m}\),但是这里是环中,选中一个点x,可以分为两种情况:选x和不选x。
如果选了x,那么剩下的就是在N - 3个点组成的链中选M - 1个不相邻的点,也就是\( C_{n - m - 1}^{m - 1}\)。
如果没选x,那么就是在剩下N - 1个点组成的链中选M个不相邻的点,也就是\(C_{n - m}^{m}\)。
所以在一个大小为N的环中选M个不相邻的点的方案数为\(f(n, m) = C_{n - m - 1}^{m - 1} + C_{n-m}^{m}\)。
有了这个,现在问题进一步变化,变成了在cnt个环中取K个,这个情况就非常多了,搜索dp啥的都不好写。比如我要在4个环中取5个点,可以是\({1,1,1,2}\),也可以\({0,1,3,1}\)等等情况,每一个环的取值范围都很大,复杂度爆炸而且还不好写。
于是多项式诞生了,我们设dp[i][j]表示第i个环中取j个不相邻点的方案数,比如\(dp[i] = {3,2,1}\),我们就将这个数组等价于\(f(x) = 3x^0 + 2x^1 + x^2\),这样每一个环都是一个多项式,最后只需要将多项式全部是相乘,得到的最终的多项式中\(x ^ K\)的系数就是答案。
这个\(x ^ K\)由前面的许多合法路径转移过来的,假如K = 4,那么最终的多项式中的\(x ^ 4\)可能由\(x ^ 1 * x ^ 1 * x ^ 2\)或者\(x ^ 2 * x ^ 2 * x ^ 0\)等等方案做出贡献。
那么现在问题就是如何求多项式的卷积呢?朴素的算法,一个个地相乘,复杂度是O(N ^ 2),肯定会超时,所以要用NTT(快速数论变换),复杂度是O(NlogN),利用了快速傅里叶变换的思想。
毕竟是竞赛选手,会用模板就行,不用深究原理。
知识点总结 / Knowledge
大小为N的链中取M个不相邻的点方案数:\(C_{n - m + 1}^{m}\)。
大小为N的环中取M个不相邻的点方案数:\(C_{n - m - 1}^{m - 1} + C_{n-m}^{m}\)。
NTT模板(取自多项式各种板子的快速实现 - mioi - 博客园 (cnblogs.com))
namespace NTT { #define SZ(v) ((int)(v).size()) typedef vector<int> Poly; const int N = 2e6 + 9, G = 3;//N要开4倍 int qmi(int a, int b,int mod) { int res = 1; while(b) { if(b & 1) res = res * a % mod; a = a * a % mod, b >>= 1; } return res; } void ntt(int a[],int lim,int inv) { //lim是预先处理好的一个大于最大长度的2的倍数 for(int i = 0,j = 0;i < lim; ++ i) { if(i < j) swap(a[i], a[j]); for(int l = (lim >> 1); (j ^= l) < l; l >>= 1); } for(int m = 1,k = 2;k <= lim; m = k, k <<= 1) { int gn = qmi(G, (P - 1) / k, P); if(inv == -1) gn = qmi(gn, P-2, P); for(int i = 0;i < lim;i += k) { int g = 1; for(int j = 0;j < m;++ j,g = g * gn % P) { int u = a[i + j], v = a[i + j + m]; a[i + j] = u + g * v; a[i + j + m]= u - g * v; } } for(int i = 0;i < lim; ++ i) a[i] = (a[i] % P + P) % P; } if(inv == 1) return; int inv_ = qmi(lim, P - 2, P); for(int i = 0;i < lim; ++ i) a[i] = a[i] * inv_ % P; } Poly mul(Poly a,Poly b) { static int A[N], B[N], C[N]; int lim = 1;while(lim < SZ(a) + SZ(b) - 1) lim <<= 1; //标准化AB数组 for(int i = 0; i < SZ(a); ++ i) A[i] = a[i]; for(int i = SZ(a);i < lim; ++ i) A[i] = 0; for(int i = 0; i < SZ(b); ++ i) B[i] = b[i]; for(int i = SZ(b);i < lim; ++ i) B[i] = 0; ntt(A, lim, 1), ntt(B, lim, 1);//系数表达转成点值表达 for(int i = 0;i < lim; ++ i) C[i] = A[i] * B[i] % P; ntt(C, lim, -1);//点值表达转换成系数表达 //将结果存入C数组 Poly c;for(int i = 0;i < SZ(a) + SZ(b) - 1; ++ i)c.push_back(C[i]); return c; } //堆的排序方法,较小的先积 struct cmp{bool operator ()(const Poly &a, const Poly &b)const{return SZ(a) > SZ(b);}}; }
代码 / Code
#include <bits/stdc++.h> #define int long long using namespace std; const int maxn = 2e6 + 9, P = 998244353; int fac[maxn], a[maxn], b[maxn], cnt = 0; namespace NTT { #define SZ(v) ((int)(v).size()) #define int long long typedef vector<int> Poly;//N要开4倍 const int N = 2e6 + 9, G = 3, P = 998244353; int qmi(int a, int b,int mod) { int res = 1; while(b) { if(b & 1) res = res * a % mod; a = a * a % mod, b >>= 1; } return res; } void ntt(int a[],int lim,int inv) { //lim是预先处理好的一个大于最大长度的2的倍数 for(int i = 0,j = 0;i < lim; ++ i) { if(i < j) swap(a[i], a[j]); for(int l = (lim >> 1); (j ^= l) < l; l >>= 1); } for(int m = 1,k = 2;k <= lim; m = k, k <<= 1) { int gn = qmi(G, (P - 1) / k, P); if(inv == -1) gn = qmi(gn, P-2, P); for(int i = 0;i < lim;i += k) { int g = 1; for(int j = 0;j < m;++ j,g = g * gn % P) { int u = a[i + j], v = a[i + j + m]; a[i + j] = u + g * v; a[i + j + m]= u - g * v; } } for(int i = 0;i < lim; ++ i) a[i] = (a[i] % P + P) % P; } if(inv == 1) return; int inv_ = qmi(lim, P - 2, P); for(int i = 0;i < lim; ++ i) a[i] = a[i] * inv_ % P; } Poly mul(Poly a,Poly b) { static int A[N], B[N], C[N]; int lim = 1;while(lim < SZ(a) + SZ(b) - 1) lim <<= 1; //标准化AB数组 for(int i = 0; i < SZ(a); ++ i) A[i] = a[i]; for(int i = SZ(a);i < lim; ++ i) A[i] = 0; for(int i = 0; i < SZ(b); ++ i) B[i] = b[i]; for(int i = SZ(b);i < lim; ++ i) B[i] = 0; ntt(A, lim, 1), ntt(B, lim, 1);//系数表达转成点值表达 for(int i = 0;i < lim; ++ i) C[i] = A[i] * B[i] % P; ntt(C, lim, -1);//点值表达转换成系数表达 //将结果存入C数组 Poly c;for(int i = 0;i < SZ(a) + SZ(b) - 1; ++ i)c.push_back(C[i]); return c; } //堆的排序方法,较小的先卷积 struct cmp{bool operator ()(const Poly &a, const Poly &b)const{return SZ(a) > SZ(b);}}; } int inv(int x){return NTT::qmi(x, P - 2, P);} int C(int n,int m){return (n < 0 || m < 0 || n < m) ? 0 : fac[n] * inv(fac[m]) % P * inv(fac[n - m]) % P;} int f(int n,int m){return (C(n - m - 1, m - 1) + C(n - m, m)) % P;} int dfs(int x, bitset<maxn> &vis) {//求环的大小的一个简单递归 if(vis[x])return 0;vis[x] = 1; return dfs(a[x], vis) + 1; } void solve() { int N, K;cin >> N >> K; cnt = 0; bitset<maxn> vis; for(int i = 1;i <= N; ++ i)cin >> a[i]; for(int i = 1;i <= N; ++ i)if(!vis[i])b[++ cnt] = dfs(i, vis); //环更新完毕 ,b[i]表示第i个环的大小,一共cnt个环 priority_queue<NTT::Poly, vector<NTT::Poly>, NTT::cmp> pq; for(int i = 1;i <= cnt; ++ i) { NTT::Poly tmp;//求得多项式并存入堆中 for(int j = 0;j <= b[i] / 2 && j <= K; ++ j)tmp.push_back(f(b[i], j)); pq.push(tmp); } while(pq.size() >= 2) { //每次选出项数最少的两个多项式进行合并,复杂度最小 NTT::Poly t1 = pq.top();pq.pop(); NTT::Poly t2 = pq.top();pq.pop(); pq.push(NTT::mul(t1, t2)); } NTT::Poly ans(pq.top()); //注意输出的时候需要判断如果无法选择K个,那么结果就0 cout << (K < ans.size() ? ans[K] : 0) << '\n'; } signed main() { ios::sync_with_stdio(0), cin.tie(0), cout.tie(0); //初始化阶乘 fac[0] = 1; for(int i = 1;i <= 5e5 + 1; ++ i)fac[i] = i * fac[i - 1] % P; int _;cin >> _; while(_ --)solve(); return 0; }
Comments NOTHING