[Codeforces *2100]D. Carry Bit(位运算 + 组合计数)

发布于 2022-11-23  353 次阅读


题目链接:Problem - D - Codeforces

Problem

给定两个整数\(n, k (0 \le k<n \le 10^6)\),求出满足以下条件的\((a, b)\)二元组的个数,\((a, b)\)和\((b, a)\)视作不同的两个二元组。

\((a, b)\)条件:\(0 \le a, b < 2^n\),且\(a + b\)的二进制运算中,进位的个数为\(k\)。

转化为公式就是:

$$设f(a, b)为a + b进位个数$$

$$ans = \sum\limits_{a=0}^{2^n - 1}\sum\limits_{b=0}^{2^n - 1}[f(a, b) == k]$$

下面是一个例子:

举例说明

Analyse

定义一个数组\(c\),表示进位。

\(c_i = 1\)表示第\(i\)位将会向更高位产生进位,规定\(c_{-1} = 0\),这是显然的,因为第0位并不需要考虑低位的进位,相当于第-1位的进位为0。

我们规定\(a_i, b_i\)分别为数字\(a, b\)的二进制位第\(i\)位,当然也可以理解为\(a, b\)的二进制位所形成的数组的第\(i\)位。

那么就不难发现\(c_i, a_i, b_i, c_{i - 1}\)之间的关系。当三元集合\(\{ a_i, b_i, c_{i-1} \}\)中的\(1\)的个数\(\ge 2\)时,\(c_i = 1\),其余情况,\(c_i = 0\)。

组合计数的关键在于:将复杂的问题不重不漏划分为多个可以计算的子问题,并找出足够强规律,运用经典模型进行计算。

那么我们可以以\(c_i, c_{i-1}\)为根据,分类讨论:

\(c_i\)\(c_{i-1}\)\((a, b)\)
\(0\)\(0\)\((0, 0), (0, 1), (1, 0)\)
\(1\)\(1\)\((1, 1), (0, 1), (1, 0)\)
\(1\)\(0\)\((1, 1)\)
\(0\)\(1\)\( (0, 0) \)

现在来找规律,其实已经很明显了,当\(c_i = c_{i-1}\)时,\((a, b)\)取值有3种,而\(c_i \ne c_{i-1}\)时,\((a, b)\)的取值只有1种。

我们设一个变量\(q\),表示\(c\)数组中\(c_i \ne c_{i-1}\)的组数。

对于任意一个\(q\),在简单画个图之后不难发现,当\(q\)确定后,整个\(c\)数组被划分为\(q+1\)段,每一段都是连续的\(1\)或连续的\(0\)。并且这个段数和\(q\)的大小以及奇偶性有关。

我们定义\(s_0, s_1\)分别为\(c\)的\([-1, n - 1]\)区间内的\(0,1\)的段数。

q与s的关系

现在问题转化成了简单的模型:将\(a\)个物品分为不为空的\(b\)份,用隔板法容易算出答案是\(C_{a - 1}^{b - 1}\)。

现在枚举\(q\),在\(q\)已知的情况下,我们就可以算出在某一种状态下对答案的贡献时\(3^{n-q}\),\(n-q\)也就是\(c_i = c_{i-1}\)的组数,而每一组都会贡献3的\(a, b\)方案数。

现在再求一下已知\(q\)情况下的状态总数,其实只需要将\(1\)的个数分为\(s1\)份,\(0\)的个数分为\(s0\)份即可,分完之后排列方式是唯一的,一定是从右往左\(0, 0, .., 1, 1, ...\)这样排,因为\(c_{-1}\)定义为\(0\)了。

显然题目给定了\(1\)的个数和\(0\)的个数,分别为\(k\)和\(n - k + 1\),注意加上\(c_{-1}\)这个0。

注意特判\(k=0\)的情况,\(q\)只能为\(0\)且状态数只有一种,所以答案就是\(3^n\)。

至此,问题解决。预处理一下阶乘和3的幂,时间复杂度\(O(n)\)。

Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 1e6 + 9, p = 1e9 + 7;
int fac[maxn], invfac[maxn], pow3[maxn];


int qmi(int a,int b)
{
	int res = 1;
	while(b)
	{
		if(b & 1)res = res * a % p;
		a = a * a % p, b >>= 1;
	}	
	return res;
}

int inv(int x){return x == 0 ? 1 : qmi(x, p - 2);}

int C(int n, int m)
{
	if(n < m || n < 0 || m < 0)return 0;
	return fac[n] * invfac[m] % p * invfac[n - m] % p;
}

signed main()
{
	int n, k;scanf("%lld %lld", &n, &k);
	
	fac[0] = 1;
	for(int i = 1;i <= n; ++ i)fac[i] = fac[i - 1] * i % p;
	
	invfac[n] = inv(fac[n]);
	for(int i = n - 1;i >= 0; -- i)invfac[i] = invfac[i + 1] * (i + 1) % p;
	pow3[0] = 1;
	for(int i = 1;i <= n; ++ i)pow3[i] = pow3[i - 1] * 3 % p;

	int ans = 0;
	
	if(k == 0)ans = qmi(3, n);
	else
	{
		for(int q = 0;q <= n; ++ q)
		{
			int tmp = pow3[n - q] * C(k - 1, q / 2 + (q % 2) - 1) % p * C(n - k, q / 2) % p; 
			ans = (ans + tmp) % p;
		}
	}

	printf("%lld\n", ans);
	
	return 0;
}