當前位置:
首頁 > 最新 > 深度學習自動編碼器還能用於數據生成?這篇文章告訴你答案

深度學習自動編碼器還能用於數據生成?這篇文章告訴你答案

AI研習社按:本文作者廖星宇,原載於作者知乎專欄,AI研習社經授權發布。

什麼是自動編碼器

自動編碼器(AutoEncoder)最開始作為一種數據的壓縮方法,其特點有:

跟數據相關程度很高,這意味著自動編碼器只能壓縮與訓練數據相似的數據,這個其實比較顯然,因為使用神經網路提取的特徵一般是高度相關於原始的訓練集,使用人臉訓練出來的自動編碼器在壓縮自然界動物的圖片是表現就會比較差,因為它只學習到了人臉的特徵,而沒有能夠學習到自然界圖片的特徵;

壓縮後數據是有損的,這是因為在降維的過程中不可避免的要丟失掉信息;

到了2012年,人們發現在卷積網路中使用自動編碼器做逐層預訓練可以訓練更加深層的網路,但是很快人們發現良好的初始化策略要比費勁的逐層預訓練有效地多,2014年出現的Batch Normalization技術也是的更深的網路能夠被被有效訓練,到了15年底,通過殘差(ResNet)我們基本可以訓練任意深度的神經網路。

所以現在自動編碼器主要應用有兩個方面,第一是數據去噪,第二是進行可視化降維。然而自動編碼器還有著一個功能就是生成數據。

我們之前講過GAN,它與GAN相比有著一些好處,同時也有著一些缺點。我們先來講講其跟GAN相比有著哪些優點。

第一點,我們使用GAN來生成圖片有個很不好的缺點就是我們生成圖片使用的隨機高斯雜訊,這意味著我們並不能生成任意我們指定類型的圖片,也就是說我們沒辦法決定使用哪種隨機雜訊能夠產生我們想要的圖片,除非我們能夠把初始分布全部試一遍。但是使用自動編碼器我們就能夠通過輸出圖片的編碼過程得到這種類型圖片的編碼之後的分布,相當於我們是知道每種圖片對應的雜訊分布,我們就能夠通過選擇特定的雜訊來生成我們想要生成的圖片。

第二點,這既是生成網路的優點同時又有著一定的局限性,這就是生成網路通過對抗過程來區分「真」的圖片和「假」的圖片,然而這樣得到的圖片只是儘可能像真的,但是這並不能保證圖片的內容是我們想要的,換句話說,有可能生成網路儘可能的去生成一些背景圖案使得其儘可能真,但是裡面沒有實際的物體。

自動編碼器的結構

首先我們給出自動編碼器的一般結構

從上面的圖中,我們能夠看到兩個部分,第一個部分是編碼器(Encoder),第二個部分是解碼器(Decoder),編碼器和解碼器都可以是任意的模型,通常我們使用神經網路模型作為編碼器和解碼器。輸入的數據經過神經網路降維到一個編碼(code),接著又通過另外一個神經網路去解碼得到一個與輸入原數據一模一樣的生成數據,然後通過去比較這兩個數據,最小化他們之間的差異來訓練這個網路中編碼器和解碼器的參數。當這個過程訓練完之後,我們可以拿出這個解碼器,隨機傳入一個編碼(code),希望通過解碼器能夠生成一個和原數據差不多的數據,上面這種圖這個例子就是希望能夠生成一張差不多的圖片。

這件事情能不能實現呢?其實是可以的,下面我們會用PyTorch來簡單的實現一個自動編碼器。

首先我們構建一個簡單的多層感知器來實現一下。

class autoencoder(nn.Module):

def __init__(self):

super(autoencoder, self).__init__()

self.encoder = nn.Sequential(

nn.Linear(28*28, 128),

nn.ReLU(True),

nn.Linear(128, 64),

nn.ReLU(True),

nn.Linear(64, 12),

nn.ReLU(True),

nn.Linear(12, 3)

)

self.decoder = nn.Sequential(

nn.Linear(3, 12),

nn.ReLU(True),

nn.Linear(12, 64),

nn.ReLU(True),

nn.Linear(64, 128),

nn.ReLU(True),

nn.Linear(128, 28*28),

nn.Tanh()

)

