ブログ 学習 機械学習 深層学習 自然言語処理

【Transformerの基礎】Multi-Head Attentionの仕組み

本記事では、Transformerの基礎として、Multi-Head Attentionの仕組みを分かりやすく解説します。

本記事の構成は、はじめにTransformerおよびTransformer Blockについて紹介し、TransformerにおけるMulti-Head Attentionの重要性について説明します。その後に、Multi-Head Attentionがどのような仕組みで実現されているのかを説明します。本記事を読めばMulti-Head Attentionについて、概要から計算式に至るまで、ほぼ全てをご理解いただけると思います。

解説動画は以下になります。

Transformer

Transformer[1]とは2017年に高精度な機械翻訳モデルとして登場した機械学習モデルです。Transformerが提案された論文のタイトルは「Attention Is All You Need」と特徴的です。

Transformerは、従来の機械翻訳モデルで使用されていたSeq2Seqの課題である、大域的な特徴やニュアンスを捉えにくい点と、学習の高速化が難しい点に対して、その解決策として提案されたモデルです。詳しくは以下の記事で解説しています。

Transformerは従来から使用されてきたCNNやRNNとは異なり、Attention機構と呼ばれる仕組みに基づいたエンコーダデコーダモデルです。Attention機構とは注意表現を学習する仕組みです。LSTMなどで使用されているゲート機構に似ており、必要な情報は強調し、不必要な情報は目立たなくする仕組みです。より良い注意表現を学習できれば膨大な情報から注意を向けるべき情報に焦点を当てて処理をすることが可能になります。注意の英単語であるAttention(アテンション)からきています。TransformerではMulti-Head Attentionと呼ばれるAttention機構が使われています。下図にTransformerの構造を示します。

Transformer

論文で提案されたTransformerは、上図のようなエンコーダデコーダ構造をしていますが、Transformerの本質的な機能単位で考えると、エンコーダ部分だけを見れば十分です。これが次に説明するTransformerの基本単位、Transformer Blockです。

Transformer Block

Transformer Block[2]はTransformerの基本単位で、Multi-Head AttentionとPosition-Wise Feed-Forward Networksからなります(下図)。BERTやGPTなどのTransformerベースのモデルにおいて、TransformerとはTransformer Blockのことを意味していることが大半です。

Transformer Block

Transformer Blockは、Query、Key、Valueの3つの入力を持ち、Multi-Head Attentionを使って注意表現に基づく計算をします。Position-Wise Feed-Forward Networks(図中、Feed Forward)では、Multi-Head Attentionの出力結果をトークン毎に変換します。トークン毎に計算するためPosition-Wiseと呼ばれます。また、各々の層に平行したショートカットパスはSkip-Connectionと呼ばれ、勾配消失が起こりにくくしモデルの深層化を可能にします(以下で補足しています)。

Multi-Head Attentionは、入力の直前にLinear層を持つScaled Dot-Product Attention(=Single-Head Attention[3])を複数並列に配置した構造をしています。Scaled Dot-Product Attentionは名前の通り、内積に基づく注意計算を行いますが、それ単体では学習パラメータを持たず、学習によりチューニングできないため、直前に学習パラメータを持つLinear層を設けて多種多様な特徴部分空間における注意表現の学習を可能にしています。学習により獲得された注意表現を可視化すると、各ヘッドが異なる注意表現を獲得していることが確認されています[1,3,4]。ヘッドが1つの場合は、Single-Head Attentionになりますが、複数並列にすることでモデルの表現能力を高くしているのです。

Position-Wise Feed-Forward Networksとは、Multi-Head Attentionの後に続く順伝播型ニューラルネットワークです。先に説明した通りPosition-Wiseには位置単位でという意味があり、各トークン毎に計算をします。Multi-Head Attentionの中にLinear層があり、これもニューラルネットワークの一種ですが、活性化関数が無いため、非線形変換をすることができません。ニューラルネットワークによる学習において、非線形の活性化関数を使うことには大きな意味があります。その為、Multi-Head Attentionの直後にPosition-Wise Feed-Forward Networksがあるのです。

