FFT 快速傅里叶变换

duo xiang shi
本文仅讨论 FFT 在 OI 中的使用。

表述可能不甚严谨,如有错漏,欢迎指出。

能解决什么问题?

有两个多项式 f(x)=a0+a1x1+a2x2++akxkf(x) = a_0 + a_1x^1 + a_2x^2 + \cdots + a_kx^kg(x)=b0+b1x1+b2x2++blxlg(x) = b_0 + b_1x^1 + b_2x^2 + \cdots + b_lx^l

h(x)=f(x)g(x)h(x) = f(x) \cdot g(x)

显然使用朴素算法,时间复杂度是 O(kl)\Omicron(kl)

但 FFT 可以将其优化到 O(nlogn)\Omicron(n \log n) 级别。

引理 I (插值)

nn 个 2元组,为 (x0,y0),(x1,y1),,(xn1,yn1)(x_0, y_0), (x_1, y_1), \cdots , (x_{n - 1}, y_{n - 1})ij,xixj\forall i \not= j, x_i \not= x_j
存在一个唯一的 n1n - 1次多项式,使得 i[0,n) f(xi)=yi\forall i \in [0, n)\ f(x_i) = y_i

唯一性

若有 f(x),g(x)f(x), g(x) 满足以上条件,设 r(x)=f(x)g(x)r(x) = f(x) - g(x)

i[0,n) r(xi)=0\forall i \in [0, n)\ r(x_i) = 0,即 rrnn 个根。

且因为 n1n - 1 次方程最多只有 n1n - 1 个根。

r(x)=0r(x) = 0

存在性

f(x)f(x) 可以通过以下方式构造:

f(x)=i=0n1yi(jixxjxixj)f(x) = \sum_{i = 0}^{n - 1} y_i \cdot (\prod_{j \not= i} \frac {x - x_j} {x_i - x_j})

解释:对于所有的 ii,若 x=xix = x_i 时,对应乘积的部分每一项均为 11,否则会有一项为 00

显然因为 xi,yix_i, y_i 均为已知量,并且乘积中只有 n1n - 1xxjx - x_j 相乘,故 f(x)f(x)n1n - 1 次多项式,符合要求。

具体步骤

由于 引理I,只要有 nn 个不同的 xx 和对应的 f(x)f(x) ,那么就可以确定 f(x)f(x),故我们只需要找出 h(x)h(x)nn 个互不相同的参数和对应的取值即可。

为了方便,我们设 n=k+l1n = k + l - 1

I. 选取 x0,x1,xn1x_0, x_1, \dots x_{n-1}

II. 确定 f(x0),f(x1),f(xn1)f(x_0), f(x_1), \ldots f(x_{n-1})g(x0),g(x1),g(xn1)g(x_0), g(x_1), \ldots g(x_{n-1})

III. 根据以上,得出 h(xi)=f(xi)g(xi)h(x_i) = f(x_i) \cdot g(x_i),求出 hh

乍看时间复杂度还是 O(n2)\Omicron(n^2) 的,但是不要着急,我们慢慢来。

步骤 I

对于 n1n - 1 次多项式,我们选取

xk=ωnk=e2πkni=cos2πn+isin2πnx_k = \omega_n^k = e^{\frac {2 \pi k} n i} = \cos{\frac {2\pi} n} + i \sin{\frac {2\pi} n}

其中 ωni\omega_n^ixn=1x^n = 1 的所有根,即 nn 次单位虚根。

假如放到复平面上理解,就是单位圆上的 nn 等分点。

例如 n=8n = 8

1

为什么选取它呢?因为有一个性质:折半定理。

n,kZ,ωnk=ω2n2k\forall n, k \in \mathbb{Z}, \omega_n^k = \omega_{2n}^{2k}

从直觉上来说比较显然,故不作证明。

那么这个有什么用呢?当然有用,我们可以借此折半问题规模,继续往下看。

步骤 II (DFT)

