當前位置:
首頁 > 知識 > PyTorch 的預訓練,是時候學習一下了

PyTorch 的預訓練,是時候學習一下了

PyTorch 的預訓練,是時候學習一下了



前言


最近使用 PyTorch 感覺妙不可言,有種當初使用 Keras 的快感,而且速度還不慢。各種設計直接簡潔,方便研究,比 tensorflow 的臃腫好多了。今天讓我們來談談 PyTorch 的預訓練,主要是自己寫代碼的經驗以及論壇PyTorch Forums上的一些回答的總結整理。


直接載入預訓練模型

如果我們使用的模型和原模型完全一樣,那麼我們可以直接載入別人訓練好的模型:


my_resnet = MyResNet(*args, **kwargs)


my_resnet.load_state_dict(torch.load("my_resnet.pth"))


當然這樣的載入方法是基於 PyTorch 推薦的存儲模型的方法:


torch.save(my_resnet.state_dict(), "my_resnet.pth")


還有第二種載入方法:


my_resnet = torch.load("my_resnet.pth")


載入部分預訓練模型


其實大多數時候我們需要根據我們的任務調節我們的模型,所以很難保證模型和公開的模型完全一樣,但是預訓練模型的參數確實有助於提高訓練的準確率,為了結合二者的優點,就需要我們載入部分預訓練模型。


pretrained_dict = model_zoo.load_url(model_urls[ resnet152 ])

model_dict = model.state_dict()


# 將 pretrained_dict 里不屬於 model_dict 的鍵剔除掉


pretrained_dict =


# 更新現有的 model_dict


model_dict.update(pretrained_dict)


# 載入我們真正需要的 state_dict


model.load_state_dict(model_dict)


因為需要剔除原模型中不匹配的鍵,也就是層的名字,所以我們的新模型改變了的層需要和原模型對應層的名字不一樣,比如:resnet 最後一層的名字是 fc(PyTorch 中),那麼我們修改過的 resnet 的最後一層就不能取這個名字,可以叫 fc_


微改基礎模型預訓練


對於改動比較大的模型,我們可能需要自己實現一下再載入別人的預訓練參數。但是,對於一些基本模型 PyTorch 中已經有了,而且我只想進行一些小的改動那麼怎麼辦呢?難道我又去實現一遍嗎?當然不是。

我們首先看看怎麼進行微改模型。


微改基礎模型


PyTorch 中的 torchvision 里已經有很多常用的模型了,可以直接調用:


AlexNet


VGG


ResNet


SqueezeNet


DenseNet


import torchvision.models as models


resnet18 = models.resnet18()

alexnet = models.alexnet()


squeezenet = models.squeezenet1_0()


densenet = models.densenet_161()


但是對於我們的任務而言有些層並不是直接能用,需要我們微微改一下,比如,resnet 最後的全連接層是分 1000 類,而我們只有 21 類;又比如,resnet 第一層卷積接收的通道是 3, 我們可能輸入圖片的通道是 4,那麼可以通過以下方法修改:


resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)


resnet.fc = nn.Linear(2048, 21)


簡單預訓練


模型已經改完了,接下來我們就進行簡單預訓練吧。


我們先從 torchvision 中調用基本模型,載入預訓練模型,然後,重點來了,將其中的層直接替換為我們需要的層即可


# 原本為 1000 類,改為 10 類

resnet.fc = torch.nn.Linear(2048, 10)


其中使用了 pretrained 參數,會直接載入預訓練模型,內部實現和前文提到的載入預訓練的方法一樣。因為是先載入的預訓練參數,相當於模型中已經有參數了,所以替換掉最後一層即可。OK!


AI研習社按:本文作者ycszen,文章原載於作者的知乎專欄


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 唯物 的精彩文章:

一個實例告訴你:Kaggle 數據競賽都有哪些套路
深度學習之父的神經網路第五課!
生成對抗網路研究年度進展評述
開放Alexa語音控制技術 亞馬遜欲打造最有影響力的聊天機器人開發平台
史上最簡潔易懂教程 用Excel理解梯度下降

TAG:唯物 |

您可能感興趣

是時候好好介紹一下Beautyweek了!
Android P來了 是時候向Nexus說再見了
你被淘汰的時候,沒有人會say sorry
Weekly | 確認過眼神,是時候和假期講 bye bye 啦
用iPhoneX,為什麼現在,完全沒有iPhone4那時候的風光?
iPhone6,是到說再見的時候了
日本插畫設計師為韓國New Balance設計了一個NB boy,那麼問題來了,NB girl什麼時候出來?
當你還在嘲笑Louis Vuitton、Gucci狗年限定系列丑的時候,別人已經在研究用這種風格做爆款了
是時候和「亞健康」say goodbye了!
什麼時候使用 CountDownLatch
MacBook Air氣數已盡 蘋果是時候向前看了
iPhone X賣的火爆,但高價下不及iP7時候的輝煌
Everybody!是時候購置「新」年貨啦
現在這個時候,是否適合買iPhone X了?
Jasper與媽媽一起做瑜伽,應采兒小時候照片竟神似Angelababy
新款 Virgil Abloh x AJ I,什麼時候入手最合適?
Nintendo Labo這麼火,任天堂啥時候回歸主業呢?
當Jordan Brand談未來的時候,都在談些什麼?
老設備是時候say 88,Android P將不會支持Nexus 5X、6P
等「小蛙」回家的時候!你還有原價購買4雙Balenciaga Triple S的機會!