當前位置:
首頁 > 科技 > 專欄 | 深入理解圖注意力機制

專欄 | 深入理解圖注意力機制


機器之心DGL專欄


作者:張昊、李牧非、王敏捷、張崢

圖卷積網路 Graph Convolutional Network (GCN) 告訴我們將局部的圖結構和節點特徵結合可以在節點分類任務中獲得不錯的表現。美中不足的是 GCN 結合鄰近節點特徵的方式和圖的結構依依相關,這局限了訓練所得模型在其他圖結構上的泛化能力。

Graph Attention Network (GAT) 提出了用注意力機制對鄰近節點特徵加權求和。鄰近節點特徵的權重完全取決於節點特徵,獨立於圖結構。

在這個教程里我們將:





  • 解釋什麼是 Graph Attention Network



  • 演示用 DGL 實現這一模型



  • 深入理解學習所得的注意力權重



  • 初探歸納學習 (inductive learning)

難度:★★★★?(需要對圖神經網路訓練和 Pytorch 有基本了解)



在 GCN 里引入注意力機制

GAT 和 GCN 的核心區別在於如何收集並累和距離為 1 的鄰居節點的特徵表示。

在 GCN 里,一次圖卷積操作包含對鄰節點特徵的標準化求和:

其中 N(i) 是對節點 i 距離為 1 鄰節點的集合。我們通常會加一條連接節點 i 和它自身的邊使得 i 本身也被包括在 N(i) 里。

 

是一個基於圖結構的標準化常數;σ是一個激活函數(GCN 使用了 ReLU);W^((l)) 是節點特徵轉換的權重矩陣,被所有節點共享。由於 c_ij 和圖的機構相關,使得在一張圖上學習到的 GCN 模型比較難直接應用到另一張圖上。解決這一問題的方法有很多,比如 GraphSAGE 提出了一種採用相同節點特徵更新規則的模型,唯一的區別是他們將 c_ij 設為了|N(i)|。

圖注意力模型 GAT 用注意力機制替代了圖卷積中固定的標準化操作。以下圖和公式定義了如何對第 l 層節點特徵做更新得到第 l+1 層節點特徵:



圖 1:圖注意力網路示意圖和更新公式。


對於上述公式的一些解釋:





  • 公式(1)對 l 層節點嵌入

    做了線性變換,W^((l)) 是該變換可訓練的參數



  • 公式(2)計算了成對節點間的原始注意力分數。它首先拼接了兩個節點的 z 嵌入,注意 || 在這裡表示拼接;隨後對拼接好的嵌入以及一個可學習的權重向量 做點積;最後應用了一個 LeakyReLU 激活函數。這一形式的注意力機制通常被稱為加性注意力,區別於 Transformer 里的點積注意力。



  • 公式(3)對於一個節點所有入邊得到的原始注意力分數應用了一個 softmax 操作,得到了注意力權重。



  • 公式(4)形似 GCN 的節點特徵更新規則,對所有鄰節點的特徵做了基於注意力的加權求和。

出於簡潔的考量,在本教程中,我們選擇省略了一些論文中的細節,如 dropout, skip connection 等等。感興趣的讀者們歡迎參閱文末鏈接的模型完整實現。

本質上,GAT 只是將原本的標準化常數替換為使用注意力權重的鄰居節點特徵聚合函數。


GAT 的 DGL 實現

以下代碼給讀者提供了在 DGL 里實現一個 GAT 層的總體印象。別擔心,我們會將以下代碼拆分成三塊,並逐塊講解每塊代碼是如何實現上面的一條公式。


import

 torch

import

 torch.nn 

as

 nn

import

 torch.nn.functional 

as

 F

class

 

GATLayer(nn.Module)

:


    

def

 

__init__(self, g, in_dim, out_dim)

:


        super(GATLayer, self).__init__()
        self.g = g
        

# 公式 (1)


        self.fc = nn.Linear(in_dim, out_dim, bias=

False

)
        

# 公式 (2)


        self.attn_fc = nn.Linear(

2

 * out_dim, 

1

, bias=

False

)

    

def

 

edge_attention(self, edges)

:


        

