题目传送门：P4245【模板】任意模数多项式乘法（洛谷，https://www.luogu.com.cn/problem/P4245）
题目 / Problem
NTT 模板题：给定两个多项式，求它们的乘积，每个系数对给定的模数 p 取模。p 是任意的，不一定是 NTT 友好的质数，因此不能直接在模 p 意义下做 NTT。
思路 / Thought
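三模数 NTT + 中国剩余定理。分别在 998244353、1004535809、469762049 三个 NTT 友好质数（原根均为 3）下计算卷积，再用中国剩余定理把三个余数合并成每个系数的真实值，最后对 p 取模。在本题数据范围下（n, m ≤ 10^5，系数 ≤ 10^9），真实系数不超过 (10^5+1)×10^9×10^9 ≈ 10^23。这小于三个模数之积（约 4.7×10^26），因此合并得到的是精确值。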
代码 / Code
#include <iostream>
#include <utility>
#include <vector>

// Arbitrary-modulus polynomial multiplication (Luogu P4245):
// convolve under three NTT-friendly primes, then merge the residues with CRT.
namespace NTT {

using poly = std::vector<long long>;

// Three NTT-friendly primes, all with primitive root 3:
//   998244353  = 119 * 2^23 + 1
//   1004535809 = 479 * 2^21 + 1
//   469762049  =   7 * 2^26 + 1
// A transform of length lim needs lim | (p - 1) for every prime, so lim <= 2^21.
const long long P[3] = {998244353, 1004535809, 469762049};
const long long G = 3;

// Returns a^b mod p by binary exponentiation (b >= 0).
// Requires p < ~3e9 so that the product of two residues fits in a signed 64-bit integer.
long long qmi(long long a, long long b, long long p) {
    long long res = 1 % p;
    a %= p;
    if (a < 0) a += p;  // reduce the base first so a * a cannot overflow
    while (b > 0) {
        if (b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}

// In-place iterative NTT of a[0..lim) modulo the prime p.
//   lim: a power of two dividing p - 1.
//   inv == 1: forward transform; inv == -1: inverse transform, including the 1/lim scaling.
// Input values may be any long long; on return every a[i] is in [0, p).
void ntt(long long a[], long long lim, long long inv, long long p) {
    // Normalize once up front; afterwards every value stays in [0, p).
    for (long long i = 0; i < lim; ++i) a[i] = (a[i] % p + p) % p;

    // Bit-reversal permutation: j walks through the bit-reversed counterparts of i.
    for (long long i = 0, j = 0; i < lim; ++i) {
        if (i < j) std::swap(a[i], a[j]);
        for (long long l = lim >> 1; (j ^= l) < l; l >>= 1) {
        }
    }

    // Butterflies: merge blocks of size m into blocks of size k = 2m.
    for (long long m = 1, k = 2; k <= lim; m = k, k <<= 1) {
        long long gn = qmi(G, (p - 1) / k, p);  // primitive k-th root of unity
        if (inv == -1) gn = qmi(gn, p - 2, p);  // inverse root for the inverse transform
        for (long long i = 0; i < lim; i += k) {
            long long g = 1;
            for (long long j = 0; j < m; ++j, g = g * gn % p) {
                const long long u = a[i + j];
                const long long t = g * a[i + j + m] % p;  // product reduced exactly once
                const long long s = u + t;
                const long long d = u - t;
                a[i + j] = s >= p ? s - p : s;
                a[i + j + m] = d < 0 ? d + p : d;
            }
        }
    }

    if (inv != -1) return;
    const long long inv_lim = qmi(lim, p - 2, p);
    for (long long i = 0; i < lim; ++i) a[i] = a[i] * inv_lim % p;
}

// CRT (Garner's method): given residues a, b, c of an exact non-negative value X
// modulo P[0], P[1], P[2] (each residue already in [0, P[t])), returns X mod p.
//   x = a + m1 * ((b - a) * m1^{-1} mod m2)      is X mod m1*m2 (fits in 64 bits)
//   X = x + m1*m2 * k,  k = (c - x) * (m1*m2)^{-1} mod m3
// Correct whenever X < m1*m2*m3 (~4.7e26).
long long get(long long a, long long b, long long c, long long p) {
    const long long m1 = P[0], m2 = P[1], m3 = P[2];
    // These inverses are constant, so compute them once rather than per coefficient.
    static const long long inv_m1_mod_m2 = qmi(m1, m2 - 2, m2);
    static const long long inv_m1m2_mod_m3 = qmi(m1 * m2 % m3, m3 - 2, m3);

    const long long x = (b - a + m2) % m2 * inv_m1_mod_m2 % m2 * m1 + a;
    const long long k = (c - x % m3 + m3) % m3 * inv_m1m2_mod_m3 % m3;
    return (k * (m1 * m2 % p) % p + x % p) % p;
}

// Returns the coefficients of a * b reduced modulo `mod` (the problem's modulus).
// The result has |a| + |b| - 1 entries, or none if either input is empty.
// Preconditions:
//   - |a| + |b| - 1 <= 2^21 (transform-length limit of the three primes);
//   - every exact product coefficient is in [0, P[0]*P[1]*P[2]). This is guaranteed
//     by P4245's limits; other callers should confirm it for their inputs.
poly mul(const poly& a, const poly& b, long long mod) {
    if (a.empty() || b.empty()) return {};
    const long long len = static_cast<long long>(a.size() + b.size()) - 1;
    long long lim = 1;
    while (lim < len) lim <<= 1;

    // The modulus may not be NTT-friendly, so convolve under each of the three
    // primes and merge the results with CRT.
    poly residues[3];
    for (int t = 0; t < 3; ++t) {
        const long long p = P[t];
        // Buffers sized per call: no fixed capacity that a long input could overrun.
        poly A(a);
        A.resize(lim, 0);
        poly B(b);
        B.resize(lim, 0);
        ntt(A.data(), lim, 1, p);
        ntt(B.data(), lim, 1, p);
        for (long long i = 0; i < lim; ++i) A[i] = A[i] * B[i] % p;
        ntt(A.data(), lim, -1, p);
        A.resize(len);
        residues[t] = std::move(A);
    }

    poly c(len);
    for (long long i = 0; i < len; ++i) {
        c[i] = get(residues[0][i], residues[1][i], residues[2][i], mod);
    }
    return c;
}

}  // namespace NTT

// Reads n, m, p, then n + 1 coefficients of the first polynomial and m + 1 of the
// second; prints the product's coefficients mod p, each followed by a space.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    long long n, m, p;
    std::cin >> n >> m >> p;
    NTT::poly a(n + 1), b(m + 1);
    for (auto& x : a) std::cin >> x;
    for (auto& x : b) std::cin >> x;

    const NTT::poly c = NTT::mul(a, b, p);
    for (const long long x : c) std::cout << x << ' ';
    return 0;
}
Comments NOTHING