一、簡介
圖注意力網路(GAT,Graph Attention Networks),是一種基於注意力機制(Attention Mechanism)的圖神經網路(Graph Neural Network),由Petar Veličković等人於2017年提出。與傳統的圖卷積神經網路(Graph Convolutional Network,簡稱GCN)相比,GAT不僅考慮了每個節點本身的特徵,還考慮了該節點與其周圍節點之間的關係,從而表達出全局信息。該網路在圖節點分類、節點聚類和圖分類等領域表現出色,被廣泛應用於社交網路、推薦系統、蛋白質結構預測等領域。
二、注意力機制
注意力機制是GAT中最重要的組成部分之一,該機制通過為每個節點學習一個權重分布來動態地調整網路中每個節點的重要性。在基本的注意力機制中,目標節點與周圍節點之間的相似性通過計算內積來量化:
1. 基本的注意力機制
def attention(input_x, neighboors_features):
# 卷積核初始化,設輸入特徵維度為d,分為h組
w = tf.Variable(tf.zeros([input_x.shape[-1], input_x.shape[-1]]))
a = tf.Variable(tf.zeros([2*input_x.shape[-1], 1]))
# 卷積核賦值,初始化
tf.compat.v1.random_normal_initializer(stddev=0.1)()(w)
tf.compat.v1.random_normal_initializer(stddev=0.1)()(a)
# 節點矩陣X,和鄰居節點矩陣
h = tf.matmul(input_x, w) # (N, h)
f_1, f_2 = tf.reshape(h, [-1, 1, h.shape[-1]]), \
tf.reshape(neighboors_features, [-1, neighboors_features.shape[1], h.shape[-1]])
# self-attention 注意力機制本質上就是計算出相似性,這裡比較的是節點自身的特徵以及相鄰節點的特徵
# 相似度結果經過 softmax,使得相似度在每行之間的分布是一個概率分布,最後使用相似度的分布對鄰居特徵進行加權平均。
# 然後進行再一次卷積並剪枝
attention_lev_1 = tf.nn.softmax(tf.reduce_sum(tf.multiply(f_1, f_2), axis=-1)) # (N, M)
attention_lev_2 = tf.multiply(tf.reshape(attention_lev_1, [-1, neighboors_features.shape[1], 1]), neighboors_features)
h_level_2 = tf.reduce_sum(attention_lev_2, axis=1) # (N, h)
return tf.nn.relu(tf.matmul(h_level_2, w))
一個節點的鄰居節點特徵與節點本身特徵加權平均後得到該節點的新特徵向量。
2. 多頭注意力機制
在GAT中,作者提出了多頭注意力機制,引入了一個超參數heads,代表了圖中每個節點可以學到不同注意力權重分布,進而提取具有不同重要性特徵:
class MultiHeadAttention(tf.keras.Model):
def __init__(self, num_head, embedding_size, output_dim, feature_map=None, activation=tf.nn.elu):
super(MultiHeadAttention, self).__init__()
self.num_head = num_head
self.embedding_size = embedding_size
self.output_dim = output_dim
self.feature_map = feature_map
self.activation = activation
# 計算特徵向量
self.embedding_w = tf.Variable(tf.compat.v1.random_normal([self.embedding_size, self.output_dim], stddev=0.1))
self.embedding_b = tf.Variable(tf.zeros([self.output_dim]))
# 計算注意力權重
self.attention_w = tf.Variable(tf.compat.v1.random_normal([self.output_dim, 1], stddev=0.1))
self.attention_b = tf.Variable(tf.zeros([self.num_head, 1]))
def call(self, input_x, neighboors_features):
assert isinstance(neighboors_features, list)
input_x_ = tf.reshape(tf.matmul(input_x, self.embedding_w), [-1, self.num_head, self.output_dim])
n_features = [tf.reshape(tf.matmul(f, self.embedding_w) , [-1, self.num_head, self.output_dim]) for f in neighboors_features]
# 每一個頭都有自己的注意力層。注意力權重通過對節點和鄰居節點的特徵的內積得到。
# Attention weights are calculated by dot product of node features and neighbor node features.
att_val = tf.concat([tf.matmul(tf.nn.leaky_relu(tf.matmul(input_x_, tf.tile(tf.expand_dims(self.attention_w, 0), [tf.shape(input_x_)[0], 1, 1])) + tf.tile(self.attention_b, [tf.shape(input_x_)[0], 1, 1])),
tf.concat([tf.matmul(tf.nn.leaky_relu(tf.matmul(input_x_, tf.tile(tf.expand_dims(self.attention_w, 0), [tf.shape(input_x_)[0], 1, 1])) + tf.tile(self.attention_b, [tf.shape(input_x_)[0], 1, 1])),
tf.matmul(tf.nn.leaky_relu(tf.matmul(nf, tf.tile(tf.expand_dims(self.attention_w, 0), [tf.shape(nf)[0], 1, 1])) + tf.tile(self.attention_b, [tf.shape(nf)[0], 1, 1]))], axis=1)], axis=1)
att_val = self.activation(att_val + tf.tile(self.embedding_b, [tf.shape(input_x_)[0], 1, 1]))
# 最後通過一個全連接層進行融合
out = tf.reduce_mean(att_val, axis=1)
if self.feature_map:
out = tf.concat([self.feature_map(input_x), out], axis=1)
return out
三、GAT 的應用
圖注意力網路已經被廣泛應用於實際的場景中。
1. 社交網路
社交網路中存在好友、關注等關係,每個用戶的影響力都不同,GAT可以將好友之間的關係視作圖數據,學習每個用戶與其好友需關注的重要性,從而進行信息傳播和影響力分析。
2. 推薦系統
推薦系統中用戶和物品都可以看作圖中的節點,節點之間的邊表示用戶和物品之間的交互行為,GAT可以對節點進行特徵抽取,通過對相鄰節點進行注意力加權來提升推薦效果。
3. 蛋白質結構預測
蛋白質結構預測是生物學中的重要問題之一,無法通過傳統的生物實驗方法進行預測。GAT可以將蛋白質中的氨基酸、二級結構等信息視作圖數據,學習節點之間的關係特徵,通過注意力機制來預測蛋白質結構。
四、總結
圖注意力網路是一種基於注意力機制的圖神經網路,適用於處理任意類型的圖結構數據。其多頭注意力機制能夠學習不同注意力權重分布,提取具有不同重要性的特徵。該網路已經被廣泛應用於社交網路、推薦系統、蛋白質結構預測等領域。
原創文章,作者:HNJMS,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/371048.html