王宇龍:如何通過關鍵數據通路去理解網路行為?
AI科技評論按:神經網路長久以來的「黑盒」屬性,導致人們一直無法理解網路的內部是如何運作的,針對這個困擾已久的問題,學界主要存在三種研究方向:數據歸因模式、特徵解碼模式以及模型理解模式。
在近日的 AI 研習社公開課上,清華大學的王宇龍就從模型理解的角度入手,為我們詳細介紹了如何通過發現網路中的關鍵數據通路(critical data routing paths, CDRPs),更好地理解網路。
公開課回放地址:
http://www.mooc.ai/open/course/520
分享主題:利用關鍵數據通路理解神經網路
分享提綱:
方法介紹——distillation guided routing(DGR)演算法
結果分析——路由的通路包含一定語義含義,幫助我們更好地理解網路行為
應用領域——安全對抗樣本檢測
AI 科技評論將其分享內容整理如下:
這張圖大概總結了一下當前這種網路可解釋性的含義,我們常常把這種網路、深度神經網路看作一個black box model,也就是一個黑盒模型,就像圖片中所描述的一樣。
我們想知道網路裡面運行的時候到底在做什麼,學習完之後究竟學習到什麼樣的知識,能夠對我們人類有什麼樣的啟發,所以對於網路可理解性來說,目前有這樣三個主要的研究方向:
左邊綠色箭頭指向的,是說我將網路所做的決策,或者說預測結果,直接歸因到數據層面,去分析樣本中哪些數據、或者說數據當中哪一塊區域,它的特徵更加重要,我就直接拿來作為網路行為的一種解釋。
右邊藍色箭頭指向的則是第二種方向也就是將網路所學到的這種行為或者特徵和人類已有的這種像知識圖譜的一些概念去進行聯繫,使網路在進行決策的時候,我們人類能夠理解它究竟在做些什麼。
下面又是另外一種方向,我畫了一個顯微鏡的圖案來表示,我們要直接去探求網路內部究竟在做什麼,這個相對來說會比較難,因為它是一種更加直接的理解網路的方法——我們的工作就屬於這個方向。
我再細說一下這三個方向。
第一類方法是歸因到數據層面的方法,我們叫做attribution methods,中文為「溯因方法」,它的一個大致的流程是說,給一張圖片,或者給一個數據,網路幫我們做出了這個決策,那麼現在我要了解,究竟是數據的哪一面最終導致這個預測結果,然後我們就可以通過attribution methods去追溯和歸因到數據層面上去。像右上方第一張圖顯示的,就是幾種不同attribution methods的一個展示結果,主要通過一種叫做saliency maps(中文稱作「顯著圖」)的方式來進行展示,圖片上面可以看到像某種熱力圖一樣,顏色越深的,代表這一部分區域所做的貢獻越大,對最終預測結果起著最關鍵的影響。
第二類方法是歸因到知識上去,這種被稱作feature decoding的方式(中文「特徵解碼」),它將網路中間的中間層的特徵給解碼出來,然後轉換到一個可理解的概念上去。換句話說,我在預測結果的同時,同時產生一段文本去解釋這個預測結果產生的原因。比如說右面的第二張圖,它在預測每一種鳥類的同時,也給出了這些鳥類局部的一些特徵,比如說這個鳥之所以是某種鳥類,是因為它的脖子或者頭部或者嘴巴具備什麼樣的特徵,是一個比較接近人類的這樣一種解釋方式。不過,這種解釋方法迴避了模型本身的解釋,也就是說產生這種文本的解釋是用另外一種網路來去訓練的,那麼產生解釋的網路又該如何去理解呢?
第三類方法是直接去理解模型本身的行為。這一類方法目前沒有一個統一的範式,主要靠大家從多種不同的角度來做解釋。比如說我本來有一個很深的網路,或者說一個很複雜的模型,那麼我通過像知識蒸餾或者說模仿這種行為,去訓練一個更容易理解的模型,通過這個更容易理解的模型,針對某個局部面再去模仿原來模型的行為,是一種基於局部的一種解釋。或者說從一開始設計這種網路的時候,就是要設計成一個可解釋的模型,比如說每一步都分別對應一種語義含義行為的設計方法。
關鍵數據通路
我們的工作主要是通過發現網路中的關鍵數據通路(critical data routing paths,CDRPs),更好地去理解網路。我們從之前類似網路壓縮的一些工作中發現,網路中其實存在很多冗餘,並不是所有的節點或者說神經元都被利用到,這些神經元即便刪去,也不會影響最終的預測結果。因此我們認為,每個樣本進來的時候,網路都只是利用了其中一部分的節點或者說通路來完成最終預測的。
我們的工作希望能夠發掘出這些通路的特徵或者說規律。在定義關鍵數據通路以前,我們首先要定義的是關鍵節點,因為關鍵數據通路實際上是由這些通路上的關鍵節點所組成的。比如像右面這張圖顯示的,共有三個卷積層,每一層紅色代表的都是不重要的可以進行刪除的節點,而綠色則代表的關鍵節點。通過連接每一層的關鍵節點,我們也就組成了所謂的關鍵數據通路。
所謂的關鍵節點,我們可以理解成輸出的關鍵channel,比如像一個卷積層,它的輸出是個三維的增量,在長寬兩維上我們認為是一種空間上的信息保留,而channel維就是一個第三維的信息保留,我們認為它是包含了這種語義含義的,或者說是代表這種節點的概念。如果一個通道全部被置為零,對最終預測結果產生了很大影響的話,這個節點就是關鍵節點。
從方法上來說,我們首先引入了control gates(中文為控制門)的概念,這個想法受到過去的模型壓縮、模型減枝方法的一些啟發,channel-wise control gates說的是通道維上的每一維或者每一個通道都去關聯一個lambda(這個lambda是一個標量值,我們認為這個lambda就是一個control gates),一旦這個 lambda 為 0,最後的優化問題也同樣顯示為 0 的話,就默認它是一個不重要的節點,完全可以刪除掉。如果它是一個帶 0 的值,之後的值我們是不做限制的,因為值的大小代表了它在這一次預測中的重要性。
我們應該如何求解 lambda 的優化目標是什麼呢?其實我們是借鑒了這個所謂「知識蒸餾」的概念(Hinton在2015年所提出),也就是說網路在進行預測的時候,它所輸出的概率分布不只是包含預測結果,還包括了隱藏知識——網路里被認為包含其他類的概念有多少。我們的目標是要在刪掉這些不重要的節點以後,網路的數據概率值能夠盡量接近原始的網路數據概率值。
同一時間,我們要加上一定的正則項去約束control gate:一方面約束它大於0,非負數;另一方面則要約束它具備「稀疏」的特性。右邊的優化目標表示的是第一部分花體 L 的 loss function,度量的是兩次網路輸出的概率分布距離。
第一項這個f_ heta(x),是原始網路在獲得樣本後的一個輸出概率分布,是針對單樣本來說的(每次只考慮一個樣本)。第二項加了一個帶有 lambda 的式子,代表的是引入了control gate後,「減枝」網路的輸出概率分布和原始網路的輸出概率分布的接近程度——L。L 實際上就是一個 cross entropy,衡量兩個概率分布的距離。
接下來說一下lambda的約束限制。
第一項是要求它「非負」。lambda可以為 0,代表的是被完全抑制掉、刪除掉(假設該值大於0,小於1,說明比原來的響應相對要小一點,要是大於 0 且大於 1,說明它比原來的響應要更高一點,有一點放大的作用),但是我們不能讓它變為負數,因為負數相當於把整個 channel 的 activations 的全部正負號都交換了,相當於所有的值都取了一個相反數,我們認為這樣對原始網路輸出值的分布範圍會有較大影響,且會對最終行為存在較大幹擾。所以我們做網路可解釋性的首要條件,是在保證盡量少改動的情況下去解釋當前網路,一旦引入過多額外的干擾,你就很難保證說現在的解釋對於原來的網路還是成立的。
第二項是要求它具有一定的稀疏性,這個和已有的一些「稀疏學習」的部分主張是吻合的,可以理解為越稀疏的模型,它將這種不同的屬性都進行解耦並取了關鍵屬性,就越發具備可理解性。
路徑的表示
我們以上所說的是 distillation guided routing(DGR)的一個大致方法。接下來我再說一下,如何對最終尋找到的路徑進行表示。
每次優化完以後,每層它都有一個contol gate value,由粗體的lambda表示(大K表示的是網路擁有K層這樣一個概念),只要將所有的control gate value拼接成一個最終長的向量,就是我們對相關路徑的一個表示。因為我們可以直接對長的向量使用tresholding這一種取閾值的方法,來獲得最終的critical nodes。比如說我認為大於0.5的才是真正的critical nodes,小於0.5的則不是,那我們可以通過取義值,得到一個最終的二值mask,那麼它就代表了哪些可以被刪除,哪些可以被保留。
我們在後來的實驗中發現,這一種表示包含了非常豐富的信息——如果不取義值,只將它原始優化出來的浮點值保留下來的話,網路在進行預測的時候,我們將發現更加豐富的功能性過程(可以把它視為一種新的activations,網路響應都是一層一層傳到最高層,最高層的feature就可以看成一個響應,我們相當於側面在channel維上去引入了新的特徵表示)。
接下來我詳細說一下這一頁(PPT第4頁)的優化問題,我們應該如何進行求解。
求解的方法其實很簡單,就是通過梯度下降演算法,每一次根據優化目標對control gate value進行求導(原始網路的權重值都是固定不變的)。所以我們解釋一些已有模型(比如像VGG,Alexnet, ResNet),都是通過引入並求解control gate value,接下來當我們再去解釋或者優化時就會非常簡單,因為它需要更新的參數非常少,比如我們在實驗中只需設置30個iteration,就能得到一個很好的解釋結果。
在優化的過程當中,這些引入control gate value的網路預測,比如說top-1 prediction,也就是那個最大類別的響應,要和原始網路的預測保持一致。比如說原始網路它看圖片預測出來是狗,那麼新的網路也要保障它的預測結果是狗。至於其他類別的響應,我們則不做要求,因為既然是distillation,肯定就會存在一定程度的不同。總的來說,你在解釋的網路的時候,不該改變網路的原始行為。
接下來說一下對抗樣本檢測,我們之所以會將該方法用到這個任務上去,是因為我們發現,我們所找到的這個feature對於對抗樣本檢測有很大的幫助。
首先什麼叫對抗樣本?非常簡單,看下面這張圖,第一個是大熊貓,它被輸入進一個標準的網路裡面,被顯示為55.7%的一個預測信度,但是我在中間加了這個雜訊圖,最後得到一張新的圖片,再把這張新的圖片輸入到網路里時,結果預測為「長臂猿」,同時擁有很高的信度,達到99.3%。從人的視角來看,新生成的圖片跟原始圖片並沒有太大差別,這種現象我們就叫做對抗樣本,也就是說新的圖片對網路而言是具有「對抗性」或者說「攻擊性」的。
對抗樣本現象引發了人們對網路可理解性的關注,因為網路的「黑盒」特性使我們無從得知它為什麼會預測正確或者預測錯誤,而且這種錯誤的特性還特別不符合人類的直覺,人類無法理解說這樣一個雜訊為何能夠引起這麼大的一個改變。因此現在有大量的工作就是在做對抗的樣本攻擊以及對抗的樣本防禦。我們的組在這方面之前也是做了很多工作,在去年的NIPS 2017年有一個對抗攻防比賽,我們的組在攻擊和對抗方面都做到了第一。
我們接下來會利用關鍵數據通路去進行對抗樣本檢測。我們的思考是這樣的,兩種樣本在輸入端從人類的感覺上看來差別並不大,這也意味著前幾層所走的網路關鍵路徑按理來說差別不大。只是對抗樣本的雜訊越往高層走,它被干擾的程度不知因何被放大了,才導致路徑開始偏離,最終走到另一個類別上去,導致預測結果完全不一樣。
那麼我們其實可以訓練出某種分類器,專門用來檢測真實樣本與對抗樣本的關鍵數據通路。如果查出來差別,就有一定的概率檢測出它究竟是真實樣本還是對抗樣本。
接下來說一下實驗的部分。
我們首先做了一個定量實驗來檢驗方法的有效性,這個實驗叫做post-hoc interpretation(中文是「事後解釋」),就是針對網路最終的預測結果再做一次解釋(一張圖片只解釋一個)。在實驗中,數據集採用來自 ImagNet 的五萬張 validation images,訓練網路則用的AlexNet、VGG-16、ResNet-50等。
需要說明的是,實驗只聚焦在卷積層,因為類似 VGG-16、ResNet 的 fully-connected layers,我們認為是一個最終的分類器,所以不考慮這一層面的關鍵數據通路。再者,ResNet 的網路層較深,我們也不可能將所有的卷積層都考慮進來,太冗餘且沒有必要。所以對於 ResNet,我們的處理方法就是只關注 ResBlocks 的輸出,而這個 Block 的量相對較少,我們再根據這些 Block 的輸出去觀察它所利用到的關鍵節點。
給大家介紹這個實驗,當我們找到關鍵節點以後,我們將有序地抑制掉一部分的關鍵節點,然後再觀察它對網路最終造成多大程度的影響。
在操作上有兩種方式,一種是先刪除control gate value最大的,我們稱作Top Mode,或者反過來,我們先刪除control gate value最小的,這兩種刪除方式最後引起網路性能下降的一個曲線,在下面這兩張圖上展示。(註:control gate value越大,那麼說明它的影響/重要性越大)
可以看下上邊這張圖,橫坐標顯示的是被抑制的關鍵節點比例,我們可以看到,只有1%的關鍵節點被抑制(通道置為0),原模型的top-1 acc還有top-5 acc就會面臨非常劇烈的下降,分別是top-1 acc下降百分之三十多,top-5 acc下降百分之二十多。
也就是說,只要1%的關鍵節點,還不是所有節點(關鍵節點其實只佔網路節點的百分之十左右)被刪除的話,網路性能就會面臨劇烈的下降。在某種程度上來說,這個結果證明了我們所尋找到的關鍵節點的有效性。
節點的語義含義
其實我們更重要的工作成果在這一部分,那就是我們所尋找的節點其實包含了一定的語義含義,這是網路可解釋性領域一直在關注的。首先我們會關注層內的路由節點的語義含義,比如說一個樣本進來,它經過每一層,我們會先看每層有哪些節點,然後再看它擁有什麼樣的語義含義。
我們在上方展示了5張圖,每張圖上有五萬個點,對應的是五萬張圖片。不過我們都知道,網路里的channel維都是像512、256這樣一個向量,我們怎麼樣可以把這五萬個向量之間的相似性更直觀的展示出來呢?我們最終採用的是t-SNE方法,類似於說將一些向量投影到二維平面上去。投影的結果就像下面5張圖展示的,顏色代表類別,同樣類別的圖片所對應的點,顏色都是相同的。我們會看到,隨著層數加深,它的點也隨著變得更加稀疏起來,然而實際上點的數量是沒有改變的,依然還是五萬張,五萬個點。
為什麼會呈現這樣一個稀疏或者分離的現象呢?因為同個類別的點都聚集在同一處,距離也就變得更加靠近,所以看起來中間有很多空白的部分。這也說明,在越高層的地方,同個類別所走過的節點或路徑會越加相似,簡單來說就是貓走貓的路徑,狗走狗的節點。
這張圖全面地展示了VGG-16里13個卷積層每一層關鍵節點的二維圖,我們會看到,在底層里各個類別都混雜在一起,沒有特別明顯的區分,然而隨著層數變高,顏色會開始有規律地聚集到同一個區域,說明這些類別開始各走各自的路徑。越到高層越稀疏。
在知曉每層節點的語義情況下,我們想進一步了解由這些節點連接構成的關鍵數據路徑,究竟具備什麼樣的語義特徵。於是我們做了一個實驗,針對類內樣本(樣本都是屬於同一類的),我們將它們所有的CDRPs的特徵表示拿去做一個層次化聚集聚類,看看它們的CDRPs表徵究竟有什麼相似性。
上面的樹形圖,表示每個樣本之間的相似程度,越往底層,兩個樣本就越靠近,而越往高層,就越慢被聚到一起。縱坐標代表了兩個樣本的距離,裡面的相似顏色代表的是他們被聚成一個子類了。我們在看這些圖片的聚類情況會發現,如果圖片特徵很相似,那麼他們的CDRPs聚類結果也是很相似的。
另外還有一個很有趣的發現,像左邊這50張圖,應該是某種魚,魚的圖片有這樣的一些分布規律:魚處在中間位置,採用的是橫拍模式,另外還有一類圖片,則是垂釣愛好者手裡捧著魚蹲在地上拍照。我們發現,這兩類圖片都被歸到魚這個類別,然而實際上圖片的特徵存在很大的不同。
目前看來網路 features 是檢驗不出來這種差異的,因為它們最終都被預測為魚這一個類別。然而我們的 CDRPs 表徵就細緻發現了其中的差異,就體現在兩者所走的關鍵路徑其實是不一樣的。
像左邊這張圖有紅框框起來的4張圖,其實是通過 CDRPs 的所分析出來的類似 outliner 的圖片。如果仔細看,會發現其中有一張圖片是一個人抱著魚,但是方向卻被旋轉了90度,按理來說這是一個類似於雜訊一樣的存在,然而我們的CDRPs卻能把它歸類到魚的類別,只是所走的關鍵路徑和其他樣本有著不一樣的特徵,因此把它給聚類出來,變成一個發現。
像右面則是一個白頭鷹,中間第二個的聚類都是聚焦於鷹的頭部,而第三類則聚焦於鷹站在樹上,而左邊這個是單獨的 outliner,都是一些非常不清晰的圖像。
這裡展示的是更多的一些結果。
對抗樣本檢測的應用
最後呢,我們嘗試用來做對抗樣本檢,像我之前所說的,正常樣本與對抗樣本,從輸入端來說沒有太大差別,但是從最後的預測結果來說,是有很大區別的。在我們看來,是兩張圖片在網路里所走的關鍵路徑逐漸有了分歧,導致了最終的分開。
我們先看上方左邊這張圖,這張圖首先是一個正常樣本,加了雜訊以後,預測結果由貓變成了車輪。我們該如何體現這兩種關鍵路徑的區別呢?我們主要算的是這兩個樣本在不同層上所走的關鍵節點的相關性,我們先找到每一層各自的關鍵節點,然後有一個向量,然後根據這個向量去推算相關係數來表示兩個路徑的相似性。
上面這張圖裡的橘紅線代表了這個相似程度,可以發現對抗樣本對於正常樣本的相似性是隨著層數增高的,而大致趨勢是逐漸下降的。簡單來說,高層的相似性要比底層小得多。
我們又算了下對抗樣本對於目標類別,它們這些樣本所走的這些關鍵路徑的相似性,接著計算車輪這一類別樣本的路徑相似性。我們找來車輪這一類別的50張圖片,將這50張樣本的每個路徑都算一遍相關係數,上面這張圖叫做violinplot,展示的就是這50個係數的分布展示情況。
由於每個樣本之間存在差異,所有顯示結果有的高有的低。最後發現,隨著層數加深,目標類別的相似係數會越來越高。比如在最高層的地方,violinplot的最低點都要比原始橘紅色的點要高。這也就是說,對抗樣本在高層所走的路徑和目標類別所走的路徑是很相似的,後面幾張圖也是在闡述這樣一個情況,具體的情況大家可以細緻地去參考一下論文。
接著我們又去做對抗樣本檢測,檢測方法是通過取一些正常的樣本,比如說從 ImageNet 里挑出一千種類別,每一種類別取出1張圖片(有些實驗取出5張圖片,有些取出10張圖片等等),然後每一張圖片我們都產生一個對抗樣本(用的 FGSM 演算法),然後作為訓練集,接著用我們的演算法去算它的 CDRPs 表徵,再取一個二分類的分類器來檢測和判斷這個路徑是屬於正常樣本還是對抗樣本。
訓練結束以後,我們就用這個分類器來做對抗樣本檢測,換句話說,我們自己構造了一個包含正常樣本與對抗樣本的數據集,然後用訓練所得到的分類器來預測哪一個是正常樣本,哪一個是對抗樣本。
下面的表格展示了我們不同實驗室的實驗結果,這個值如果越高,越近於1,就說明這個分類越完美。隨著訓練樣本的增加,分類結果變得越來越好之餘,不同的二分類器所能達到的水準還是比較相似的(可能使用像 gradient boosting 或者 random forest 的方法會更好一些)。
結論
最後總結一下我今天所分享的內容,首先是我們提了一個全新的角度來進行網路可解釋性,也就是通過尋找關鍵數據路徑,我們會發現有一些語義含義包含在數據路徑裡頭。包括像層內節點,它會有一定的區分能力,而且隨著層數的增高,區分能力會逐漸加深。
同一時間,關鍵路徑又體現出類內樣本不同的輸入特徵,有助於幫助我們發現一些數據集當中的樣本問題。
最後我們提了一個新的對抗樣本檢測演算法,通過利用CDRPs的特徵來檢測它究竟是真實樣本還是對抗樣本。CDRPs反映出對抗樣本在高層與正常樣本的距離較遠,在底層與正常樣本距離較近這樣一種特徵模式,利用這種特徵模式我們可以進行檢測,達到一個很好的防禦效果。
※探討自然語言處理的商業落地:從基礎平台到數據演算法
※清華大學Thinker團隊在VLSI 2018發表兩款極低功耗AI晶元
TAG:AI科技評論 |