0%

快速傅里叶变换

读这篇文章的前置条件

1.了解函数,多项式等概念
2.了解复数以及其乘法运算
3.会写代码
因为在看其他fft的教程的时候复数读完就不知道在说啥了,读了好多篇又抄了std才大致了解,所以通俗的讲一下fft是啥,主要是确定一下概念,梳理一下思路。一些更基础的内容就不再赘述了,如果遇到不知道复数是什么之类的,那可以看一下其他的教程。

什么是FFT

法法塔(FFT)全称快速傅里叶变换(Fast Fourier Transform),我们经常在求高精度乘法的时候用到这个算法,但是高精乘只是本身FFT的一种应用 ,事实上,傅里叶变换本身只是一种将多项式在点值表示法和多项式表示法之间转换的方法

简单的多项式乘法

我们考虑一个n次多项式P(x)和另一个m次多项式Q(x),当我们将x取为10,那么事实上多项式就表示一个十进制数。我们在做乘法的时候,可以将P(x)和Q(x)做卷积,也就是我们日常所用的竖式乘法,得到一个新的多项式 。对A(x)取x=10,就可以得到乘法的结果。这样做的复杂度是

用点值表示法来进行多项式乘法

我们上面是在多项式表示法下解决多项式乘法问题,事实上我们还有另一种解决方案。

我们先介绍多项式的点值表示法。对于一个n次多项式P(x),如果我们能在平面上取n+1个点,记为H(x)={(x1,y1),(x2,y2)…(xn+1,yn+1)},那么我们一定可以定出一个唯一的n次多项式(当然,也可以取更多的点,只要这些点都落在多项式函数上就可以,它等价于一个大于n的次数的系数都为0的更高次多项式,事实上我们在后面操作时的确会取过量个点)。此时我们可以称H(x)和P(x)等价。

那么假设我们现在有两个点值表示法下的多项式H(x),G(x),我们不妨取n=(H的次数+y的次数+1)个点,这样我们一定可以正确表示两个小于等于n次的多项式H(x)和G(x)。那么接下来我们可以很快速的求解对应的点值表示法下的多项式:考虑每一个选取的横坐标x,我们有y1=H(x),y2=G(x),当这这两个多项式的乘积也y3=Z(x)一定满足:

这个很容易理解,想象一下,当x=10时,这不就是我们熟知的乘法吗?只不过是直接在结果上的表现形式罢了。

而因为我们取得点足够多,所以求出来的n个点一定足够表述Z(x)这个多项式方程。
这样,我们就可以在O(n+m)时间[两个多项式的长度分别为n,m的情况下]内求解多项式乘法,也就是将这么多个点对应相乘即可。

那么现在我们要解决的问题就是:如何将点值表示法和多项式表示法进行快速的转换。

离散傅里叶变换

怎么样把多项式表示法变成点值表示法呢,一个简单的想法是,随便取一些点,求出对应的值,那么就完成了。那么怎么变回去呢?这就是离散傅里叶变换给我们提供的思路:如果我们在将多项式变成点值表示法时取一些特殊的点,那么我们就可以用类似的方法将点值表示法变回多项式表示法。

傅里叶告诉我们,如果我们要取n个点,那么取n次单位复根就会有一个美妙的性质。(如果你是从其他教程过来的,那么应该知道这说的是啥,就不多做赘述了。)
接下来讲一下为什么要取这些值呢?

这一篇里面的数学推导非常直白https://www.cnblogs.com/RabbitHu/p/FFT.html,我这里就直接用他的图了。
离散傅里叶变换
这是什么意思呢?
意思就是,如果我们取单位负根,那么我们就能用几乎一模一样的函数把点值表示法变回多项式表示法,这是x取单位复根的意义。事实上,这个单位负根和我们求出来的结果其实没什么关系,这个单纯是用来做变换的手段。
结论

加速

但是,直接带入点计算,就算使用高明的秦九韶算法 每个坐标需要O(n)的时间,n个坐标就是O(n^2^),那么并没有什么用处。这时候,单位负根的另一个性质就得到了体现,我们可以利用分治的思想,将这个过程的时间复杂度降到O(nlogn)。

这就是快速傅里叶变换FFT,说白了就是一个形式转换函数的加速版本。

首先,为了方便起见,我们需要项数为2的幂次,不妨记为n=2^k^。研究多项式,有多项式Pn的最高次为n-1次,超出原多项式的项系数为0,这个我们在上面讨论过。

接下来 我们可以按照奇偶将多项式函数P(x)分成两部分,即:

Pn(x)=ODD(x)+EVEN(x)