因为求 f(xi)f(x_i) 与求 g(xi)g(x_i) 本质相同,故仅以 f(xi)f(x_i) 为例。

为了叙述和编写代码的方便,我们令 nn 为总是可以表示成 n=2qn = 2^q 。(多出的项补 00 即可,此处 nn 与上面不同)

同时,用下标表示序列的长度, fm(xi),m=2qf_m(x_i), m = 2^q ,设 am,ia_{m, i}fmf_m 各项的系数。

fn(xi),i[0,n)f_n(x_i), i \in [0, n) 依然是 O(n2)\Omicron(n^2) 的,要优化到 O(nlogn)\Omicron(n \log n)

考虑求 fn(xi)f_n(x_i) 的情况:

n=1n = 1,因为 ω1=1\omega_1 = 1,故 f1=a1,0f_1 = a_{1, 0}

否则,nn 为偶数,所以有:(根据上面的约定 fn(x)f_n(x)n1n-1 次多项式)

fn(x)f_n(x) 根据奇偶性拆开:

fn(x)=an,0+an,1x1++an,n1xn1=(an,0+an,2x2++an,n2xn2)  +(an,1x1+an,3x3++an,n1xn1)=(an,0(x2)0+an,2(x2)1++an,n2(x2)n22)  +x(an,1(x2)0+an,3(x2)1++an,n1(x2)n22)\begin{aligned} f_n(x) &= a_{n, 0} + a_{n, 1}x^1 + \cdots + a_{n, n - 1}x^{n - 1} \\ &= (a_{n, 0} + a_{n, 2}x^2 + \cdots + a_{n, n - 2}x^{n - 2}) \\ &\ \ + (a_{n, 1}x^1 + a_{n, 3}x^3 + \cdots + a_{n, n - 1}x^{n - 1}) \\ &= (a_{n, 0}(x^2)^0 + a_{n, 2}(x^2)^1 + \cdots + a_{n, n - 2}(x ^ 2)^{\frac {n - 2} 2}) \\ &\ \ + x(a_{n, 1}(x^2)^0 + a_{n, 3}(x^2)^1 + \cdots + a_{n, n - 1}(x ^ 2)^{\frac {n - 2} 2}) \\ \end{aligned}

把平方提出来,得到:

fn,0(x)=an,0x0+an,2x1++an,n2xn22fn,1(x)=an,1x0+an,3x1++an,n1xn22f_{n, 0}(x) = a_{n, 0}x^0 + a_{n, 2}x^1 + \cdots + a_{n, n - 2}x^{\frac {n - 2} 2} \\ f_{n, 1}(x) = a_{n, 1}x^0 + a_{n, 3}x^1 + \cdots + a_{n, n - 1}x^{\frac {n - 2} 2} \\

那么:

fn(x)=fn,0(x2)+xfn,1(x2)fn(ωni)=fn,0(ωn2i)+ωnifn,1(ωn2i)f_n(x) = f_{n, 0}(x ^ 2) + xf_{n, 1}(x ^ 2) \\ f_n(\omega_n^i) = f_{n, 0}(\omega_n^{2i}) + \omega_n^i f_{n, 1}(\omega_n^{2i})

发现,可以引用折半定理,令 p=n2p = \frac n 2

fn(ωni)=fn,0(ωpi)+ωnifn,1(ωpi)f_n(\omega_n^i) = f_{n, 0}(\omega_p^i) + \omega_n^i f_{n, 1}(\omega_p^i)

我们可以观察一下这一番操作的结果:

