當前位置:
首頁 > 最新 > 對抗思想與強化學習的碰撞-SeqGAN模型原理和代碼解析

對抗思想與強化學習的碰撞-SeqGAN模型原理和代碼解析

1、背景

GAN作為生成模型的一種新型訓練方法,通過discriminative model來指導generative model的訓練,並在真實數據中取得了很好的效果。儘管如此,當目標是一個待生成的非連續性序列時,該方法就會表現出其局限性。非連續性序列生成,比如說文本生成,為什麼單純的使用GAN沒有取得很好的效果呢?主要的屏障有兩點:

1)在GAN中,Generator是通過隨機抽樣作為開始,然後根據模型的參數進行確定性的轉化。通過generative model G的輸出,discriminative model D計算的損失值,根據得到的損失梯度去指導generative model G做輕微改變,從而使G產生更加真實的數據。而在文本生成任務中,G通常使用的是LSTM,那麼G傳遞給D的是一堆離散值序列,即每一個LSTM單元的輸出經過softmax之後再取argmax或者基於概率採樣得到一個具體的單詞,那麼這使得梯度下架很難處理。

2)GAN只能評估出整個生成序列的score/loss,不能夠細化到去評估當前生成token的好壞和對後面生成的影響。

強化學習可以很好的解決上述的兩點。再回想一下Policy Gradient的基本思想,即通過reward作為反饋,增加得到reward大的動作出現的概率,減小reward小的動作出現的概率,如果我們有了reward,就可以進行梯度訓練,更新參數。如果使用Policy Gradient的演算法,當G產生一個單詞時,如果我們能夠得到一個反饋的Reward,就能通過這個reward來更新G的參數,而不再需要依賴於D的反向傳播來更新參數,因此較好的解決了上面所說的第一個屏障。對於第二個屏障,當產生一個單詞時,我們可以使用蒙塔卡羅樹搜索(Alpho Go也運用了此方法)立即評估當前單詞的好壞,而不需要等到整個序列結束再來評價這個單詞的好壞。

因此,強化學習和對抗思想的結合,理論上可以解決非連續序列生成的問題,而SeqGAN模型,正是這兩種思想碰撞而產生的可用於文本序列生成的模型。

SeqGAN模型的原文地址為:https://arxiv.org/abs/1609.05473,當然在我的github鏈接中已經把下載好的原文貼進去啦。

結合代碼可以更好的理解模型的細節喲:https://github.com/princewen/tensorflow_practice/tree/master/seqgan


2、SeqGAN的原理

SeqGAN的全稱是Sequence Generative Adversarial Nets。這裡打公式太麻煩了,所以我們用word打好再粘過來,沖這波手打也要給小編一個贊呀,哈哈!

整體流程

模型的示意圖如下:

Generator模型和訓練

接下來,我們分別來說一下Generator模型和Discriminator模型結構。

Generator一般選擇的是循環神經網路結構,RNN,LSTM或者是GRU都可以。對於輸入的序列,我們首先得到序列中單詞的embedding,然後輸入每個cell中,並結合一層全鏈接隱藏層得到輸出每個單詞的概率,即:

有了這個概率,Generator可以根據它採樣一批產生的序列,比如我們生成一個只有,兩個單詞的序列,總共的單詞序列有3個,第一個cell的輸出為(0.5,0.5,0.0),第二個cell的輸出為(0.1,0.8,0.1),那麼Generator產生的序列以0.4的概率是1->2,以0.05的概率是1->1。注意這裡Generator產生的序列是概率採樣得到的,而不是對每個輸出進行argmax得到的固定的值。這和policy gradient的思想是一致的。

在每一個cell我們都能得到一個概率分布,我們基於它選擇了一個動作或者說一個單詞,如何判定基於這個概率分布得到的單詞的還是壞的呢?即我們需要一個reward來左右這個單詞被選擇的概率。這個reward怎麼得到呢,就需要我們的Discriminator以及蒙塔卡羅樹搜索方法了。前面提到過Reward的計算依據是最大可能的Discriminator,即儘可能的讓Discriminator認為Generator產生的數據為real-world的數據。這裡我們設定real-world的數據的label為1,而Generator產生的數據label為0.