# 公式 (2) 所需,邊上的用戶定義函數


        z2 = torch.cat([edges.src[

"z"

], edges.dst[

"z"

]], dim=

1

)
        a = self.attn_fc(z2)
        

return

 {

"e"

 : F.leaky_relu(a)}

    

def

 

message_func(self, edges)

:


        

# 公式 (3), (4)所需,傳遞消息用的用戶定義函數


        

return

 {

"z"

 : edges.src[

"z"

], 

"e"

 : edges.data[

"e"

]}

    

def

 

reduce_func(self, nodes)

:


        

# 公式 (3), (4)所需, 歸約用的用戶定義函數


        

# 公式 (3)


        alpha = F.softmax(nodes.mailbox[

"e"

], dim=

1

)
        

# 公式 (4)


        h = torch.sum(alpha * nodes.mailbox[

"z"

], dim=

1

)
        

return

 {

"h"

 : h}

    

def

 

forward(self, h)

:


        

# 公式 (1)


        z = self.fc(h)
        self.g.ndata[

"z"

] = z
        

# 公式 (2)


        self.g.apply_edges(self.edge_attention)
        

# 公式 (3) & (4)


        self.g.update_all(self.message_func, self.reduce_func)
        

return

 self.g.ndata.pop(

"h"

)


實現公式 (1)



第一個公式相對比較簡單。線性變換非常常見。在 PyTorch 里,我們可以通過 torch.nn.Linear 很方便地實現。



實現公式 (2)



原始注意力權重 e_ij 是基於一對鄰近節點 i 和 j 的表示計算得到。我們可以把注意力權重 e_ij 看成在 i->j 這條邊的數據。因此,在 DGL 里,我們可以使用 g.apply_edges 這一 API 來調用邊上的操作,用一個邊上的用戶定義函數來指定具體操作的內容。我們在用戶定義函數里實現了公式(2)的操作:


 

def

 

edge_attention(self, edges)

:


        

# 公式 (2) 所需,邊上的用戶定義函數


        z2 = torch.cat([edges.src[

"z"

], edges.dst[

"z"

]], dim=

1

)
        a = self.attn_fc(z2)
        

return

 {

"e"

 : F.leaky_relu(a)}

公式中的點積同樣藉由 PyTorch 的一個線性變換 attn_fc 實現。注意 apply_edges 會把所有邊上的數據打包為一個張量,這使得拼接和點積可以並行完成。



實現公式 (3) 和 (4)



類似 GCN,在 DGL 里我們使用 update_all API 來觸發所有節點上的消息傳遞函數。update_all 接收兩個用戶自定義函數作為參數。message_function 發送了兩種張量作為消息:消息原節點的 z 表示以及每條邊上的原始注意力權重。reduce_function 隨後進行了兩項操作:



  1. 使用 softmax 歸一化注意力權重(公式(3))。



  2. 使用注意力權重聚合鄰節點特徵(公式(4))。

這兩項操作都先從節點的 mailbox 獲取了數據,隨後在數據的第二維(dim = 1 ) 上進行了運算。注意數據的第一維代表了節點的數量,第二維代表了每個節點收到消息的數量。


 

def

 

reduce_func(self, nodes)

:


        

# 公式 (3), (4)所需, 歸約用的用戶定義函數


        

# 公式 (3)


        alpha = F.softmax(nodes.mailbox[

"e"

], dim=

1

)
        

# 公式 (4)


        h = torch.sum(alpha * nodes.mailbox[

"z"

], dim=

1

)
        

return

 {

"h"

 : h}


多頭注意力 (Multi-head attention)

神似卷積神經網路里的多通道,GAT 引入了多頭注意力來豐富模型的能力和穩定訓練的過程。每一個注意力的頭都有它自己的參數。如何整合多個注意力機制的輸出結果一般有兩種方式:

以上式子中 K 是注意力頭的數量。作者們建議對中間層使用拼接對最後一層使用求平均。

我們之前有定義單頭注意力的 GAT 層,它可作為多頭注意力 GAT 層的組建單元:


class

 

MultiHeadGATLayer(nn.Module)

:


    

def

 