带入的值 00 11 \cdots p1p - 1 p+0p + 0 p+1p + 1 \cdots 2p12p - 1
fn(x)f_n(x) ωn0\omega_n^0 ωn1\omega_n^1 \cdots ωnp1\omega_n^{p-1} ωnp+0\omega_n^{p+0} ωnp+1\omega_n^{p+1} \cdots ωnp+p1\omega_n^{p+p-1}
fn,0(x)f_{n, 0}(x) ωp0\omega_p^0 ωp1\omega_p^1 \cdots ωpp1\omega_p^{p-1} ωpp+0\omega_p^{p+0} ωpp+1\omega_p^{p+1} \cdots ωpp+p1\omega_p^{p+p-1}
fn,1(x)f_{n, 1}(x) ωp0\omega_p^0 ωp1\omega_p^1 \cdots ωpp1\omega_p^{p-1} ωpp+0\omega_p^{p+0} ωpp+1\omega_p^{p+1} \cdots ωpp+p1\omega_p^{p+p-1}

可以发现,由于显然有 ωpi=ωpi+p\omega_p^i = \omega_p^{i+p}
每次只用处理 fn,0f_{n,0} 带入 i[0,p),ωpii\in [0, p), \omega_p^i 的值即可,fn,1f_{n, 1} 同理。

由于 p=n2p = \frac n 2,所以每次处理的规模折半,可以实现 O(nlogn)O(n \log n)

以下是代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
typedef double db;
typedef complex<double> comp;

void fft(int n, comp* a /* 输入数组 */ , int op /* 会有其他的用途,这里暂时忽略,可以认为总为 1 */)
{
if (n == 1) return;
comp a0[n >> 1], a1[n >> 1];
for (int i = 0; i < n >> 1; i++) a0[i] = a[i << 1], a1[i] = a[i << 1 | 1]; // 拆成两半
fft(n >> 1, a0, op), fft(n >> 1, a1, op); // 分治

comp wn = comp(cos(2.0 * PI / n), op * sin(2.0 * PI / n)), w = comp(1, 0);
for (int i = 0; i < (n >> 1); i++, w *= wn) { // 合起来
a[i] = a0[i] + w * a1[i]; // 前一半
a[i + (n >> 1)] = a0[i] - w * a1[i]; // 后一半
}
}

步骤 III (IDFT)

ffgg 插值后相乘,现在我们需要根据这个相乘的结果反退出 h(x)h(x) 的系数,即点值。

可以从矩阵的角度考虑,插值过程中发生了什么。

A=[an,0an,1an,n1] V=[ωn00ωn01ωn0(n1)ωn10ωn11ωn1(n1)ωn(n1)0ωn(n1)1ωn(n1)(n1)] Y=[fn(ωn0)fn(ωn1)fn(ωnn1)]=VA\vec A = \begin{bmatrix} a_{n, 0} \\ a_{n, 1} \\ \vdots \\ a_{n, n-1} \\ \end{bmatrix} \\ \ \\ V = \begin {bmatrix} \omega_n^{0\cdot0} & \omega_n^{0\cdot1} & \cdots & \omega_n^{0\cdot(n-1)} \\ \omega_n^{1\cdot0} & \omega_n^{1\cdot1} & \cdots & \omega_n^{1\cdot(n-1)} \\ \vdots & \vdots & \ddots & \vdots \\ \omega_n^{(n-1)\cdot0} & \omega_n^{(n-1)\cdot1} & \cdots & \omega_n^{(n-1)\cdot(n-1)} \\ \end{bmatrix} \\ \ \\ \vec Y = \begin{bmatrix} f_n(\omega_n^0)\\ f_n(\omega_n^1)\\ \vdots \\ f_n(\omega_n^{n-1})\\ \end{bmatrix} = V \vec A \\

现在对于 hh 知道了 Y\vec YVV,求 A\vec A
显然 A=YV1\vec A = \vec Y V^{-1}

