[CCPC2017哈尔滨站]Palindrome(Manacher + 树状数组)

发布于 2022-11-02  318 次阅读


题目链接:Palindrome - HDU 6230 - Virtual Judge (vjudge.net)

从这里进也可以:Problem - 6230 (hdu.edu.cn)

题目大意 / Problem

有多组样例。

每组样例给一个字符串,问有多少个“一个半回文串”,一个半回文串的定义是,第一个回文串的右端点是第二个回文串的中心,第一个回文串的中心是第二个回文串的端点。

比如abcbabc就是“一个半”回文串,其中abcba是第一个回文串,cbabc是第二个回文串,两者合到一起组成了“一个半回文串”。

分析 / Analyse

不难发现,当两个回文串的中心相互包含,就可以组成1个且仅有一个“一个半回文串”。比如有两个回文串p1和p2,只要p1的中心在p2范围内,p2的中心在p1范围内即可。

解释一下为什么“仅有一个”:因为确定了两个中心,那么这两个回文串的半径也就确定了,这两个回文串也就确定了,这“一个半”回文串也就确定了,仅有一个。

同时可以发现,题目转变为:找二元组(i, j)使得i < j 且 i,j分别为p1和p2的中心。只要找出这样的二元组数量即可。其中i < j是构造偏序关系,避免二元组的重复,比如(1, 3)和(3, 1)其实只应该被算一次。

那么如何找出所有的二元组呢?

我们需要得到所有点为中心的回文串半径,有了这个就可以通过一些方法来计数了。

用manacher算法得到所有点为中心的回文串半径后,我们可以将所有的(i - p[i] + 1, i)作为“左边”存下来。

示意图

其中l[i], r[i]分别表示以i为中心点的左右两端,存下来方便后续处理。

manacher算法需要数组开3倍空间。

代码 / Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 2e6 + 9;
char s[maxn];
int p[maxn], l[maxn], r[maxn];
int N;


void Manacher(int N)
{
	memset(p, 0, sizeof(int) * (2 * N));
	for(int i = 2 * N + 1;i >= 1; -- i)
		s[i] = (i & 1) ? '#' : s[i / 2];      
     
    s[0] = '^', s[2 * N + 2] = '$';//将头部和尾部特殊化 
     
	int C = 0, R = 0;//R为右边界开区间
	for(int i = 1;i <= 2 * N + 1; ++ i)
	{
		p[i] = i < R ? min(R - i, p[2 * C - i]) : 1;
		while(s[i + p[i]] == s[i - p[i]])p[i] ++;
		if(i + p[i] > R)R = i + p[i], C = i;
	} 
	for(int i = 1;i <= N; ++ i)s[i] = s[i * 2], p[i] = p[i * 2] / 2;
	for(int i = 1;i <= N; ++ i)l[i] = i - p[i] + 1, r[i] = i + p[i] - 1;
}

int d[maxn];
int lowbit(int x)
{
	return x & -x;
}
void add(int k, int x)
{
	for(int i = k;i <= N;i += lowbit(i))d[i] += x;
}

int getsum(int k)
{
	int res = 0;
	for(int i = k;i > 0;i -= lowbit(i))res += d[i];
	return res;
}

void solve()
{
	cin >> s + 1;
	N = strlen(s + 1);
	Manacher(N);
	memset(d, 0, sizeof(int) * (N + 5));
	//for(int i = 1;i <= N; ++ i)cout << p[i] << ' '; 
	
	vector<pair<int, int> > v;
	for(int i = 1;i <= N; ++ i)
		if(r[i] - l[i] + 1 >= 2)v.push_back({l[i], i});
	
	sort(v.begin(), v.end(), [](const pair<int, int> &a, const pair<int, int> &b)
	{
		return a.first == b.first ? a.second < b.second : a.first < b.first;
	});
	
	int t = 0, ans = 0;
	for(int i = 1;i <= N; ++ i)
	{
		while(t < (int)v.size() && v[t].first <= i)add(v[t].second, 1), t ++;
		ans += getsum(r[i]) - getsum(i);
	}
	cout << ans << '\n';
}

signed main()
{
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	int _;cin >> _;
	while(_ --)solve();
	return 0;
}