题目链接:Problem - 7173 (hdu.edu.cn)
题意 / Problem
给定2个排列,P、Q,长度为N,给定一个长度为2N的序列S,问能够通过以下方法构造出S:
序列Q初始为空,每次从P或Q的最左侧选出一个数字放到序列Q的最右侧。
问构造的方案数,只要其中一步不同就算作不同方案。
方法一(DFS + 记忆化搜索)/ Solution1
这个应该不是正解,因为复杂度太离谱了。
这是我比赛的时候写的,基本是卡过去的。
代码 / Code
#include <bits/stdc++.h> #define int long long using namespace std; const int maxn = 3e5 + 9, p = 998244353; int P[maxn], Q[maxn], S[maxn << 1], N; unordered_map<int, int> dp[maxn * 2]; int dfs(int dep, int kp, int kq) { if(dep == 2 * N + 1)return 1; if(dp[dep][kp])return dp[dep][kp]; int res = 0; if(kp <= N && S[dep] == P[kp])res = (res + dfs(dep + 1, kp + 1, kq)) % p; if(kq <= N && S[dep] == Q[kq])res = (res + dfs(dep + 1, kp, kq + 1)) % p; dp[dep][kp] = res; return res; } void solve() { cin >> N; for(int i = 1;i <= N; ++ i)cin >> P[i]; for(int i = 1;i <= N; ++ i)cin >> Q[i]; for(int i = 1;i <= 2 * N; ++ i) { cin >> S[i]; dp[i].clear(); } cout << dfs(1, 1, 1) << '\n'; } signed main() { ios::sync_with_stdio(0); int _;cin >> _; while(_ --)solve(); return 0; }

耗时和空间都非常大!属于是卡过去的。
方法二(DP + Hash)正解 / Solution2
在S中,1~N的数字每个都会出现两次,否则直接输出0。
记录下每个数字在S中两次出现的位置,记作pos[S[i][0]和pos[S[i][1]。既然是dp,就需要一个状态,dp[i][j]表示P中前i - 1个元素方案确定的情况下,第i个元素选择S中对应元素的第一或第二个位置(也就是0 / 1的位置)的方案数。
我们这里采用从前往后dp的方法,从第i个状态到第i+1个状态有四条路径,分别是dp[i][0]->dp[i+1][0],dp[i][0]->dp[i+1][1],dp[i][1]->dp[i+1][0],dp[i][1]->dp[i+1][1],转移方法是,如果可以转移就直接全部加到后面去。
怎样是“可以转移”的呢?举个例子,从dp[i][0]转移到dp[i+1][1],比如前者在S中的位置是x,后者在S中的位置是y,如果在S中(x, y)开区间中的所有元素是由全部Q中对应元素填充的(因为P中连续两个元素之间不可能还有元素来填充这段区间),那么就认为“可以转移”。
我们首先确定在S中的(x, y)区间,那么需要判断的区间就是[x + 1, y - 1] == [l, r],还需要确定一个在Q中的[ql, qr]区间,这里ql = l - i,qr = ql + r - l = r - i,这里是通过简单的数学方法算出来的,因为到y为止,一共选了y个元素,且P中选了i + 1个,所以y - (i + 1)就是Q中所选的所有元素,那么Q中判断区间的起始位置ql就应该是y - (i + 1) - (r - l),将y = r + 1代入得ql = l - i, 所以qr = ql + r - l = r - i;
这里的对应区间相等,可以用字符串Hash来进行O(1)的判断。
具体的实现看代码。
复杂度O(N),正解属于是。
代码 / Code

#include <bits/stdc++.h> #define int long long using namespace std; typedef unsigned long long ULL; const ULL maxn = 3e5 + 9, p = 998244353, base = 1333331; ULL P[maxn], Q[maxn], S[maxn << 1], N; //Hash ULL hS[maxn << 1], hQ[maxn], b[maxn << 1]; void initHash(ULL a[], ULL n, ULL h[]){for(int i = 1;i <= n; ++ i)h[i] = h[i - 1] * base + a[i];} ULL getHash(ULL h[], int l,int r){return h[r] - h[l - 1] * b[r - l + 1];} int dp[maxn][2], pos[maxn][2]; bool check(int i,int l,int r) { if(l > r)return 1; int ql = l - i, qr = r - i; if(qr > N)return 0;//越界了返回假 return getHash(hQ, ql, qr) == getHash(hS, l, r); } void solve() { memset(dp, 0, sizeof dp); memset(pos, 0, sizeof pos); cin >> N; for(int i = 1;i <= N; ++ i)cin >> P[i]; for(int i = 1;i <= N; ++ i)cin >> Q[i]; bool tag = 1;//tag表示S串是否合法 for(int i = 1;i <= 2 * N; ++ i) { cin >> S[i]; if(!pos[S[i]][0])pos[S[i]][0] = i; else if(!pos[S[i]][1])pos[S[i]][1] = i; else tag = 0; } if(!tag)//S串不合法直接输出0并退出 { cout << 0 << '\n'; return; } //初始化Hash initHash(Q, N << 0, hQ);//N << 0仅为了代码工整好看 initHash(S, N << 1, hS); //先从 0 -> 1 进行dp的初始化 for(int i = 0;i < 2; ++ i) { int x = 0, y = pos[P[1]][i]; if(check(0, x + 1, y - 1))dp[1][i] = 1; } //逐层递推 for(int i = 1;i < N; ++ i) for(int j = 0;j < 2; ++ j) { int x = pos[P[i]][j]; for(int k = 0;k < 2; ++ k) { int y = pos[P[i + 1]][k]; if(x < y && check(i, x + 1,y - 1))dp[i + 1][k] = (dp[i + 1][k] + dp[i][j]) % p; } } int ans = 0; for(int i = 0;i < 2; ++ i) { int x = pos[P[N]][i], y = 2 * N + 1; if(check(N, x + 1, y - 1))ans = (ans + dp[N][i]) % p; } cout << ans << '\n'; } signed main() { ios::sync_with_stdio(0); //初始化b数组 b[0] = 1; for(int i = 1;i <= maxn * 2 - 5; ++ i)b[i] = b[i - 1] * base; int _;cin >> _; while(_ --)solve(); return 0; }
Comments NOTHING