TFHE実装入門

7.FFTによる多項式乗算

松岡 航太郎

モチベーション

  • TFHEにおいて一番重い処理は多項式乗算(External Product)
  • ここを速くすることが全体の高速化に大きく寄与する
  • 今回の話を一般化すると多倍長精度の乗算につながる
  • 8基底FFTまで説明する

FFTとは

  • Fast Fourier Transformの略で日本語では高速フーリエ変換という
    • これは離散フーリエ変換と呼ばれる処理を高速に実行するアルゴリズム
    • なので離散フーリエ変換を使って多項式乗算を表現することを先に行う

離散フーリエ変換とは

  • 文字通りフーリエ変換を離散化したもの
    • フーリエ変換は時間信号を複数の周波数成分に分解する処理
    • Discrete Fourier TransformでDFTとよく呼ぶ
    • ここでは多項式乗算に転用するだけなのでフーリエ変換については割愛する
  • 入力となる多項式をa[X]a[X]、出力をA[X]A[X]としよう
  • このとき離散フーリエ変換は以下のように与えられる
  • 指数の符号は逆でもいい
  • ei2πNe^{-i\frac{2π}{N}}は回転因子とよぶ

At=x=0N1axei2πtxN,t{0,1,...,N1}A_t = ∑^{N-1}_{x=0} a_x⋅e^{-i\frac{2πtx}{N}},t∈\{0,1,...,N-1\}

離散フーリエ変換の行列表現(N=4)

  • 積分よりこっちがイメージしやすい人が居るかもしれない
  • ここではN=4N=4の場合を示す
    • 行列の形には規則性があるので簡単にNNは大きくできる

