Files
Data-Structure/Algorithm/FFT/template-iterative.cpp
2026-02-13 12:38:29 +08:00

104 lines
2.7 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include <iostream>
#include <vector>
#include <complex>
#include <cmath>
#include <algorithm>
using namespace std;
// 使用复数类complex<double> 自带加减乘运算
typedef complex<double> cd;
const double PI = acos(-1.0);
/**
* FFT 核心函数
* @param a 待转换的系数向量
* @param invert 是否为逆变换 (IFFT)。false 为 DFTtrue 为 IDFT
*/
void fft(vector<cd>& a, bool invert) {
int n = a.size();
// 1. 位逆序置换 (Bit-reversal permutation)
// 将原数组的下标进行二进制翻转,使得最底层的系数排列在正确位置
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
// 2. 自底向上合并 (迭代实现蝴蝶变换)
// len 是当前合并区间的长度,从 2, 4, 8 ... 一直合并到 n
for (int len = 2; len <= n; len <<= 1) {
double ang = 2 * PI / len * (invert ? -1 : 1);
cd wlen(cos(ang), sin(ang)); // 单位根 w_n^1
for (int i = 0; i < n; i += len) {
cd w(1); // 初始 w_n^0 = 1
for (int j = 0; j < len / 2; j++) {
// 蝴蝶变换核心公式:
// u = A_e, v = w * A_o
cd u = a[i + j], v = a[i + j + len / 2] * w;
a[i + j] = u + v; // A(w^k) = A_e + w*A_o
a[i + j + len / 2] = u - v; // A(w^{k+n/2}) = A_e - w*A_o
w *= wlen; // 迭代单位根
}
}
}
// 3. 如果是逆变换,最后需要除以 n
if (invert) {
for (cd & x : a)
x /= n;
}
}
/**
* 多项式乘法
*/
vector<long long> multiply(vector<int> const& a, vector<int> const& b) {
vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
// 补零到 2 的幂次
int n = 1;
while (n < a.size() + b.size())
n <<= 1;
fa.resize(n);
fb.resize(n);
// 1. DFT: 系数 -> 点值
fft(fa, false);
fft(fb, false);
// 2. Pointwise Multiplication: 点值直接相乘
for (int i = 0; i < n; i++)
fa[i] *= fb[i];
// 3. IDFT: 点值 -> 系数
fft(fa, true);
// 4. 取实部并四舍五入
vector<long long> result(n);
for (int i = 0; i < n; i++)
result[i] = round(fa[i].real());
return result;
}
int main() {
// 你的例子:
// f1 = 1 + 8x + 1x^2 + 1x^3 -> [1, 8, 1, 1]
// f2 = -2 - 5x + 1x^2 -> [-2, -5, 1]
vector<int> f1 = {1, 8, 1, 1};
vector<int> f2 = {-2, -5, 1};
vector<long long> res = multiply(f1, f2);
cout << "Result coefficients: ";
for (int i = 0; i < f1.size() + f2.size() - 1; i++) {
cout << res[i] << " ";
}
return 0;
}