雙向循環神經網路及TensorFlow實現
前言
循環神經網路得益於其記憶功能使其擅長處理序列方面的問題,它能提取序列之間的特徵,進而對序列輸出進行預測。比如我說「我肚子餓了,準備去xx」,那麼根據前面的序列輸入來預測「xx」很可能就是「吃飯」。
單向循環神經網路
所謂的單向循環神經網路其實就是常見的循環神經網路,可以看到t時刻、t-1時刻、t+1時刻,不同時刻輸入對應不同的輸出,而且上一時刻的隱含層會影響當前時刻的輸出。這種結構就是單向循環神經網路結構。
這裡寫圖片描述
單向循環神經網路不足
從單向的結構可以知道它的下一刻預測輸出是根據前面多個時刻的輸入來共同影響的,而有些時候預測可能需要由前面若干輸入和後面若干輸入共同決定,這樣會更加準確。比如我說「我肚子xx,準備去吃飯」,那麼如果沒有後面的部分就不能很好地推斷出是「餓了」,也可以是「好疼」或「胖了」之類的。
雙向循環神經網路
鑒於單向循環神經網路某些情況下的不足,提出了雙向循環神經網路。因為是需要能關聯未來的數據,而單向循環神經網路屬於關聯歷史數據,所以對於未來數據提出了反向循環神經網路,兩個方向的網路結合到一起就能關聯歷史與未來了。
雙向循環神經網路按時刻展開的結構如下,可以看到向前和向後層共同連接著輸出層,其中包含了6個共享權值,分別為輸入到向前層和向後層兩個權值、向前層和向後層各自隱含層到隱含層的權值、向前層和向後層各自隱含層到輸出層的權值。
這裡寫圖片描述
可以由下列式子表示,
雙向循環網路如何訓練
前向傳播
沿著時刻1到時刻T正向計算一遍,得到並保存每個時刻向前隱含層的輸出。
沿著時刻T到時刻1反向計算一遍,得到並保存每個時刻向後隱含層的輸出。
正向和反向都計算完所有輸入時刻後,每個時刻根據向前向後隱含層得到最終輸出。
反向傳播
計算所有時刻輸出層的項。
根據所有輸出層的項,使用 BPTT 演算法更新向前層。
根據所有輸出層的項,使用 BPTT 演算法更新向後層。
實現代碼
創建辭彙
處理字元首先就是需要創建包含語料中所有的詞的辭彙,需要一個從字元到辭彙位置索引的詞典,也需要一個從位置索引到字元的詞典。
批量生成器
創建一個批量生成器用於將文本生成批量的訓練樣本,其中text為整個語料,batch_size為批大小,vocab_size為辭彙大小,seq_length為序列長度,vocab_index_dict為辭彙索引詞典。生成器的生成結構大致如下圖,按文本順序豎著填進矩陣,而矩陣的列大小為batch_size。
這裡寫圖片描述
構建圖
分別定義向前和向後兩個LSTM循環神經網路,需要指定隱含層的神經元數hidden_size,然後將創建的兩個神經網路傳入完成雙向循環神經網路的創建。
接著創建佔位符,主要有輸入佔位符和target佔位符,輸入佔位符,與批大小和序列長度相關的結構[batch_size, seq_length]。最後是target佔位符,結構與輸入佔位符是一樣的。為更好理解這裡給輸入和target畫個圖,如下:
這裡寫圖片描述
這裡寫圖片描述
這裡寫圖片描述
上面得到的3維的嵌入層空間向量,我們無法直接傳入循環神經網路,需要一些處理。需要根據序列長度切割,通過split後再經過squeeze操作後得到一個list,這個list就是最終要進入到循環神經網路的輸入,list的長度為seq_length,這個很好理解,就是由這麼多個時刻的輸入。每個輸入的結構為(batch_size,embedding_size),也即是(20,128)。注意這裡的embedding_size,剛好也是128,與循環神經網路的隱含層神經元數量一樣,這裡不是巧合,而是他們必須要相同,這樣嵌入層出來的矩陣輸入到神經網路才能剛好與神經網路的各個權重完美相乘。最終得到循環神經網路的輸出和最終狀態。
經過2層循環神經網路得到了輸出outputs,但該輸出是一個list結構,我們要通過tf.reshape轉成tf張量形式,該張量結構為(200,128)。同樣target佔位符也要連接起來,結構為(200,)。接著構建softmax層,權重結構為[hidden_size, vocab_size],偏置項結構為[vocab_size],輸出矩陣與權重矩陣相乘並加上偏置項得到logits,然後使用sparse_softmax_cross_entropy_with_logits計算交叉熵損失,最後求損失平均值。
最後使用優化器對損失函數進行優化。為了防止梯度爆炸或梯度消失需要用clip_by_global_norm對梯度進行修正。
計算訓練準確率。
創建會話
創建會話開始訓練,設置需要訓練多少輪,由num_epochs指定。epoch_size為完整訓練一遍語料庫需要的輪數。通過批量生成器獲取一批樣本數據,因為當前時刻的輸入對應的正確輸出為下一時刻的值,所以用data[:-1]和data[1:]得到輸入和target。組織ops並將輸入、target和狀態對應輸入到佔位符上,執行。
github
https://github.com/sea-boat/DeepLearning-Lab/blob/master/BiLstm.py
TAG:遠洋號 |