__init__(self, g, in_dim, out_dim, num_heads, merge=

"cat"

)

:


        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        

for

 i 

in

 range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    

def

 

forward(self, h)

:


        head_outs = [attn_head(h) 

for

 attn_head 

in

 self.heads]
        

if

 self.merge == 

"cat"

:
            

# 對輸出特徵維度(第1維)做拼接


            

return

 torch.cat(head_outs, dim=

1

)
        

else

:
            

# 用求平均整合多頭結果


            

return

 torch.mean(torch.stack(head_outs))

在 Cora 數據集上訓練一個 GAT 模型


Cora 是經典的文章引用網路數據集。Cora 圖上的每個節點是一篇文章,邊代表文章和文章間的引用關係。每個節點的初始特徵是文章的詞袋(Bag of words)表示。其目標是根據引用關係預測文章的類別(比如機器學習還是遺傳演算法)。在這裡,我們定義一個兩層的 GAT 模型:


class

 

GAT(nn.Module)

:


    

def

 

__init__(self, g, in_dim, hidden_dim, out_dim, num_heads)

:


        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        

# 注意輸入的維度是 hidden_dim * num_heads 因為多頭的結果都被拼接在了


        

# 一起。 此外輸出層只有一個頭。


        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 

1

)

    

def

 

forward(self, h)

:


        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        

return

 h

我們使用 DGL 自帶的數據模塊載入 Cora 數據集。



from

 dgl 

import

 DGLGraph

from

 dgl.data 

import

 citation_graph 

as

 citegrh

def

 

load_cora_data()

:


    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    g = DGLGraph(data.graph)
    

return

 g, features, labels, mask

模型訓練的流程和 GCN 教程里的一樣。


import

 time

import

 numpy 

as

 np
g, features, labels, mask = load_cora_data()

# 創建模型


net = GAT(g, 
          in_dim=features.size()[

1

], 
          hidden_dim=

8


          out_dim=

7


          num_heads=

8

)
print(net)

# 創建優化器


optimizer = torch.optim.Adam(net.parameters(), lr=

1e-3

)

# 主流程


dur = []

for

 epoch 

in

 range(

30

):
    

if

 epoch >=

3

:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 

1

)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    

if

 epoch >=

3

:
        dur.append(time.time() - t0)

    print(

"Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}"

.format(
            epoch, loss.item(), np.mean(dur)))

可視化並理解學到的注意力

Cora 數據集

以下表格總結了 GAT 論文以及 dgl 實現的模型在 Cora 數據集上的表現:

可以看到 DGL 能完全復現原論文中的實驗結果。對比圖卷積網路 GCN,GAT 在 Cora 上有 2~3 個百分點的提升。

不過,我們的模型究竟學到了怎樣的注意力機制呢?

由於注意力權重

與圖上的邊密切相關,我們可以通過給邊著色來可視化注意力權重。以下圖片中我們選取了 Cora 的一個子圖並且在圖上畫出了 GAT 模型最後一層的注意力權重。我們根據圖上節點的標籤對節點進行了著色,根據注意力權重的大小對邊進行了著色(可參考圖右側的色條)。




圖 2:Cora 數據集上學習到的注意力權重。

乍看之下模型似乎學到了不同的注意力權重。為了對注意力機制有一個全局觀念,我們衡量了注意力分布的熵。對於節點 i,{α_ij }_(j∈N(i)) 構成了一個在 i 鄰節點上的離散概率分布。它的熵被定義為:






直觀地說,熵低代表了概率高度集中,反之亦然。熵為 0 則所有的注意力都被放在一個點上。均勻分布具有最高的熵(log N(i))。在理想情況下,我們想要模型習得一個熵較低的分布(即某一、兩個節點比其它節點重要的多)。注意由於節點的入度不同,它們注意力權重的分布所能達到的最大熵也會不同。

基於圖中所有節點的熵,我們畫了所有頭注意力的直方圖。




圖 3:Cora 數據集上學到的注意力權重直方圖。

作為參考,下圖是在所有節點的注意力權重都是均勻分布的情況下得到的直方圖。



出人意料的,

模型學到的節點注意力權重非常接近均勻分布

