Retentive Network: A Successor to Transformer for Large Language Modelsの日本語訳を教えて!
こういった悩みにお答えします.
本記事の信頼性
- リアルタイムシステムの研究歴12年.
- 東大教員の時に,英語でOS(Linuxカーネル)の授業.
- 2012年9月~2013年8月にアメリカのノースカロライナ大学チャペルヒル校(UNC)コンピュータサイエンス学部で客員研究員として勤務.C言語でリアルタイムLinuxの研究開発.
- プログラミング歴15年以上,習得している言語: C/C++,Python,Solidity/Vyper,Java,Ruby,Go,Rust,D,HTML/CSS/JS/PHP,MATLAB,Verse(UEFN), Assembler (x64,ARM).
- 東大教員の時に,C++言語で開発した「LLVMコンパイラの拡張」,C言語で開発した独自のリアルタイムOS「Mcube Kernel」をGitHubにオープンソースとして公開.
- 2020年1月~現在はアメリカのノースカロライナ州チャペルヒルにあるGuarantee Happiness LLCのCTOとしてECサイト開発やWeb/SNSマーケティングの業務.2022年6月~現在はアメリカのノースカロライナ州チャペルヒルにあるJapanese Tar Heel, Inc.のCEO兼CTO.
- 最近は自然言語処理AIとイーサリアムに関する有益な情報発信に従事.
- (AI全般を含む)自然言語処理AIの論文の日本語訳や,AIチャットボット(ChatGPT,Auto-GPT,Gemini(旧Bard)など)の記事を50本以上執筆.アメリカのサンフランシスコ(広義のシリコンバレー)の会社でプロンプトエンジニア・マネージャー・Quality Assurance(QA)の業務委託の経験あり.
- (スマートコントラクトのプログラミングを含む)イーサリアムや仮想通貨全般の記事を200本以上執筆.イギリスのロンドンの会社で仮想通貨の英語の記事を日本語に翻訳する業務委託の経験あり.
こういった私から学べます.
AIのプログラミング言語「C++/Python言語」を学べるおすすめのWebサイトを知りたいあなたはこちらからどうぞ.
独学が難しいあなたは,AIを学べるオンラインプログラミングスクール3社で自分に合うスクールを見つけましょう.後悔はさせません!
国内・海外のAIエンジニアのおすすめ求人サイトを知りたいあなたはこちらからどうぞ. こういった悩みにお答えします. こういった私が解説していきます. 国内・海外のAIエンジニアのおすすめ求人サイト(転職エージェント)を紹介します. AIエンジニアになるためには,主にC++/Pytho ... 続きを見る
国内・海外のAIエンジニアのおすすめ求人サイト【転職エージェント】【C++/Python言語】
国内・海外のプロンプトエンジニアのおすすめ求人サイトを知りたいあなたはこちらからどうぞ.
Retentive Network: A Successor to Transformer for Large Language Modelsの日本語訳を紹介します.
Microsoftの清華大学によるTransformerの後継「RetNet」がわかります.
※図表を含む論文の著作権はRetentive Network: A Successor to Transformer for Large Language Modelsの著者に帰属します.
Retentive Network: A Successor to Transformer for Large Language Modelsの目次は以下になります.
- Abstract
- 1章:Introduction
- 2章:Retentive Networks
- 3章:Experiments
- 4章:Conclusion
- References
- 付録A:Hyperparameters
- 付録B:Grouped Results of Different Context Lengths
Retentive Network: A Successor to Transformer for Large Language Modelsを解説しつつ,私の考えも語ります.
Retentive Network: A Successor to Transformer for Large Language Modelsの概要と私の日本語訳は以下になります.
In this work, we propose Retentive Network (RetNet) as a foundation architecture for large language models, simultaneously achieving training parallelism, low-cost inference, and good performance.
本研究では,大規模言語モデルのための基礎アーキテクチャとして,訓練並列性,低コスト推論,高性能を同時に実現するRetentive Network(RetNet)を提案する.We theoretically derive the connection between recurrence and attention.
我々は,リカレントとAttentionの関係を理論的に導出する.Then we propose the retention mechanism for sequence modeling, which supports three computation paradigms, i.e., parallel, recurrent, and chunkwise recurrent.
そして,並列,リカレント,チャンクワイズリカレントという3つの計算パラダイムをサポートする,シーケンスモデリングのためのRetention機構を提案する.Specifically, the parallel representation allows for training parallelism.
具体的には,並列表現は訓練並列性を可能にする.The recurrent representation enables low-cost \(O(1)\) inference, which improves decoding throughput, latency, and GPU memory without sacrificing performance.
リカレント表現では,低コストで\(O(1)\)推論が可能であり,性能を犠牲にすることなく,デコードスループット,レイテンシ,GPUメモリを向上させることができる.The chunkwise recurrent representation facilitates efficient long-sequence modeling with linear complexity, where each chunk is encoded parallelly while recurrently summarizing the chunks.
チャンクワイズリカレント表現は,チャンクをリカレントに要約しながら各チャンクを並列にエンコードすることで,線形複雑度で効率的な長シーケンスモデリングを容易にする.Experimental results on language modeling show that RetNet achieves favorable scaling results, parallel training, low-cost deployment, and efficient inference.
言語モデリングに関する実験結果は,RetNetが良好なスケーリング結果,並列訓練,低コストのデプロイ,効率的な推論を達成することを示している.The intriguing properties make RetNet a strong successor to Transformer for large language models.
この興味深い特性により,RetNetは大規模言語モデルのためのTransformerの強力な後継となる.Code will be available at https://aka.ms/retnet.
コードはhttps://aka.ms/retnetから入手できる.https://arxiv.org/abs/2307.08621
私の日本語訳の注意点は以下になります.
- 概要は英語と日本語を両方掲載しましたが,本文は私の日本語訳のみを掲載していること(英語で読みたいあなたは原文を読みましょう!)
- 基本的には原文の直訳ですが,わかりにくい箇所は意訳や説明を追加している箇所があること
- 原文の「Acknowledgements」(謝辞)は省略していること
- 本文中に登場する表記「[VSP+17]」などは参考文献ですので,興味がある方は本記事の参考文献を参照されたいこと
それでは,Retentive Network: A Successor to Transformer for Large Language Modelsの本文を読みすすめましょう!
目次
1章:Introduction(はじめに)
※訳注:原文では図1を参照していませんが,こちらに記載しておきます.
The only way to discover the limits of the possible is to go beyond them into the impossible.
Arthur C. Clarke可能性の限界を発見する唯一の方法は,その限界を超えて不可能性の中に入っていくことだ.
アーサー・C・クラークhttps://arxiv.org/abs/2307.08621
Transformer[VSP+17]は,大規模言語モデル[BMR+20]のためのデファクトアーキテクチャとなっており,当初はリカレントモデル[HS97]のシーケンス訓練問題を克服するために提案された.
しかし,Transformerの訓練並列性は,ステップごとに\(O(N)\)の複雑度とメモリに束縛されるキー・バリュー・キャッシュ[Sha19]のために,非効率的な推論の代償となり,Transformerのデプロイで不利になる.
シーケンスの長さが長くなると,GPUのメモリ消費量とレイテンシが増加し,推論速度が低下する.
効率的な\(O(1)\)推論を実現しつつ,Transformerのような訓練並列性と競争力を維持することを目標に,次世代アーキテクチャの開発に多くの努力が続けられてきた.
上記の目標を同時に達成することは困難であり,図2に示すような,いわゆる「不可能の三角形」(トリレンマ)と呼ばれている.
主に3つの研究がある.
1つ目の研究であるLinear Attention[KVPF20]は,標準Attentionスコア\(\exp(q \cdot k)\)をカーネル\(\phi(q) \cdot \phi(k)\)で近似し,自己回帰推論をリカレント形式で書き換えることができる.
しかし,モデリング能力と性能はTransformersより悪く,この手法の人気を妨げている.
2つ目は,訓練並列性を犠牲にしながらも,効率的な推論のためにリカレントモデルに戻る.
救済策として,要素ごとの演算子[PAA+23]がアクセラレーションに使われるが,表現能力と性能が損なわれる.
3つ目の研究は,S4[GGR21]やそのVariants[DFS+22, PMN+23]など,Attentionを他の機構に置き換えることを研究している.
これまでの研究はどれも不可能の三角形を突破することができず,結果としてTransformerと比較して明確な勝者はいない.
本研究では,低コストな推論,効率的な長シーケンスモデリング,Transformerに匹敵する性能,並列モデル訓練を同時に実現するRetentive Network(RetNet)を提案する.
具体的には,並列表現,リカレント表現,チャンクワイズリカレント表現という3つの計算パラダイムを持つ,Multi-Head Attentionに代わるMulti-Scale Retention機構を導入する.
第1に,並列表現は,GPUデバイスをフルに活用するための訓練並列性を強化する.
第2に,リカレント表現により,メモリと計算の点で効率的な\(O(1)\)推論が可能となる.
デプロイコストとレイテンシを大幅に削減できる.
さらに,キー・バリュー・キャッシュのトリックなしで実装が大幅に簡素化される.
第3に,チャンクワイズリカレント表現は効率的な長シーケンスモデリングを行うことができる.
計算速度のために各ローカルブロックを並列にエンコードする一方,GPUメモリを節約するためにグローバルブロックをリカレントエンコードする.
我々はRetNetとTransformerおよびそのVariantsを比較するために広範な実験を行った.
言語モデリングに関する実験結果は,RetNetがスケーリング曲線とコンテキスト内学習の両面で一貫して競争力があることを示している.
さらに,RetNetの推論コストは長さ不変である.
7Bのモデルと8kのシーケンス長の場合,RetNetはキー・バリュー・キャッシュを持つTransformersよりも8.4倍速くデコードし,70%のメモリを節約する.
訓練中,RetNetは標準的なTransformerよりも25~50%のメモリ節約と7倍の高速化を達成し,高度に最適化されたFlashAttention[DFE+22]に対して優位性を持つ.
さらに,RetNetの推論レイテンシはバッチサイズに依存しないため,膨大なスループットを可能にする.
このような興味深い特性により,RetNetは大規模言語モデルのためのTransformerの強力な後継となる.
2章:Retentive Networks
Retentive Network(RetNet)は,Transformer[VSP+17]と同様のレイアウト(すなわち,残差接続,pre-LayerNorm)に従ったL個の同一のブロックでスタックされる.
各RetNetブロックは2つのモジュールを含む:Multi-Scale Retention(MSR)モジュールとFeed-Forward Network(FFN)モジュールである.
以下の節ではMSRモジュールを紹介する.
入力シーケンス\(x = x_1 … x_{|x|}\)が与えられると,RetNetはそのシーケンスを自己回帰的にエンコードする.
入力ベクトル\(\{x_i \}_{i=1}^{|x|}\)はまず\(X^0 = [x_1, ... , x_{|x|}] \in \mathbb{R}^{|x|*d_{model}}\)にパッキングされる.
ここで,\(d_{model}\)は隠れ次元である.
次にコンテキスト化ベクトル表現\(X^l = RetNet_l (X^{l-1}), l \in [1,L]\)を計算する.
2.1節:Retention
本節では,リカレント性と並列性の二重の形式を持つRetention機構を紹介する.
そのため,推論をリカレントに行いつつ,並列的にモデルを訓練することができる.
入力\(X \in \mathbb{R}^{|x|*d_{model}}\)が与えられたら,それを一次元関数\(v(n) = X_n \cdot w_V\)に射影する.
状態\(s_n\)を介して\(v(n) \mapsto o(n)\)を写像するシーケンスモデリング問題を考える.
簡単のために\(v_n, o_n\)を\(v(n), o(n)\)とする.
この写像をリカレントに定式化する.
\begin{align}
s_n &= A s_{n-1} + K_n^\top v_n, & A \in \mathbb{R}^{d*d}, K_n \in \mathbb{R}^{1*d} \tag{1} \\
o_n &= Q_n s_n = \sum_{m=1}^n Q_n A^{n-m} K_m^\top v_m, & Q_n \in \mathbb{R}^{1*d}
\end{align}
ここで,\(v_n\)を状態ベクトル\(s_n\)に対応付け,シーケンス情報をリカレントにエンコードするために線形変換を実行する.
次に,射影\(Q_n, K_n\)の内容を考慮する.
$$ Q = XW_Q, \ \ \ K = XW_K \tag{2} $$
ここで,\(W_Q, W_K \in \mathbb{R}^{d*d}\)は学習可能な行列である.
行列\(A = \Lambda(\gamma e^{i \theta}) \Lambda^{-1}\)を対角化する.
ここで,\(\gamma, \theta \in \mathbb{R}^d\)である.
すると,\(A^{n-m} =\Lambda(\gamma e^{i \theta})^{n-m} \Lambda^{-1}\)が得られる.
\(\Lambda\)を\(W_Q\)と\(W_K\)に吸収させることで,式1を次のように書き換えることができる.
\begin{align}
o_n &= \sum_{m=1}^n Q_n (\gamma e^{i \theta})^{n-m} K_m^\top v_m \tag{3} \\
&= \sum_{m=1}^n (Q_n (\gamma e^{i \theta})^n) (K_m (\gamma e^{i \theta})^{-m} )^\top v_m
\end{align}
ここで,\(Q_n(\gamma e^{i \theta})^n, K_m(\gamma e^{i \theta})^{-m}\)は,xPos[SDP+22],すなわち,Transformerのために提案された相対位置埋め込みとして知られている.
さらに\(\gamma\)をスカラーとして単純化すると,式3は次のようになる.
$$ o_n = \sum_{m=1}^n \gamma^{n-m} (Q_n e^{in \theta}) (K_m e^{im \theta})^\dagger v_m \tag{4} $$
ここで,\(\dagger\)は共役転置である.
この定式化は訓練インスタンス内で容易に並列化可能である.
要約すると,式1のようなリカレントモデリングから始め,式4でその並列定式化を導く.
元の写像\(v(n) \mapsto o(n)\)をベクトルとして考え,以下のようにRetention機構を得る.
Retentionの並列表現:
図3aに示すように,Retention層は次のように定義される.
\begin{eqnarray}
Q = (XW_Q) \odot \Theta,\ \ \ K = (XW_K) \odot \overline{\Theta},\ \ \ V = XW_V \\
\Theta_n = e^{in \theta}, D_{nm} = \begin{cases}
\gamma^{n-m}, & n \geq m \tag{5} \\
0, & n < m
\end{cases} \\
Retention(X) = (QK^\top \odot D)V
\end{eqnarray}
ここで,\(\overline{\Theta}\)は\(\Theta\)の複素共役であり,\(D \in \mathbb{R}^{|x|*|x|}\)は因果的マスキングと相対距離に沿った指数関数的減衰を1つの行列として組み合わせたものである.
Self-Attentionと同様に,並列表現によりGPUで効率的にモデルを訓練することができる.
Retentionのリカレント表現:
図3bに示すように,提案された機構は,推論に有利なリカレントニューラルネットワーク(RNNs:Recurrent Neural Networks)として記述することもできる.
n番目のタイムステップでは,リカレント的に次のような出力を得る.
\begin{align}
& S_n = \gamma S_{n-1} + K_n^\top V_n \tag{6} \\
& Retention(X_n) = Q_n S_n, \ \ \ n = 1, ..., |x|
\end{align}
ここで,\(Q, K, V, \gamma\)は式5と同じである.
Retentionのチャンクワイズリカレント表現:
並列表現とリカレント表現のハイブリッド形式を利用することで,特に長いシーケンスの訓練を高速化することができる.
入力シーケンスをチャンクに分割する.
各チャンク内では,並列表現(式5)に従って計算を行う.
一方,チャンク間の情報はリカレント表現(式6)に従って渡される.
具体的には,チャンクの長さをBとする.
i番目のチャンクのRetention出力を計算する.
\begin{eqnarray}
Q_{[i]} = Q_{Bi:B(i+1)},\ \ \ K_{[i]} = K_{Bi:B(i+1)},\ \ \ V_{[i]} = V_{Bi:B(i+1)} \\
R_i = K_{[i]}^\top V_{[i]} + \gamma^B R_{i-1} \tag{7} \\
Retention(X_{[i]}) = \underbrace{(Q_{[i]} K_{[i]}^\top \odot D) V_{[i]}}_{Inner-Chunk} + \underbrace{(Q_{[i]} R_i) \odot \xi}_{Cross-Chunk},\ \ \ \xi_{ij} = \gamma^{i+1}
\end{eqnarray}
ここで,[i]はi番目のチャンク,すなわち\(x_{[i]} = [x_{(i-1)B+1}, …, x_{iB}]\)を示す.
2.2節:Gated Multi-Scale Retention
各レイヤで\(h=d_{model}/d\)のRetention Headsを使用し,dはHeadの次元である.
Headは異なるパラメータ行列\(W_Q, W_K, W_V \in \mathbb{R}^{d*d}\)を使用する.
さらに,Multi-Scale Retention(MSR)は,各Headに異なる\(\gamma\)を割り当てる.
簡単のため,異なるレイヤ間で\(\gamma\)を同一に設定し,固定とする.
さらに,Swish Gate[RZL17]を追加し,Retention層の非線形性を高める.
形式的には,入力Xが与えられたとき,レイヤを次のように定義する.
\begin{eqnarray}
\gamma &=& 1 - 2^{-5 - arange(0, h)} \in \mathbb{R}^h \\
head_i &=& Retention(X, \gamma_i) \tag{8} \\
Y &=& GroupNorm(Concat(head_1, ..., head_h))) \\
MSR(X) &=& (swish(XW_G) \odot Y)W_O
\end{eqnarray}
ここで,\(W_G, W_O \in \mathbb{R}^{d_{model}*d_{model}}\)は学習可能なパラメータであり,GroupNorm[WH18]は[SPP+19]で提案されたSubLNに従って各Headの出力を正規化する.
Headsが複数の\(\gamma\)スケールを使用するため,分散統計量が異なることに注意されたい.
そこで,Head出力を別々に正規化する.
Retentionの擬似コードを図4にまとめた.
Retentionスコアの正規化:
GroupNormのスケール不変の性質を利用して,Retention層の数値精度を向上させる.
具体的には,GroupNorm内でスカラー値を乗算しても,出力や後方勾配には影響しない.
すなわち,\(GroupNorm(\alpha ∗ head_i) = GroupNorm(head_i)\)である.
式5では3つの正規化係数を実装する.
まず,\(QK^\top\)を\(QK^\top/\sqrt{d}\)として正規化する.
第2に,Dを\(\tilde{D}_{nm} = D_{nm}/\sqrt{\sum_{i=1}^n D_{ni}}\)と置き換える.
第3に,RをRetentionスコア\(R = QK^\top \odot D\)とすると,我々はそれを \(\tilde{R}_{nm} = R_{nm} / \max(| \sum_{i=1}^n R_{ni}|,1)\)として正規化する.
そして,Retention出力は\(Retention(X) = \tilde{R}V\)となる.
上記のトリックは,スケール不変の特性のため,フォワードパスとバックワードパスの両方の数値フローを安定化させながら,最終結果に影響を与えない.
2.3節:Overall Architecture of Retention Networks(Retention Networkの全体アーキテクチャ)
L層Retentionネットワークでは,Multi-Scale Retention(MSR)とFeed-Forward Network(FFN)を積み重ねてモデルを構築する.
形式的には,入力シーケンス\(\{x_i\}_{i=1}^{|x|}\)は単語埋め込み層によってベクトルに変換される.
パックされた埋め込み\(X^0 = [x_1, …, x_{|x|}] \in \mathbb{R}^{|x|*d_{model}}\)を入力とし,モデル出力\(X^L\)を計算する.
\begin{eqnarray}
Y^l &=& MSR(LN(X^l)) + X^l \\
X^{l+1} &=& FFN(LN(Y^l)) + Y^l
\end{eqnarray}
ここで,\(LN(\cdot)\)はLayerNorm[BKH16]である.
FFNの部分は,\(FFN(X) = gelu(XW_1)W_2\)(\(W_1, W_2\)はパラメータ行列)として計算される.
- 訓練:訓練過程では,並列表現(式5)とチャンクワイズリカレント表現(式7)を使用する.シーケンスやチャンク内での並列化は,計算を高速化するためにGPUを効率的に利用する.さらに有利なことに,チャンクワイズリカレントは長いシーケンスの訓練に特に有効であり,FLOPsとメモリ消費量の両方において効率的である.
- 推論:リカレント表現(式6)が推論中に採用され,自己回帰デコーディングにうまく適合する.複雑度が\(O(1)\)であるため,同等の結果を得ながら,メモリと推論のレイテンシが短縮される.
2.4節:Relation to and Differences from Previous Methods(従来の方法との関係および相違点)
表1は様々な観点からRetNetと従来の手法を比較したものである.
比較結果は図2に示した「不可能の三角形」(トリレンマ)と呼応している.
さらに,RetNetはチャンクワイズリカレント表現により,長いシーケンスに対して線形メモリ複雑度を持つ.
また,特定の手法との比較を以下にまとめる.
- Transformer:Retentionの並列表現は,Transformer[VSP+17]と同様の精神を共有している.最も関連性の高いTransformerはLex Transformer[SDP+22]であり,これはxPosを位置埋め込みとして実装している.式3で説明されるように,Retentionの導出はxPosと一致する.Attentionと比較して,Retentionはsoftmaxを削除し,推論に大きな利点をもたらすリカレントな定式化を可能にする.
- S4:式2とは異なり,\(Q_n\)と\(K_n\)がコンテンツ非認識である場合,定式化はS4[GGR21]に縮退することができ,ここで\(O = (QK^\top, QAK^\top, …,QA^{|x|-1}K^\top)∗ V\)となる.
- Linear Attention:このVariantsは通常,softmax関数を置き換えるために,様々なカーネル\(\phi(q_i)\phi(k_j)/ \sum_{n=1}^{|x|} \phi(q_i) \phi(k_n)\)を使用する.しかし,Linear Attentionは位置情報を効果的にエンコードするのに苦労し,モデルの性能を低下させる.また,softmaxの近似を目指すのではなく,ゼロからシーケンスモデリングを再検討する.
- AFT/RWKV:Attention Free Transformer(AFT)は,ドット積のAttentionを要素ごとの演算に単純化し,softmaxをキーベクトルに移動する.RWKVはAFTの位置埋め込みを指数減衰に置き換え,訓練と推論のためにモデルをリカレントに実行する.比較すると,Retentionはシーケンス情報をエンコードするために高次元の状態を保持し,表現力と性能向上に寄与する.
- xPos/RoPE:Transformerで提案されている相対位置埋め込み手法と比較すると,式3はxPos[SDP+22]やRoPE[SLP+21]と同様の定式化を示している.
- Sub-LayerNorm:式8に示すように,Retention層は出力を正規化するためにSub-LayerNorm [WMH+22]を使用する.マルチスケールモデリングはHeadsに異なる分散をもたらすので,元のLayerNormをGroupNormに置き換える.
3章:Experiments(実験)
RetNetを評価するために言語モデリングに関する実験を行う.
提案アーキテクチャを様々なベンチマーク,すなわち言語モデリング性能,下流タスクのZero-Shot/Few-Shot学習で評価する.
さらに,訓練と推論について,速度,メモリ消費量,レイテンシを比較する.
3.1節:Setup(セットアップ)
- Parameter Allocation:公正な比較のために,MSRとFFNのパラメータを再配分する.ここでは簡単のためdを\(d_{model}\)とする.Transformersでは,\(W_Q, W_K, W_V, W_O \in \mathbb{R}^{d*d}\)のSelf-Attentionに約\(4d^2\)のパラメータがあり,中間次元が4dのFFNでは\(8d^2\)のパラメータがある.これに対してRetNetは,\(W_Q, W_K \in \mathbb{R}^{d*d}, W_G, W_V \in \mathbb{R}^{d*2d}, W_O \in \mathbb{R}^{2d*d}\)の\(8d^2\)のパラメータを持つ.VのHeadの次元はQ,Kの2倍であることに注意されたい.広がった次元は\(W_O\)によってdに投影される.パラメータ数をTransformerと同じにするため,RetNetのFFN中間次元は2dである.一方,実験ではHead次元を256に設定した.すなわち,クエリーとキーは256,バリューは512である.公平な比較のため,異なるモデルサイズ間で\(\gamma\)を同一に保ち,式8のデフォルト値ではなく,\(\gamma = 1 - e^{linspace(\log 1/32,\log 1/512,h)} \in \mathbb{R}^h\)とする.
- Language Model Training:表2に示すように,様々なサイズ(1.3B,2.7B,6.7B)の言語モデルをゼロから訓練する.訓練コーパスは,The Pile[GBB+20],C4[DMI+21],The Stack[KLBA+22]を編集したものである.シーケンスの開始を示すために<bos>トークンを付加する.訓練バッチサイズは4Mトークンで,最大長は2048である.100Bトークン,つまり25kステップでモデルを訓練する.AdamW[LH19]オプティマイザを使用し,\(\beta_1 = 0.9\),\(\beta_2 = 0.98\),重み減衰は0.05とする.ウォームアップステップ数は線形学習率減衰で375である.パラメータはDeepNet[WMD+22]に従って初期化し,訓練の安定性を保証する.実装はTorchScale[MWH+22]に基づいている.512 AMD MI200 GPUでモデルを訓練する.
※<bos>トークンを最初に追加することで,訓練の安定性とパフォーマンスが向上することがわかった.
3.2節:Comparisons with Transformer(Transformerとの比較)
- 言語モデリング:図5に示すように,TransformerとRetNetに基づく言語モデルについて,検証セットでのパープレキシティを報告する.3つのモデルサイズ,すなわち1.3B,2.7B,6.7Bでのスケーリング曲線を示す.RetNetはTransformerと同等の結果を達成している.さらに重要なことは,この結果はRetNetがサイズのスケーリングに関して有利であることを示している.パフォーマンスだけでなく,RetNetの訓練は我々の実験では非常に安定している.実験結果は,RetNetが大規模言語モデルにおいてTransformerの強力なライバルであることを示している.経験的に,モデルサイズが2Bより大きくなると,RetNetがTransformerを上回り始めることが分かる.また,付録Bにコンテキストの長さを変えた場合の言語モデリング結果をまとめる.
- 下流タスクのZero-Shot/Few-Shot評価:また,下流の幅広いタスクで言語モデルを比較した.6.7Bモデルを用いたZero-Shot学習と4-Shot学習を評価する.表3に示すように,データセットにはHellaSwag(HS)[ZHB+19],BoolQ[CLC+19],COPA[WPN+19],PIQA[BZB+20],Winograd,Winogrande[LDM12],StoryCloze(SC)[MRL+17]が含まれる.正解率の数値は,図5で示された言語モデリングのパープレキシティと一致している.RetNetはTransformerとZero-Shot学習やコンテキスト内学習において同等の性能を達成している.
3.3節:Training Cost(訓練コスト)
表4に示すように,TransformerとRetNetの訓練速度とメモリ消費量を比較する.
ここで,訓練シーケンスの長さは8192である.
FlashAttention[DFE+22]との比較も行い,再計算とカーネルフュージョンにより,速度の向上とGPUメモリIOの削減を実現する.
比較のため,RetNetの実装にはVanilla PyTorchのコードを使用し,カーネルフュージョンやFlashAttentionのようなアクセラレーションは今後の研究に委ねる.
式7のように,チャンクワイズリカレント表現でRetentionを表現する.
チャンクサイズは512に設定されている.
FlashAttentionはA100に高度に最適化されているため,8台のNvidia A100-80GBのGPUで結果を評価した.
テンソル並列は6.7Bと13Bモデルで有効になっている.
実験結果によれば,RetNetはTransformersよりもメモリ効率が高く,訓練時のスループットが高い.
FlashAttentionと比較しても,RetNetはスピードとメモリコストの点で競争力がある.
さらに,特定のカーネルに依存することなく,RetNetを他のプラットフォームで効率的に訓練させることも容易である.
例えば,AMD MI200クラスタでRetNetモデルを訓練したところ,十分なスループットを得ることができた.
RetNetはカーネルフュージョンなどの高度な実装により,さらにコストを削減できる可能性を秘めていることは注目に値する.
3.4節:Inference Cost(推論コスト)
図6に示すように,推論中のTransformerとRetNetのメモリコスト,スループット,レイテンシを比較する.
Transformerは以前にデコードされたトークンのKVキャッシュを再利用する.
RetNetは式6で説明されるようなリカレント表現を使用する.
実験ではA100-80GB GPUで6.7Bモデルを評価した.
図6は,推論コストの点でRetNetがTransformerを上回ることを示している.
- メモリ:図6aに示すように,TransformerのメモリコストはKVキャッシュにより直線的に増加する.これに対して,RetNetのメモリ消費量は長いシーケンスでも一定で,RetNetをホストするために必要なGPUメモリははるかに少なくて済む.RetNetの追加メモリ消費はほとんど無視できる程度(つまり約3%)で,モデルの重みが97%を占めている.
- スループット:図6bに示すように,Transformerのスループットはデコード長が長くなるにつれて低下する.これに比べて,RetNetは,Retentionのリカレント表現を利用することで,デコード時のスループットが高く,長さに依存しない.
- レイテンシ:レイテンシは,ユーザエクスペリエンスに大きく影響するデプロイにおける重要な指標である.図6cにデコードのレイテンシを示す.実験結果は,バッチサイズを大きくするとTransformerのレイテンシが大きくなることを示している.さらに,Transformerのレイテンシは入力が長くなるほど速くなる.レイテンシを許容できるようにするためには,バッチサイズを制限しなければならないが,これはTransformerの推論スループットを全体的に悪化させる.これに対して,RetNetのデコードのレイテンシはTransformersを上回り,異なるバッチサイズや入力長でもほとんど変わらない.
3.5節:Comparison with Transformer Variants(Transformer Variantsとの比較)
Transformerとは別に,Linear Transformer[KVPF20],RWKV[PAA+23],H3[DFS+22],Hyena[PMN+23]を含む様々な効率的なTransformer VariantsとRetNetを比較する.
すべてのモデルのパラメータは200Mで,16層,隠れ次元は1024である.
H3では,Headの次元を8とした.
RWKVについては,公正な比較のために,FFN層を他のモデルとの一貫性を保ちながら,Self-Attention層を代用するためにTimeMixモジュールを使用する.
バッチサイズは0.5Mトークンで,10kステップでモデルを訓練する.
ほとんどのハイパーパラメータと訓練コーパスは3.1節と同じである.
表5は,領域内検証セットと,他の領域外コーパス,例えばProject Gutenberg 2019-2022(PG22)[SDP+22],QMSum[ZYY+21],GovReport[HCP+21],SummScreen[CCWG21, SSI+22]におけるパープレキシティを示している.
全体として,RetNetは様々なデータセットにおいて以前の手法を凌駕している.
RetNetはドメイン内コーパスでより良い評価結果を得るだけでなく,いくつかのドメイン外データセットでもより低いパープレキシティを得る.
この良好な性能により,RetNetは大幅なコスト削減という利点(3.3節と3.4節)に加え,Transformerの強力な後継となる.
さらに,比較した手法の訓練効率と推論効率について議論する.
dを隠れ次元,nをシーケンス長とする.
訓練において,RWKVのトークンミキシングの複雑度は\(O(dn)\)であるのに対し,Hyenaのそれは高速フーリエ変換アクセラレーションにより\(O(dn \log n)\)である.
上記2つの方法は,モデル化能力をトレードオフするために要素ごとの演算子を採用することで,訓練FLOPSを削減している.
Retentionと比較すると,チャンクワイズリカレント表現は\(O(dn(b + h))\)であり,bはチャンクサイズ,hはヘッド次元である(通常,b = 512,h = 256).
モデルサイズが大きい場合(すなわちdが大きい場合),あるいはシーケンスの長さが長い場合,追加のb + hの影響は無視できる.
つまり,RetNetの訓練は,モデル化性能を犠牲にすることなく,非常に効率的である.
推論に関しては,比較された効率的なアーキテクチャの中で,HyenaはTransformerと同じ複雑度(すなわち,1ステップあたり\(O(n)\))であるのに対し,他のアーキテクチャは\(O(1)\)のデコードが可能である.
3.6節:Ablation Studies(アブレーション試験)
RetNetの様々な設計を検討し,言語モデリング結果を表6に示す.
評価設定と評価基準は3.5節と同じである.
- Architecture:式8のように,Swish GateとGroupNormを除去する.表6は,上記の2つのコンポーネントが最終的な性能を向上させていることを示している.第1に,ゲーティングモジュールは非線形性を強化し,モデル能力を向上させるために不可欠である.ゲートを取り除いた後のTransformersと同じパラメータ割り当てを使用していることに注意されたい.第2に,Retentionにおけるグループ正規化は,Multi-Head出力の分散のバランスをとり,訓練安定性と言語モデリング結果を向上させる.
- Multi-Scale Decay:式8は,Retention Headsの減衰率として異なる\(\gamma\)を使用していることを示している.アブレーション試験では,\(\gamma\)崩壊を除去し(すなわち,「- \(\gamma\)崩壊」),Heads間で同じ減衰率を適用する(すなわち,「- マルチスケール崩壊」)ことを検討する.具体的には,\(\gamma\)減衰を除去することは,\(\gamma=1\)に相当する.2つ目の設定では,すべてのヘッドについて\(\gamma=127/128\)とした.表6から,減衰の仕組みと複数の減衰率を使用することの両方で,言語モデリング性能が向上することがわかる.
- Head Dimension:式1のリカレントの観点からは,Head次元は隠れ状態のメモリ容量を意味する.アブレーション試験では,デフォルトのHead次元を256から64に縮小する.すなわち,クエリーとキーは64,バリューは128である.隠れ次元\(d_{model}\)は同じままなので,Heads数は増加する.表6の実験結果は,Head次元を大きくするほど性能が向上することを示している.
4章:Conclusion(結論)
本研究では,並列,リカレント,チャンクワイズリカレントといった様々な表現が可能な,シーケンスモデリングのためのRetentive Network(RetNet)を提案する.
RetNetは,Transformerと比較して,(メモリ,速度,レイテンシの点で)格段に優れた推論効率,有利な訓練並列化,競争力のある性能を達成する.
以上の利点から,RetNetは大規模言語モデルにおいてTransformersの後継として理想的であり,特に推論複雑度が\(O(1)\)であることがもたらすデプロイ上の利点を考慮すると理想的である.
将来的には,モデルサイズ[CDH+22]と訓練ステップの点でRetNetをスケールアップしたい.
さらに,長期記憶を圧縮することで,構造化プロンプト[HSD+22b]を効率的に扱うことができる.
また,マルチモーダルな大規模言語モデル[HSD+22a, HDW+23, PWD+23]を訓練するためのバックボーンアーキテクチャとしてRetNetを使用する予定である.
さらに,RetNetモデルを携帯電話などの様々なエッジデバイスに展開することにも興味がある.
References(参考文献)
- [BKH16] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
- [BMR+20] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020.
- [BZB+20] Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. Piqa: Reasoning about physical commonsense in natural language. In Thirty-Fourth AAAI Conference on Artificial Intelligence, 2020.
- [CCWG21] Mingda Chen, Zewei Chu, Sam Wiseman, and Kevin Gimpel. Summscreen: A dataset for abstractive screenplay summarization. arXiv preprint arXiv:2104.07091, 2021.
- [CDH+22] Zewen Chi, Li Dong, Shaohan Huang, Damai Dai, Shuming Ma, Barun Patra, Saksham Singhal, Payal Bajaj, Xia Song, Xian-Ling Mao, Heyan Huang, and Furu Wei. On the representation collapse of sparse mixture of experts. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.
- [CLC+19] Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina Toutanova. BoolQ: Exploring the surprising difficulty of natural yes/no questions. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 2924–2936, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics.
- [DFE+22] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
- [DFS+22] Tri Dao, Daniel Y Fu, Khaled K Saab, Armin W Thomas, Atri Rudra, and Christopher Ré. Hungry hungry hippos: Towards language modeling with state space models. arXiv preprint arXiv:2212.14052, 2022.
- [DMI+21] Jesse Dodge, Ana Marasovi´c, Gabriel Ilharco, Dirk Groeneveld, Margaret Mitchell, and Matt Gardner. Documenting large webtext corpora: A case study on the colossal clean crawled corpus. In Conference on Empirical Methods in Natural Language Processing, 2021.
- [GBB+20] Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, et al. The Pile: An 800GB dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027, 2020.
- [GGR21] Albert Gu, Karan Goel, and Christopher Ré. Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396, 2021.
- [HCP+21] Luyang Huang, Shuyang Cao, Nikolaus Parulian, Heng Ji, and Lu Wang. Efficient attentions for long document summarization. arXiv preprint arXiv:2104.02112, 2021.
- [HDW+23] Shaohan Huang, Li Dong, Wenhui Wang, Yaru Hao, Saksham Singhal, Shuming Ma, Tengchao Lv, Lei Cui, Owais Khan Mohammed, Qiang Liu, Kriti Aggarwal, Zewen Chi, Johan Bjorck, Vishrav Chaudhary, Subhojit Som, Xia Song, and Furu Wei. Language is not all you need: Aligning perception with language models. ArXiv, abs/2302.14045, 2023.
- [HS97] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9:1735–1780, November 1997.
- [HSD+22a] Yaru Hao, Haoyu Song, Li Dong, Shaohan Huang, Zewen Chi, Wenhui Wang, Shuming Ma, and Furu Wei. Language models are general-purpose interfaces. ArXiv, abs/2206.06336, 2022.
- [HSD+22b] Yaru Hao, Yutao Sun, Li Dong, Zhixiong Han, Yuxian Gu, and Furu Wei. Structured prompting: Scaling in-context learning to 1,000 examples. ArXiv, abs/2212.06713, 2022.
- [KLBA+22] Denis Kocetkov, Raymond Li, Loubna Ben Allal, Jia Li, Chenghao Mou, Carlos Muñoz Ferrandis, Yacine Jernite, Margaret Mitchell, Sean Hughes, Thomas Wolf, Dzmitry Bahdanau, Leandro von Werra, and Harm de Vries. The Stack: 3 tb of permissively licensed source code. Preprint, 2022.
- [KVPF20] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pages 5156–5165. PMLR, 2020.
- [LDM12] Hector Levesque, Ernest Davis, and Leora Morgenstern. The winograd schema challenge. In Thirteenth International Conference on the Principles of Knowledge Representation and Reasoning, 2012.
- [LH19] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
- [MRL+17] Nasrin Mostafazadeh, Michael Roth, Annie Louis, Nathanael Chambers, and James Allen. Lsdsem 2017 shared task: The story cloze test. In Proceedings of the 2nd Workshop on Linking Models of Lexical, Sentential and Discourse-level Semantics, pages 46–51, 2017.
- [MWH+22] Shuming Ma, Hongyu Wang, Shaohan Huang, Wenhui Wang, Zewen Chi, Li Dong, Alon Benhaim, Barun Patra, Vishrav Chaudhary, Xia Song, and Furu Wei. TorchScale: Transformers at scale. CoRR, abs/2211.13184, 2022.
- [OSG+23] Antonio Orvieto, Samuel L. Smith, Albert Gu, Anushan Fernando, Caglar Gulcehre, Razvan Pascanu, and Soham De. Resurrecting recurrent neural networks for long sequences. ArXiv, abs/2303.06349, 2023.
- [PAA+23] Bo Peng, Eric Alcaide, Quentin Anthony, Alon Albalak, Samuel Arcadinho, Huanqi Cao, Xin Cheng, Michael Chung, Matteo Grella, Kranthi Kiran GV, Xuzheng He, Haowen Hou, Przemyslaw Kazienko, Jan Kocon, Jiaming Kong, Bartlomiej Koptyra, Hayden Lau, Krishna Sri Ipsit Mantri, Ferdinand Mom, Atsushi Saito, Xiangru Tang, Bolun Wang, Johan S. Wind, Stansilaw Wozniak, Ruichong Zhang, Zhenyuan Zhang, Qihang Zhao, Peng Zhou, Jian Zhu, and Rui-Jie Zhu. Rwkv: Reinventing rnns for the transformer era, 2023.
- [PMN+23] Michael Poli, Stefano Massaroli, Eric Nguyen, Daniel Y Fu, Tri Dao, Stephen Baccus, Yoshua Bengio, Stefano Ermon, and Christopher Ré. Hyena hierarchy: Towards larger convolutional language models. arXiv preprint arXiv:2302.10866, 2023.
- [PWD+23] Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, and Furu Wei. Kosmos-2: Grounding multimodal large language models to the world. ArXiv, abs/2306.14824, 2023.
- [RZL17] Prajit Ramachandran, Barret Zoph, and Quoc V. Le. Swish: a self-gated activation function. arXiv: Neural and Evolutionary Computing, 2017.
- [SDP+22] Yutao Sun, Li Dong, Barun Patra, Shuming Ma, Shaohan Huang, Alon Benhaim, Vishrav Chaudhary, Xia Song, and Furu Wei. A length-extrapolatable transformer. arXiv preprint arXiv:2212.10554, 2022.
- [Sha19] Noam M. Shazeer. Fast transformer decoding: One write-head is all you need. ArXiv, abs/1911.02150, 2019.
- [SLP+21] Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding. arXiv preprint arXiv:2104.09864, 2021.
- [SPP+19] Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, and Bryan Catanzaro. Megatron-lm: Training multi-billion parameter language models using model parallelism. arXiv preprint arXiv:1909.08053, 2019.
- [SSI+22] Uri Shaham, Elad Segal, Maor Ivgi, Avia Efrat, Ori Yoran, Adi Haviv, Ankit Gupta, Wenhan Xiong, Mor Geva, Jonathan Berant, et al. Scrolls: Standardized comparison over long language sequences. arXiv preprint arXiv:2201.03533, 2022.
- [VSP+17] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, 4-9 December 2017, Long Beach, CA, USA, pages 6000– 6010, 2017.
- [WH18] Yuxin Wu and Kaiming He. Group normalization. In Proceedings of the European conference on computer vision (ECCV), pages 3–19, 2018.
- [WMD+22] Hongyu Wang, Shuming Ma, Li Dong, Shaohan Huang, Dongdong Zhang, and Furu Wei. DeepNet: Scaling Transformers to 1,000 layers. ArXiv, abs/2203.00555, 2022.
- [WMH+22] Hongyu Wang, Shuming Ma, Shaohan Huang, Li Dong, Wenhui Wang, Zhiliang Peng, Yu Wu, Payal Bajaj, Saksham Singhal, Alon Benhaim, et al. Foundation transformers. arXiv preprint arXiv:2210.06423, 2022.
- [WPN+19] Alex Wang, Yada Pruksachatkun, Nikita Nangia, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. SuperGLUE: A stickier benchmark for general-purpose language understanding systems. arXiv preprint arXiv:1905.00537, 2019.
- [ZHB+19] Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. Hellaswag: Can a machine really finish your sentence? In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, 2019.
- [ZYY+21] Ming Zhong, Da Yin, Tao Yu, Ahmad Zaidi, Mutethia Mutuma, Rahul Jha, Ahmed Hassan Awadallah, Asli Celikyilmaz, Yang Liu, Xipeng Qiu, et al. Qmsum: A new benchmark for query-based multi-domain meeting summarization. arXiv preprint arXiv:2104.05938, 2021.
付録A:Hyperparameters(ハイパーパラメータ)
付録B:Grouped Results of Different Context Lengths(異なるコンテキスト長をグループ化した結果)
表8に示すように,コンテキスト長を変えた場合の言語モデリング結果を報告する.
数値を比較できるようにするため,2048個のテキストチャンクを評価データとして使用し,最後の128個のトークンについてのみパープレキシティを計算する.
実験の結果,RetNetは様々なコンテキスト長においてTransformerを上回ることがわかった.
さらに,RetNetはより良い結果を得るために,より長いコンテキストを利用することができる.
参考:Retentive Network: A Successor to Transformer for Large Language Modelsの解説動画
Retentive Network: A Successor to Transformer for Large Language Modelsの解説動画です.
まとめ
Retentive Network: A Successor to Transformer for Large Language Modelsの日本語訳を紹介しました.
Microsoftの清華大学によるTransformerの後継「RetNet」がわかりました.
AIのプログラミング言語「C++/Python言語」を学べるおすすめのWebサイトを知りたいあなたはこちらからどうぞ.
独学が難しいあなたは,AIを学べるオンラインプログラミングスクール3社で自分に合うスクールを見つけましょう.後悔はさせません!
国内・海外のAIエンジニアのおすすめ求人サイトを知りたいあなたはこちらからどうぞ. こういった悩みにお答えします. こういった私が解説していきます. 国内・海外のAIエンジニアのおすすめ求人サイト(転職エージェント)を紹介します. AIエンジニアになるためには,主にC++/Pytho ... 続きを見る
国内・海外のAIエンジニアのおすすめ求人サイト【転職エージェント】【C++/Python言語】
国内・海外のプロンプトエンジニアのおすすめ求人サイトを知りたいあなたはこちらからどうぞ.