當前位置:
首頁 > 知識 > 這份攻略幫你「穩住」反覆無常的 GAN

這份攻略幫你「穩住」反覆無常的 GAN

選自Medium

作者:Bharath Raj

參與:Geek AI、張倩

GAN 自 2014 年提出以來得到了廣泛應用,BigGAN 等生成的以假亂真的圖像更是引發了眾多關注,但由於訓練穩定性較差,GAN 的使用變得非常困難。本文列出了一些提高 GAN 訓練穩定性的常用技術。

生成對抗網路(GAN)是一類非常強大的神經網路,具有非常廣闊的應用前景。GAN 本質上是由兩個相互競爭的神經網路(生成器和判別器)組成的系統。

GAN 的工作流程示意圖。

給定一組目標樣本,生成器會試圖生成一些人造的樣本,這些生成的樣本能夠欺騙判別器將其視為真實的目標樣本,達到「以假亂真」的目的。而判別器則會試圖將真實的(目標)樣本與虛假的(生成)樣本區分開來。通過這樣循環往複的訓練方法,我們最終可以得到一個能夠很好地生成與目標樣本相似的樣本的生成器。

由於 GAN 幾乎可以學會模擬出所有類型的數據分布,它有著非常廣泛的應用場景。通常,GAN 被用來去除圖片中的人為影響、超解析度、姿勢遷移以及任何類型的圖像轉換,如下所示:

使用 GAN 完成的圖像變換。

然而,由於 GAN 的訓練穩定性反覆無常,使用 GAN 是十分困難的。誠然,許多研究人員已經提出了很好的解決方案來緩解 GAN 訓練中涉及的一些問題。然而,這一領域的研究進展是如此之快,以至於人們很難跟上這些最新的有趣的想法。本文列出了一些常用的使 GAN 訓練穩定的技術。

使用 GAN 的弊端

由於一系列原因,想要使用 GAN 是十分困難的。本節列舉出了其中的一些原因:

1. 模式崩潰

自然的數據分布是極其複雜的多模態函數(也稱多峰函數)。也就是說,數據分布有許多「峰」或「模式」。每個模態代表相似的數據樣本聚集在一起,但是與其它的模態並不相同。

在模式崩潰的情況下,生成器會生成從屬於有限模態集集合的樣本。當生成器認為它可以通過生成單一模式的樣本來欺騙鑒別器時,就會發生這種情況。也就是說,生成器只從這種模式生成樣本。

上面一排圖片表示沒有發生模式崩潰的情況下 GAN 輸出的樣本。下面一排圖片表示發生模式崩潰時 GAN 輸出的樣本。