(換言之,所有的鄰節點都獲得了同等重視)。這在一定程度上解釋了為什麼在 Cora 上 GAT 的表現和 GCN 非常接近(在上面表格里我們可以看到兩者的差距平均下來不到 2%)。由於沒有顯著區分節點,注意力並沒有那麼重要。

這是否說明了注意力機制沒什麼用?不!在接下來的數據集上我們觀察到了完全不同的現象。

蛋白質交互網路 (PPI)

PPI(蛋白質間相互作用)數據集包含了 24 張圖,對應了不同的人體組織。節點最多可以有 121 種標籤(比如蛋白質的一些性質、所處位置等)。因此節點標籤被表示為有 121 個元素的二元張量。數據集的任務是預測節點標籤。

我們使用了 20 張圖進行訓練,2 張圖進行驗證,2 張圖進行測試。平均下來每張圖有 2372 個節點。每個節點有 50 個特徵,包含定位基因集合、特徵基因集合以及免疫特徵。至關重要的是,測試用圖在訓練過程中對模型完全不可見。這一設定被稱為歸納學習。

我們比較了 dgl 實現的 GAT 和 GCN 在 10 次隨機訓練中的表現。模型的超參數在驗證集上進行了優化。在實驗中我們使用了 micro f1 score 來衡量模型的表現。

在訓練過程中,我們使用了 BCEWithLogitsLoss 作為損失函數。下圖繪製了 GAT 和 GCN 的學習曲線;顯然 GAT 的表現遠優於 GCN。




圖 4:PPI 數據集上 GCN 和 GAT 學習曲線比較。

像之前一樣,我們可以通過繪製節點注意力分布之熵的直方圖來有一個統計意義上的直觀了解。以下我們基於一個 3 層 GAT 模型中不同模型層不同注意力頭繪製了直方圖。

第一層學到的注意力



第二層學到的注意力



最後一層學到的注意力



作為參考,下圖是在所有節點的注意力權重都是均勻分布的情況下得到的直方圖。



可以很明顯地看到,

GAT 在 PPI 上確實學到了一個尖銳的注意力權重分布

。與此同時,GAT 層與層之間的注意力也呈現出一個清晰的模式:在中間層隨著層數的增加註意力權重變得愈發集中;最後的輸出層由於我們對不同頭結果做了平均,注意力分布再次趨近均勻分布。


不同於在 Cora 數據集上非常有限的收益,GAT 在 PPI 數據集上較 GCN 和其它圖模型的變種取得了明顯的優勢(根據原論文的結果在測試集上的表現提升了至少 20%)。我們的實驗揭示了 GAT 學到的注意力顯著區別於均勻分布。雖然這值得進一步的深入研究,一個由此而生的假設是 GAT 的優勢在於處理更複雜領域結構的能力。



拓展閱讀

到目前為止我們演示了如何用 DGL 實現 GAT。簡介起見,我們忽略了 dropout, skip connection 等一些細節。這些細節很常見且獨立於 DGL 相關的概念。有興趣的讀者歡迎參閱完整的代碼實現。



  • 經過優化的完整代碼實現:https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py



  • 在下一個教程中我們將介紹如何通過並行多頭注意力和稀疏矩陣向量乘法來加速 GAT 模型,敬請期待!


關於 DGL 專欄: DGL 是一款全新的面向圖神經網路的開源框架。通過該專欄,我們 DGL 團隊希望和大家一起學習圖神經網路的最新進展。同時展示 DGL 的靈活性和高效性。通過系統學習演算法,通過演算法理解系統。

更多 DGL 專欄信息,請查看機器之心官網,或者點擊閱讀原文。




本文為機器之心專欄,

轉載請聯繫本公眾號獲得授權



?------------------------------------------------


加入機器之心(全職記者 / 實習生):hr@jiqizhixin.com


投稿或尋求報道:

content

@jiqizhixin.com


廣告 & 商務合作:bd@jiqizhixin.com

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

9102年,你已經是個大春節了,你要自己學會用AI了
AI貓窩:一位工程師鏟屎官給流浪貓主子們的賀年禮

TAG:機器之心 |