我们将ODD(x)中提出一个x,那么我们得到

x * EVEN(x)=even(x)

联立可以得到:

Pn(x)=x * even(x)+EVEN(x)

不难发现,偶次项组成的多项式中,我们可以用x^2来代替x,即得到新的

Pn(x)=An/2(x2)+Bn/2(x2)

其中A和B都是项数减半的多项式。

假设我们现在代入x=wnk(当 k < n/2时),有:

Pn(wnk)=An/2(wnwk)+Bn/2(wn2k) * x

整理变形得到:

Pn(wnk)=An/2(wn2k)+Bn/2(wn2k) * wnk

由单位复根性质可知:

wn2k=wn/2k

即:
Pn(wnk)=An/2(wn/2k)+Bn/2(wn/2k) * wnk

类似地,当n=k+n/2,我们有:

Pn(wnk+n/2)=An/2(wn/2k)+Bn/2(wn/2k) * wnk+n/2

整理得到:

Pn(wnk+n/2)=An/2(wn/2k)-Bn/2(wn/2k) * wnk

综上,我们得到两个递推公式:

Pn(wnk)=An/2(wn/2k)+Bn/2(wn/2k) * wnk
Pn(wnk+n/2)=An/2(wn/2k)-Bn/2(wn/2k) * wnk

从而可以在nlogn时间内进行傅里叶变换。

迭代解法

分治常规的处理方法是递归。但是众所周知,迭代比递归快,所以我们将考虑如何迭代来求解该问题。
迭代

1
2
3
4
5
6
7
8
//求二进制翻转数位置
int lim=0;
while((1<<lim)<n)++lim;
for(int i=0;i<n;++i){
int t=0;
for(int j=0;j<lim;++j)if((i>>j)&1)t|=(1<<(lim-j-1));
if(i<t)swap(a[i],a[t]);
}

蝴蝶操作

这个操作是什么意思呢?
我们有两个式子:

Pn(wnk)=An/2(wn/2k)+Bn/2(wn/2k) * wnk
Pn(wnk+n/2)=An/2(wn/2k)-Bn/2(wn/2k) * wnk

但是我们又要同时更新这两个式子,所以怎么办呢?

那么我们另一个新的t等于后项,然后再去分别更新。有点类似于什么呢?大概是swap里面的t的作用。

于是我们就可以在没有buf的情况下计算fft了。

1
2
3
4
5
6
7
8
9
10
for(int l=2;l<=n;l<<=1){
int m=l/2;
for(cp *p=a;p!=a+n;p+=l){
for(int i=0;i<m;++i){
cp t=omg[n/l*i]*p[i+m];
p[i+m]=p[i]-t;
p[i]+=t;
}
}
}

高精度完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<bits/stdc++.h>
using namespace std;
#define cp complex<double>
const double PI=acos(-1.0);
const int N=1005;
int n=1,res[N];
string s1,s2;

void init(cp *omg,cp *inv){
for(int i=0;i<N;++i){
omg[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
inv[i]=conj(omg[i]);
}
}
void fft(cp *a,cp *omg){
int lim=0;
while((1<<lim)<n)++lim;
for(int i=0;i<n;++i){
int t=0;
for(int j=0;j<lim;++j)if((i>>j)&1)t|=(1<<(lim-j-1));
if(i<t)swap(a[i],a[t]);
}
for(int l=2;l<=n;l<<=1){
int m=l/2;
for(cp *p=a;p!=a+n;p+=l){
for(int i=0;i<m;++i){
cp t=omg[n/l*i]*p[i+m];
p[i+m]=p[i]-t;
p[i]+=t;
}
}
}
}

int main(){
while(cin>>s1>>s2){
cp a[N],b[N],omg[N],inv[N];
memset(res,0,sizeof(res));
int l1=s1.length(),l2=s2.length();
n=1;
while(n<l1+l2)n<<=1;
for(int i=0;i<l1;++i)a[i].real(s1[l1-i-1]-'0');
for(int i=0;i<l2;++i)b[i].real(s2[l2-i-1]-'0');
init(omg,inv);
fft(a,omg);
fft(b,omg);
for(int i=0;i<n;++i)a[i]*=b[i];
fft(a,inv);
for(int i=0;i<n;++i){
res[i]+=floor(a[i].real()/n+0.5);
res[i+1]+=res[i]/10;
res[i]%=10;
}
int flag=0;
for(int i=n;~i;--i){
if(flag||res[i]){
flag=1;
cout<<res[i];
}
}
if(flag==0)cout<<0;
cout<<endl;
}
}