複雑な構造をしていますが、根本はScaled Dot-Product Attentionであり、それさえ理解できてしまえば、あとは表現能力を上げるための工夫であり、Multi-Head Attentionについても容易に理解することが可能です。そこで本記事では、Scaled Dot-Product Attentionについて最初に詳細を説明した後に、その表現能力を向上させたSingle-Head Attention、さらに表現能力を向上させたMulti-Head Attentionについて順番に解説をします。

Skip-Connectionについて補足します。Skip-Connectionとは、ResNet[5]で提案された、層をスキップする情報伝達経路のことをいいます。Skip-Connectionを設けることで、恒等写像を実現できるようになり、また、学習時の勾配消失問題の改善に寄与するため、ネットワークの深層化を実現しました。下図は、2つの層(Layer 1とLayer 2)をスキップする情報伝達経路がある例を表しています。\(\boldsymbol{x}\)が入力されたとき、2つの層による計算結果は\(\mathcal{F}(\boldsymbol{x})\)でSkip-Connectionにより\(\boldsymbol{x}\)がそのまま伝達されるので、最終的な出力\(\mathcal{H}(\boldsymbol{x})\)は\(\mathcal{F}(\boldsymbol{x})+\boldsymbol{x}\)になります。

Skip-Connectionの概要

Scaled Dot-Product Attention

Scaled Dot-Product Attentionとは、日本語でスケール化内積注意と訳すことが可能であるように、内積を利用したベクトル間の類似性に基づく変換を行います。入力は、3つでQuery、Key、Valueと呼ばれます。これらのネーミングはデータベースで使われているワードから来ています。内部で行われている操作はデータベースとは大きく異なりますが、Queryに基づいてKeyに何らかの変更を施し、Valueを取り出してくる操作をするため、なんとなく似ている感じはしますね。

Scaled Dot-Product Attentionの構造を下図に示します。

Scaled Dot-Product Attentionの構造

3つの入力を持ちますが、ここでは仕組みを理解しやすくするためにKeyとValueは同じ入力として説明をします。上図は複雑な構造を表しているように見えますが、単純に\(\text{softmax}\left(\frac{\boldsymbol{QK^T}}{\sqrt{d_k}}\right)\)を計算しているにすぎず、学習要素は全くありません。

具体例を見た方が理解が進むと思うので、具体例を見てみましょう。Queryには任意の単語を表すベクトルを、KeyとValueには文章を単語ごとにベクトル化して並べた行列を入力することを考えます(ちなみに単語のベクトル化では埋め込み層を用いることができます)。例えば、「I have a pen.」という文章を単語ごとにトークン化しベクトルに変換したとします(下図)。色の違いはベクトルの違いを表しています。

ここで、Query(\(\boldsymbol{Q}\))として「have」を表すベクトルを、Key(\(\boldsymbol{K}\))とValue(\(\boldsymbol{V}\))には各ベクトルを順番に並べた行列を与えることにします。

このとき、ソフトマックス関数の入力である、\(\frac{\boldsymbol{QK^T}}{\sqrt{d_k}}\)では、Queryのベクトルと入力文章の全ベクトルに対して内積を計算し、Keyのベクトルの次元数\(d_k\)のルートでスケール化したベクトルを計算します。すなわち、「I have a pen.」を構成するトークンに対して、「have」を表すベクトルにどの程度類似しているかを計算するのです。スケール化についてはソフトマックス関数との兼ね合いで必要になります。内積の計算結果のままでは扱いにくいため、ソフトマックス関数を適用し、類似している単語ほど1に近く、相違している単語ほど0に近く、また全体の和が1になるような変換を施します。この結果を用いて、Valueと内積を計算します。Valueとの内積とは文章を構成する各ベクトルのQueryベクトルとの類似度に基づく線形結合になります。一言でまとめると、Scaled Dot-Product AttentionではQueryベクトルとKeyの各ベクトルの類似性に基づいてValueの各ベクトルの線形結合を計算します。この計算の流れを下図に示します。

Scaled Dot-Product Attentionの処理

