當前位置:
首頁 > 最新 > DenseNet:比ResNet更優的CNN模型

DenseNet:比ResNet更優的CNN模型

作者:葉 虎

編輯:祝鑫泉

前 言

在計算機視覺領域,卷積神經網路(CNN)已經成為最主流的方法,比如最近的GoogLenet,VGG-19,Incepetion等模型。CNN史上的一個里程碑事件是ResNet模型的出現,ResNet可以訓練出更深的CNN模型,從而實現更高的準確度。ResNet模型的核心是通過建立前面層與後面層之間的「短路連接」(shortcuts,skip connection),這有助於訓練過程中梯度的反向傳播,從而能訓練出更深的CNN網路。今天我們要介紹的是DenseNet模型,它的基本思路與ResNet一致,但是它建立的是前面所有層與後面層的密集連接(dense connection),它的名稱也是由此而來。DenseNet的另一大特色是通過特徵在channel上的連接來實現特徵重用(feature reuse)。這些特點讓DenseNet在參數和計算成本更少的情形下實現比ResNet更優的性能,DenseNet也因此斬獲CVPR 2017的最佳論文獎。本篇文章首先介紹DenseNet的原理以及網路架構,然後講解DenseNet在Pytorch上的實現。

01


設計理念

相比ResNet,DenseNet提出了一個更激進的密集連接機制:即互相連接所有的層,具體來說就是每個層都會接受其前面所有層作為其額外的輸入。圖1為ResNet網路的連接機制,作為對比,圖2為DenseNet的密集連接機制。可以看到,ResNet是每個層與前面的某層(一般是2~3層)短路連接在一起,連接方式是通過元素級相加。而在DenseNet中,每個層都會與前面所有層在channel維度上連接(concat)在一起(這裡各個層的特徵圖大小是相同的,後面會有說明),並作為下一層的輸入。對於一個L層的網路,DenseNet共包含個連接,相比ResNet,這是一種密集連接。而且DenseNet是直接concat來自不同層的特徵圖,這可以實現特徵重用,提升效率,這一特點是DenseNet與ResNet最主要的區別。

圖1 ResNet網路的短路連接機制(其中+代表的是元素級相加操作)

圖2 DenseNet網路的密集連接機制(其中c代表的是channel級連接操作)

如果用公式表示的話,傳統的網路在L層的輸出為:

而對於ResNet,增加了來自上一層輸入的identity函數:

在DenseNet中,會連接前面所有層作為輸入:

其中,上面的代表是非線性轉化函數(non-liear transformation),它是一個組合操作,其可能包括一系列的BN(Batch Normalization),ReLU,Pooling及Conv操作。注意這裡L層與層之間可能實際上包含多個卷積層。

圖3 DenseNet的前向過程

CNN網路一般要經過Pooling或者stride>1的Conv來降低特徵圖的大小,而DenseNet的密集連接方式需要特徵圖大小保持一致。為了解決這個問題,DenseNet網路中使用DenseBlock+Transition的結構,其中DenseBlock是包含很多層的模塊,每個層的特徵圖大小相同,層與層之間採用密集連接方式。而Transition模塊是連接兩個相鄰的DenseBlock,並且通過Pooling使特徵圖大小降低。圖4給出了DenseNet的網路結構,它共包含4個DenseBlock,各個DenseBlock之間通過Transition連接在一起。

02


網路結構

如前所示,DenseNet的網路結構主要由DenseBlock和Transition組成,如圖5所示。下面具體介紹網路的具體實現細節。

圖6 DenseNet的網路結構

在DenseBlock中,各個層的特徵圖大小一致,可以在channel維度上連接。DenseBlock中的非線性組合函數採用的BN+ReLU+3x3 Conv的結構,如圖6所示。另外值得注意的一點是,與ResNet不同,所有DenseBlock中各個層卷積之後均輸出個k特徵圖,即得到的特徵圖的channel數為k,或者說採用k個卷積核。k在DenseNet稱為growth rate,這是一個超參數。一般情況下使用較小的k(比如12),就可以得到較佳的性能。假定輸入層的特徵圖的channel數為,那麼L層輸入的channel數為,因此隨著層數增加,儘管k設定得較小,DenseBlock的輸入會非常多,不過這是由於特徵重用所造成的,每個層僅有k個特徵是自己獨有的。

圖6 DenseBlock中的非線性轉換結構

由於後面層的輸入會非常大,DenseBlock內部可以採用bottleneck層來減少計算量,主要是原有的結構中增加1x1 Conv,如圖7所示,即BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv,稱為DenseNet-B結構。其中1x1 Conv得到4k個特徵圖它起到的作用是降低特徵數量,從而提升計算效率。

圖7 使用bottleneck層的DenseBlock結構

對於Transition層,它主要是連接兩個相鄰的DenseBlock,並且降低特徵圖大小。Transition層包括一個1x1的卷積和2x2的AvgPooling,結構為BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition層可以起到壓縮模型的作用。假定Transition的上接DenseBlock得到的特徵圖channels數為m,Transition層可以產生個特徵(通過卷積層),其中是壓縮係數(compression rate)。當時,特徵個數經過Transition層沒有變化,即無壓縮,而當壓縮係數小於1時,這種結構稱為DenseNet-C,文中使用。對於使用bottleneck層的DenseBlock結構和壓縮係數小於1的Transition組合結構稱為DenseNet-BC。

