當前位置:
首頁 > 科技 > NIPS 2018 | 將RNN內存佔用縮小90%:多倫多大學提出可逆循環神經網路

NIPS 2018 | 將RNN內存佔用縮小90%:多倫多大學提出可逆循環神經網路


選自arXiv


作者:Matthew MacKay等

機器之心編譯


參與:高璇、張倩





循環神經網路(RNN)在處理序列數據方面取得了當前最佳的性能表現,但訓練時需要大量內存。可逆循環神經網路提供了一個減少訓練內存需求的路徑,因為隱藏狀態不需要存儲,而是可以在反向傳播過程中重新計算。本文首先展示了不需要存儲隱藏激活的完全可逆 RNN 從根本上是有限制的,因為它們不能忘記隱藏狀態的信息。然後,研究人員提供了一個存儲少量比特的方案,使遺忘信息實現完全逆轉。本文的方法實現了與傳統模型相當的性能,但所需內存只佔傳統模型的 1/10 到 1/15。




循環神經網路(RNN)在語音識別 [1]、語言建模 [2,3] 和機器翻譯 [4,5] 等多種任務上都取得了極優的性能。然而,訓練 RNN 需要大量的內存。標準的訓練演算法是時間截斷的反向傳播(TBPTT)[6,7]。該演算法將輸入序列劃分為較短的子序列 T,然後對每個子序列進行處理,並對梯度進行反向傳播。如果模型隱藏狀態的大小為 H,那麼 TBPTT 所需的內存是 O(T H)。






減少 TBPTT 演算法對內存的需求會增加被截斷序列的長度 T,從而抓取更長的時間尺度上的相關性。也可以增加隱藏狀態的大小 H,或者利用更深層的輸入到隱藏、隱藏到隱藏或隱藏到輸出轉換,從而賦予模型更強的表達能力。增加這些轉換的深度可以提高復調音樂預測、語言建模和神經機器翻譯(NMT)的性能 [8,9,10]。




可逆循環網路架構提供了一種降低 TBPTT 內存需求的方法。可逆架構實現了在給定下一個隱藏狀態和當前輸入的當前時間步上的隱藏狀態重建,這樣無需在每個時間步上存儲隱藏狀態就能執行 TBPTT。代價就是要增加計算成本來重建反向傳播過程中的隱藏狀態。




本文首先介紹了廣泛使用的門控循環單元(GRU)[11] 和長短期記憶(LSTM)[12] 架構的可逆相似架構。然後證明,任何不需要存儲隱藏激活的完全可逆的 RNN,在一個簡單的一步預測任務中都會失敗。即使這個任務對於普通的 RNN 很簡單,但在完全可逆模型中卻失敗了,因為它們需要記住輸入序列才能完成任務。根據這一發現,研究人員擴展了 Maclaurin 等人 [13] 的高效內存反轉法,在每個單元中存儲少量的比特,以對遺忘信息的架構進行完全逆轉。






研究人員在語言建模和神經機器翻譯基準上評估了這些模型的性能。根據任務、數據集和所選架構,可逆模型(無注意力機制)所需內存只佔傳統模型的 1/10 到 1/15。可逆模型在 Penn TreeBank 數據集 [14] 上的詞級語言建模任務中得到了與傳統的 LSTM 模型和 GRU 模型相似的性能,在 WikiText-2 數據集 [15] 上比傳統模型困惑度落後 2-5 個點。




使用基於注意力的循環序列到序列模型來節省內存是很困難的,因為為執行注意力機制,編碼器的隱藏狀態必須同時保存在內存中。對與嵌入詞相連的隱藏狀態的子集執行注意力機制,可以解決這一問題。使用這種技術後,可逆模型在神經機器翻譯任務中取得了成功,在 Multi30K 數據集 [16] 上的性能優於基線 GRU 和 LSTM 模型,並在 IWSLT 2016[17] 基準上取得了很有競爭力的表現。使用該技術可以將內存在解碼器中減少到原來的 1/10-1/15,在編碼器中減少到原來的 1/5-1/10。






論文:Reversible Recurrent Neural Networks







論文鏈接:https://arxiv.org/pdf/1810.10999v1.pdf




摘要

:循環神經網路(RNN)在處理序列數據方面取得了當前最佳的性能表現,但訓練時需要大量的內存,限制了可訓練的 RNN 模型的靈活性。可逆的 RNN-RNN 可以對其進行隱藏狀態到隱藏狀態的轉換,提供了一個減少訓練內存需求的路徑,因為隱藏狀態不需要存儲,而是可以在反向傳播過程中重新計算。我們首先展示了不需要存儲隱藏激活的完全可逆 RNN,從根本上是有限制的,因為它們不能忘記隱藏狀態的信息。然後,我們提供了一個存儲少量比特的方案,使遺忘信息實現完全逆轉。我們的方法實現了與傳統模型相當的性能,同時激活內存開銷變為原來的 1/10-1/15。然後我們將技術擴展到基於注意力的序列到序列模型,在這種模型中性能不變,但將內存開銷在解碼器中減少到原來的 1/10-1/15,在編碼器中減少到原來的 1/5-1/10。




3 可逆循環架構




構建 RevNet 的技術可以與傳統的 RNN 模型相結合,生成可逆的 RNN。在本節中,我們提出了類似 GRU 和 LSTM 的可逆架構。




3.1 可逆 GRU




我們首先回顧一下,在給定當前隱藏狀態 h^(t) 和當前輸入 x^(t)(省略偏差)時計算下一個隱藏狀態 h^(t+1) 的 GRU 方程:







方程中的⊙表示元素的乘法。為了使更新可逆,我們將隱藏狀態 h 分成兩組,h = [h_1; h_2]。使用以下規則更新這些組:







注意,h_1^(t) 和非 h_1^(t-1) 用於計算 h_2^(t) 的更新。我們把這個模型稱為可逆門控循環單元或 RevGRU。




3.2 可逆 LSTM




接下來構造一個可逆的 LSTM。該 LSTM 將隱藏狀態分為輸出狀態 h 和單元狀態 c,更新方程為:







我們不能直接使用我們的可逆方法,因為 h^(t) 的更新不是 h^(t-1) 的非零線性變換。儘管如此,可通過使用如下方程來實現可逆:







我們使用 c_1^(t) 和 h_1^(t),利用與上述等式相同的方式計算 c_2、h_2 的更新。我們將此模型稱為可逆 LSTM 或 RevLSTM。




3.3 有限精度演算法的可逆性




我們已經定義了在精準演算法中可逆的 RNN。但在實際中,由於數值精度有限,隱藏狀態不能被完全地重建。考慮 RevGRU 方程 4 和 5,如果隱藏狀態 h 存儲在固定點中,則將 h 乘以 z(其條目小於 1)會破壞信息,從而阻止完全重建。例如,將隱藏單位乘以 1/2 相當於丟棄最低位位元組,其值在反向計算中無法恢復。信息丟失的這些誤差在時間步長上呈指數級累積,導致通過反轉獲得的初始隱藏狀態與真實的初始狀態相去甚遠。同樣的問題也會影響 RevLSTM 隱藏狀態的重建。因此,我們發現遺忘是構建完全可逆的循環架構的主要障礙。




解決這一問題有兩種可行途徑。首先是移除遺忘步驟。對於 RevGRU,這意味著我們像以前一樣計算 z_i^(t) 、 r_i^(t) 和 g_i^(t),並使用以下方法更新 h_i^(t):







我們將此模型稱為無遺忘(No-Forgetting)RevGRU 或 NF-RevGRU。因為 NFRevGRU 的更新不會丟棄信息,所以在給定時間內的訓練過程中,我們只需要在內存中存儲一個隱藏狀態即可。可以採用類似的步驟定義 NF-RevLSTM。




第二種方法是接受一些內存使用,並將從隱藏狀態中遺忘的信息存儲在前向傳播中。然後,我們可以在反向計算中將這些信息還原到隱藏狀態,以實現完全重建。具體內容將在第 5 節中詳細討論。




4 No Forgetting 的不可能性




我們已經證明,如果不丟棄任何信息,可以構造出具有有限精度的可逆 RNN。我們無法找到能夠在語言建模之類的任務上獲得理想性能的架構。這與之前發現的遺忘對 LSTM 性能至關重要是一致的 [23,24]。在本節中,我們認為這是由不可遺忘可逆模型的一個基本限制造成的:如果任何隱藏狀態都不能被遺忘,那麼任何給定時間步上的隱藏狀態必須包含足夠的信息,來重建所有以前的隱藏狀態。因此,在一個時間步長上存儲在隱藏狀態中的任何信息都必須保留在將來的所有時間步上,以確保精準重構,這超過了模型的存儲容量。





圖 1:在重複任務上展開完全可逆模型的反向計算,得到序列到序列的計算。左:重複任務本身,模型重複每個輸入指令。右:展開反轉。模型有效利用最終隱藏狀態來重構所有輸入指令,這意味著整個輸入序列必須存儲在最終隱藏狀態中。




5 遺忘的可逆性




由於零遺忘不可能實現,我們不得不探索實現可逆性的第二種方案:在前向計算中存儲隱藏狀態丟失的信息,在反向計算中恢複信息。最開始我們研究了只允許遺忘一個整數位的離散遺忘。這導致:如果前向傳遞中遺忘了 n 位位元組,我們可以將這 n 位位元組存儲在堆棧中,在重構期間彈出並恢復到隱藏狀態。但是,與基線模型相比,限制我們的模型僅僅遺忘整數位位元組就會導致性能大幅下降。本文接下來的內容會側重只遺忘一小部分位元組的部分遺忘。




5.2 注意力機制下的內存節省





圖 2:NMT 的注意力機制。詞嵌入、編碼器隱藏狀態和解碼器隱藏狀態分別用橙色、藍色和綠色表示;編碼器隱藏狀態的條紋區域表示為注意力機制而存儲在內存中的部分。用於計算上下文向量的最後幾個向量連接了詞嵌入和編碼器隱藏狀態。




6 實驗





表 1:Penn TreeBank 詞級語言建模上的驗證困惑度(內存節省)。在沒有限制的情況下,當遺忘被限制在 2 位、3 位和 5 位比特時,每個隱藏單元每個時間步的結果顯示如表。





表 2: WikiText-2 詞級語言建模上的驗證困惑度。在沒有限制的情況下,當遺忘被限制在 2 位、3 位和 5 位比特時,每個隱藏單元每個時間步的結果顯示如表。





表 3:不同的遺忘限制下 Multi30K 數據集的性能。P 為測試 BLEU 分數;M 表示編碼器在訓練期間平均節省的內存。






本文為機器之心編譯,

轉載請聯繫本公眾號獲得授權



?------------------------------------------------


加入機器之心(全職記者 / 實習生):hr@jiqizhixin.com


投稿或尋求報道:

content

@jiqizhixin.com


廣告 & 商務合作:bd@jiqizhixin.com

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

AIIA人工智慧開發者大會預告:開源賦能,智慧共享
AIIA開發者大會開啟在即,思必馳俞凱談語音交互技術的「AI互聯」

TAG:機器之心 |