U=[ωn00ωn01ωn0(n1)ωn10ωn11ωn1(n1)ωn(n1)0ωn(n1)1ωn(n1)(n1)] VU=[n0000n0000n0000n] V1=1nUU = \begin {bmatrix} \omega_n^{-0\cdot0} & \omega_n^{-0\cdot1} & \cdots & \omega_n^{-0\cdot(n-1)} \\ \omega_n^{-1\cdot0} & \omega_n^{-1\cdot1} & \cdots & \omega_n^{-1\cdot(n-1)} \\ \vdots & \vdots & \ddots & \vdots \\ \omega_n^{-(n-1)\cdot0} & \omega_n^{-(n-1)\cdot1} & \cdots & \omega_n^{-(n-1)\cdot(n-1)} \\ \end{bmatrix} \\ \ \\ VU = \begin {bmatrix} n & 0 & 0 & \cdots & 0 \\ 0 & n & 0 & \cdots & 0 \\ 0 & 0 & n & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & n \\ \end {bmatrix} \\ \ \\ V^{-1} = \frac 1 n U \\

解释:

Xi,jX_{i, j} 为矩阵 XXiijj 列的值。令 B=VUB = VU

要证明

Bi,j={n, if i=j0, otherwiseB_{i,j} = \left\{\begin{aligned} n,\ & \text{if}\ i = j \\ 0,\ & \text{otherwise} \end{aligned} \right. \\

Bi,j=k=0nVi,kUk,j=k=0nωikωkj=k=0nωk(ij)B_{i, j} = \sum_{k=0}^n V_{i,k}U_{k, j} = \sum_{k=0}^n \omega^{ik} \omega^{-kj} = \sum_{k=0}^n \omega^{k(i - j)}

i=ji = j 时,Bi,j=k=0nω0=nB_{i, j} = \sum_{k=0}^n \omega^0 = n

否则设 t=ij,t0t = i - j, t \not= 0,要证明

Bi,j=k=0nωkt=0B_{i, j} = \sum_{k=0}^n \omega^{kt} = 0

即将其看作复平面的单位圆上 nn 个点,第 00 个点为 (1,0)(1, 0),相邻点间隔 2πnt\frac {2\pi} n t,求每个点是否都有唯一的点对应使得二者之和为 00 (过原点对称,且这里只考虑 nn2q2^q 的情况)。

t=2pct = 2^p \cdot cn=2qn = 2^qcc 为奇数,且 p<qp < q

nn 个点分成 2p2^p 组,每组 2qp2^{q-p} 个,显然每组相同,故只考虑第 00 组。

对于该组的第 k<2qp2k \lt \frac {2^{q-p}} 2 个,其与第 k+2qp2k + \frac {2^{q-p}} 2 个相对应,因为有

ωk+t2qp2=ωkω2pc2q2p+1=ωk(ω2q2)c=ωk(ωn2)c=ωk(1)c=ωk\omega^{k + t \frac {2^ {q-p}} 2} = \omega^k \omega^{2^p \cdot c \cdot \frac {2^q} {2^{p + 1}}} = \omega^k (\omega^{\frac{2^q} 2})^c = \omega^k (\omega^{\frac n 2})^c = \omega^k (-1)^c = -\omega^k

所以只要把插值过程中的 ωni\omega_n^i 变为 ωni\omega_n^{-i} ,再对结果乘 1n\frac 1 n 即为答案。

代码 (luogu P3803):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/*
author: byronwan
problem:
url: https://www.luogu.com.cn/problem/P3803
title: FFT
date: 2022-11-12
*/

#include <bits/stdc++.h>

using namespace std;

typedef double db;
typedef complex<double> comp;

const int MAXN = 4e6 + 10;
const double PI = acos(-1);

int k, l, n;
comp a[MAXN], b[MAXN];

void setup()
{
cin >> k >> l;
for (int i = 0; i <= k; i++) cin >> a[i];
for (int i = 0; i <= l; i++) cin >> b[i];
for (n = 1; n <= k + l; n <<= 1)
;
}

void fft(int n, comp* a, int op)
{
if (n == 1) return;
comp a0[n >> 1], a1[n >> 1];
for (int i = 0; i < n >> 1; i++) a0[i] = a[i << 1], a1[i] = a[i << 1 | 1];
fft(n >> 1, a0, op), fft(n >> 1, a1, op);

comp wn = comp(cos(2.0 * PI / n), op * sin(2.0 * PI / n)), w = comp(1, 0);
for (int i = 0; i < (n >> 1); i++, w *= wn) {
a[i] = a0[i] + w * a1[i];
a[i + (n >> 1)] = a0[i] - w * a1[i];
}
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr), clog.tie(nullptr);

setup();
fft(n, a, 1), fft(n, b, 1);

for (int i = 0; i < n; i++) a[i] = a[i] * b[i];
fft(n, a, -1);

for (int i = 0; i <= k + l; i++) cout << (long long)(a[i].real() / n + 0.5) << " ";
cout << endl;
}

迭代实现

刚才的代码是递归实现,由于一些原因(个人猜测是要在栈上开数组),递归实现虽然写起来方便,但可能会被卡,以下介绍不递归的迭代实现。

首先,我们以 n=8n = 8 为例,模拟递归的过程。

1
2
3
4
5
6
               a0 a1 a2 a3 a4 a5 a6 a7
a0 a2 a4 a6 | a1 a3 a5 a7
a0 a4 | a2 a6 | a1 a5 | a3 a7
a0 | a4 | a2 | a6 | a1 | a5 | a3 | a7

binary 000 100 010 110 001 101 011 111

不难发现,每一位上 a 的下标是该位的下标的二进制反过来。

我们首先把 a 按照这个规律放好(跳过了递归实现中分配 a 的步骤),直接开始合并,而合并的方法和原来相同(合并过程中不需要 a,只要位置对了就行)。

那么问题就是如何把 a 放好,设 rev(i)rev(i)ii 二进制反过来的数,显然有 rev(rev(i))=irev(rev(i)) = i,所以 rev(i)rev(i)ii 是一一映射的。

考虑用递推的方式求 rev(i)rev(i)

1
2
3
4
5
6
7
8
9
     i: a b c d e  f
^---+---^ ^
| |
+--+ |
| |
+------|---+
| |
+ +---+---+
rev(i): f e d c b a

所以

1
2
3
rev[i] = 
(rev[i >> 1] >> 1 /* 除以2是因为它的长度为 log n 原来的最高位即现在的最低位为 0 */)
| ((i & 1) << (l - 1 /*长度*/))

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/*
author: byronwan
problem:
url: https://www.luogu.com.cn/problem/P3803
title: FFT
date: 2023-01-10
*/

#include <bits/stdc++.h>

using namespace std;

typedef double db;
typedef complex<double> comp;

const int MAXN = 4e6 + 10;
const double PI = acos(-1);

int k, l, n, m;
int rev[MAXN];
comp a[MAXN], b[MAXN];

void setup()
{
cin >> k >> l;
for (int i = 0; i <= k; i++) cin >> a[i];
for (int i = 0; i <= l; i++) cin >> b[i];
for (n = 1; n <= k + l; n <<= 1, m++)
;
}

void fft(comp* a, int op)
{
for (int i = 0; i < n; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (m - 1));
if (rev[i] > i) swap(a[rev[i]], a[i]);
}

for (int p = 1; p < n; p <<= 1) {
int len = p * 2;
auto wn = comp(cos(2.0 * PI / len), op * sin(2.0 * PI / len));
for (int i = 0; i < n; i += len) {
auto w = comp(1, 0);
for (int j = 0; j < p; j++, w *= wn) {
auto x = a[i + j], y = w * a[i + j + p];
a[i + j] = x + y, a[i + j + p] = x - y;
}
}
}
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr), clog.tie(nullptr);

setup();
fft(a, 1), fft(b, 1);

for (int i = 0; i < n; i++) a[i] = a[i] * b[i];
fft(a, -1);

for (int i = 0; i <= k + l; i++) cout << (long long) (a[i].real() / n + 0.5) << " ";
cout << endl;
}

参考资料

「自为风月马前卒」大大的博客

oi-wiki.org