def forward(self, x):

x = self.encoder(x)

x = self.decoder(x)

return x

這裡我們定義了一個簡單的4層網路作為編碼器,中間使用ReLU激活函數,最後輸出的維度是3維的,定義的解碼器,輸入三維的編碼,輸出一個28x28的圖像數據,特別要注意最後使用的激活函數是Tanh,這個激活函數能夠將最後的輸出轉換到-1 ~1之間,這是因為我們輸入的圖片已經變換到了-1~1之間了,這裡的輸出必須和其對應。

訓練過程也比較簡單,我們使用最小均方誤差來作為損失函數,比較生成的圖片與原始圖片的每個像素點的差異。

同時我們也可以將多層感知器換成卷積神經網路,這樣對圖片的特徵提取有著更好的效果。

class autoencoder(nn.Module):

def __init__(self):

super(autoencoder, self).__init__()

self.encoder = nn.Sequential(

nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10

nn.ReLU(True),

nn.MaxPool2d(2, stride=2), # b, 16, 5, 5

nn.Conv2d(16, 8, 3, stride=2, padding=1), # b, 8, 3, 3

nn.ReLU(True),

nn.MaxPool2d(2, stride=1) # b, 8, 2, 2

)

self.decoder = nn.Sequential(

nn.ConvTranspose2d(8, 16, 3, stride=2), # b, 16, 5, 5

nn.ReLU(True),

nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b, 8, 15, 15

nn.ReLU(True),

nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b, 1, 28, 28

nn.Tanh()

)

def forward(self, x):

x = self.encoder(x)

x = self.decoder(x)

return x

這裡使用了 nn.ConvTranspose2d(),這可以看作是卷積的反操作,可以在某種意義上看作是反卷積。

我們使用卷積網路得到的最後生成的圖片效果會更好,具體的圖片效果我就不再這裡放了,可以在我們的github上看到圖片的展示。github 地址:

http://t.cn/RK5gxpM

變分自動編碼器(Variational Auto Encoder)

變分編碼器是自動編碼器的升級版本,其結構跟自動編碼器是類似的,也由編碼器和解碼器構成。

回憶一下我們在自動編碼器中所做的事,我們需要輸入一張圖片,然後將一張圖片編碼之後得到一個隱含向量,這比我們隨機取一個隨機雜訊更好,因為這包含著原圖片的信息,然後我們隱含向量解碼得到與原圖片對應的照片。

但是這樣我們其實並不能任意生成圖片,因為我們沒有辦法自己去構造隱藏向量,我們需要通過一張圖片輸入編碼我們才知道得到的隱含向量是什麼,這時我們就可以通過變分自動編碼器來解決這個問題。

其實原理特別簡單,只需要在編碼過程給它增加一些限制,迫使其生成的隱含向量能夠粗略的遵循一個標準正態分布,這就是其與一般的自動編碼器最大的不同。

這樣我們生成一張新圖片就很簡單了,我們只需要給它一個標準正態分布的隨機隱含向量,這樣通過解碼器就能夠生成我們想要的圖片,而不需要給它一張原始圖片先編碼。

在實際情況中,我們需要在模型的準確率上與隱含向量服從標準正態分布之間做一個權衡,所謂模型的準確率就是指解碼器生成的圖片與原圖片的相似程度。我們可以讓網路自己來做這個決定,非常簡單,我們只需要將這兩者都做一個loss,然後在將他們求和作為總的loss,這樣網路就能夠自己選擇如何才能夠使得這個總的loss下降。另外我們要衡量兩種分布的相似程度,如何看過之前一片GAN的數學推導,你就知道會有一個東西叫KL divergence來衡量兩種分布的相似程度,這裡我們就是用KL divergence來表示隱含向量與標準正態分布之間差異的loss,另外一個loss仍然使用生成圖片與原圖片的均方誤差來表示。

我們可以給出KL divergence 的公式

這裡變分編碼器使用了一個技巧「重新參數化」來解決 KL divergence 的計算問題。

