#include #include #include #include #include using namespace std; // 使用复数类,complex 自带加减乘运算 typedef complex cd; const double PI = acos(-1.0); /** * FFT 核心函数 * @param a 待转换的系数向量 * @param invert 是否为逆变换 (IFFT)。false 为 DFT,true 为 IDFT */ void fft(vector& a, bool invert) { int n = a.size(); // 1. 位逆序置换 (Bit-reversal permutation) // 将原数组的下标进行二进制翻转,使得最底层的系数排列在正确位置 for (int i = 1, j = 0; i < n; i++) { int bit = n >> 1; for (; j & bit; bit >>= 1) j ^= bit; j ^= bit; if (i < j) swap(a[i], a[j]); } // 2. 自底向上合并 (迭代实现蝴蝶变换) // len 是当前合并区间的长度,从 2, 4, 8 ... 一直合并到 n for (int len = 2; len <= n; len <<= 1) { double ang = 2 * PI / len * (invert ? -1 : 1); cd wlen(cos(ang), sin(ang)); // 单位根 w_n^1 for (int i = 0; i < n; i += len) { cd w(1); // 初始 w_n^0 = 1 for (int j = 0; j < len / 2; j++) { // 蝴蝶变换核心公式: // u = A_e, v = w * A_o cd u = a[i + j], v = a[i + j + len / 2] * w; a[i + j] = u + v; // A(w^k) = A_e + w*A_o a[i + j + len / 2] = u - v; // A(w^{k+n/2}) = A_e - w*A_o w *= wlen; // 迭代单位根 } } } // 3. 如果是逆变换,最后需要除以 n if (invert) { for (cd & x : a) x /= n; } } /** * 多项式乘法 */ vector multiply(vector const& a, vector const& b) { vector fa(a.begin(), a.end()), fb(b.begin(), b.end()); // 补零到 2 的幂次 int n = 1; while (n < a.size() + b.size()) n <<= 1; fa.resize(n); fb.resize(n); // 1. DFT: 系数 -> 点值 fft(fa, false); fft(fb, false); // 2. Pointwise Multiplication: 点值直接相乘 for (int i = 0; i < n; i++) fa[i] *= fb[i]; // 3. IDFT: 点值 -> 系数 fft(fa, true); // 4. 取实部并四舍五入 vector result(n); for (int i = 0; i < n; i++) result[i] = round(fa[i].real()); return result; } int main() { // 你的例子: // f1 = 1 + 8x + 1x^2 + 1x^3 -> [1, 8, 1, 1] // f2 = -2 - 5x + 1x^2 -> [-2, -5, 1] vector f1 = {1, 8, 1, 1}; vector f2 = {-2, -5, 1}; vector res = multiply(f1, f2); cout << "Result coefficients: "; for (int i = 0; i < f1.size() + f2.size() - 1; i++) { cout << res[i] << " "; } return 0; }