TFHE実装入門

7.FFTによる多項式乗算

松岡 航太郎

モチベーション

  • TFHEにおいて一番重い処理は多項式乗算(External Product)
  • ここを速くすることが全体の高速化に大きく寄与する
  • 今回の話を一般化すると多倍長精度の乗算につながる
  • 8基底FFTまで説明する

FFTとは

  • Fast Fourier Transformの略で日本語では高速フーリエ変換という
  • これは離散フーリエ変換と呼ばれる処理を高速に実行するアルゴリズム
  • なので離散フーリエ変換を使って多項式乗算を表現することを先に行う

離散フーリエ変換とは

  • 文字通りフーリエ変換を離散化したもの
  • Descrete Fourier TransformでDFTとよく呼ぶ
  • フーリエ変換は時間信号を複数の周波数成分に分解する処理
  • ここでは多項式乗算に転用するだけなのでフーリエ変換については割愛する
  • 入力となる多項式をa[X]a[X]、出力をA[X]A[X]としよう
  • このとき離散フーリエ変換は以下のように与えられる
  • 指数の符号は逆でもいい

At=x=0N1axei2πtxN,t{0,1,...,N1}A_t = ∑^{N-1}_{x=0} a_x⋅e^{-i\frac{2πtx}{N}},t∈\{0,1,...,N-1\}

逆離散フーリエ変換

  • フーリエ変換の重要な性質として、逆変換が存在することがある
  • 離散フーリエ変換でも逆変換がありInverseでIDFTという
  • 定数倍と指数の符号くらいしか変わらない
  • ちなみに順と逆のどちらを定数倍してもいいし符号も逆にしていい
  • 実際にもとに戻ることは自分で確かめるか適当な本を参照してほしい
  • 行列表現にしてエルミート行列であることを使うのが簡単かも?
  • 1N\frac{1}{N}を順の方にかけたり両方に1N\frac{1}{\sqrt{N}}を掛ける流儀もある

ax=1Nt=0N1Atei2πxtN,x{0,1,...,N1}a_x = \frac{1}{N}∑^{N-1}_{t=0} A_t⋅e^{i\frac{2πxt}{N}},x∈\{0,1,...,N-1\}

畳み込み定理

  • フーリエ変換の重要な性質として畳み込み定理がある
  • これはフーリエ変換をした結果(周波数成分)の要素ごとの積をとったものの逆変換が入力の畳み込みになっているという定理
  • 離散フーリエ変換での畳み込み定理がXN1X^N-1を法とする多項式乗算になる
  • つまり、a[X]b[X]c[X]modXN1a[X]⋅b[X]≡c[X]\mod{X^N-1}が以下と同値になる

ck=1N[l=0N1(n=0N1anei2πnkNm=0N1bmei2πmkN)ei2πklN]c_k = \frac{1}{N}[\sum^{N-1}_{l=0}(\sum^{N-1}_{n=0}a_ne^{-i\frac{2\pi n k}{N}} \cdot \sum^{N-1}_{m=0}b_me^{-i\frac{2\pi m k}{N}})e^{i\frac{2\pi k l}{N}}]

補題:指数が非ゼロなら回転因子の総和がゼロになること

  • 前の式が同値なことを示すには以下の式が成り立つことを示す必要がある
  • これはすべての係数が1の多項式の離散フーリエ変換は定数項がNNで他が0の多項式であることを意味する
  • 離散フーリエ変換に掛けるとたしかにそうなることが簡単にわかる