如果當前的cell是最後的一個cell,即我們已經得到了一個完整的序列,那麼此時很好辦,直接把這個序列扔給Discriminator,得到輸出為1的概率就可以得到reward值。如果當前的cell不是最後一個cell,即當前的單詞不是最後的單詞,我們還沒有得到一個完整的序列,如何估計當前這個單詞的reward呢?我們用到了蒙特卡羅樹搜索的方法。即使用前面已經產生的序列,從當前位置的下一個位置開始採樣,得到一堆完整的序列。在原文中,採樣策略被稱為roll-out policy,這個策略也是通過一個神經網路實現,這個神經網路我們可以認為就是我們的Generator。得到採樣的序列後,我們把這一堆序列扔給Discriminator,得到一批輸出為1的概率,這堆概率的平均值即我們的reward。這部分正如過程示意圖中的下面一部分:

用原文中的公式表示如下:

得到了reward,我們訓練Generator的方式就很簡單了,即通過Policy Gradient的方式進行訓練。最簡單的思想就是增加reward大的動作的選擇概率,減小reward小的動作的選擇概率。

Discriminator模型和訓練

Discriminator模型即一個分類器,對文本分類的分類器很多,原文採用的是卷積神經網路。同時為了使模型的分類效果更好,在CNN的基礎上增加了一個highway network。有關highway network的介紹參考博客:https://blog.csdn.net/l494926429/article/details/51737883,這裡就不再細講啦。

對於Discriminator來說,既然是一個分類器,輸出的又是兩個類別的概率值,我們很自然的想到使用類似邏輯回歸的對數損失函數,沒錯,論文中也是使用對數損失來訓練Discriminator的。

結合oracle模型

可以說,模型我們已經介紹完了,但是在實驗部分,論文中引入了一個新的模型中,被稱為oracle model。這裡的oracle如何翻譯,我還真的是不知道,總不能翻譯為甲骨文吧。這個oracle model被用來生成真實的序列,可以認為這個model就是一個被訓練完美的lstm模型,輸出的序列都是real-world數據。論文中使用這個模型的原因有兩點:首先是可以用來產生訓練數據,另一點是可以用來評價我們Generator的真實表現。原文如下:

我們會在訓練過程中不斷通過上面的式子來評估我們的Generator與oracle model的相似性。

預訓練過程

上面我們講的其實是在對抗過程中Generator和Discriminator的訓練過程,其實在進行對抗之前,我們的Generator和Discriminator都有一個預訓練的過程,這能使我們的模型更快的收斂。

對於Generator來說,預訓練和對抗過程中使用的損失函數是不一樣的,在預訓練過程中,Generator使用的是交叉熵損失函數,而在對抗過程中,我們使用的則是Policy Gradient中的損失函數,即對數損失*獎勵值。

而對Discriminator來說,兩個過程中的損失函數都是一樣的,即我們前面介紹的對數損失函數。

SeqGAN模型流程

介紹了這麼多,我們再來看一看SeqGAN的流程:

3、SeqGAN代碼解析

這裡我們用到的代碼高度還原了原文中的實驗過程,本文參考的github代碼地址為:https://github.com/ChenChengKuan/SeqGAN_tensorflow

參考的代碼為python2版本的,本文將其稍作修改,改成了python3版本的。其實主要就是print和pickle兩個地方。本文代碼的github地址為:https://github.com/princewen/tensorflow_practice/tree/master/seqgan

代碼實在是太多了,我們這裡只介紹一下代碼結構,具體的代碼細節大家可以參考github進行學習。

3.1 代碼結構

本文的代碼結構如下:

save:save文件夾下保存了我們的實驗日誌,eval_file是由Generator產生,用來評價Generator和oracle model相似性所產生的數據。real_data是由oracle model產生的real-world數據,generator_sample是由Generator產生的數據,target_params是oracle model的參數,我們直接用裡面的參數還原oracle model。

configuration: 一些配置參數

dataloader.py: 產生訓練數據,對於Generator來說,我們只在預訓練中使用dataloader來得到訓練數據,對Discriminator來說,在預訓練和對抗過程中都要使用dataloader來得到訓練數據。而在eval過程即進行Generator和oracle model相似性判定時,會用刀dataloader來產生數據。

discriminator.py:定義了我們的discriminator

generator.py:定義了我們的generator

rollout.py:計算reward時的採樣過程

target_lstm.py:定義了我們的oracle model,這個文件不用管,複製過去就好,哈哈。

train.py: 定義了我們的訓練過程,這是我們一會重點講解的文件

utils.py: 定義了一些在訓練過程中的通用過程。

下面,我們就來介紹一下每個文件。

3.2 dataloader

dataloader是我們的數據生成器。