WN=ei2πN(A0A1A2A3)=(W0W0W0W0W0W1W2W3W0W2W4W6W0W3W6W9)(a0a1a2a3)\begin{aligned} W_N &= e^{-i\frac{2π}{N}}\\ \begin{pmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \\ \end{pmatrix} &= \begin{pmatrix} W^0 & W^0 & W^0 & W^0 \\ W^0 & W^1 & W^2 & W^3 \\ W^0 & W^2 & W^4 & W^6 \\ W^0 & W^3 & W^6 & W^9 \end{pmatrix} \cdot \begin{pmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \end{pmatrix} \end{aligned}

逆離散フーリエ変換

  • フーリエ変換の重要な性質として、逆変換が存在することがある
    • 離散フーリエ変換でも逆変換がありInverseでIDFTという
    • 定数倍と指数の符号くらいしか変わらない
    • ちなみに順と逆のどちらを1N\frac{1}{N}倍してもいいし符号も逆にしていい
      • 両方に1N\frac{1}{\sqrt{N}}を掛ける流儀もある
    • 実際にもとに戻ることは自分で確かめるか適当な本を参照してほしい
    • 行列表現にして実際に掛けてみるのが簡単かも?

ax=1Nt=0N1Atei2πxtN,x{0,1,...,N1}a_x = \frac{1}{N}∑^{N-1}_{t=0} A_t⋅e^{i\frac{2πxt}{N}},x∈\{0,1,...,N-1\}

畳み込み定理

  • フーリエ変換の重要な性質として畳み込み定理がある
    • 畳み込みはConvolutionなのでCNNのCと一緒
  • これはフーリエ変換をした結果(周波数成分)の要素ごとの積をとったものの逆変換が入力の畳み込みになっているという定理
  • 離散フーリエ変換での畳み込み定理がXN1X^N-1を法とする多項式乗算になる
  • つまり、a[X]b[X]c[X]modXN1a[X]⋅b[X]≡c[X]\mod{X^N-1}が以下と同値になる
    • まずはこれを証明する

ck=1N[l=0N1(n=0N1anei2πnkNm=0N1bmei2πmkN)ei2πklN]=1N[l=0N1(AkBk)ei2πklN]\begin{aligned} c_k &= \frac{1}{N}[\sum^{N-1}_{l=0}(\sum^{N-1}_{n=0}a_ne^{-i\frac{2\pi n k}{N}} \cdot \sum^{N-1}_{m=0}b_me^{-i\frac{2\pi m k}{N}})e^{i\frac{2\pi k l}{N}}]\\ & = \frac{1}{N}[\sum^{N-1}_{l=0}(A_k \cdot B_k)e^{i\frac{2\pi k l}{N}}] \end{aligned}

補題:指数が非ゼロなら回転因子の総和がゼロになること

  • 前の式が同値なことを示すには以下の式が成り立つことを示す必要がある
    • 実際に計算するとたしかにそうなることが簡単にわかる
      • l≢0modNl\not\equiv 0\bmod{N}なら回転因子が全部打ち消し合う
  • すべての係数が1の多項式の離散フーリエ変換はA0=NA_0=Nで他が0
    • 定数関数のフーリエ変換はインパルス関数である

n=0N1ei2πlnN={N if l0modN0 otherwise (e2πlnN1)n=0N1ei2πlnN=n=1Nei2πlnNn=0N1ei2πlnN=eN2πlnN1=0\begin{aligned} \sum^{N-1}_{n=0} e^{-i\frac{2\pi l n}{N}} &= \begin{cases}N\ if\ l\equiv 0\bmod{N}\\ 0\ otherwise\end{cases}\\\ ∴ (e^{-\frac{2\pi l n}{N}}-1)\sum^{N-1}_{n=0} e^{-i\frac{2\pi l n}{N}}&=\sum^{N}_{n=1} e^{-i\frac{2\pi l n}{N}}-\sum^{N-1}_{n=0} e^{-i\frac{2\pi l n}{N}}=e^{-N\frac{2\pi l n}{N}}-1=0 \end{aligned}

畳み込み定理(続き)

  • 補題を利用してさっきの式を展開してみよう
    • XN1X^N-1を法とするので上の項が下の項に足し算として回り込んでくる
    • このような回り込み方を正巡回という
    • 全係数が1の場合だと簡単に確かめられる
  • TFHEではXN+1X^N+1が法なので逆の負巡回が欲しく、工夫がいる

ck=1N[l=0N1(n=0N1anei2πnlNm=0N1bmei2πmlN)ei2πklN]=n=0N1m=0N1anbm(1Nl=0N1ei2π[k(n+m)]lN)=n+m=kanbm+n+m=k+Nanbm\begin{aligned} c_k &= \frac{1}{N}[\sum^{N-1}_{l=0}(\sum^{N-1}_{n=0}a_ne^{-i\frac{2\pi n l}{N}} \cdot \sum^{N-1}_{m=0}b_me^{-i\frac{2\pi m l}{N}})e^{i\frac{2\pi k l}{N}}]\\ &= \sum^{N-1}_{n=0}\sum^{N-1}_{m=0}a_nb_m (\frac{1}{N}\sum^{N-1}_{l=0}e^{i\frac{2\pi [k-(n+m)]l}{N}})\\ &= \sum_{n+m = k}a_nb_m + \sum_{n+m = k+N}a_nb_m \end{aligned}

離散荷重変換

  • 負巡回を作るにはこれがいる
  • Discrete Weighted fourier TransformでDWTともいう
    • 離散ウェーブレット変換とは関係ない
  • 定義的には下の式のように重みwxw_xをかけて(逆)離散フーリエ変換をするだけ
    • wxw_x倍して離散フーリエ変換して逆フーリエ変換するとwxw_x倍された係数が戻ってくるので割ればもとに戻る
  • 重みをうまく選べば負巡回乗算が実現できる

At=x=0N1wxaxei2πtxN,t{0,1,...,N1}ax=1wxNt=0N1Atei2πxtN,x{0,1,...,N1}\begin{aligned} A_t &= ∑^{N-1}_{x=0} w_xa_x⋅e^{-i\frac{2πtx}{N}},t∈\{0,1,...,N-1\}\\ a_x &= \frac{1}{w_xN}∑^{N-1}_{t=0} A_t⋅e^{i\frac{2πxt}{N}},x∈\{0,1,...,N-1\} \end{aligned}

負巡回乗算

  • wx=ei2πx2Nw_x = e^{i\frac{2πx}{2N}}と選んで畳み込み乗算
  • さらに上位N2\frac{N}{2}桁分を虚数部に入れておく
    • 虚数部も使うとメモリと計算量の両方を節約できる
    • 実数部と虚数部を連続した別の配列にすると変換が簡単
  • 1行目から2行目はDFTによる正巡回の乗算として変形する

ck=(ei2πk2N1N2)l=0N21[n=0N21ei2πn2N(an+iaN2+n)ei2πlnN2m=0N21ei2πm2N(bm+ibN2+m)ei2πlmN2]ei2πlkN2=ei2πk2N[n+m=kei2πk2N(an+iaN2+n)(bm+ibN2+m)+n+m=k+N2ei2π(k+N2)2N(an+iaN2+n)(bm+ibN2+m)]=n+m=k(an+iaN2+n)(bm+ibN2+m)+n+m=k+N2i(an+iaN2+n)(bm+ibN2+m)=[n+m=k(anbmaN2+nbN2+m)n+m=k+N2(anbN2+m+aN2+nbm)]+i[n+m=k(anbN2+m+aN2+nbn)+n+m=k+N2(anbmaN2+nbN2+m)]\begin{aligned} c_k &= (e^{-i\frac{2\pi k}{2N}}\frac{1}{\frac{N}{2}})\sum^{\frac{N}{2}-1}_{l=0}[\sum^{\frac{N}{2}-1}_{n=0}e^{i\frac{2\pi n}{2N}}(a_n+ia_{\frac{N}{2}+n})e^{-i\frac{2\pi l n}{\frac{N}{2}}} \cdot \sum^{\frac{N}{2}-1}_{m=0}e^{i\frac{2\pi m}{2N}}(b_{m}+ib_{\frac{N}{2}+m})e^{-i\frac{2\pi l m}{\frac{N}{2}}}]e^{i\frac{2\pi l k}{\frac{N}{2}}}\\ &= e^{-i\frac{2\pi k}{2N}}[\sum_{n+m=k}e^{i\frac{2\pi k}{2N}}(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m}) + \sum_{n+m=k+\frac{N}{2}}e^{i\frac{2\pi (k+\frac{N}{2})}{2N}}(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m})]\\ &= \sum_{n+m=k}(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m}) + \sum_{n+m=k+\frac{N}{2}}i(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m})\\ &= [\sum_{n+m=k}(a_nb_m-a_{\frac{N}{2}+n}b_{\frac{N}{2}+m}) - \sum_{n+m=k+\frac{N}{2}}(a_nb_{\frac{N}{2}+m}+a_{\frac{N}{2}+n}b_m)]+i[\sum_{n+m=k}(a_nb_{\frac{N}{2}+m}+a_{\frac{N}{2}+n}b_n) + \sum_{n+m=k+\frac{N}{2}}(a_nb_m-a_{\frac{N}{2}+n}b_{\frac{N}{2}+m})] \end{aligned}