Scaled Dot-Product Attentionがどのような計算処理を行うか理解できたところで、その働きについて説明します。下図に先ほどの例におけるScaled Dot-Product Attentionの働きを示します。KeyとValueには「I have a pen.」を表す行列を、Queryには「have」を表すベクトルを与えると、「have」のベクトルとの類似性に基づいて文章全体の単語ベクトルを用いて線形結合を計算しベクトルを出力します。出力されるベクトルのサイズはQueryのベクトルと同じです。その為、出力ベクトルを単語ベクトルとして考えることができます。これを再帰的に行えば、文章を生成することが可能であることが分かりますね。

Queryに用いるベクトルは、KeyやValueに含まれるベクトルである必要はありません(下図)。

このように、KeyやValueは同じで、Queryは異なる入力であるAttentionをSourceTarget Attentionと呼びます。一方で、QueryもKeyやValueと同じにすることもでき、そのようなAttentionをSelf-Attentionと呼びます。

TransformerではSelf-Attentionが多く使われており、重要な意味を持ちます。下図は、Scaled Dot-Product AttentionをSelf-Attentionとして使用した場合の例を表しています。

Scaled Dot-Product Attentionによる注意表現

以上が、Scaled Dot-Product Attentionの説明になります。このように入力されたベクトルの類似性に基づいて計算するのですが、Scaled Dot-Product Attentionには大きな欠点があります。それは、学習パラメータを持たないという点です。すなわち、多種多様な特徴部分空間における注意表現を学習することができず、埋め込み層により変換された分散表現におけるベクトルの類似性に基づくワンパターンな変換しかできないのです。そこで、多種多様な特徴部分空間における注意表現を学習可能にするために、Query、Key、Valueの入力の直前に学習パラメータをもつLinear層を用意し、多種多様な特徴部分空間の注意表現を学習できるようにしたのが次に説明するSingle-Head Attentionです。

※Scaled Dot-Product Attentionの図中、Mask部分について説明をします。Generatorの学習では、文章のはじめだけを見せて、その後を予測するように学習を行います。Keyには文章全体が格納されていますので、そのまま処理してしまうと、予測したい未来の情報が入力に含まれることになります。そこで、内積をとった後にマスク処理をして未来の情報を見れないようにしているのです。マスクしてしまえば、マスクされベクトルはソフトマックスによる線形結合で使われなくなります。

Single-Head Attention

Single-Head Attentionとは、学習パラメータを持たないScaled Dot-Product Attentionの表現能力を広げるために、各入力の直前に学習パラメータを持つLinear層を追加したものです。これにより、入力されるベクトルの特徴空間に依存しない注意表現を学習することが可能になります。

Single-Head Attention

すなわち、先ほどのScaled Dot-Product Attentionでは、埋め込み層により変換された分散表現における注意表現しか学習できませんでしたが、Linear層が追加されたことで必ずしも単語ベクトル同士が近くなくても、Linear層で近くなるように変換するように重みパラメータを調節すればよいことがわかります。これが、Single-Head Attentionです。しかし、Single-Head Attentionにも課題があります。私たちが使用する言語には、同じ単語でも意味や文法などが異なるものが多数ありますが、Single-Head Attentionでは、それらを平均化して扱ってしまい、複雑な文章構造に対応できないのです。そこで、考案されたのが次に説明するMulti-Head Attentionです。

Multi-Head Attention

Multi-Head Attentionとは、Single-Head Attentionを多数並列に配置することで、さまざまな注意表現の学習を可能にしたAttention機構です。

原論文には以下のような記述があります。

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.

参考文献[1]の5ページの冒頭

と記載されています。すなわち、Single-Head Attentionでは多種多様な意味や文法をもつ単語に対しても単一の注意表現に平均化されてしまいますが、Single-Head Attentionを多数並列に配置してMulti-Headにすれば、複数の特徴部分空間における注意表現の獲得をすることができます。

Multi-Head Attentionの構造を下図に示します。

Multi-Head Attention

Multi-Head Attentionの中にある個々のSingle-Head Attentionは、下図のようにそれぞれ異なる特徴を学習することが可能[3,4]です。

文献[3]より引用: Multi-Head Attentionによる複数の注意表現の獲得
[4]より引用

