ICLR 2018最佳論文AMSGrad能夠取代Adam嗎
作者:Filip Korzeniowski
編譯:weakish
編者按:Google的Reddi等關於Adam收斂性的論文最近被評為ICLR 2018最佳論文,其中提出了一個Adam的變體AMSGrad。那麼,在實踐中,AMSGrad是不是能夠取代Adam(目前深度學習中最流行的優化方法之一)呢?讓我們一起來看奧地利林茨大學(JKU)博士Filip Korzeniowski所做的試驗。
在ICLR 2018最佳論文On the Convergence of Adam and Beyond(關於Adam的收斂性及其他)中,Google的Reddi等指出了Adam收斂性證明的缺陷,並提出了一個Adam演算法的變體AMSGrad。論文通過一個合成任務和少量試驗展示了AMSGrad的優勢。然而,它僅僅使用了小型網路(MNIST上的單層MLP,CIFAR-10上的小型卷積網路),並且沒有表明測試精確度(顯然,比起交叉熵損失,我們更加關心精確度)。從訓練和測試損失上看,他們在CIFAR-10上訓練的卷積網路,比當前最先進的結果要差很多(我們並不知道精確度)。
由於我一般習慣使用Adam,所以我決定在更實際的網路上評估一下AMSGrad。請注意,我這裡訓練的模型也不大,也不是當前最先進的,但絕對比論文展示的更實際。
我在Lasagne的Adam實現的基礎上做了修改,實現了AMSGrad。我還添加了一個關閉Adam的偏置糾正(bias correction)的選項。
首先,讓我們檢查下實現是否正確。我運行了論文中概述的合成試驗(使用隨機配置)。下圖為所得學習曲線。
這和論文中的圖形相當接近了,也和ycmario的重新實現的結果差不多。Adam錯誤地收斂於1,而x的最佳值為-1.
我接著試驗了四種配置:
MNIST上的邏輯回歸。
CIFAR-10上的CifarNet(論文中描述的CNN)。
CIFAR-10上的SmallVgg(一個小型的VGG風格網路)
CIFAR-10上的Vgg(一個較大的VGG風格網路)
在每種配置中,我為所有下列組合各運行了5遍訓練:
beta2 ∈ (據論文建議)
學習率 ∈
偏置糾正 ∈
batch大小為128(和原論文一樣)。訓練了150個epoch,初始學習率線性下降,直到在第150個epoch後降至0. 在CIFAR-10試驗中,我同樣使用了標準的左右翻轉數據增強。相應的代碼見GitHub倉庫fdlm/amsgrad_experiments。
讓我們首先關注Vgg配置。雖然這些配置各有其古怪之處,我們可以從結果得出的結論卻是類似的。
CIFAR-10上的Vgg
我這裡使用的卷積網路結構如下:
每層卷積之後是batch normalisation和rectifier激活函數。虛線標誌著dropout,相應的概率標明在線的上方。細節詳見代碼。
在下面的圖形中,Adam的結果用藍色表示,AMSGrad的結果用紅色表示。較淡的顏色表明偏置糾正關閉了。每條線表示一次訓練。
訓練損失
在我的試驗中,訓練損失非常接近0,而論文中的數值則是0.3. 這當然是因為我用的是大得多的模型。我們同時看到,如果學習率不是太低的話,AMSGrad看起來在最後的訓練階段收斂得快一點,並且大多數情況下,使用偏置糾正有助於收斂。不過,最後所有變體在大多數超參數配置下(學習率過低除外)收斂至相似的最終損失。如果我們仔細查看,我們會發現在最後的若干epoch上Adam超越了AMSGrad。這沒什麼大不了的——畢竟這是訓練損失。然而,這類模型(以及訓練方案)和論文中報告的CifarNet模型的結果很不一樣,論文中的情況是,AMSGrad的訓練損失比Adam低很多。
驗證損失
驗證損失的表現不同。我們看到,AMSGrad持續超越Adam,特別是在後期的epoch中。兩個演算法達到了相似的最小驗證誤差(在第20-25個epoch周圍),但自此之後,Adam看起來過擬合地更厲害,至少是就交叉熵損失而言。所以AMSGrad可以彌補Adam經常比標準SGD(隨機梯度下降)概括性差這一點嗎?需要做一些標準SGD的試驗來回答這一問題。
訓練精確度
讓我們看下精確度。在分類問題中,精確度是一個比交叉熵損失重要得多的量度。畢竟,我們在乎的是我們的模型成功分類了多少樣本。首先,讓我們看下訓練精確度。就CIFAR-10而言,大部分強大的模型能夠達到接近100%的精確度,如果訓練是恰當的話(事實上,即使使用隨機標籤,它們仍能達到100%精確度。)
正如我們看到的,訓練精確度和訓練損失的表現差不多:AMSGrad比Adam收斂得快,但最終的結果差不多。如果超參數不是太離譜的話,不管使用哪種演算法,我們都能達到差不多100%的精確度(同樣,學習率不能太低)。
最後,到了最有意思的部分——驗證精確度。
驗證精確度
現在是失望時刻。儘管在所有超參數配置上,AMSGrad的驗證損失更低,從驗證精確性來看,基本是平局。在某些設定下,Adam表現更好(lr = 0.01、b2 = 0.999),在其他一些設定下,AMSGrad(lr = 0.001、b2 = 0.999)表現更好。在所有配置中取得最好結果的是Adam(lr = 0.001、b2 = 0.99),但我認為這一差別並不顯著。看起來,給定合適的超參數設定,兩個演算法的表現差不多好。
討論
論文指出了Adam收斂性證明的一個致命缺陷,並展示了一個證明成立的Adam變體,AMSGrad。然而,我發現論文對AMSGrad的實際影響的試驗評估比較有限。在論文中,作者聲稱(著重由我所加)「AMSGrad在訓練損失和精確度上表現得比Adam好多了。此外,這一表現提升同樣轉換到了測試損失上。」不幸的是,就我使用的模型和訓練方案而言,我的試驗不能證實這些:
無論是基於損失,還是基於精確度,兩者在訓練集上的表現類似。
AMSGrad的測試(這裡是驗證)損失確實比較低。然而,
測試損失上的提升並沒有轉換為更好的測試精確度(公平起見,作者從來沒有聲稱這一點)。
測試損失和測試精確度間的表現差異提出了一個關鍵的問題:類別交叉熵訓練用於分類的神經網路有多合適?我們能做得更好嗎?
特此聲明:我確實非常欣賞論文作者指出Adam弱點的工作。儘管我並沒有驗證證明(顯然,直到這篇論文,都沒人驗證Adam的證明),我傾向於相信他們的結論。同時,人造樣本確實表明Adam在特定條件下無法工作。我認為這是一篇好論文。
然而,實際影響需要更多的以經驗為基礎的探索。我在本文中描述的試驗並沒有表明Adam和AMSGrad之間在實踐上有很大的差異。
附錄
這裡我展示其他設定的結果。
MNIST上的邏輯回歸
這一設定和論文中的一個試驗類似。唯一的差別是學習率:論文使用的是α/√t,其中t為迭代,而我使用的是之前提到的每個epoch後的線性衰減。讓我們看下結果。
訓練損失
我們首先看到的是,線性學習率衰減可以得到低得多的訓練損失。在論文中,訓練曲線看起來在5000多次迭代(batch大小128,約13個epoch)後,在約0.25處保持平坦。而在我的試驗中,取決於超參數,它達到了0.2. 另外一個有趣的發現是,在所有配置中,Adam取得了最低的最終訓練損失。這和論文中的結果相反,特別是論文聲稱的,相比Adam,AMSGrad對參數變動更加強韌(robust)。
驗證損失
這裡情況有所不同。AMSGrad通常給出最低的驗證損失(除非我們用了過低的學習率)。然而,在lr = 0.002、b2 = 0.99時,Adam取得了最低的驗證損失(差別微乎其微),不過之後發散了。
訓練精確度
訓練精確度的表現和訓練損失非常類似,驗證精確度同樣如此。
驗證精確度
從這些結果來看,在實踐層面,相比Adam,AMSGrad並沒有明確的優勢。結果取決於學習率和beta2的選擇。同樣,結果並未確認AMSGrad對超參數變動更強韌這一主張。
CIFAR-10上的CifarNet
我嘗試重現實現論文描述的卷積神經網路。然而,我沒能重現訓練取現。取決於學習率和beta2,我要麼得到了更好的訓練損失(但是驗證損失更差),要麼兩者都比論文中的差。由於論文並未提供模型的所有細節(例如,我們不知道初始化方案,是否使用了L2正則化),很難查找原因。不管怎麼說,下面是我得到的結果。
訓練損失
我們看到,兩個演算法在較高的學習率下都不能收斂。就訓練損失而言,學習率0.001看起來效果最好。儘管Adam取得了最低的訓練損失,但AMSGrad的訓練看起來更穩定(差異較小),不過我認為需要更多試驗來確認這一點。
驗證損失
在取得最佳訓練損失的配置上,驗證損失很糟糕地發散了。兩種演算法都受到了類似的影響。看起來基於較低的學習率訓練的模型概括性要好很多。我在這裡說點掃興的話,精確度的圖片會大不一樣。
訓練精確度
訓練精確度的表現和訓練損失類似。
驗證精確度
這裡是最有意思的部分了。我們之前看到驗證損失糟糕地發散的模型,驗證精確度實際上是最優的。就驗證損失而言看起來概括性更好的模型,從驗證精確性的角度而言,概括性並不好。記住這一點很重要。
取決於設定,Adam或AMSGrad以微小的優勢超越彼此。更重要的是,和當前最先進模型相比,模型的表現差太多了(只有78%的精確度),不管使用的是什麼優化演算法。
CIFAR-10上的SmallVgg
這是我在本文主要部分所用模型的一個較小的版本。在每層中,使用一半的過濾器,每塊(block)只使用兩個卷積層。我將省略這些試驗的結果的解釋,因為它們並沒有增加了什麼新的內容。
訓練損失
驗證損失
訓練精確度
驗證精確度
原文地址:https://fdlm.github.io/post/amsgrad/
※MIT提出TbD網路,讓視覺問答模型更易於解釋同時保持高性能
※演算法是新的醫藥:人工智慧醫療的風口
TAG:論智 |