n=0N1e2πlnN={N if l0modN0 otherwise\sum^{N-1}_{n=0} e^{\frac{2\pi l n}{N}} = \begin{cases}N\ if\ l\equiv 0\bmod{N}\\ 0\ otherwise\end{cases}

畳み込み定理(続き)

  • 補題を利用してさっきの式を展開してみよう
  • XN1X^N-1を法とするので上の項が下の項に足し算として回り込んでくる
  • このような回り込み方を正巡回という
  • TFHEではXN+1X^N+1が法なので逆の負巡回が欲しく、工夫がいる

ck=1N[l=0N1(n=0N1anei2πnlNm=0N1bmei2πmlN)ei2πklN]=1N[l=0N1n=0N1m=0N1anbmei2π(kn+m)lN]=n+m=kanbm+n+m=k+Nanbm\begin{aligned} c_k &= \frac{1}{N}[\sum^{N-1}_{l=0}(\sum^{N-1}_{n=0}a_ne^{-i\frac{2\pi n l}{N}} \cdot \sum^{N-1}_{m=0}b_me^{-i\frac{2\pi m l}{N}})e^{i\frac{2\pi k l}{N}}]\\ &= \frac{1}{N}[\sum^{N-1}_{l=0}\sum^{N-1}_{n=0}\sum^{N-1}_{m=0}a_nb_me^{-i\frac{2\pi (k-n+m)l}{N}}]\\ &= \sum_{n+m = k}a_nb_m + \sum_{n+m = k+N}a_nb_m \end{aligned}

離散荷重変換

  • 負巡回を作るにはこれがいる
  • Descrete Weighted fourier TransformでDWTともいう
  • 離散ウェーブレット変換とは関係ない
  • 定義的には下の式のように重みwxw_xをかけて(逆)離散フーリエ変換をするだけ
  • ちゃんともとに戻る
  • 重みをうまく選べば負巡回乗算が実現できる

At=x=0N1wxaxei2πtxN,t{0,1,...,N1}ax=1wxNt=0N1Atei2πxtN,x{0,1,...,N1}\begin{aligned} A_t &= ∑^{N-1}_{x=0} w_xa_x⋅e^{-i\frac{2πtx}{N}},t∈\{0,1,...,N-1\}\\ a_x &= \frac{1}{w_xN}∑^{N-1}_{t=0} A_t⋅e^{i\frac{2πxt}{N}},x∈\{0,1,...,N-1\} \end{aligned}

負巡回乗算

  • wx=ei2πtx2Nw_x = e^{i\frac{2πtx}{2N}}と選んで畳み込み乗算
  • さらに上位N2\frac{N}{2}桁分を虚数部に入れておく
  • 虚数部も使うとメモリと計算量の両方を節約できる
  • 実数部と虚数部を連続した別の配列にすると変換が簡単

l=0N211N2ei2πl2N{m=0N21[n=0N21ei2πn2N(an+iaN2+n)ei2πlnN2n=0N21ei2πn2N(bn+ibN2+n)ei2πlnN2]ei2πlmN2}xl=l=0N21[ei2πl2Nn+m=lei2πl2N(an+iaN2+n)(bm+ibN2+m)+n+m=l+N2ei2π(l+N2)2N(an+iaN2+n)(bm+ibN2+m)]xl=l=0N21[n+m=l(an+iaN2+n)(bm+ibN2+m)+n+m=l+N2i(an+iaN2+n)(bm+ibN2+m)]xl=l=0N21{[n+m=l(anbmaN2+nbN2+m)n+m=l+N2(anbN2+m+aN2+nbm)]+i[n+m=l(anbN2+m+aN2+nbn)+n+m=l+N2(anbmaN2+nbN2+m)]}xl\begin{aligned} \sum^{\frac{N}{2}-1}_{l = 0}&\frac{1}{\frac{N}{2}}e^{-i\frac{2\pi l}{2N}}\{\sum^{\frac{N}{2}-1}_{m=0}[\sum^{\frac{N}{2}-1}_{n=0}e^{i\frac{2\pi n}{2N}}(a_n+ia_{\frac{N}{2}+n})e^{-i\frac{2\pi l n}{\frac{N}{2}}} \cdot \sum^{\frac{N}{2}-1}_{n=0}e^{i\frac{2\pi n}{2N}}(b_n+ib_{\frac{N}{2}+n})e^{-i\frac{2\pi l n}{\frac{N}{2}}}]e^{i\frac{2\pi l m}{\frac{N}{2}}}\}x^l\\ &= \sum^{\frac{N}{2}-1}_{l = 0}[e^{-i\frac{2\pi l}{2N}}\sum_{n+m=l}e^{i\frac{2\pi l}{2N}}(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m}) + \sum_{n+m=l+\frac{N}{2}}e^{i\frac{2\pi (l+\frac{N}{2})}{2N}}(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m})]x^l\\ &= \sum^{\frac{N}{2}-1}_{l = 0}[\sum_{n+m=l}(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m}) + \sum_{n+m=l+\frac{N}{2}}i(a_n+ia_{\frac{N}{2}+n})(b_m+ib_{\frac{N}{2}+m})]x^l \\ &= \sum^{\frac{N}{2}-1}_{l = 0}\{[\sum_{n+m=l}(a_nb_m-a_{\frac{N}{2}+n}b_{\frac{N}{2}+m}) - \sum_{n+m=l+\frac{N}{2}}(a_nb_{\frac{N}{2}+m}+a_{\frac{N}{2}+n}b_m)]+i[\sum_{n+m=l}(a_nb_{\frac{N}{2}+m}+a_{\frac{N}{2}+n}b_n) + \sum_{n+m=l+\frac{N}{2}}(a_nb_m-a_{\frac{N}{2}+n}b_{\frac{N}{2}+m})]\}x^l \end{aligned}