這時不再是每次產生一個隱含向量,而是生成兩個向量,一個表示均值,一個表示標準差,然後通過這兩個統計量來合成隱含向量,這也非常簡單,用一個標準正態分布先乘上標準差再加上均值就行了,這裡我們默認編碼之後的隱含向量是服從一個正態分布的。這個時候我們是想讓均值儘可能接近0,標準差儘可能接近1。而論文裡面有詳細的推導如何得到這個loss的計算公式,有興趣的同學可以去看看具體推到過程:

https://arxiv.org/pdf/1606.05908.pdf

下面是PyTorch的實現:

reconstruction_function = nn.BCELoss(size_average=False) # mse loss

def loss_function(recon_x, x, mu, logvar):

"""

recon_x: generating images

x: origin images

mu: latent mean

logvar: latent log variance

"""

BCE = reconstruction_function(recon_x, x)

# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)

KLD = torch.sum(KLD_element).mul_(-0.5)

# KL divergence

return BCE + KLD

另外變分編碼器除了可以讓我們隨機生成隱含變數,還能夠提高網路的泛化能力。

最後是VAE的代碼實現:

class VAE(nn.Module):

def __init__(self):

super(VAE, self).__init__()

self.fc1 = nn.Linear(784, 400)

self.fc21 = nn.Linear(400, 20)

self.fc22 = nn.Linear(400, 20)

self.fc3 = nn.Linear(20, 400)

self.fc4 = nn.Linear(400, 784)

def encode(self, x):

h1 = F.relu(self.fc1(x))

return self.fc21(h1), self.fc22(h1)

def reparametrize(self, mu, logvar):

std = logvar.mul(0.5).exp_()

eps = torch.cuda.FloatTensor(std.size()).normal_()

else:

eps = torch.FloatTensor(std.size()).normal_()

eps = Variable(eps)

return eps.mul(std).add_(mu)

def decode(self, z):

h3 = F.relu(self.fc3(z))

return F.sigmoid(self.fc4(h3))

def forward(self, x):

mu, logvar = self.encode(x)

z = self.reparametrize(mu, logvar)

return self.decode(z), mu, logvar

VAE的結果比普通的自動編碼器要好很多,下面是結果:

VAE的缺點也很明顯,他是直接計算生成圖片和原始圖片的均方誤差而不是像GAN那樣去對抗來學習,這就使得生成的圖片會有點模糊。現在已經有一些工作是將VAE和GAN結合起來,使用VAE的結構,但是使用對抗網路來進行訓練,具體可以參考一下這篇論文:

https://arxiv.org/pdf/1512.09300.pdf

文中相關代碼鏈接:

http://t.cn/RK5gxpM

英文參考:

http://t.cn/RtoJRAa


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 唯物 的精彩文章:

只要130 行代碼即可生成二維樣本,心動了嗎?
如何用 Caffe 生成對抗樣本?這篇文章告訴你一個更高效的演算法
一文詳解 DNN 在聲學應用中的模型訓練
深度學習下的醫學圖像分析(三)
今日頭條成功的核心技術秘訣是什麼?

TAG:唯物 |

您可能感興趣

家用體脂秤准嗎?這篇文章告訴你真實答案
吃雞選擇手機還是模擬器?這篇文章告訴你最權威的答案
怎樣才能確定Ta愛不愛你,這篇文章告訴你答案
用一篇文章來了解數據編碼
讀了那麼多繪本,真有用嗎?這篇文章告訴你答案
為什麼火影忍者被稱為不朽的經典?這篇文章告訴你答案!
數字能量婚姻測算身份證號碼,你也許還不明白!這篇文章告訴原因
摺疊屏為何能屈能伸?這篇文章告訴你
便攜與性能如何兼得?這篇文章給你答案
OTK卡組真的簡單無腦嗎?這篇文章告訴你答案
初學者該練什麼字體呢?這篇文章給你答案!
他心裡都是怎麼想的?這篇文章給你答案
怎樣用技術流寫一篇文章?
開放式廚房,不知道怎麼設計?這篇文章給你答案
怕患者噎食?這篇文章可能有用
核桃該怎麼吃 這篇文章告訴您
闢謠!胖的人容易得脂肪瘤?應該如何治療?這篇文章告訴你答案
我花生命寫這篇文章,你也在花生命閱讀這篇文章!
為什麼要給貓狗做絕育?這篇文章告訴你答案
缺乏維生素導致什麼病?一篇文章告訴你!