高速フーリエ変換を図で理解する

問題

(n-1)多項式 f(x) := a_0 + a_1x + \ldots + a _ {n-1}x ^{n-1} について、 x = \zeta_n^ i \ (i = 0, 1, \ldots, n-1) での値を計算したい。

(※ ただし\zeta_n := \exp\left(\frac{2\pi\sqrt{-1}}{n}\right) とおいた。)

愚直に計算するとO(n^ 2)かかるが、これをO(n\log n)で計算する方法を考える。

アルゴリズム

簡単のため、以降 n2の冪乗とする。

(n/2-1)多項式 f_0(x), \ f_1(x) を以下で定義する。

  • f_0(x) := a_0 + a_2x + \ldots + a _ {n-2}x^ {n/2 - 1}
  • f_1(x) := a_1 + a_3x + \ldots + a _ {n-1}x^ {n/2 - 1}

すると、f(x) = f_0(x^ 2) + x \ f_1(x^ 2) が成り立つ。

x = \zeta_n^ i \ (0 \leq i \lt n) での f(x)の値を求めたいので、

0 \leq i \lt n について f_0(\zeta_n^ {2i}), \ f_1(\zeta_n^ {2i})の値が求まっていれば良い。

実は

  • \zeta_n^ {2i} = \zeta_ {n/2}^ i \ \ (0 \leq i \lt n)
  • \zeta _ {n/2}^ i = \zeta _ {n/2}^ {i + n/2} \ \ (0 \leq i \lt n/2)

より、\zeta_n^ {2i} = \zeta _ {n/2}^ {i \ \mathrm{mod} \ n/2} \ \ (0 \leq i \lt n) であることがわかる。

つまり 0 \leq i \lt n/2 について f_0(\zeta _ {n/2}^ i), \ f_1(\zeta _ {n/2}^ i) が求まっていれば良い。

(※ 以下にn=8の場合の図を示す。x = \zeta_n^ {2i} は添字が2づつ進んで2周するため、●における f_0(x), \ f_1(x)の値が求まっていれば良いことがわかる。) f:id:habara_k:20200520234237p:plain

したがって、以下の再帰的な計算で (n-1)多項式 f(x) = \sum_{i=0}^ {n-1} a_i x^ i に対する f(\zeta_n^ i) \ (0 \leq i \lt n) を求めることができる。

  • n = 1のとき、f(x) = a_0 より f(\zeta_1^ 0) = a_0を返す。
  • そうでないとき、

    1. f_0(x) := \sum _ {i=0}^ {n/2-1} a _ {2i} x^ i, \ f_1(x) := \sum _ {i=0}^ {n/2-1} a _ {2i+1} x^ i を定義する。

    2. (n/2-1)多項式 f_0(x), \ f_1(x) に対する f_0(\zeta _ {n/2}^ i), \ f_1(\zeta _ {n/2}^ i) \ \ (0 \leq i \lt n/2) をそれぞれ再帰的に計算する。

    3. 0 \leq i \lt n について f(\zeta_n^ i) = f_0(\zeta _ {n/2}^ {i \ \mathrm{mod} \ (n/2)}) + \zeta _ n^ i \ f_1(\zeta _ {n/2}^ {i \ \mathrm{mod} \ (n/2)}) を計算して返す。

各呼び出しで扱う多項式は、n=8の場合以下の図のようになる。 ただし、f _ {*} から定義した2つの多項式をそれぞれ f _ {0*}, \ f _ {1*} と表記した。 f:id:habara_k:20200520182929p:plain 再帰の終端において、f _ {*}(x) = a _ {*} になっていることがわかる。 次節の実装ではこれを利用している。

計算量はマージソートと同様の解析で O(n\log n)となる (分割統治法)。

実装

using Complex = std::complex<double>;

struct FFT {

    std::vector<Complex> a_;
    int n;

    FFT(const std::vector<Complex>& a) : a_(a), n(1) {
        // n を2の冪乗にする
        while (n < a.size()) n <<= 1;
        a_.resize(n);
    };

    std::vector<Complex> solve() {
        return fft(0, 0);
    }

    std::vector<Complex> fft(int d, int bit) {
        int sz = n >> d;
        if (sz == 1) return {a_[bit]};

        auto f0 = fft(d+1, bit);
        auto f1 = fft(d+1, bit | 1<<d);

        std::vector<Complex> f(sz);

        for (int i = 0; i < sz; ++i) {
            Complex x = std::polar(1.0, 2*M_PI / sz * i);
            f[i] = f0[i % (sz / 2)] + x * f1[i % (sz / 2)];
        }
        return f;
    }
};

参考

http://compro.tsutajiro.com/archive/fft.pdf

C - 高速フーリエ変換