高速フーリエ変換

  • ここまで見たように離散フーリエ変換ができると負巡回乗算ができる
  • ∴離散フーリエ変換が高速化できれば負巡回乗算が高速化できる
  • そのアルゴリズムが高速フーリエ変換

2基底周波数間引き高速フーリエ変換(数式)

  • 高速フーリエ変換の基本的アイデアは問題の再帰的分割
  • nZ+n∈\mathbb{Z}^+個に分割するとnn基底
  • 以下のような式から2つに分割できることがわかる

At=x=0N1axei2πtxN=x=0N21axei2πtxN+ax+N2ei2πt(x+N2)N={x=0N21(ax+ax+N2)ei2πt2xN2ift0mod2x=0N21[(axax+N2)ei2πxN]ei2πt12xN2otherwise\begin{aligned} A_t &= ∑^{N-1}_{x=0} a_x⋅e^{-i\frac{2πtx}{N}}\\ &= ∑^{\frac{N}{2}-1}_{x=0} a_x⋅e^{-i\frac{2πtx}{N}}+a_{x+\frac{N}{2}}⋅e^{-i\frac{2πt(x+\frac{N}{2})}{N}}\\ &=\begin{cases}∑^{\frac{N}{2}-1}_{x=0} (a_{x}+a_{x+\frac{N}{2}})⋅e^{-i\frac{2π\frac{t}{2}x}{\frac{N}{2}}} \qquad if\quad t≡0 \bmod 2\\∑^{\frac{N}{2}-1}_{x=0} [(a_{x}-a_{x+\frac{N}{2}})⋅e^{-i\frac{2πx}{N}}]⋅e^{-i\frac{2π\frac{t-1}{2}x}{\frac{N}{2}}}\qquad otherwise\end{cases} \end{aligned}

2基底周波数間引き高速フーリエ変換(バタフライ図)

  • FFTのアルゴリズムを図示したものがバタフライ図
  • AAの添字がぐちゃぐちゃになる(ビットリバース)のは今示しているのがCooley-Tukey型と呼ばれるin-placeなアルゴリズムであるため
  • 普通は添字を戻す並び替えかout-placeなStockham型を使う
  • 今回は畳み込み乗算さえできればいいので並び替えをする必要がない
  • 論文実装のFFTが自前な理由の一つはこの並び替えがいらないのを利用するため

2基底周波数間引き高速フーリエ変換(疑似コード)

  • in-placeなので受け取った配列を書き換えながら実行される
  • 簡潔なので再帰で書いているが最適化しづらいのでForで書くことが多い
Radix2FFT(𝐚,start,end) //startとendは最初0とN/
  N = start-end
  n = (start-end)/2
  for x from 0 to n-1
    temp = aₓ+aₓ₊ₙ
    aₓ₊ₙ = (aₓ-aₓ₊ₙ)*exp(-i2πx/N)
    aₓ = temp
  if(N>2)
    Radix2FFT(𝐚,start,start+n)
    Radix2FFT(𝐚,start+n,end)

2基底時間間引き高速フーリエ変換

  • 周波数間引きの逆変換を時間間引きという
  • ぐちゃぐちゃになった方の並びを入力とする

4基底周波数間引き高速フーリエ変換(バタフライ図)

  • バタフライ図をぐっと睨むと回転因子を移動させて以下のように分解できることがわかる
  • W2=ei2π28=iW^2=e^{-i\frac{2π2}{8}}=-iなのでこれは乗算にならない

4基底周波数間引き高速フーリエ変換(疑似コード)

Radix4FFT(𝐚,start,end)
  N = start-end
  if(N==2)
    s = start //下付きにするため
    temp = aₛ+aₛ₊₁
    aₛ₊₁ = aₛ - aₛ₊₁
    aₛ = temp
  else
    n = (start-end)/4
    for x from 0 to n-1
      for j from 0 to 1
        l = j*n
        m = l+2*n
        temp = aₓ₊ₗ+aₓ₊ₘ
        aₓ₊ₘ = (aₓ₊ₗ-aₓ₊ₘ)*i //実数部と虚数部の交換と符号反転
        aₓ₊ₗ = temp
      for j from 0 to 1
        l = j*2*n
        m = l+n
        if(j==0)
          temp = aₓ₊ₗ+aₓ₊ₘ
          aₓ₊ₘ = (aₓ₊ₗ-aₓ₊ₘ)*exp(-i2π2x/N)
        else
          temp = (aₓ₊ₗ+aₓ₊ₘ)*exp(-i2πx/N)
          aₓ₊ₘ = (aₓ₊ₗ-aₓ₊ₘ)*exp(-i2π3x/N)
        aₓ₊ₗ = temp
    Radix4FFT(𝐚,start,start+n)
    Radix4FFT(𝐚,start+n,start+2n)
    Radix4FFT(𝐚,start+2n,start+3n)
    Radix4FFT(𝐚,start+3n,end)

Split Radix

  • 4基底で計算量削減効果があるのは下半分だけ
  • 上半分は2基底でやめるとSplit-Radixになる
  • メモリアクセスやプログラムの煩雑さの都合でSplit-Radixはあまりやらない
  • 8基底でも同じことはできる

8基底周波数間引き高速フーリエ変換(バタフライ図)

  • N=16なのでW2=ei2π216=12(1i),W6=12(1i)W^2=e^{-i\frac{2π2}{16}}=\frac{1}{\sqrt{2}}(1-i),W^6=\frac{1}{\sqrt{2}}(-1-i)で乗算が減る

width:500px

16基底以上のFFT

  • 16基底以上も同じように回転因子をまとめていくことで構成できる
  • しかし、16基底以上では8基底までで見られたような乗算の削減はできない
  • 計算量としては基底が大きいほど下がる
  • とはいえ基底を大きくするほど並列性が下がる面もあるのでレジスタの本数やメモリアクセス量との兼ね合いが必要

AVX2について

  • x86で定義されているベクトル拡張命令の一つ
  • 256bit、倍精度4つ分を1サイクルで計算することができる
  • コンパイラに任せるとうまく使ってくれないことがあるので、FFTのように並列計算をゴリゴリやる場合にはアセンブラを書いたほうが速い
  • 配列の加算程度ならコンパイラに-march=nativeを渡せば勝手に使ってくれる

TFHEでの利用

  • FFTを適切に書けば多項式乗算をこれに置き換えたほうが速い
  • 32bit固定小数点数を倍精度に変換して計算する
  • 倍精度には符号bitがあるので、符号付きとして解釈したほうが良い
  • FFTは実数の上で定義されるので多項式乗算をしている間はトーラスをただの実数とみなし乗算後に剰余を取る

TFHEの多項式乗算における精度要求について

  • 倍精度は53bitしか精度がない
  • External Productでは最大で(Bg21)(2311)(\frac{Bg}{2}-1)⋅(2^{31}-1)NN個足したものが結果の係数として出てくる
  • これを正確に保持するには少なくとも(Bgbit1+31+Nbit)(Bgbit-1+31+Nbit)bitの精度が必要
  • 今回は61+31+10=46<536-1+31+10=46<53で問題ない
  • どうせ後で32bitにまるめてしまうので実際は多少はみ出ても問題ない

実装について

  • 暗号の本筋とは関係ないのでTFHEの論文実装で使われているSPQLIOSを使うのが良いようには思う
  • 有名所のFFT実装としてはIntel IPPやFFTWがあるがこれらは一般的なFFT用途を想定しているためにTFHEで用いるとSPQLIOSに勝てない
  • GPUでFFTをするcuFFTもあるがGPUならcuFHEのようにNTTを実装したほうが速い
  • AVX512やHIPでの実装はあり?

参考文献

page_number: true