このように、文章全体にわたって複数の注意表現の獲得が可能なAttention機構が、Multi-Head Attentionです。さて、Multi-Head Attentionの利点をご理解いただいたところで、多くの方は幾つのSingle-Head Attentionを並列にすればよいかという疑問を持つでしょう。多すぎると冗長になり、少なすぎるとモデルの表現能力が下がってしまいます。Transformerの原論文[1]では8つのSingle-Head Attentionが並列に用いられています。8つと聞いて少ないと感じた方が多いのではないでしょうか?

これでも十分な精度が出る理由について私なりの解釈を述べます。それは、各Single-Head Attentionの出力をConcatしたとに、Linear層が用意されているからだと考えます。私たちの住む世界は3次元で、とても広い世界が広がっていますが、それらはxyzの単位ベクトルの線形結合で表すことができてしまいます。それと同様に、本質的な注意表現は8つしか獲得できなくても、その後のConcatとLinear層により結合&線形変換されるなかで、新たな注意表現に変換することが可能だからではないでしょうか?

少し私なりの解釈を混ぜてしまいましたが、以上を踏まえたときのMulti-Head Attentionの計算式は以下に示すようになります。

$$
\begin{eqnarray} \rm{MultiHead Attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) &=& \rm{Concat}(head_1, head_2, \cdots, head_h)\boldsymbol{W}_o\\ \rm{where}\ head_i &=& \rm{ScaledDotProductAttention}(\boldsymbol{QW}^Q_i, \boldsymbol{KW}^K_i, \boldsymbol{VW}^V_i) \end{eqnarray}
$$

Transformerの論文では、ヘッドの増加に伴う計算量の増加を防ぐために、Scaled Dot-Product Attentionの直前のLinear層で入力ベクトルのサイズをヘッド数\(h\)で除算したサイズに変換し、Scaled Dot-Product Attentionの入力としています。そして、\(h\)個のSingle-Head Attentionの出力を結合して入力ベクトルと同じサイズにしてから、出力直前のLinear層で変換を施して最終的な出力を計算しています。下図は、Self-Attentionとして使用するときのMulti-Head Attention内部における行列サイズの変化を表しています。

まとめ

以上の内容をまとめます。Transformerの本質的な機能単位はTransformer Blockとよばれ、Multi-Head AttentionとPosition-Wise Feed-Forward Networksに分けられます。Multi-Head Attentionは、Scaled Dot-Product AttentionとLinear層からなるSingle-Head Attentionを並列化したものです。

Scaled Dot-Product Attentionは内積によるQueryベクトルとKey行列の各ベクトル間の類似性に基づいて、Value行列の各ベクトルの線形結合を計算します。しかし、学習パラメータを持たず学習により獲得される注意表現を変えることができないため、入力部分に線形層を追加した、Single-Head Attentionが考えられました。しかし、Single-Head Attentionでは、多数の注意表現が全て平均化されてしまうため、複数並列にし、それらの出力を合体させてLinear層による変換を行うようにしたMulti-Head Attentionが提案されました。

Multi-Head Attentionは多数の注意表現を獲得可能となっています。

以上で本記事の内容を終わりにします。最後までお読みいただきありがとうございました。

参考文献

[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin, "Attention is all you need," in Proc. NeurIPS, 2017.
[2] Wenyng Duan, Liu Jiang, Ning Wang, and Hong Rao, "Pre-Trained Bidirectional Temporal Representation for Crowd Flows Prediction in Regular Region," IEEE Access, 2019.
[3] Hyeongu Yun, Taegwan Kang, and Kyomin Jung, "Analyzing and Controlling Inter-Head Diversity in Multi-Head Attention," Appl. Sci, 2021.
[4] Kevin Clark, Urvashi Khandelwal, Omer Levy, and Christopher D. Manning, "What Does BERT Look At? An Analysis of BERT's Attention," in Proc. ACL, 2019.
[5] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, "Deep Residual Learning for Image Recognition," in Proc. CVPR, 2016.

  • この記事を書いた人
管理人

管理人

このサイトの管理人です。 人工知能や脳科学、ロボットなど幅広い領域に興味をもっています。 将来の目標は、人間のような高度な身体と知能をもったパーソナルロボットを開発することです。 最近は、ロボット開発と強化学習の勉強に力を入れています(NOW)。

-ブログ, 学習, 機械学習, 深層学習, 自然言語処理

PAGE TOP