DenseNet共在三個圖像分類數據集(CIFAR,SVHN和ImageNet)上進行測試。對於前兩個數據集,其輸入圖片大小為32*32,所使用的DenseNet在進入第一個DenseBlock之前,首先進行進行一次3x3卷積(stride=1),卷積核數為16(對於DenseNet-BC為2K)。DenseNet共包含三個DenseBlock,各個模塊的特徵圖大小分別為32*32,16*16和8*8,每個DenseBlock裡面的層數相同。最後的DenseBlock之後是一個global AvgPooling層,然後送入一個softmax分類器。注意,在DenseNet中,所有的3x3卷積均採用padding=1的方式以保證特徵圖大小維持不變。對於基本的DenseNet,使用如下三種網路配置:,,。而對於DenseNet-BC結構,使用如下三種網路配置,,。這裡的L指的是網路總層數(網路深度),一般情況下,我們只把帶有訓練參數的層算入其中,而像Pooling這樣的無參數層不納入統計中,此外BN層儘管包含參數但是也不單獨統計,而是可以計入它所附屬的卷積層。對於普通的網路,除去第一個卷積層、2個Transition中卷積層以及最後的Linear層,共剩餘36層,均分到三個DenseBlock可知每個DenseBlock包含12層。其它的網路配置同樣可以算出各個DenseBlock所含層數。

對於ImageNet數據集,圖片輸入大小為224*224,網路結構採用包含4個DenseBlock的DenseNet-BC,其首先是一個stride=2的7x7卷積層(卷積核數為2K),然後是一個stride=2的3x3 MaxPooling層,後面才進入DenseBlock。ImageNet數據集所採用的網路配置如表1所示:

03

實驗結果與討論

這裡給出DenseNet在CIFAR-100和ImageNet數據集上與ResNet的對比結果,如圖8和9所示。從圖8中可以看到,只有0.8M的DenseNet-100性能已經超越ResNet-1001,並且後者參數大小為10.2M。而從圖9中可以看出,同等參數大小時,DenseNet也優於ResNet網路。其它實驗結果見原論文。

圖8 在CIFAR-100數據集上ResNet vs DenseNet

圖9 在ImageNet數據集上ResNet vs DenseNet

綜合來看,DenseNet的優勢主要體現在以下幾個方面:

由於密集連接方式,DenseNet提升了梯度的反向傳播,使得網路更容易訓練。由於每層可以直達最後的誤差信號,實現了隱式的「deep supervision」;超鏈接:https://arxiv.org/abs/1409.5185

參數更小且計算更高效,這有點違反直覺,由於DenseNet是通過concat特徵來實現短路連接,實現了特徵重用,並且採用較小的growth rate,每個層所獨有的特徵圖是比較小的;

由於特徵復用,最後的分類器使用了低級特徵。

要注意的一點是,如果實現方式不當的話,DenseNet可能耗費很多GPU顯存,一種高效的實現如圖10所示,更多細節可以見這篇論文Memory-Efficient Implementation of DenseNets,超鏈接:https://arxiv.org/abs/1707.06990。不過我們下面使用Pytorch框架可以自動實現這種優化。

圖10 DenseNet的更高效實現方式

04


使用Pytorch實現Denseet

這裡我們採用Pytorch框架(https://pytorch.org/)來實現DenseNet,目前它已經支持Windows系統。對於DenseNet,Pytorch在torchvision.models模塊(https://github.com/pytorch/vision/tree/master/torchvision/models)里給出了官方實現,這個DenseNet版本是用於ImageNet數據集的DenseNet-BC模型,下面簡單介紹實現過程。

首先實現DenseBlock中的內部結構,這裡是BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv結構,最後也加入dropout層以用於訓練過程。

據此,實現DenseBlock模塊,內部是密集連接方式(輸入特徵數線性增長):

此外,我們實現Transition層,它主要是一個卷積層和一個池化層:

最後我們實現DenseNet網路:

選擇不同網路參數,就可以實現不同深度的DenseNet,這裡實現DenseNet-121網路,而且Pytorch提供了預訓練好的網路參數:

下面,我們使用預訓練好的網路對圖片進行測試,這裡給出top-5預測值:

給出的預測結果為:

05

小結

這篇文章詳細介紹了DenseNet的設計理念以及網路結構,並給出了如何使用Pytorch來實現。值得注意的是,DenseNet在ResNet基礎上前進了一步,相比ResNet具有一定的優勢,但是其卻並沒有像ResNet那麼出名(吃顯存問題?深度不能太大?)。期待未來有更好的網路模型出現吧!

06

參考文獻

1.DenseNet-CVPR-Slides.

超鏈接:http://www.cs.cornell.edu/~gaohuang/papers/DenseNet-CVPR-Slides.pdf

2.Densely Connected Convolutional Networks.

超鏈接:https://arxiv.org/abs/1608.06993

END

機器學習演算法工程師

一個用心的公眾號

進群,學習,得幫助

你的關注,我們的熱度,

我們一定給你學習最大的幫助

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器學習演算法工程師 的精彩文章:

TAG:機器學習演算法工程師 |