判別器最終會發現這種模式是人為生成的。結果,生成器會直接轉而生成另一種模式。這樣的情況會無限循環下去,從本質上限制了生成樣本的多樣性。詳細解釋請參考博客《Mode collapse in GANs》(http://aiden.nibali.org/blog/2017-01-18-mode-collapse-gans/)

2. 收斂性

在 GAN 的訓練過程中,一個普遍的問題就是「何時停止訓練 GAN 模型?」由於在判別器損失降低的同時生成器的損失會增高(反之亦然),我們並不能基於損失函數的值就來判別 GAN 的收斂性。下圖說明了這一點:

一張典型的 GAN 損失函數示意圖。請注意,此圖無法說明 GAN 的收斂性。

3. 質量

和前面提到的問題一樣,很難定量地判斷生成器何時能生成高質量的樣本。向損失函數中加入額外的感知正則化項可以在一定程度上幫助我們緩解這種情況。

4. 評價標準

GAN 的目標函數說明了生成器(G)與判別器(D)這一對相互博弈的模型相對於其對手的性能,但卻不能代表輸出樣本的質量或多樣性。因此,我們需要能夠在目標函數相同的情況下進行度量的獨特的評價標準。

術語

在我們深入研究可能有助於提升 GAN 模型性能的技術之前,讓我們回顧一些術語。

1. 下確界及上確界

簡而言之,下確界是集合的最大下界,上確界是集合的最小上界,「上確界、下確界」與「最小值、最大值」的區別在於下確界和上確界不一定屬於集合。

2. 散度度量

散度度量代表了兩個分布之間的距離。傳統的 GAN 本質上是最小化了真實數據分布和生成的數據分布之間的 Jensen Shannon 散度(JS 散度)。GAN 的損失函數可以被改寫為最小化其它的散度度量,例如:Kulback Leibler 散度(KL 散度)或全變分距離。通常,Wasserstein GAN 最小化了推土機距離。

3. Kantorovich Rubenstein 對偶性

我們很難使用一些散度度量的原始形式進行優化。然而,它們的對偶形式(用上確界替換下確界,反之亦然)可能就較為容易優化。對偶原理為將一種形式轉化為另一種形式提供了框架。詳細解釋請參考博客:《Wasserstein GAN and the Kantorovich-Rubinstein Duality》(https://vincentherrmann.github.io/blog/wasserstein/)

4. LiPSCHITZ 連續性

一個 Lipschitz 連續函數的變化速度是有限的。對具備 Lipschitz 連續性的函數來說,函數曲線上任一點的斜率的絕對值不能超過實數 K。這樣的函數也被稱為 K-Lipschitz 連續函數。

Lipschitz 連續性是 GAN 所期望滿足的,因為它們會限制判別器的梯度,從而從根本上避免了梯度爆炸問題。另外,Kantorovich-Rubinstein 對偶性要求 Wasserstein GAN 也滿足 Lipschitz 連續性,詳細解釋請參考博客:《Spectral Normalization Explained》(https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html)。

用於提升模型性能的技術

有許多技巧和技術可以被用來使 GAN 更加穩定和強大。為了保證本文的簡潔性,我僅僅解釋了一些相對來說較新或較為複雜的技術。在本節的末尾,我列舉出了其它的各種各樣的技巧和技術。

1. 替換損失函數

針對 GAN 存在的的缺點,最流行的修正方法之一是使用「Wasserstein GAN」。它本質上是使用「推土機距離」(Wasserstein-1 距離或 EM 距離)代替傳統 GAN 的 Jensen Shannon 散度。然而,EM 距離的原始形式是難以進行優化的,因此我們使用它的對偶形式(通過 Kantorovich Rubenstein 對偶性計算得出)。這要求判別器滿足「1-Lipschitz」,我們是通過裁剪判別器的權重來保證這一點的。

使用推土機距離的優點是,即使真實的和生成的樣本的數據分布沒有交集,推土機距離也是「連續的」,這與 JS 或 KL 散度不同。此外,此時生成圖像的質量與損失函數值之間存在相關性。而使用推土機距離的缺點是,我們需要在每次更新生成器時更新好幾個判別器(對於原始實現的每次生成器更新也是如此)。此外,作者聲稱,權值裁剪是一種糟糕的確保 1-Lipschitz 約束的方法。

與 Jensen Shannon 散度(如右圖所示)不同,即使數據分布不是連續的,推土機距離(如左圖所示)也是連續的。詳細的解釋請參閱論文《Wasserstein GAN》(https://arxiv.org/pdf/1701.07875.pdf)

另一種有趣的解決方案是採用均方損失而非對數損失。LSGAN 的作者認為,傳統的 GAN 損失函數並沒有提供足夠的刺激來「拉動」生成的數據分布逼近真實的數據分布。

原始 GAN 損失函數中的對數損失並不影響生成數據與決策邊界之間的距離(決策邊界將真實數據和生成的數據分開)。另一方面,LSGAN 對遠離決策邊界的生成樣本進行懲罰,本質上將生成的數據分布「拉向」實際的數據分布。它通過使用均方損失替代對數損失來做到這一點。詳細解釋請參考博客:《Least Squares GAN》(https://wiseodd.github.io/techblog/2017/03/02/least-squares-gan/)。

2. 兩個時間尺度上的更新規則(TTUR)

在此方法中,我們為判別器和生成器使用了不同的學習率。通常,生成器使用較慢的更新規則,而判別器使用較快的更新規則。通過使用這種方法,我們只需對學習率進行微調,就可以以 1:1 的比例執行生成器和判別器的更新。值得注意的是,SAGAN 的實現就使用了這個方法。

3. 梯度懲罰

在論文「Improved Training of WGANs」中,作者聲稱權值裁剪(正如在原始的 WGAN 中執行的那樣)導致一些優化問題的產生。作者認為權重裁剪迫使神經網路去學習「較為簡單的近似」從而得到最優的數據分布,這導致 GAN 得到的最終結果質量變低。他們還聲稱,如果 WGAN 的超參數設置不正確,權重裁剪會導致梯度爆炸或梯度消失的問題。作者在損失函數中引入了一個簡單的梯度懲罰規則,從而緩解了上述問題。除此之外,正如在原始的 WGAN 實現中那樣,這樣做還保證了 1-Lipschitz 連續性。

正如在原始的 WGAN-GP 論文中提到的,將梯度懲罰作為正則化項加入。

DRAGAN 的作者聲稱,當 GAN 中進行的博弈(即判別器和生成器互相進行對抗)達到了「局部均衡狀態」時,模式崩潰現象就會發生。他們還聲稱,此時由判別器所貢獻的梯度是非常「尖銳的」。使用這樣的梯度懲罰能夠很自然地幫助我們避開這些狀態,大大提高訓練的穩定性,並減少模式崩潰現象的發生。

4. 譜歸一化

譜歸一化是一種通常在判別器中使用的權值歸一化技術,它能夠優化訓練過程(使訓練過程更穩定),從本質上保證了判別器滿足「K-Lipschitz 連續性」。

SAGAN 等實現也在生成器中使用了譜歸一化技術。博文《Spectral Normalization Explained》(https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html)也指出,譜歸一化比梯度懲罰的計算效率更高。

5. 展開和打包

正如博文《Mode collapse in GANs》所描述的,一個阻止模式崩潰發生的方法是在更新參數時預測「對策」。當判別器有機會對生成器的結果做出反應時(考慮到對策,就像 min-max 方法),展開(unrolled)的 GAN 就可以讓生成器騙過判別器。

另一個阻止模式崩潰發生的方式是將屬於同一類的一些樣本「打包」,然後將其傳給判別器。這種方法被 PacGAN 所採用,該論文聲稱它們減少了模式崩潰的發生。

6. 堆疊 GAN

單個的 GAN 可能不夠強大,無法有效地處理某些任務。因此,我們可以使用連續放置的多個 GAN,其中每個 GAN 可以解決一個簡化的問題模塊。例如,FashionGAN 使用了兩個 GAN 處理局部的圖像轉換任務。

FashionGAN 使用了兩個 GAN 來執行局部的圖像轉換。

將這種情況推到極致,可以逐步提高 GAN 模型所面臨問題的難度。例如,Progressive GAN(ProGAN)可以生成解析度超高的高質量圖像。

7. 相對 GAN

傳統的 GAN 會度量生成數據是真實數據的概率。相對 GAN(Relativistic GAN)則會去度量生成數據比真實數據「更加真實」的概率。正如 RGAN 相關論文《The relativistic discriminator: a key element missing from standard GAN》中提到的那樣,我們可以使用一個合適的距離來度量這種「相對真實性」。

圖 B 為我們使用標準 GAN 損失得到的判別器的輸出。圖 C 為輸出的曲線實際的樣子。圖 A 為 JS 散度的最優解。

作者還提到,當判別器達到最優狀態時,其輸出的概率 D(x)應該收斂到 0.5。然而,傳統的 GAN 訓練演算法會迫使判別器為任何圖像輸出「真實」(即概率為 1)的結果。這在某種程度上阻止了判別器的輸出概率達到其最優值。相對 GAN 也解決了這個問題,並且如下圖所示,取得了非常顯著的效果。

在 5000 輪迭代後,標準 GAN 得到的輸出(左圖),以及相對 GAN 得到的輸出(右圖)。

8. 自注意力機制

自注意力 GAN 的作者聲稱,用於生成圖像的卷積操作關注的是局部傳播的信息。也就是說,由於它們的感受野(restrictive receptive field)有限,它們忽略了在全局傳播的關係。

將注意力映射(由黃色方框中的網路計算得出)加入到標準的卷積運算中。

自注意力生成對抗網路使圖像生成任務能夠進行注意力機制驅動的遠距離依賴建模。自注意力機制是對於常規的卷積運算的補充。全局信息(遠距離依賴)有助於生成更高質量的圖像。網路可以選擇忽略注意力機制,或將其與常規的卷積運算一同進行考慮。要想更細緻地了解自注意力機制,請參閱論文《Self-Attention Generative Adversarial Networks》(https://arxiv.org/pdf/1805.08318.pdf)。

9. 其它各種各樣的技術

下面是其它的一些被用來提升 GAN 模型性能的技術(不完全統計!):

特徵匹配

使用 Mini Batch 技術優化的判別器

歷史平均

單邊標籤平滑法

虛擬批量歸一化

你可以通過論文《Improved Techniques for Training GANs》以及博文《From GAN to WGAN》了解更多關於這些技術的信息。在下面的 GitHub 代碼倉庫中列舉出了更多的技術:https://github.com/soumith/ganhacks。

評價指標

到目前為止,讀者已經了解了提升 GAN 訓練效果的方法,我們需要使用一些指標來量化證明這些方法有效。下面,本文將列舉出一些常用的 GAN 模型的性能評價指標。

1. Inception(GoogleNet)得分

Inception 得分可以度量生成數據有多「真實」。

Inception Score 的計算方法。

上面的方程由兩個部分(p(y|x) 和 p(y))組成。在這裡,x 代表由生成器生成的圖像,p(y|x) 是將圖像 x 輸入給一個預訓練好的 Inception 網路(正如在原始實現中使用 ImageNet 數據集進行預訓練,https://arxiv.org/pdf/1801.01973.pdf)時得到的概率分布。同時,p(y) 是邊緣概率分布,可以通過對生成圖像 x 的一些不同的樣本求 p(y|x) 平均值計算得出。這兩項代表了真實圖像所需要滿足的兩種特性:

生成圖像應該包含「有意義」的目標(清晰、不模糊的目標)。這就意味著 p(y|x) 應該具有「較小的熵」。也就是說,我們的 Inception 網路必須非常有把握地確定生成的圖像從屬於某個特定的類。

生成的圖像應該要「多樣」。這就意味著 p(y) 應該有「較大的熵」。換句話說,生成器應該在生成圖像時使得每張圖像代表不同類的標籤(理想情況下)。

理想狀況下 p(y|x) 和 p(y) 的示意圖。這種情況下,二者的 KL 散度非常大。

如果一個隨機變數是高度可預測的,那麼它的熵就很小(即,p(y) 應該是有一個尖峰的分布)。相反,如果隨機變數是不可預測的,其熵就應該很大(即 p(y|x) 應該是一個均勻分布)。如果這兩個特性都得到了滿足,我們應該認為 p(y|x) 和 p(y) 的 KL 散度很大。自然,Inception 得分(IS)越大越好。如果讀者想要了解對 Inception 得分更加深入的分析,請參閱論文《A Note on the Inception Score》(https://arxiv.org/pdf/1801.01973.pdf)。

2. Fréchet Inception 距離(FID)

Inception 得分的一個不足之處在於,並沒有對真實數據和生成數據的統計量(如均值和方差)進行比較。Fréchet 距離通過對比真實圖像和生成圖像的均值和方差解決了這個問題。Fréchet Inception 距離(FID)執行了與 Inception 得分相同的分析過程,但是它是在通過向預訓練好的 Inception-v3 網路傳入真實的和生成的圖像後得到的特徵圖上完成的。FID 的公式如下所示:

FID 得分對比了真實的數據分布和生成數據分布的均值和方差。「Tr」代表矩陣的「跡」。

FID 得分越低越好,因為此時它表明生成圖像的統計量與真實圖像非常接近。

結語

為了克服 GAN 訓練中的種種弊端,研究社區提出了許多解決方案和方法。然而,由於大量湧現的新研究成果,很難跟進所有有意義的新工作。因此,本文分享的細節是不詳盡的,並且可能在不久的將來就會過時。但是,筆者希望本文可以為那些想要提高 GAN 模型性能的人提供一定的指導。

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

------------------------------------------------


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

李沐等將目標檢測絕對精度提升 5%,不犧牲推理速度
絕佳的ASR學習方案:這是一套開源的中文語音識別系統

TAG:機器之心 |