它定義了兩個類,一個時Generator的數據生成器,主要用於Generator的預訓練以及計算Generator和Oracle model的相似性。另一個時Discriminator的數據生成器,主要用於Discriminator的訓練。

3.3 generator

generator中定義了我們的Generator,代碼結構如下:

build_input:定義了我們的預訓練模型和對抗過程中需要輸入的數據

build_pretrain_network: 定義了Generator的預訓練過程中的網路結構,其實這個網路結構在預訓練,對抗和採樣的過程中是一樣的,參數共享。預訓練過程中定義的損失是交叉熵損失。

build_adversarial_network: 定義了Generator的對抗過程的網路結構,和預訓練過程共享參數,因此你可以發現代碼基本上是一樣的,只不過在對抗過程中的損失函數是policy gradient的損失函數,即 -log(p(xi) * v(xi):

build_sample_network:定義了我們Generator採樣得到生成序列過程的網路結構,與前兩個網路參數是共享的。

那麼這三個網路是如何使用的呢?pretrain_network就是用來預訓練我們的Generator的,這個沒有異議。然後在對抗時的每一個epoch,首先用sample_network得到一堆採樣的序列samples,然後對採樣序列的對每一個時點,使用roll-out-policy結合Discriminator得到reward值。最後,把這些samples和reward值餵給adversarial_network進行參數更新。

3.4 discriminator

discriminator的文件結構如下:

前面的linear和highway函數實現了highway network。

在Discriminator類中,我們採用CNN建立了Discriminator的網路結構,值得注意的是,我們這裡採用的損失函數加入了正則項:

3.5 rollout

這個文件實現的通過rollout-policy得到一堆完整序列的過程,前面我們提到過了,rollout-policy實現需要一個神經網路,而我們這裡用Generator當作這個神經網路,所以它與前面提到的三個Generator的網路的參數也是共享的。

另外需要注意的是,我們這裡要得到每個序列每個時點的採樣數據,因此需要進行兩層循環:

假設我們傳過來的序列長度是20,最後一個不需要進行採樣,因為已經是完整的序列了。假設當前的step是5,那麼0-4是不需要採樣的,但我們需要把0-4位置的序列輸入到網路中得到state。得到state之後,我們再經過一層循環得到5-19位的採樣序列,然後將0-4位置的序列的和5-19位置的序列的進行拼接。

3.6 utils

utils中定義了兩個函數:

generate_samples函數用於調用Generator中的sample_network產生sample或者用於調用target-lstm中的sample_network產生real-world數據

target_loss函數用於計算Generator和oracle model的相似性。

3.7 train

終於改介紹我們的主要流程式控制制代碼了,先深呼吸一口,準備開始!

定義dataloader以及網路

首先,我們獲取了configuration中定義的參數,然後基於這些參數,我們得到了三個dataloader。

隨後,我們定義了Generator和Discriminator,以及通過讀文件來創建了我們的oracle model,在代碼中叫target_lstm。

預訓練Generator

我們首先定義了預訓練過程中Generator的優化器,即通過AdamOptimizer來最小化交叉熵損失,隨後我們通過target-lstm網路來產生Generator的訓練數據,利用dataloader來讀取每一個batch的數據。

同時,每隔一定的步數,我們會計算Generator與target-lstm的相似性(likelihood)

預訓練Discriminator

預訓練好Generator之後,我們就可以通過Generator得到一批負樣本,並結合target-lstm產生的正樣本來預訓練我們的Discriminator。

定義對抗過程中Generator的優化器

這裡定義的對抗過程中Generator的優化器即最小化我們前面提到的policy gradient損失,再回顧一遍:

對抗過程中訓練Generator

對抗過程中訓練Generator,我們首先需要通過Generator得到一批序列sample,然後使用roll-out結合Dsicriminator得到每個序列中每個時點的reward,再將reward和sample餵給adversarial_network進行參數更新。

對抗過程中訓練Discriminator

對抗過程中Discriminator的訓練和預訓練過程一樣,這裡就不再贅述。

3.8 訓練效果

來一發訓練效果截圖:

可以看到,我們的Generator越來越接近oracle model啦,哈哈哈!


參考文獻:

1、https://blog.csdn.net/liuyuemaicha/article/details/70161273

2、https://blog.csdn.net/yinruiyang94/article/details/77675586

3、https://www.jianshu.com/p/32e164883eab

4、https://blog.csdn.net/l494926429/article/details/51737883


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 小小挖掘機 的精彩文章:

深度強化學習-Policy Gradient基本實現

TAG:小小挖掘機 |