高速フーリエ変換

  • ここまで見たように離散フーリエ変換ができると負巡回乗算ができる
    • ∴離散フーリエ変換が高速化できれば負巡回乗算が高速化できる
  • そのアルゴリズムが高速フーリエ変換(FFT: Fast Fourier Transform)

2基底周波数間引き高速フーリエ変換(数式)

  • 高速フーリエ変換の基本的アイデアは問題の再帰的分割
    • nZ+n∈\mathbb{Z}^+個に分割するとnn基底
  • 以下のような式から2つに分割できることがわかる

At=x=0N1axei2πtxN=x=0N21axei2πtxN+ax+N2ei2πt(x+N2)N={x=0N21(ax+ax+N2)ei2πt2xN2ift0mod2x=0N21[(axax+N2)ei2πxN]ei2πt12xN2otherwise\begin{aligned} A_t &= ∑^{N-1}_{x=0} a_x⋅e^{-i\frac{2πtx}{N}}\\ &= ∑^{\frac{N}{2}-1}_{x=0} a_x⋅e^{-i\frac{2πtx}{N}}+a_{x+\frac{N}{2}}⋅e^{-i\frac{2πt(x+\frac{N}{2})}{N}}\\ &=\begin{cases}∑^{\frac{N}{2}-1}_{x=0} (a_{x}+a_{x+\frac{N}{2}})⋅e^{-i\frac{2π\frac{t}{2}x}{\frac{N}{2}}} \qquad if\quad t≡0 \bmod 2\\∑^{\frac{N}{2}-1}_{x=0} [(a_{x}-a_{x+\frac{N}{2}})⋅e^{-i\frac{2πx}{N}}]⋅e^{-i\frac{2π\frac{t-1}{2}x}{\frac{N}{2}}}\qquad otherwise\end{cases} \end{aligned}

2基底周波数間引き高速フーリエ変換(バタフライ図)

  • FFTのアルゴリズムを図示したものがバタフライ図
    • AAの添字がぐちゃぐちゃになる(ビットリバース)のは今示しているのがCooley-Tukey型と呼ばれるin-placeなアルゴリズムであるため
    • 普通は添字を戻す並び替えかout-placeなStockham型を使う
  • 今回は畳み込み乗算さえできればいいので並び替えをする必要がない
    • 論文実装のFFTが自前な理由の一つは並び替えがいらないのを利用するため

2基底周波数間引き高速フーリエ変換(疑似コード)

  • in-placeなので受け取った配列を書き換えながら実行される
  • 簡潔なので再帰で書いているが最適化しづらいのでForで書くことが多い
Radix2FFT(𝐚,start,end) //startとendは最初0とN
  N = start-end
  n = (start-end)/2
  for x from 0 to n-1
    temp = aₓ+aₓ₊ₙ
    aₓ₊ₙ = (aₓ-aₓ₊ₙ)*exp(-i2πx/N)
    aₓ = temp
  if(N>2)
    Radix2FFT(𝐚,start,start+n)
    Radix2FFT(𝐚,start+n,end)

2基底時間間引き高速フーリエ変換

  • 周波数間引きの逆変換を時間間引きという
    • ビットリバースの並びを入力とする

4基底周波数間引き高速フーリエ変換(バタフライ図)

  • 左のバタフライ図をぐっと睨むと回転因子を移動させて右のように分解できることがわかる
  • WN2=ei2π28=iW_N^2=e^{-i\frac{2π2}{8}}=-iなのでこれは乗算にならない

4基底周波数間引き高速フーリエ変換(疑似コード)

Radix4FFT(𝐚,start,end)
  N = start-end
  if(N==2)
    s = start //下付きにするため
    temp = aₛ+aₛ₊₁
    aₛ₊₁ = aₛ - aₛ₊₁
    aₛ = temp
  else
    n = (start-end)/4
    for x from 0 to n-1
      for j from 0 to 1
        l = j*n
        m = l+2*n
        temp = aₓ₊ₗ+aₓ₊ₘ
        aₓ₊ₘ = (aₓ₊ₗ-aₓ₊ₘ)*i //実数部と虚数部の交換と符号反転
        aₓ₊ₗ = temp
      for j from 0 to 1
        l = j*2*n
        m = l+n
        if(j==0)
          temp = aₓ₊ₗ+aₓ₊ₘ
          aₓ₊ₘ = (aₓ₊ₗ-aₓ₊ₘ)*exp(-i2π2x/N)
        else
          temp = (aₓ₊ₗ+aₓ₊ₘ)*exp(-i2πx/N)
          aₓ₊ₘ = (aₓ₊ₗ-aₓ₊ₘ)*exp(-i2π3x/N)
        aₓ₊ₗ = temp
    Radix4FFT(𝐚,start,start+n)
    Radix4FFT(𝐚,start+n,start+2n)
    Radix4FFT(𝐚,start+2n,start+3n)
    Radix4FFT(𝐚,start+3n,end)

Split Radix(蛇足)

  • 4基底で計算量削減効果があるのは下半分だけ
  • 上半分は2基底でやめるとSplit-Radixになる
    • メモリアクセスやプログラムの煩雑さの都合でSplit-Radixはあまりやらない
    • 8基底でも同じことはできる

8基底周波数間引き高速フーリエ変換(バタフライ図)

  • N=16なのでWN2=ei2π216=12(1i),W6=12(1i)W_N^2=e^{-i\frac{2π2}{16}}=\frac{1}{\sqrt{2}}(1-i),W^6=\frac{1}{\sqrt{2}}(-1-i)で乗算が減る

16基底以上のFFT

  • 16基底以上も同じように回転因子をまとめていくことで構成できる
  • しかし、16基底以上では8基底までで見られたような乗算の削減はできない
  • 計算量としては基底が大きいほど下がる
  • とはいえ基底を大きくするほど並列性が下がる面もあるのでレジスタの本数やメモリアクセス量との兼ね合いが必要

AVX2について

  • x86で定義されているベクトル拡張命令の一つ
    • 256bit、倍精度4つ分を1サイクルで計算することができる
  • コンパイラに任せるとうまく使ってくれないことがあるので、FFTのように並列計算をゴリゴリやる場合にはアセンブラを書いたほうが速いことがある
    • 配列の加算程度ならコンパイラに-march=nativeを渡せば勝手に使ってくれる

TFHEでの利用

  • FFTを適切に書けば多項式乗算をこれに置き換えたほうが速い
    • 32bit固定小数点数を倍精度に変換して計算する
    • 倍精度には符号bitがあるので、符号付きとして解釈したほうが良い
    • FFTは実数の上で定義されるので多項式乗算をしている間はトーラスをただの実数とみなし乗算後に剰余を取る
  • GPUの場合は倍精度が弱いので数論変換(NTT)を使ったほうが良い

TFHEの多項式乗算における精度要求について

  • 倍精度は53bitしか精度がない
    • External Productでは最大で(Bg21)(2311)(\frac{Bg}{2}-1)⋅(2^{31}-1)NlN⋅ l個足したものが結果の係数として出てくる
    • これを正確に保持するには少なくとも(Bgbit1+31+Nbit+log2((k+1)l))(Bgbit-1+31+Nbit+log_2 ((k+1)⋅ l))bitの精度が必要
    • 今回は81+31+9+log26<50<538-1+31+9+log_2 6<50<53で問題ない
    • どうせ後で32bitにまるめてしまうので実際は多少はみ出ても問題ない

実装について

  • 暗号の本筋とは関係ないのでTFHEの論文実装で使われているSPQLIOSを使うのが良いようには思う
  • 有名所のFFT実装としてはIntel IPPやFFTWがあるがこれらは一般的なFFT用途を想定しているためにTFHEで用いるとSPQLIOSに勝てない
  • GPUでFFTをするcuFFTもあるがGPUならcuFHEのようにNTTを実装したほうが速い
  • SVE2やHIPでの実装はあり?
    • AVX512に関しては[]
    • NTTに関してはIntel HEXLがある

参考文献

page_number: true