當前位置:
首頁 > 知識 > 谷歌論文新突破:通過輔助損失提升RNN學習長期依賴關係的能力

谷歌論文新突破:通過輔助損失提升RNN學習長期依賴關係的能力

本文提出了一種簡單的方法,通過在原始函數中加入輔助損失改善 RNN 捕捉長期依賴關係的能力,並在各種設置下評估了該方法,包括用長達 16,000 的序列對一張圖的逐個像素進行分類,以及對一個真實的基準文件進行分類;和其他常用模型和大小相當的轉換器相比,該方法在性能和資源使用效率方面的表現都非常突出。

介紹

大量人工智慧應用的前提是首先理解序列中事件間的長期依賴關係。例如,在自然語言處理中,有時就必須要對書中描述的遠距離事件之間的關係有所了解,這樣才能回答問題。一般而言,現在是通過梯度下降和帶有循環網路的 BPTT(Rumelhart et al., 1986)解決這一問題的。然而,通過梯度下降方法學習長期依賴性很難,因為藉助 BPTT 計算的梯度在訓練過程中有消失或爆炸的傾向。除此以外,如果想要使 BPTT 起作用,人們需要存儲中間過程的隱藏狀態。內存需求與序列長度成正比,使得這種方法難以處理大問題。

圖 1:本文方法概述。輔助損失改善了循環網路的內存問題,主任務的 BPTT 需要的步驟少了一些。

也有人提出過若干個有望解決這些問題的方法。首先,可以使用 LSTM(Hochreiter & Schmidhuber, 1997)代替常用的循環神經網路,這可以改善循環網路中的梯度流的問題。此外,還可以使用梯度裁減(Pascanu et al., 2013)提高 LSTM 訓練過程的穩定性。最後,為了減少內存方面的需求,可以使用截斷的 BPTT 或合成梯度(Jaderberg et al., 2017)定期存儲隱藏層的狀態(Gruslys et al., 2016; Chen et al., 2016)。

卷積神經網路也可以消除長期的依賴關係問題,因為內核較大,而且像 ResNets(He et al., 2016)這樣的深度網路允許跨越圖像中相距較遠的兩個部分學習長期依賴關係。但這樣就會用到完全不同的架構,我們可以對此進行權衡。例如,在訓練過程中,模型的輸入(一張圖像或者一個序列)以及中間的激活都要存儲在內存中。在推斷時,典型的 CNN 需 O(n) 的存儲空間(n 代表輸入的大小)。儘管由於訓練和推斷的計算需要隨機存取到內存 O(n),但變換器(Vaswani et al., 2017)也有相似的問題,並且嚴重一些。

圖 2:本文方法的草圖。對每個隨機定位點,也就是 F 點而言,我們在這個位置上建立輔助損失。左圖:我們預測了 F 點前的一段隨機序列 BCD。將 B 點插入解碼器網路以開始重建,而 C 點和 D 點可以選擇是否饋送。右圖:我們通過在主窗口堆疊輔助 RNN 對子序列 GHI 進行預測。在這兩種情況中,輔助損失的梯度都被截斷,通過這種方式來保證 BPTT 總體消耗維持不變。

RNN 的優勢在於,假設 BPTT 的長度為 l,訓練就需要 O(l) 的內存。這是一個用 PTB 數據集(Marcus et al., 1994)訓練語言模型的典型實例,這樣做 100 萬個符號序列的狀態就永遠不會重置。因此,從理論上講 RNN 可以從極遠的距離學到這種關係。此外,RNN 的推斷也需要 O(l) 的內存,因為 RNN 不需要「回頭」。

在這篇論文中,我們提出一種正交技術以進一步解決循環網路單純依賴 BPTT 的缺陷。該方法介紹了一種無監督輔助損失,可以重建/預測錨點前後的一部分隨機序列。實現這個方法,只需要幾步有監督損失的 BPTT。

