106 lines
2.6 KiB
C++
106 lines
2.6 KiB
C++
#include <iostream>
|
||
#include <vector>
|
||
#include <complex>
|
||
#include <cmath>
|
||
|
||
using namespace std;
|
||
|
||
typedef complex<double> cd;
|
||
const double PI = acos(-1.0);
|
||
|
||
/**
|
||
* 递归版 FFT
|
||
* @param a 传入的系数向量
|
||
* @param invert 是否为逆变换(FFT or IFFT)
|
||
*/
|
||
void fft_recursive(vector<cd>& a, bool invert) {
|
||
int n = a.size();
|
||
|
||
// 递归终点:当多项式只有一个常数项时,直接返回
|
||
if (n == 1) return;
|
||
|
||
// 1. 拆分阶段 (Divide)
|
||
// 按照下标奇偶性将 a 拆分成 a_e 和 a_o
|
||
vector<cd> a0(n / 2), a1(n / 2);
|
||
for (int i = 0; 2 * i < n; i++) {
|
||
a0[i] = a[2 * i]; // 偶数项
|
||
a1[i] = a[2 * i + 1]; // 奇数项
|
||
}
|
||
|
||
// 2. 递归阶段 (Conquer)
|
||
fft_recursive(a0, invert);
|
||
fft_recursive(a1, invert);
|
||
|
||
// 3. 合并阶段 (Combine / Butterfly Operation)
|
||
// 计算单位根 w_n^k
|
||
double ang = 2 * PI / n * (invert ? -1 : 1);
|
||
cd w(1); // 初始 w = w_n^0 = 1
|
||
cd wn(cos(ang), sin(ang)); // 单位步进 w_n^1
|
||
|
||
for (int k = 0; k < n / 2; k++) {
|
||
// 利用蝴蝶变换公式:
|
||
// A(w^k) = Ae(w^k) + w * Ao(w^k)
|
||
// A(w^{k+n/2}) = Ae(w^k) - w * Ao(w^k)
|
||
|
||
cd t = w * a1[k]; // 旋转因子与奇数项乘积
|
||
a[k] = a0[k] + t;
|
||
a[k + n / 2] = a0[k] - t;
|
||
|
||
if (invert) {
|
||
// 如果是逆变换,顺便除以 2(递归层层除以 2,最终相当于除以 n)
|
||
a[k] /= 2;
|
||
a[k + n / 2] /= 2;
|
||
}
|
||
|
||
w *= wn; // 移动到下一个单位根 w_n^{k+1}
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 多项式乘法封装
|
||
*/
|
||
vector<long long> multiply(vector<int> const& f1, vector<int> const& f2) {
|
||
vector<cd> fa(f1.begin(), f1.end()), fb(f2.begin(), f2.end());
|
||
|
||
// 补零到 2 的幂
|
||
int n = 1;
|
||
while (n < f1.size() + f2.size()) n <<= 1;
|
||
fa.resize(n);
|
||
fb.resize(n);
|
||
|
||
// 变换到点值表示
|
||
fft_recursive(fa, false);
|
||
fft_recursive(fb, false);
|
||
|
||
// 点值相乘
|
||
for (int i = 0; i < n; i++)
|
||
fa[i] *= fb[i];
|
||
|
||
// 变换回系数表示
|
||
fft_recursive(fa, true);
|
||
|
||
// 结果取整
|
||
vector<long long> result;
|
||
for (int i = 0; i < n; i++)
|
||
result.push_back(round(fa[i].real()));
|
||
|
||
return result;
|
||
}
|
||
|
||
int main() {
|
||
// A(x) = 1x^3 + 1x^2 + 8x + 1
|
||
// B(x) = 1x^2 - 5x - 2
|
||
vector<int> a = {1, 8, 1, 1}; // 注意低位在前
|
||
vector<int> b = {-2, -5, 1};
|
||
|
||
vector<long long> res = multiply(a, b);
|
||
|
||
cout << "Final Coefficients: ";
|
||
for (int i = 0; i < a.size() + b.size() - 1; i++) {
|
||
cout << res[i] << " ";
|
||
}
|
||
cout << endl;
|
||
|
||
return 0;
|
||
}
|