當前位置:
首頁 > 最新 > 教你用PyTorch實現「看圖說話」

教你用PyTorch實現「看圖說話」

作者:FAIZAN SHAIKH

翻譯:和中華

校對:白靜

本文共2200字,建議閱讀10分鐘


本文用淺顯易懂的方式解釋了什麼是「看圖說話」(Image Captioning),藉助github上的PyTorch代碼帶領大家自己做一個模型,並附帶了很多相關的學習資源。

深度學習目前是一個非常活躍的領域---每天都會有許多應用出現。進一步學習DeepLearning最好的方法就是親自動手。儘可能多的接觸項目並且嘗試自己去做。這將會幫助你更深刻地掌握各個主題,成為一名更好的DeepLearning實踐者。

這篇文章將和大家一起看一個有趣的多模態主題,我們將結合圖像和文本處理技術來構建一個有用的深度學習應用,即看圖說話(ImageCaptioning)。看圖說話是指從一個圖像中基於其中的對象和動作生成文本描述的過程。例如:

這種過程在現實生活中有很多潛在的應用場景。一個明顯的應用比如保存圖片的描述字幕,以便該圖片隨後可以根據這個描述輕鬆地被檢索出來。

我們開始吧!

注意: 本文假定你了解深度學習的基礎知識,以前曾使用CNN處理過圖像問題。如果想複習這些概念,可以先閱讀下面的文章:

Fundamentalsof Deep Learning – Starting with Artificial Neural Network

Architectureof Convolutional Neural Networks (CNNs) demystified

Tutorial:Optimizing Neural Networks using Keras (with Image recognition casestudy)

Essentialsof Deep Learning – Sequence to Sequence modelling with Attention(using python)


什麼是ImageCaptioning問題?

解決任務的方法

應用演練

下一步工作


設想你看到了這張圖:

你首先想到的是什麼?下面是一些人們可以想到的句子:

A man and a girl sit on the ground and eat . (一個男人和一個女孩坐在地上吃東西)

A man and a little girl are sitting on a sidewalk near a blue bag eating . (一個男人和一個小女孩坐在藍色包旁邊的人行道上吃東西)

A man wearing a black shirt and a little girl wearing an orange dress share a treat .(一個穿黑色襯衣的男人和一個穿橘色連衣裙的小女孩分享美食)

快速看一眼就足以讓你理解和描述圖片中發生的事情。從一個人造系統中自動生成這種文字描述就是ImageCaptioning的任務。

該任務很明確,即產生的輸出是用一句話來描述這幅圖片中的內容---存在的對象,屬性,正在發生的動作以及對象之間的互動等。但是與其他圖像處理問題一樣,在人造系統中再現這種行為也是一項艱巨的任務。因此需要使用像DeepLearning這樣先進複雜的技術來解決該任務。

在繼續下文之前,我想特別感謝AndrejKartpathy等學者,他們富有洞察力的課程CS231n幫助我理解了這個主題。


可以把imagecaptioning任務在邏輯上分為兩個模塊——一個是基於圖像的模型,從圖像中提取特徵和細微的差別,另一個是基於語言的模型,將第一個模型給出的特徵和對象翻譯成自然的語句。

對於基於圖像的模型而言(即編碼器)我們通常依靠CNN網路。對於基於語言的模型而言(即解碼器),我們依賴RNN網路。下圖總結了前面提到的方法:

通常,一個預先訓練好的CNN網路從輸入圖像中提取特徵。特徵向量被線性轉換成與RNN/LSTM網路的輸入具有相同的維度。這個網路被訓練作為我們特徵向量的語言模型。

為了訓練LSTM模型,我們預先定義了標籤和目標文本。比如,如果字幕是Aman and a girl sit on the ground and eat .(一個男人和一個女孩坐在地上吃東西),則我們的標籤和目標文本如下:

這樣做是為了讓模型理解我們標記序列的開始和結束。

讓我們看一個Pytorch中imagecaptioning的簡單實現。我們將以一幅圖作為輸入,然後使用深度學習模型來預測它的描述。

例子的代碼可以在GitHub上找到。代碼的原始作者是YunjeyChoi 向他傑出的pytorch例子致敬。

在本例中,一個預先訓練好的ResNet-152被用作編碼器,而解碼器是一個LSTM網路。

要運行本例中的代碼,你需要安裝必備軟體,確保有一個可以工作的python環境,最好使用anaconda。然後運行以下命令來安裝其他所需要的庫。

gitclone https://github.com/pdollar/coco.git

cdcoco/PythonAPI/

make

pythonsetup.py build

pythonsetup.py install

cd../../

gitclone https://github.com/yunjey/pytorch-tutorial.git

cdpytorch-tutorial/tutorials/03-advanced/image_captioning/

pipinstall -r requirements.txt

設置完系統後,就該下載所需的數據集並且訓練模型了。這裡我們使用的是MS-COCO數據集。可以運行如下命令來自動下載數據集:

chmod+x download.sh

./download.sh

現在可以繼續並開始模型的構建過程了。首先,你需要處理輸入:

#Search for all the possible words in the dataset and

#build a vocabulary list

pythonbuild_vocab.py

