Files
Data-Structure/Algorithm/FFT/template-recursive.cpp
2026-02-13 12:38:29 +08:00

106 lines
2.6 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include <iostream>
#include <vector>
#include <complex>
#include <cmath>
using namespace std;
typedef complex<double> cd;
const double PI = acos(-1.0);
/**
* 递归版 FFT
* @param a 传入的系数向量
* @param invert 是否为逆变换(FFT or IFFT)
*/
void fft_recursive(vector<cd>& a, bool invert) {
int n = a.size();
// 递归终点:当多项式只有一个常数项时,直接返回
if (n == 1) return;
// 1. 拆分阶段 (Divide)
// 按照下标奇偶性将 a 拆分成 a_e 和 a_o
vector<cd> a0(n / 2), a1(n / 2);
for (int i = 0; 2 * i < n; i++) {
a0[i] = a[2 * i]; // 偶数项
a1[i] = a[2 * i + 1]; // 奇数项
}
// 2. 递归阶段 (Conquer)
fft_recursive(a0, invert);
fft_recursive(a1, invert);
// 3. 合并阶段 (Combine / Butterfly Operation)
// 计算单位根 w_n^k
double ang = 2 * PI / n * (invert ? -1 : 1);
cd w(1); // 初始 w = w_n^0 = 1
cd wn(cos(ang), sin(ang)); // 单位步进 w_n^1
for (int k = 0; k < n / 2; k++) {
// 利用蝴蝶变换公式:
// A(w^k) = Ae(w^k) + w * Ao(w^k)
// A(w^{k+n/2}) = Ae(w^k) - w * Ao(w^k)
cd t = w * a1[k]; // 旋转因子与奇数项乘积
a[k] = a0[k] + t;
a[k + n / 2] = a0[k] - t;
if (invert) {
// 如果是逆变换,顺便除以 2递归层层除以 2最终相当于除以 n
a[k] /= 2;
a[k + n / 2] /= 2;
}
w *= wn; // 移动到下一个单位根 w_n^{k+1}
}
}
/**
* 多项式乘法封装
*/
vector<long long> multiply(vector<int> const& f1, vector<int> const& f2) {
vector<cd> fa(f1.begin(), f1.end()), fb(f2.begin(), f2.end());
// 补零到 2 的幂
int n = 1;
while (n < f1.size() + f2.size()) n <<= 1;
fa.resize(n);
fb.resize(n);
// 变换到点值表示
fft_recursive(fa, false);
fft_recursive(fb, false);
// 点值相乘
for (int i = 0; i < n; i++)
fa[i] *= fb[i];
// 变换回系数表示
fft_recursive(fa, true);
// 结果取整
vector<long long> result;
for (int i = 0; i < n; i++)
result.push_back(round(fa[i].real()));
return result;
}
int main() {
// A(x) = 1x^3 + 1x^2 + 8x + 1
// B(x) = 1x^2 - 5x - 2
vector<int> a = {1, 8, 1, 1}; // 注意低位在前
vector<int> b = {-2, -5, 1};
vector<long long> res = multiply(a, b);
cout << "Final Coefficients: ";
for (int i = 0; i < a.size() + b.size() - 1; i++) {
cout << res[i] << " ";
}
cout << endl;
return 0;
}