104 lines
2.7 KiB
C++
104 lines
2.7 KiB
C++
#include <iostream>
|
||
#include <vector>
|
||
#include <complex>
|
||
#include <cmath>
|
||
#include <algorithm>
|
||
|
||
using namespace std;
|
||
|
||
// 使用复数类,complex<double> 自带加减乘运算
|
||
typedef complex<double> cd;
|
||
const double PI = acos(-1.0);
|
||
|
||
/**
|
||
* FFT 核心函数
|
||
* @param a 待转换的系数向量
|
||
* @param invert 是否为逆变换 (IFFT)。false 为 DFT,true 为 IDFT
|
||
*/
|
||
void fft(vector<cd>& a, bool invert) {
|
||
int n = a.size();
|
||
|
||
// 1. 位逆序置换 (Bit-reversal permutation)
|
||
// 将原数组的下标进行二进制翻转,使得最底层的系数排列在正确位置
|
||
for (int i = 1, j = 0; i < n; i++) {
|
||
int bit = n >> 1;
|
||
for (; j & bit; bit >>= 1)
|
||
j ^= bit;
|
||
j ^= bit;
|
||
if (i < j) swap(a[i], a[j]);
|
||
}
|
||
|
||
// 2. 自底向上合并 (迭代实现蝴蝶变换)
|
||
// len 是当前合并区间的长度,从 2, 4, 8 ... 一直合并到 n
|
||
for (int len = 2; len <= n; len <<= 1) {
|
||
double ang = 2 * PI / len * (invert ? -1 : 1);
|
||
cd wlen(cos(ang), sin(ang)); // 单位根 w_n^1
|
||
|
||
for (int i = 0; i < n; i += len) {
|
||
cd w(1); // 初始 w_n^0 = 1
|
||
for (int j = 0; j < len / 2; j++) {
|
||
// 蝴蝶变换核心公式:
|
||
// u = A_e, v = w * A_o
|
||
cd u = a[i + j], v = a[i + j + len / 2] * w;
|
||
a[i + j] = u + v; // A(w^k) = A_e + w*A_o
|
||
a[i + j + len / 2] = u - v; // A(w^{k+n/2}) = A_e - w*A_o
|
||
w *= wlen; // 迭代单位根
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. 如果是逆变换,最后需要除以 n
|
||
if (invert) {
|
||
for (cd & x : a)
|
||
x /= n;
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 多项式乘法
|
||
*/
|
||
vector<long long> multiply(vector<int> const& a, vector<int> const& b) {
|
||
vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
|
||
|
||
// 补零到 2 的幂次
|
||
int n = 1;
|
||
while (n < a.size() + b.size())
|
||
n <<= 1;
|
||
fa.resize(n);
|
||
fb.resize(n);
|
||
|
||
// 1. DFT: 系数 -> 点值
|
||
fft(fa, false);
|
||
fft(fb, false);
|
||
|
||
// 2. Pointwise Multiplication: 点值直接相乘
|
||
for (int i = 0; i < n; i++)
|
||
fa[i] *= fb[i];
|
||
|
||
// 3. IDFT: 点值 -> 系数
|
||
fft(fa, true);
|
||
|
||
// 4. 取实部并四舍五入
|
||
vector<long long> result(n);
|
||
for (int i = 0; i < n; i++)
|
||
result[i] = round(fa[i].real());
|
||
|
||
return result;
|
||
}
|
||
|
||
int main() {
|
||
// 你的例子:
|
||
// f1 = 1 + 8x + 1x^2 + 1x^3 -> [1, 8, 1, 1]
|
||
// f2 = -2 - 5x + 1x^2 -> [-2, -5, 1]
|
||
vector<int> f1 = {1, 8, 1, 1};
|
||
vector<int> f2 = {-2, -5, 1};
|
||
|
||
vector<long long> res = multiply(f1, f2);
|
||
|
||
cout << "Result coefficients: ";
|
||
for (int i = 0; i < f1.size() + f2.size() - 1; i++) {
|
||
cout << res[i] << " ";
|
||
}
|
||
return 0;
|
||
}
|