模型不收斂,訓練速度慢,如何才能改善 GAN 的性能?
AI 研習社按:本文為雷鋒字幕組編譯的技術博客,原標題 GAN?—?Ways to improve GAN performance,作者Jonathan Hui。
翻譯 | 姚秀清 郭蘊哲 校對 | 吳桐 整理 | 孔令雙
與其他深度網路相比,GAN 模型在以下方面可能會受到嚴重影響。
不收斂:模型永遠不會收斂,更糟糕的是它們變得不穩定。
模式崩潰:生成器生成單個或有限模式。
慢速訓練:訓練生成器的梯度會消失。
作為 GAN 系列的一部分,本文探討了如何改進 GAN 的方法。 尤其在如下方面,
更改成本函數以獲得更好的優化目標。
在成本函數中添加額外的懲罰以強制執行約束。
避免過度自信和過度擬合。
更好的優化模型的方法。
添加標籤。
特徵匹配
生成器試圖找到最好的圖像來欺騙鑒別器。當兩個網路相互抵抗時,「最佳「圖像會不斷變化。 然而,優化可能變得過於貪婪,並使其成為永無止境的貓捉老鼠遊戲。這是模型不收斂且模式崩潰的場景之一。
特徵匹配改變了生成器的成本函數,用來最小化真實圖像的特徵與生成圖像之間的統計差異,即,它將目標從擊敗對手擴展到真實圖像中的特徵匹配。 我們使用圖像特徵函數 f(x) 對真實圖像和生成圖像的均值間的L2範數距離來懲罰生成器。
其中 f(x) 是鑒別器立即層的輸出,用於提取圖像特徵。
每個批次計算的實際圖像特徵的平均值,都會波動。這對於減輕模式崩潰來說可能是個好消息。它引入了隨機性,使得鑒別器更難以過擬合。
當 GAN 模型在訓練期間不穩定時,特徵匹配是有效的。
微批次鑒別
當模式坍塌時,創建的所有圖像看起來都相似。為了緩解這個問題,我們將不同批次的實際圖像和生成的圖像分別送給鑒別器,並計算圖像 x 與同一批次中其餘圖像的相似度。 我們在鑒別器的一個密集層中附加相似度 o(x) ,來確定該圖像是真實的還是生成的。
如果模式開始崩潰,則生成的圖像的相似性增加。鑒別器可以使用該分數來檢測生成的圖像。這促使生成器生成具有更接近真實圖像的多樣性的圖像。
圖像 xi 與同一批次中的其他圖像之間的相似度 o(xi) 是通過一個變換矩陣 T 計算得到的。如下所示,xi 是輸入圖像,xj 是同一批次中的其餘圖像。
方程式有點難以追蹤,但概念非常簡單。(讀者可以選擇直接跳到下一部分。)我們使用變換矩陣 T 將特徵 xi 轉換為 Mi , 一個 B×C 的矩陣。
我們使用 L1 範數和下面的等式導出圖像 i 和 j 之間的相似度 c(xi, xj) 。
圖像 xi 與批次中其餘圖像之間的相似度 o(xi) 為
這裡是回顧:
引用自論文「 Improved Techniques for Training GANs 」
微批次鑒別使我們能夠非常快速地生成視覺上吸引人的樣本,在這方面它優於特徵匹配。
單面標籤平滑
深度網路可能會過自信。 例如,它使用很少的特徵來對對象進行分類。 深度學習使用正則化和 Dropout 來緩解問題。
在 GAN 中,我們不希望模型過擬合,尤其是在數據雜訊大時。如果鑒別器過分依賴於某一小組特徵來檢測真實圖像,則生成器可能迅速模仿這些特徵以擊敗鑒別器。在 GAN 中,過度自信的負面作用嚴重,因為鑒別器很容易成為生成器利用的目標。為了避免這個問題,當任何真實圖像的預測超過 0.9(D(實際圖像)> 0.9)時,我們會對鑒別器進行懲罰。 這是通過將目標標籤值設置為 0.9 而不是 1.0 來完成的。 這裡是偽代碼:
歷史平均
在歷史平均中,我們跟蹤最後 t 個模型的模型參數。 或者,如果我們需要保留一長串模型,我們會更新模型參數的運行平均值。
我們為成本函數添加了如下的一個 L2 成本,來懲罰不同於歷史平均值的模型。
對於具有非凸對象函數的 GAN,歷史平均可以迫使模型參數停止圍繞平衡點兜圈子,從而令其收斂。
經驗回放
為了擊敗生成器當前產生的內容,模型優化可能變得過於貪婪。為了解決這個問題,經驗回放維護了過去優化迭代中最新生成的圖像。我們不僅僅使用當前生成的圖像去擬合模型,而且還為鑒別器提供了所有最近生成的圖像。因此,鑒別器不會針對生成器某一特定時間段生成的實例進行過度擬合。
使用標籤(CGAN)
許多數據集都帶有樣本對象類型的標籤。訓練 GAN 已經很難了。因此,對於引導 GAN 的訓練來說,任何額外的幫助都可以大大提高其性能。添加標籤作為潛在空間 z 的一部分, 有助於 GAN 的訓練。如下所示 , CGAN 中採用的數據流就充分利用了樣本的標籤。
成本函數
成本函數重要嗎? 它當然重要,否則那麼多研究工作的心血都將是一種浪費。但是如果你聽說過 2017 年 Google Brain 的一篇論文,你肯定會有疑慮。 但努力提升圖像質量仍然是首要任務。因此在我們對成本函數的作用有一個明確的認識之前,我們很有可能會看到研究人員仍在努力嘗試著不同的成本函數。
下圖列出了一些常見 GAN 模型的成本函數。
表格改動自這裡:
https://github.com/hwalsuklee/tensorflow-generative-model-collections
我們決定不在本文中詳細介紹這些成本函數。實際上,如果您想了解更多信息,我們強烈建議您細緻地閱讀這些文章中的至少一篇:WGAN/WGAN-GP,EBGAN / BEGAN,LSGAN,RGAN 和 RaGAN 。 在本文的最後,我們還列出了一篇更詳細地研究成本函數的文章。 成本函數是 GAN 的一個主要研究領域,我們鼓勵您稍後閱讀該文章。
以下是某些數據集中的一些 FID 分數(越低越好)。這是一個參考點,但需要注意的是,現在對於究竟哪些成本函數表現最佳下結論還為時尚早。 實際上,目前還沒有哪一個成本函數在所有不同數據集中都具有最佳表現。
但缺乏好的超參數的模型不可能表現良好,而調參需要大量時間。所以在隨機測試不同的成本函數之前,請耐心地優化超參數。
實現技巧
將圖像的像素值轉換到 -1 到 1 之間。在生成模型的最後一層使用 tanh 作為激活函數。
在實驗中使用高斯分布對 z 取樣。
Batch normalization 可以讓訓練結果更穩定。
上採樣時使用 PixelShuffle 和反卷積。
下採樣時不要使用最大池化而使用卷積步長。
Adam 優化通常比別的優化方法表現的更好。
圖像交給判別模型之前添加一些雜訊,不管是真實的圖片還是生成的。
GAN 模型的動態特性尚未得到很好的解釋。所以這些技巧只是建議,其優化結果如何可能存在差異。例如,提出 LSGAN 的文章指出 RMSProp 在他們的實驗中表現更加穩定。這種情況非常稀少,但是也表明了提出普遍性的建議是非常困難的。
Virtual batch normalization (VBN)
Batch normalization 已經成為很多深度神經網路設計中的事實標準。Batch normalization 的均值和方差來自當前的 minibatch 。然而,它會在樣本之間創建依賴關係,導致生成的圖像不是彼此獨立的。
下圖顯示了在使用同一個 batch 的數據訓練時,生成的圖像有著相同的色調。
本來, 我們對雜訊 z 是從隨機分布中採樣,為我們提供獨立樣本。然而,這種 batch normalization 造成的偏見卻抵消了 z 的隨機性。
Virtual batch normalization (VBN) 是在訓練前從一個 reference batch 中採樣。在前向傳播中,我們提前選擇一個 reference batch 為 batch normalization 去計算 normalization 的參數( μ 和 σ )。 然而,我們在整個訓練過程中使用同一個 batch,會讓模型過擬合。為了解決這個問題,我們將 reference batch 與當前 batch 相結合起來計算參數。
隨機種子
用於初始化模型參數的隨機種子會影響 GAN 的性能。 如下表所示,測量GAN性能的FID分數在50次獨立運行(訓練)中有所不同。但是波動的範圍不大,並且可以在後續的微調中完成。
一篇來自 Google Brain 的論文指出 LSGAN 偶爾會在某些數據集中失敗或崩潰,並且需要使用另一個隨機種子重新啟動訓練。
Batch normalization
DGCAN 強力建議在網路設計中加入 batch normalization 。 Batch normalization 的使用也成為許多深度網路模型的一般做法。 但是,也會有例外。 下圖演示了 batch normalization 對不同數據集的影響。 y 軸是 FID 得分,越低越好。 正如 WGAN-GP 論文所建議的那樣,當使用成本函數 WGAN-GP 時,不應該使用 batch normalization 。 我們建議讀者檢查 batch normalization 上使用的成本函數和相應的FID性能,並通過實驗驗證來設置。
多重 GANs
模式崩潰可能並不全是壞事。 實際上,當模式崩潰時,圖像質量通常會提高。 實際上,我們可以會為每種模式收集最佳模型,並使用它們來重建不同的圖像模式。
判別模型和生成模型之間的平衡
判別模型和生成模型總是處於拉鋸戰中以相互削弱。生成模型積極創造最好的圖像來擊敗判別模型。 但如果判別模型響應緩慢,生成的圖像將收斂,模式開始崩潰。 相反,當判別模型表現良好時,原始生成模型的成本函數的梯度消失,學習速度慢。 我們可以將注意力轉向平衡生成模型和判別模型之間的損失,以便在訓練 GAN 中找到最佳位置。 不幸的是,解決方案似乎難以捉摸。 在判別模型和生成模型之間的交替梯度下降中,定義它們之間的靜態比率似乎是有效的,但也有很多人懷疑它的效果。 如果說已經有人做過這件事的話,那就是研究人員每訓練生成模型5次再更新判別模型的嘗試了。 其他動態平衡兩個網路的建議僅在最近才引起關注。
另一方面,一些研究人員認為平衡這些網路的可行性和願景是很困難的。 訓練有素的判別模型無論如何都能為生成模型提供高質量的反饋。 然而訓練生成模型使之能與判斷模型抗衡也並不容易。 相反,當生成模型表現不佳時,我們可能會將注意力轉向尋找不具有接近零梯度的成本函數。
然而問題仍然是存在的。 人們提出了許多建議,研究者們對什麼是最好的損失函數的爭論仍在繼續。
判別模型和生成模型的網路容量
判別模型通常比生成模型更複雜(有更多濾波器和更多層),而良好的判別模型可以提供高質量的信息。 在許多 GAN 應用中,當增加生成模型容量並沒有帶來質量上的改進時,我們便遇到了瓶頸。 在我們確定遭遇了瓶頸並解決這個問題之前,增加生成模型容量不會成為優先考慮的選項。
※用 OpenCV 檢測圖像中各物體大小
※Kaggle 新賽:Google AI Open Images 目標檢測
TAG:AI研習社 |