[杭电多校3 | HDUOJ]7173.Two Permutations(DFS + 记忆化搜索 或 DP + Hash)

发布于 2022-07-27  471 次阅读


题目链接:Problem - 7173 (hdu.edu.cn)

题意 / Problem

给定2个排列,P、Q,长度为N,给定一个长度为2N的序列S,问能够通过以下方法构造出S:

序列Q初始为空,每次从P或Q的最左侧选出一个数字放到序列Q的最右侧。

问构造的方案数,只要其中一步不同就算作不同方案。

方法一(DFS + 记忆化搜索)/ Solution1

这个应该不是正解,因为复杂度太离谱了。

这是我比赛的时候写的,基本是卡过去的。

代码 / Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 3e5 + 9, p = 998244353;
int P[maxn], Q[maxn], S[maxn << 1], N;
unordered_map<int, int> dp[maxn * 2];

int dfs(int dep, int kp, int kq)
{
	if(dep == 2 * N + 1)return 1;
	if(dp[dep][kp])return dp[dep][kp];
	
	int res = 0;
	if(kp <= N && S[dep] == P[kp])res = (res + dfs(dep + 1, kp + 1, kq)) % p;
	if(kq <= N && S[dep] == Q[kq])res = (res + dfs(dep + 1, kp, kq + 1)) % p;
	dp[dep][kp] = res;
	return res;
}

void solve()
{
	cin >> N;
	for(int i = 1;i <= N; ++ i)cin >> P[i];
	for(int i = 1;i <= N; ++ i)cin >> Q[i];
	for(int i = 1;i <= 2 * N; ++ i)
	{
		cin >> S[i];
		dp[i].clear();
	}
	cout << dfs(1, 1, 1) << '\n';
}

signed main()
{
	ios::sync_with_stdio(0);
	int _;cin >> _; 
	while(_ --)solve();
	return 0;
}
AC图

耗时和空间都非常大!属于是卡过去的。

方法二(DP + Hash)正解 / Solution2

在S中,1~N的数字每个都会出现两次,否则直接输出0。

记录下每个数字在S中两次出现的位置,记作pos[S[i][0]和pos[S[i][1]。既然是dp,就需要一个状态,dp[i][j]表示P中前i - 1个元素方案确定的情况下,第i个元素选择S中对应元素的第一或第二个位置(也就是0 / 1的位置)的方案数。

我们这里采用从前往后dp的方法,从第i个状态到第i+1个状态有四条路径,分别是dp[i][0]->dp[i+1][0],dp[i][0]->dp[i+1][1],dp[i][1]->dp[i+1][0],dp[i][1]->dp[i+1][1],转移方法是,如果可以转移就直接全部加到后面去。

怎样是“可以转移”的呢?举个例子,从dp[i][0]转移到dp[i+1][1],比如前者在S中的位置是x,后者在S中的位置是y,如果在S中(x, y)开区间中的所有元素是由全部Q中对应元素填充的(因为P中连续两个元素之间不可能还有元素来填充这段区间),那么就认为“可以转移”。

我们首先确定在S中的(x, y)区间,那么需要判断的区间就是[x + 1, y - 1] == [l, r],还需要确定一个在Q中的[ql, qr]区间,这里ql = l - i,qr = ql + r - l = r - i,这里是通过简单的数学方法算出来的,因为到y为止,一共选了y个元素,且P中选了i + 1个,所以y - (i + 1)就是Q中所选的所有元素,那么Q中判断区间的起始位置ql就应该是y - (i + 1) - (r - l),将y = r + 1代入得ql = l - i, 所以qr = ql + r - l = r - i;

这里的对应区间相等,可以用字符串Hash来进行O(1)的判断。

具体的实现看代码。

复杂度O(N),正解属于是。

代码 / Code

正解
#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef unsigned long long ULL;
const ULL maxn = 3e5 + 9, p = 998244353, base = 1333331;
ULL P[maxn], Q[maxn], S[maxn << 1], N;

//Hash
ULL hS[maxn << 1], hQ[maxn], b[maxn << 1];
void initHash(ULL a[], ULL n, ULL h[]){for(int i = 1;i <= n; ++ i)h[i] = h[i - 1] * base + a[i];}
ULL getHash(ULL h[], int l,int r){return h[r] - h[l - 1] * b[r - l + 1];}

int dp[maxn][2], pos[maxn][2];

bool check(int i,int l,int r)
{
	if(l > r)return 1;
	int ql = l - i, qr = r - i;
	if(qr > N)return 0;//越界了返回假 
	return getHash(hQ, ql, qr) == getHash(hS, l, r);
} 


void solve()
{
	memset(dp, 0, sizeof dp);
	memset(pos, 0, sizeof pos);
	
	cin >> N;
	for(int i = 1;i <= N; ++ i)cin >> P[i];
	for(int i = 1;i <= N; ++ i)cin >> Q[i];
	bool tag = 1;//tag表示S串是否合法 
	for(int i = 1;i <= 2 * N; ++ i)
	{ 
		cin >> S[i];
		if(!pos[S[i]][0])pos[S[i]][0] = i;
		else if(!pos[S[i]][1])pos[S[i]][1] = i;
		else tag = 0;
	}
	
	if(!tag)//S串不合法直接输出0并退出 
	{
		cout << 0 << '\n';
		return; 
	}
	
	//初始化Hash 
	initHash(Q, N << 0, hQ);//N << 0仅为了代码工整好看 
	initHash(S, N << 1, hS);
	//先从 0 -> 1 进行dp的初始化
	for(int i = 0;i < 2; ++ i)
	{
		int x = 0, y = pos[P[1]][i];
		if(check(0, x + 1, y - 1))dp[1][i] = 1;
	}
	//逐层递推 
	for(int i = 1;i < N; ++ i)
		for(int j = 0;j < 2; ++ j)
		{
			int x = pos[P[i]][j];
			for(int k = 0;k < 2; ++ k)
			{
				int y = pos[P[i + 1]][k];
				if(x < y && check(i, x + 1,y - 1))dp[i + 1][k] = (dp[i + 1][k] + dp[i][j]) % p;
			}
		}
	
	int ans = 0;
	for(int i = 0;i < 2; ++ i)
	{
		int x = pos[P[N]][i], y = 2 * N + 1;
		if(check(N, x + 1, y - 1))ans = (ans + dp[N][i]) % p;
	}
	cout << ans << '\n';
}

signed main()
{
	ios::sync_with_stdio(0);
	//初始化b数组 
	b[0] = 1;
	for(int i = 1;i <= maxn * 2 - 5; ++ i)b[i] = b[i - 1] * base;
	int _;cin >> _; 
	while(_ --)solve();
	return 0;
}