An Attention Free Transformer

概要

Transformerは従来のRNNやCNNと比べて計算量を大幅に減らしたが、まだその計算量は大きいままである。ボトルネックとなっているのは、Attention layerにおける行列積の計算で、テキストサイズの2乗に比例した時間がかかっている。そこで、行列積に近似した別の手法を用いて計算量を削減した、Attention Free Transformer(AFT)を提案する。また、経験的に、Transformerはlocal patternを示す(ナニコレ 多分Attentionしてる部分がlayer, head間でほぼ変わってないことを示している)ことから、AFT-localとAFT-convを提案する。

前提知識

Transformerのモデル構造は以下から。

usapyoi.hatenablog.com

詳細はリンクを見ればわかるが、要するにAttention layerの出力を計算する際に {\bf Q}, {\bf K}, {\bf V}の行列積を求める必要があり、それがボトルネックになっているらしい。

導入

AFTもTransformerと同様に {\bf Q}, {\bf K}, {\bf V}から成るが、まず違いとしてあげられるのは、AFTの出力が以下のようになっていることだ。

AFT(の中の層)の出力

ここで、 w_{t}は位置バイアスで、学習対象のようだ。  {\bf Q}, {\bf K}, {\bf V}による式を弄るような研究はいくつもなされてきたが、それらとの違いは本研究が行列積ではなくアダマール積(element-wise product)を採用している点である。

AFT-localはglobal connectivity(だからナニコレ)を維持しつつ、ある程度のバリエーションを残すらしい。つまり、Attentionする箇所が全てのlayer, head間で同じに成る程の平均化はしないということかも。

AFT-convはspatial weight sharing(畳み込みのことらしい。最初からそう書け)を課すことで上述の設計を拡張し、global receptive field(?????????)を持ったCNNの一種のような形。

方法論

前節でも述べたが、AFTでAttention layerの代わりになるのは以下のような式だ。


{\bf Y}_{t} = {\rm Sigmoid}({\bf Q}_{t}) \bigodot \frac{\sum^{T}_{\tau = 1}(K_{\tau}+w_{t,\tau}) \bigodot {\bf V}_{\tau}}{\sum^{T}_{\tau = 1}(K_{\tau}+w_{t,\tau})}

いろんな説明が欠如、もとい、省略されている。まず、 {\bf K}_{\tau}={\bf X}{\bf W}_{\tau}^{K} {\bf X} \in R^{T \times d}, {\bf W}_{\tau}^{K} \in R^{d \times d_{k}}より {\bf K}_{\tau} \in R^{T \times d_{k}}である一方で、 w \in R^{T \times T}とある。おそらくこれは、 w_{t, \tau}スカラーであり、行列の各要素にそのスカラーを加算しているのだろう。また、 {\bf V}_{\tau}についても同様に {\bf V}_{\tau} \in {T \times d_{v}}であるが、特にことわりのない限り d_{k}=d_{v}らしいので、アダマール積は可能である。また、割り算については各要素ごとの除算…なのだろうか…おそらくその後 {\rm Sigmoid}({\bf Q}_{t}) \in R^{T \times d_{k}}アダマール積を取っているのでそうかと思われる。

AFT-local

学習済の w_{t, \tau}をいじった形。そんなことしていいのだろうか?


w\_{t, \tau} = 
\left\{
\begin{array}{ll}
w\_{t, \tau} & if |t - \tau| < s \\
0 & otherwise
\end{array}
\right.

ここで、 s \leq Tはlocal window sizeと呼ばれ、ローカル性を制御する変数となっている。計算量的には大助かりの設計のようですね。

AFT-simple

AFT-localで s = 0とした場合。位置バイアス w_{t,\tau}が消えてスッキリします。


{\bf Y}_{t} = {\rm Sigmoid}({\bf Q}_{t}) \bigodot \frac{\sum^{T}_{\tau = 1}(K_{\tau}) \bigodot {\bf V}_{\tau}}{\sum^{T}_{\tau = 1}(K_{\tau})}\\
\quad = {\rm Sigmoid}({\bf Q}_{t}) \bigodot \sum^{T}_{\tau = 1} {({\rm softmax}(K) \bigodot V)}_{\tau}

AFT-conv

前述のローカル性の考え方を延長させてspatial weight sharing…つまり畳み込みを組み込んだ形。


{\bf Y}_{t}^{i} = {\rm Sigmoid}({\bf Q}_{t}^{i}) \bigodot \frac
{{\rm conv1d}({\rm exp}(K^{i}) \bigodot V^{i},  {\           rm exp}(w^{i}) - 1)
  +  {\sum}_{\tau = 1}^{T} {\rm exp}(K_{\tau}^{i}) \bigodot V_{\tau}^{i} }
{{\rm conv1d}({\rm exp}(K^{i}),  {\rm exp}(w^{i}) - 1)
  +  {\sum}_{\tau = 1}^{T} {\rm exp}(K_{\tau}^{i})}

で、この {\rm conv1d}はdepth-wise separable 1d convolution operationらしい。なんのこっちゃといった感じだが、入力画像群の1枚(channel)に対して1枚のフィルタしか用意しない畳み込み(channel-wise convolution)と、複数のchannelにおける入力画像群の各1点(point)に対してしか行わない畳み込み(point-wise convolution)の組み合わせによって行う畳み込みらしい。この辺は改めて記事を書くつもり。

シミュレーション

行ったのは以下の3つ。

  1. image autoregressive modeling

時系列データを予測する自己回帰モデルを画像、つまり二次元に拡張したものかと思われる。左上から右下まで画素を精査していき、ある画素はそれまでの画素に基づいた確率分布から生成されたものと捉えるらしい。わからん。

とりあえずリンクを貼っておく。後で読む。

uvadlc-notebooks.readthedocs.io

  1. character level language modeling

入力文字から次の文字を予測するタスク。例えば「I」と打ったら次に来るのは「'm」ではないか?のように。詰まるところ予測変換、オートフィルのようなもの。

  1. image classification

画像認識。

結果

学習速度や使用メモリの改善が見られた。精度の結果書いてないんですがどこにあるんですかね?

感想

畳み込みニューラルネットの知識が足りてないのと、Transformerって結局どのパラメタが最適化されるんだ?ってのと。ググっても論文読んでも「○○のデータセット使って〜」とか「GPUは△△の□□で〜」とかで学習則どこにも書いてないし。RWKVの理解のために手を出したはいいけど…うーん。