論文結果表明無監督輔助損失顯著改善了 LSTM 的優化和泛化能力。此外,如果使用這一方法,無需在訓練過程中執行冗長的 BPTT 以獲得良好的結果。因此,該方法適用於長序列,在此之前,這些長序列中出現的梯度消失/爆炸問題以及冗長的 BPTT 消耗問題都是模型訓練中的重要瓶頸。

實驗採用的序列長達 16,000 個元素,帶有輔助損失的 LSTM 訓練得更快並使用了更少的內存,而採用完整的反向傳播訓練 LSTM 則非常困難。

方法

假設目標是使用循環網路閱讀序列並分類。我們隨機採樣一個或多個錨點,並在每個錨點插入無監督輔助損失。

3.1. 重建輔助損失

在重建過去事件時,我們取樣了錨點之前的子序列,並將第一段標記序列插入解碼器網路;然後我們要求解碼器網路預測出剩下的子序列。整個過程如圖 2 左圖所示。

我們推斷,如果擬預測序列離定位點足夠近,解碼重建過去事件所需的 BPTT 的步驟就會非常少。另外,隨著訓練的進一步加強,定位點會在循環網路中充當臨時存儲的角色來記錄序列中過去的事件。如果我們選擇了足夠多的定位點,就會在整段序列上建立足夠多的存儲,當我們到序列末端時,分類器會記住序列從而更好地進行分類。因此,分類器僅需幾步反向傳播步驟對 LSTM 的權重進行微調,因為網路已經通過優化的輔助損失很好地對輸入序列的嵌入進行了學習。

3.2. 預測輔助損失

本文考慮的另一種輔助損失類似於語言模型損失,如圖 2 右圖所示。這種情況要求解碼器網路在子序列中從錨點出發預測出所給序列的下一段標記序列。這類無監督輔助損失第一次是 Dai & Le (2015) 提出的,他們將其應用於整個輸入序列。但我們將其應用在長期依賴關係學習的擴展方案中,因此我們僅將這種損失應用在隨機錨點之後的子序列中。

3.3. 訓練

我們將前一種方法稱為 r-LSTM , 將後一種方法稱為 p-LSTM(r 和 p 分別代表重建和預測),在兩個階段對這兩個模型進行訓練。第一階段是單純的無監督預訓練,在這一方法中輔助損失取最小值。而在第二階段中,執行的是半監督學習,在這一階段中我們取主要目標損失 L_supervised 和 L_auxiliary 的總和最小值。用定期採樣(Bengio et al., 2015a)的方法訓練執行重建操作的輔助 LSTM。

表 2:在 MNIST、pMNIST 和 CIFAR10 上測試的準確率

圖 3:上圖: StanfordDogs 的 8 個級別序列長度測試的準確度。下圖:運行具有 128 個訓練實例的單個小批次的時間,以秒為測量單位。

圖 5 :輔助損失對訓練和測試準確率的影響

論文:Learning Longer-term Dependencies in RNNs with Auxiliary Losses

論文鏈接:https://arxiv.org/abs/1803.00144

儘管訓練循環神經網路(RNNs)最近仍有進展,但在長序列中捕捉長期依賴關係仍舊是根本的挑戰。現在一般會用通過時間的反向傳播(BPTT)解決這一問題,但這很難應用於極長的序列。本文提出了一種簡單的方法,可以通過在原始函數中加入輔助損失改善 RNN 捕捉長期依賴關係的能力。輔助損失強制 RNN 在序列中重建之前的事件或是預測接下來的事件,這樣的操作可以截斷長序列中的反饋,還可以提高 BPTT 整體的能力。我們在各種設置下評估了所述方法,包括用長達 16,000 的序列對一張圖的逐個像素進行分類,以及對一個真實的基準文件進行分類。和其他常用模型和大小相當的轉換器相比,我們的方法在性能和資源使用效率方面的表現都非常突出。我們進一步分析了輔助損失在優化和正則化方面的積極影響,和沒有反向傳播相比,幾乎不會出現極端情況。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

如何使用貪婪搜索和束搜索解碼演算法進行自然語言處理
要不要買延誤險?來,先看看Google Flight預測大法

TAG:機器之心 |