#resize all the images to bring them to shape 224x224

pythonresize.py

現在,運行下面的命令來訓練模型:

pythontrain.py --num_epochs 10 --learning_rate 0.01

來看一下被封裝好的代碼中是如何定義模型的,可以在model.py文件中找到:

importtorch

importtorch.nn as nn

importtorchvision.models as models

fromtorch.autograd import Variable

classEncoderCNN(nn.Module):

def__init__(self, embed_size):

"""Loadthe pretrained ResNet-152 and replace top fc layer."""

super(EncoderCNN,self).__init__()

resnet= models.resnet152(pretrained=True)

modules= list(resnet.children())[:-1] # delete the last fc layer.

self.resnet= nn.Sequential(*modules)

self.bn= nn.BatchNorm1d(embed_size, momentum=0.01)

self.init_weights()

definit_weights(self):

"""Initializethe weights."""

defforward(self, images):

"""Extractthe image feature vectors."""

features= self.resnet(images)

features= Variable(features.data)

features= features.view(features.size(0), -1)

features= self.bn(self.linear(features))

returnfeatures

classDecoderRNN(nn.Module):

def__init__(self, embed_size, hidden_size, vocab_size, num_layers):

"""Setthe hyper-parameters and build the layers."""

super(DecoderRNN,self).__init__()

self.embed= nn.Embedding(vocab_size, embed_size)

self.lstm= nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)

self.linear= nn.Linear(hidden_size, vocab_size)

self.init_weights()

definit_weights(self):

"""Initializeweights."""

defforward(self, features, captions, lengths):

"""Decodeimage feature vectors and generates captions."""

embeddings= self.embed(captions)

embeddings= torch.cat((features.unsqueeze(1), embeddings), 1)

packed= pack_padded_sequence(embeddings, lengths, batch_first=True)

hiddens,_ = self.lstm(packed)

outputs= self.linear(hiddens[0])

returnoutputs

defsample(self, features, states=None):

"""Samplescaptions for given image features (Greedy search)."""

sampled_ids= []

inputs= features.unsqueeze(1)

fori in range(20): # maximum samplinglength

hiddens,states = self.lstm(inputs, states) # (batch_size, 1,hidden_size),

outputs= self.linear(hiddens.squeeze(1)) # (batch_size, vocab_size)

predicted= outputs.max(1)[1]

sampled_ids.append(predicted)

inputs= self.embed(predicted)

inputs= inputs.unsqueeze(1) # (batch_size, 1,embed_size)

sampled_ids= torch.cat(sampled_ids, 1) # (batch_size, 20)

returnsampled_ids.squeeze()

現在測試我們的模型:

pythonsample.py --image="png/example.png"

對於樣例圖片,我們的模型給出了這樣的輸出:

a group of giraffes standing in a grassy area .

一群長頸鹿站在草地上

以上就是如何建立一個用於imagecaptioning的深度學習模型。


以上模型只是冰山一角。關於這個主題已經有很多的研究。目前在imagecaptioning領域最先進的模型是微軟的CaptionBot。可以在他們的官網上看一個系統的demo.

我列舉一些可以用來構建更好的imagecaptioning模型的想法:

加入更多數據當然這也是深度學習模型通常的趨勢。提供的數據越多,模型效果越好。可以在這裡找到其他的數據集:http://www.cs.toronto.edu/~fidler/slides/2017/CSC2539/Kaustav_slides.pdf

使用Attention模型正如這篇文章所述(Essentialsof Deep Learning – Sequence to Sequence modelling with Attention),使用attention模型有助於微調模型的性能

轉向更大更好的技術研究人員一直在研究一些技術,比如使用強化學習來構建端到端的深度學習系統,或者使用新穎的attention模型用於「視覺哨兵(visualsentinel)」。


這篇文章中,我介紹了imagecaptioning,這是一個多模態任務,它由解密圖片和用自然語句描述圖片兩部分組成。然後我解釋了解決該任務用到的方法並給出了一個應用演練。對於好奇心強的讀者,我還列舉了幾條可以改進模型性能的方法。

希望這篇文章可以激勵你去發現更多可以用深度學習解決的任務,從而在工業中出現越來越多的突破和創新。如果有任何建議/反饋,歡迎在下面的評論中留言!


譯者簡介

和中華,留德軟體工程碩士。由於對機器學習感興趣,碩士論文選擇了利用遺傳演算法思想改進傳統kmeans。目前在杭州進行大數據相關實踐。加入數據派THU希望為IT同行們盡自己一份綿薄之力,也希望結交許多志趣相投的小夥伴。

轉載須知

如需轉載,請在開篇顯著位置註明作者和出處(轉自:數據派THU ID:DatapiTHU),並在文章結尾放置數據派醒目二維碼。有原創標識文章,請發送【文章名稱-待授權公眾號名稱及ID】至聯繫郵箱,申請白名單授權並按要求編輯。

發布後請將鏈接反饋至聯繫郵箱(見下方)。未經許可的轉載以及改編者,我們將依法追究其法律責任。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 數據派THU 的精彩文章:

有趣的Github項目萬里挑一!
從特徵金字塔網路、Mask R-CNN到學習分割一切

TAG:數據派THU |