從零開始,了解元學習
選自Medium
作者:Thomas Wolf
機器之心編譯
參與:Tianci LIU、路
本文介紹了元學習,一個解決「學習如何學習」的問題。
元學習是目前機器學習領域一個令人振奮的研究趨勢,它解決的是學習如何學習的問題。
傳統的機器學習研究模式是:獲取特定任務的大型數據集,然後用這個數據集從頭開始訓練模型。很明顯,這和人類利用以往經驗,僅僅通過少量樣本就迅速完成學習的情況相差甚遠。
因為人類學習了「如何學習」。
在這篇文章中,我將從一個非常直觀的元學習簡介入手,從它最早的起源一直談到如今的元學習研究現狀。然後,我會從頭開始,在 PyTorch 中實現一個元學習模型,同時會分享一些從該項目中學到的經驗教訓。
首先,什麼是學習?
我們先來簡單了解一下,當我們訓練一個用來實現貓狗圖像分類的簡單神經網路時,到底發生了什麼。假設我們現在有一張貓的圖像,以及對應的表示「這是一隻貓」的標籤。為簡潔起見,我做了一個簡單的動畫來展示訓練的過程。
神經網路訓練過程的單步。該網路用來實現貓狗圖像分類。
反向傳播是神經網路訓練中很關鍵的一步。因為神經網路執行的計算和損失函數都是可微函數,因此我們能夠求出網路中每一個參數所對應的梯度,進而減少神經網路當前給出的預測標籤與真實/目標標籤之間的差異(這個差異是用損失函數度量的)。在反向傳播完成後,就可以使用優化器來計算模型的更新參數了。而這正是使神經網路的訓練更像是一門「藝術」而不是科學的原因:因為有太多的優化器和優化設置(超參數)可供選擇了。
我們把該「單個訓練步」放在一張圖中展示,如下所示:
現在,訓練圖像是一隻,表示圖像是一隻貓的標籤是 。最大的這些 表示我們的神經網路,裡面的 表示參數和梯度,標有 L 的四邊形表示損失函數,標有 O 的四邊形表示優化器。
完整的學習過程就是不斷地重複這個優化步,直到神經網路中的參數收斂到一個不錯的結果上。
上圖表示神經網路的訓練過程的三步,神經網路(用最大的 表示)用於實現貓狗圖像分類。
元學習
元學習的思想是學習「學習(訓練)」過程。
元學習有好幾種實現方法,不過本文談到的兩種「學習『學習』過程」的方法和上文介紹的方式很類似。
在我們的訓練過程中,具體而言,可以學習到兩點:
神經網路的初始參數(圖中的藍色);
優化器的參數(粉色的★)。
我會介紹將這兩點結合的情況,不過這裡的每一點本身也非常有趣,而且可獲得到簡化、加速以及一些不錯的理論結果。
現在,我們有兩個部分需要訓練:
用「模型(M)」這個詞來指代我們之前的神經網路,現在也可以將其理解為一個低級網路。有時,人們也會用「優化對象(optimizee)」或者「學習器(learner)」來稱呼它。該模型的權重在圖中用 表示。
用「優化器(O)」或者「元學習器」來指代用於更新低級網路(即上述模型)權重的高級模型。優化器的權重在圖中用 ★ 表示。
如何學習這些元參數?
事實上,我們可以將訓練過程中的元損失的梯度反向傳播到初始的模型權重和/或優化器的參數。
現在,我們有了兩個嵌套的訓練過程:優化器/元學習器上的元訓練過程,其中(元)前向傳輸包含模型的多個訓練步:我們之前見過的前饋、反向傳播以及優化步驟。
現在我們來看看元訓練的步驟:
元訓練步(訓練優化器 O)包含 3 個模型(M)的訓練步。
在這裡,元訓練過程中的單個步驟是橫向表示的。它包含模型訓練過程中的兩個步驟(在元前饋和元反向傳播的方格中縱向表示),模型的訓練過程和我們之前看到的訓練過程完全一樣。
可以看到,元前向傳輸的輸入是在模型訓練過程中依次使用的一列樣本/標籤(或一列批次)。
元訓練步中的輸入是一列樣本(、)及其對應的標籤(、)。
我們應該如何使用元損失來訓練元學習器呢?在訓練模型時,我們可以直接將模型的預測和目標標籤做比較,得到誤差值。
在訓練元學習器時,我們可以用元損失來度量元學習器在目標任務——訓練模型——上的表現。
一個可行的方法是在一些訓練數據上計算模型的損失:損失越低,模型就越好。最後,我們可以計算出元損失,或者直接將模型訓練過程中已經計算得到的損失結合在一起(例如,把它們直接加起來)。
我們還需要一個元優化器來更新優化器的權重,在這裡,問題就變得很「meta」了:我們可以用另一個元學習器來優化當前的元學習器……不過最終,我們需要人為選擇一個優化器,例如 SGD 或者 ADAM(不能像「turtles all the way down」一樣(註:turtles all the way down 這裡大概是說,「不能一個模型套一個模型,這樣無限的套下去」)。
這裡給出一些備註,它們對於我們現在要討論的實現而言非常重要:
二階導數:將元損失通過模型的梯度進行反向傳播時,需要計算導數的導數,也就是二階導數(在最後一個動畫中的元反向傳播部分,這是用綠色的 穿過綠色的 來表示的)。我們可以使用 TensorFlow 或 PyTorch 等現代框架來計算二階導數,不過在實踐中,我們通常不考慮二階導數,而只是通過模型權重進行反向傳播(元反向傳播圖中的黃色 ),以降低複雜度。
坐標共享:如今,深度學習模型中的參數數量非常多(在 NLP 任務中,很容易就有將近 3000 萬 ~2億個參數)。當前的 GPU 內存無法將這麼多參數作為單獨輸入傳輸給優化器。我們經常採用的方法是「坐標共享」(coordinate sharing),這表示我們為一個參數設計一個優化器,然後將其複製到所有的參數上(具體而言,將它的權重沿著模型參數的輸入維度進行共享)。在這個方法中,元學習器的參數數量和模型中的參數數量之間並沒有函數關係。如果元學習器是一個記憶網路,如 RNN,我們依然可以令模型中的每個參數都具有單獨的隱藏狀態,以保留每個參數的單獨變化情況。
在 PyTorch 中實現元學習
我們來嘗試寫些代碼,看看真實情況如何吧。
現在我們有了一個模型,它包含一個我們想要進行訓練的權重集合,我們將使用該集合解決這兩項任務:
在元前饋步驟中:我們使用這個模型計算(損失函數的)梯度,並作為優化器的輸入來更新模型參數;
在元反向傳播步驟中:我們使用這個模型作為反向傳播優化器參數梯度(從元損失中計算得到)的路徑。
在 PyTorch 中完成這個任務最簡單的方法是:使用兩個一樣的模塊來表示模型,每個任務一個。我們把存儲元前饋步驟中使用的模型梯度的模塊稱為前向模型(forward model),把元反向傳播步驟中將參數存儲為反向傳播優化器梯度的連續路徑的模塊稱為後向模型(backward model)。
兩個模塊之間會使用共享的 Tensor,以防止重複佔用內存(Tensor 是內存中真正有意義的部分);但同時,也會保留各自的 Variable,以明確區分模型的梯度和元學習器的梯度。
PyTorch 中的一個簡單元學習器類
在 PyTorch 中共享張量非常直接:只需要更新 Variable 類中的指針,讓它們指向相同的 Tensor 就可以了。但如果模型已經是內存優化模型,例如 AWD-LSTM 或 AWD-QRNN 這類共享 Tensors(輸入和輸出嵌入)的演算法時,我們就會遇到問難。這時,我們在更新兩個模塊中的模型參數時,需要很小心,以確保我們保留的指針是正確的。
在這裡給出一個實現方法:設置一個簡單的輔助程序來完成遍歷參數的任務,並返回更新 Parameter 指針(而不只是 Tensor)所需的全部信息,並保持共享參數同步。
以下是一個實現函數:
通過這個函數,我們可以嵌入任何模型,並且很整潔地遍曆元學習器的模型參數。
現在,我們來寫一個簡單的元學習器類。我們的優化器是一個模塊:在前饋階段,它可以將前向模型(及其梯度)和後向模型作為輸入接受,並遍歷它們的參數來更新後向模型中的參數,同時允許元梯度反向傳播(通過更新 Parameter 指針,而不僅僅是 Tensor 指針)。
這樣一來,我們就可以像在第一部分中看到的那樣來訓練優化器了。以下是一個簡單的要點示例,展示了前文描述的元訓練過程:
避免內存爆炸——隱藏狀態記憶
有時,我們想要學習一個可在非常龐大的(可能有幾千萬個參數的)模型上運行的優化器;同時,我們還希望可以在大量步驟上實現元訓練,以得到優質梯度;就像我們在論文《Meta-Learning a Dynamical Language Model》中所實現的那樣。
在實踐中,這意味著,我們想要在元前饋中包含一個很長的訓練過程,以及很多時間步;同時我們還需要將每一步的參數(黃色)和梯度(綠色)保存在內存中,這些參數和梯度會在元反向傳播中使用到。
我們如何在不讓 GPU 內存爆炸的情況下做到這一點呢?
一個辦法是,使用梯度檢查點(gradient checkpointing)來用內存換取計算,這個方法也叫「隱藏狀態記憶」(Hidden State Memorization)。在我們的案例中,梯度檢查點表示,將我們連續計算的元前饋和元反向傳播切分成片段。
來自 Open AI 的 Yaroslav Bulatov 有一篇很好的介紹梯度檢查點的文章,如果你感興趣,可以了解一下:
Fitting larger networks into memory(https://medium.com/@yaroslavvb/fitting-larger-networks-into-memory-583e3c758ff9)
這篇文章非常長,所以我沒有給出一個完整的梯度檢查點代碼示例,建議大家使用已經很完善的 TSHadley 的 PyTorch 實現,以及當前還在開發的梯度檢查點的 PyTorch 本地實現。
元學習中的其他方法
元學習中還有另外兩個很有前景的研究方向,但本文沒有時間來討論了。在這裡我給出一些提示,這樣,當你知道了它們大致的原理後,就可以自己查閱相關資料了:
循環神經網路:我們之前給出了神經網路的標準訓練過程。還有一個方法:將連續的任務作為一個輸入序列,然後建立一個循環模型,並用它提取、構建一個可用於新任務的序列表徵。在這種方法中,對於某個帶有記憶或注意力的循環神經網路,我們通常只使用一個訓練過程。這個方法的效果也很不錯,尤其是當你設計出適合任務的嵌入時。最近的這篇 SNAIL 論文是一個很好的例子:A Simple Neural Attentive Meta-Learner(https://openreview.net/forum?id=B1DmUzWAW)。
強化學習:優化器在元前饋過程中完成的計算和循環神經網路的計算過程很類似:在輸入序列(學習過程中模型的權重序列和梯度序列)上重複使用相同的參數。在真實場景下,這表示我們會遇到循環神經網路經常遇到的一個問題:一旦模型出錯,就很難返回安全路徑,因為我們並沒有訓練模型從訓練誤差中恢復的能力;同時,當遇到一個比元學習過程中使用的序列更長的序列時,模型難以泛化。為了解決這些問題,我們可以求助於強化學習方法,讓模型學習一個和當前訓練狀態相關的動作策略。
自然語言處理中的元學習
元學習和用於自然語言處理(NLP)的神經網路模型(如循環神經網路)之間有一個非常有趣的相似之處。在上一段中,我們曾提到:
用於優化神經網路模型的元學習器的行為和循環神經網路類似。
和 RNN 類似,元學習器會提取一系列模型訓練過程中的參數和梯度作為輸入序列,並根據這個輸入序列計算得到一個輸出序列(更新後的模型參數序列)。
我們的論文《Meta-Learning a Dynamical Language Model》中詳細論述了該相似性,並研究了將元學習器用於神經網路語言模型中,以實現中期記憶:經過學習,元學習器能夠在標準 RNN(如 LSTM)的權重中,編碼中期記憶(除了短期記憶在 LSTM 隱藏狀態中的傳統編碼方式以外)。
我們的元學習語言模型由 3 層記憶層級組成,自下而上分別是:標準 LSTM、用於更新 LSTM 權重以存儲中期記憶的元學習器,以及一個長期靜態記憶。
我們發現,元學習語言模型可以通過訓練來編碼最近輸入的記憶,就像一篇維基百科文章的開始部分對預測文章的結尾部分非常有幫助一樣。
上圖中的曲線展示了在給定一篇維基百科文章開始部分的情況下(A, …, H 是連續的維基百科文章),模型預測文章辭彙的效果。單詞顏色表示的意思相同:藍色表示更好,紅色表示更差。當模型在閱讀一篇文章時,它從文章的開始部分進行學習,讀到結尾部分的時候,它的預測效果也變得更好了(更多細節,請閱讀我們的論文)。
以上是我對元學習的介紹,希望對大家有所幫助!
參考文獻
1. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#afeb) As such, meta-learning can be seen as a generalization of「transfer learning」and is related to the techniques for fine-tuning model on a task as well as techniques for hyper-parameters optimization. There was an interesting workshop on meta-learning (https://nips.cc/Conferences/2017/Schedule?showEvent=8767) at NIPS 2017 last December.
2. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#dc5a) Of course in a real training we would be using a mini-batch of examples.
3. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#e0bb) More precisely:「most of」these operations are differentiable.
4. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#d640) Good blog posts introducing the relevant literature are the BAIR posts: Learning to learn (http://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/) by Chelsea Finn and Learning to Optimize with Reinforcement Learning (http://bair.berkeley.edu/blog/2017/09/12/learning-to-optimize-with-rl/) by Ke Li.
5. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#930c) Good examples of learning the model initial parameters are Model-Agnostic Meta-Learning (https://arxiv.org/abs/1703.03400) of UC Berkeley and its recent developments (https://openreview.net/forum?id=BJ_UL-k0b) as well as the Reptile algorithm (https://blog.openai.com/reptile/) of OpenAI. A good example of learning the optimizer』s parameters is the Learning to learn by gradient descent by gradient descent (https://arxiv.org/abs/1606.04474) paper of DeepMind. A paper combining the two is the work Optimization as a Model for Few-Shot Learning (https://openreview.net/forum?id=rJY0-Kcll) by Sachin Ravi and Hugo Larochelle. An nice and very recent overview can be found in Learning Unsupervised Learning Rules (https://arxiv.org/abs/1804.00222).
6. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#d094) Similarly to the way we back propagate through time in an unrolled recurrent network.
7. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#725d) Initially described in DeepMind』s Learning to learn by gradient descent by gradient descent (https://arxiv.org/abs/1606.04474) paper.
8. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#4e23) We are using coordinate-sharing in our meta-learner as mentioned earlier. In practice, it means we simply iterate over the model parameters and apply our optimizer broadcasted on each parameters (no need to flatten and gather parameters like in L-BFGS for instance).
9. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#d029) There is a surprising under-statement of how important back-propagating over very long sequence can be to get good results. The recent paper An Analysis of Neural Language Modeling at Multiple Scales (https://arxiv.org/abs/1803.08240) from Salesforce research is a good pointer in that direction.
10. ^ (https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a#6c6f) Gradient checkpointing is described for example in Memory-Efficient Backpropagation Through Time (https://arxiv.org/abs/1606.03401) and the nice blog post (https://medium.com/@yaroslavvb/fitting-larger-networks-into-memory-583e3c758ff9) of Yaroslav Bulatov.
本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。
------------------------------------------------
※英特爾劉茵茵:持續優化NLP服務,助推人工智慧創新和落地
※學界 | 用GAN自動生成法線貼圖,讓圖形設計更輕鬆
TAG:機器之心 |