探索用于图分类的图胶囊网络 | Exploring graph capsual network for graph classification
摘要
图神经网络(GNN) 因其在拓扑结构建模和特征信息聚合学习图表示方面的强大能力而受到极大关注。然而,从GNN学习到的标量节点表示可能不足以有效地保留节点/图特征的属性,从而导致次优图表示。重复平均收集了过多的噪声,使得不同类的节点特征过度混合,导致过平滑问题。受Hinton提出的胶囊网络概念的启发,作者们提出了一种新的图分类框架CapsualGNN,该框架充分利用了图神经网络和胶囊网络的优点。具体来说,首先将节点表示为胶囊组,每个胶囊提取对应节点的不同特征。然后,利用路由机制,通过对每个图生成多个嵌入来捕获图级的重要信息和属性,并利用注意机制来聚焦重要特征。最后,为了解决过平滑的问题,引入了GCN的类残差连接。此外,作者们还引入了参数来区分自连接节点和其他节点。本文利用生物信息和社会网络上的6个图数据集对该框架进行了评估,并证明了CapsualGNN在图分类任务上优于其他SOTA技术。
背景
GNN的动机主要来自于CNN和图嵌入。GNN学习节点表示,并通过迭代聚合其拓扑邻域的特征来更新它们,因此GNN可以捕获节点的局部结构信息。然后采用最大池法或平均池法得到图的表示形式。然而,1) 将所有相邻的特征相加或平均,使GNN缺乏可解释性。2) 由于不可能确定模型中使用的节点特征的最重要部分,因此需要复杂的事后方法来解释特定应用域的GNN。3) 此外,从GNN中学到的标量节点表示可能不足以有效地保留节点/图特征的属性,导致图嵌入的次优。故而,基于向量的胶囊网络(Capsule Network, vector-based neural network)被引入以解决上述问题。然而,现有的胶囊网络方法中也存在如下问题:
- 图中没有从节点到节点本身的边,因此,图中存在与其他边具有相同效果的自连通边是不合理的
- 当堆叠多个消息传递层来捕获远程邻居信息时,可能会聚合过多的噪声
- 在GNN中,节点特征在不同的类中难以区分,导致过度平滑
贡献
- 作者引入了参数来区分自连接节点和其他节点,它可以反映自连接节点及其邻居节点在消息传递时的不同重要性
- 作者们为GCN引入类残差连接,解决了在获取远距离邻域信息时,堆叠多个消息传递层从而导致噪声过度聚集造成的过平滑问题
相关工作
图神经网络 (Graph Neural Networks, GNNs)
图卷积神经网络主要有两种类型。一种是基于空间域,另一种是基于光谱域。
- 空间卷积将图像上的卷积运算推广为图,聚合方法可采用加权和。使用空间卷积的模型有DCNN、GAT、GIN、GraphSAGE等。
频谱域卷积中,采用傅立叶变换和傅立叶反变换。谱域卷积模型包括ChebNet、GCN、GCAPS-CNN等。
GCN变体: 1)开发新的图滤波器,增强图卷积运算的效果,更合理地聚合有效的图信息 2)设计适当的图池操作来选择网络中传输的信息 ChebNet: 采用k阶截断切比雪夫多项式作为图滤波器,不仅可以利用切比雪夫多项式的递归计算性质简化计算过程,而且可以从中心节点聚合k-hop邻域内的节点信息 DCNN: 使用均值池作为卷积层后的池化层,聚合所有节点的表示,形成图表示 GIN讨论了各种池操作对结果的影响,发现总和池化比最大池化和平均值池化更好
问题: 这些池化操作仍然会丢失一些关键信息,并且这些方法使用的卷积是标量卷积,不能度量特征之间的空间关系。如图1所示,左右图的结构不同,但均值池化后的值是相同的,此时无法区分两图。
胶囊网络 (Capsule Network)
利用胶囊网络可以改善池化操作丢失部分信息的缺陷。改进的动态路由池操作可以避免池化过程中信息的丢失,因为它考虑了每个节点,可以通过不同的权值知道不同的图结构。引入胶囊可以让我们使用路由机制来生成高级特征,与CNN的最大池化(max-pooling)中除了最活跃的信息外所有信息都被丢弃相比,路由保留了来自低级胶囊的所有信息,并将它们路由到最近的高级胶囊。此外,这允许用多个嵌入来建模每个图,并且每个嵌入反映了图的不同属性。这比在其他基于标量的方法中只使用一种嵌入更具代表性。
图分类 (Graph Classification)
图分类的早期解决方案是使用图核(graph kernels),它将图结构嵌入到向量空间中,并根据子图的成对相似度计算核函数。随后,提出了各种子图,如++路径++和++子树++。然后,许多研究开始使用GNN对图进行分类。图池化方法被提出来处理节点嵌入问题。通常,池化方法是对所有节点特征进行汇总或平均,但这样可能会丢弃大量信息。最近,受Hinton提出的胶囊概念的启发,引入了胶囊网络,并使用动态路由机制通过为每个图生成多个嵌入来捕获图级的重要信息和属性,而不是池化操作。
CapsualGNN模型
CapsualGNN框架由四个模块组成:
- 1)节点特征嵌入模块: 利用可训练矩阵将离散的节点特征转换为连续的特征
- 2)主胶囊生成模块: 利用具有类残差连接的多层GCN生成多个张量,并将其串联成主胶囊
- 3)图嵌入生成模块: 利用注意机制获取主胶囊中的重要特征,然后利用动态路由形成图胶囊
- 4)图分类模块: 重用图胶囊上的动态路由,形成类胶囊,用于图分类
1)节点特征嵌入模块
图的节点通常具有特征,这些特征有时是节点的标签(离散特征)。离散特征不利于网络计算。如果将特征独热编码,在特征值范围分布较广的情况下,节点的特征维数过大,增加了时空复杂度,不适合图的泛化。且独热编码会包含过多的零,导致结构稀疏,难以表示网络。针对这种情况,CapsualGNN采用了一种类似于词嵌入的思想。论文使用一个可训练的嵌入矩阵W_0 \in R^{l×d}, 其中l为不同特征值的个数,d为嵌入的维数。作者将每个节点的F维特征扩展为F×d维特征。这些嵌入是归一化为值-1到1的,形成最终的输入数据属于R^{N×F×d}: 其中b0是可学习偏差。节点嵌入模块将离散特征转换为连续特征。
2)主胶囊生成模块
2.1)胶囊网络
网络结构由三部分组成。第一部分是输入层,用于处理输入样本并将其转换为初级胶囊。第二部分是通过动态路由将初级胶囊转化为数字胶囊。第三部分是通过三层全连通网络获取数字胶囊的分类结果。胶囊结构也采用了上述方法。在分类层,首先构造主胶囊,然后采用动态路由的方法构建图胶囊进行分类。中间层可以添加多个胶囊层。
2.2)图卷积网络
离散卷积本质上是一个加权和。因此,根据卷积原理,GCN使用邻接矩阵和单位矩阵作为权值卷积核来聚合邻居信息。当使用k层GCN时,相当于聚合k阶邻居信息。Kipf提出的GCN卷积形式为: GCN使用全连接网络作为分类器。GCN是谱卷积的一阶局部逼近,是一个多层图卷积神经网络。每个卷积层只聚合一阶邻域信息。多阶邻域信息传输可通过多个卷积层叠加实现。但是GCN的缺点是它只能处理无向图,并且没有考虑到自连通边在重整化过程中与其他边不同,所以它们的权值是相同的。
2.3)注意力机制
传统的Seq2Seq模型有两个缺陷: 1) 它将输入X的所有信息编码成一个固定长度的隐藏向量Z,而不考虑X的长度,这导致当输入句子很长时,模型的性能急剧下降; 2) 将输入的X编码为固定长度,并为句子中的每个单词分配相同的权重是不合理的。基于此,注意力机制被提出。本研究注意机制的计算过程如图4所示。
主胶囊网络生成: 区分自连接边和其他边的一种方法是为自连接边在0到1的范围内添加系数k。该系数可以通过训练得到。有了这个系数,图中的自连通边就不会具有与其他类型的边相同的效果。由于GCN过于平滑,CapsualGNN使用带有类残差连接的GCN作为节点嵌入的生成器。该模块连接GCN每一层的输出以形成主胶囊,而不是仅使用最后一层的输出来构建主胶囊。根据Kipf提出的GCN公式,生成函数表示为: 模块为每个节点生成h个胶囊,每个胶囊为d维,d为GCN层的输出维数。与Sara Sabour提出的模型不同,CapsualGNN使用多个胶囊来表示一个节点,可以理解为从多个角度描述一个节点。多个不同的透视图可以帮助我们更全面地理解节点的信息。
3)图嵌入生成模块
主胶囊表示节点信息,但图分类需要获取图信息,使用多个胶囊从不同角度表示节点的思路是一样的。多个胶囊可以用来表示一个图的信息。利用注意机制和动态路由将原始胶囊转化为图胶囊。在CapsualGNN中,使用双层全连接网络作为注意值的计算函数,其表示为F_atten(·)。此处利用的注意力过程如下: 在图中应用注意机制,不仅可以使模型注意到对分类有重要意义的节点,且可以使结果与图的大小无关,即减少节点数量对结果的影响。在使用动态路由之前,可以使用坐标添加模块来包含节点的位置信息。具体的方法是将从不同角度描述每个节点的多个胶囊中的最后一个胶囊作为存储坐标信息的胶囊。对于胶囊和剩余胶囊,利用参数矩阵改变特征维数,将坐标胶囊的信息合并到剩余胶囊的向量维数中。计算过程如图8所示。
4)图分类模块
在CapsualGNN中,胶囊的模块长度表示胶囊所代表的实体的概率。在图分类的过程中,得到图嵌入后,需要映射到每个类别,因此我们需要将图嵌入胶囊路由到对应的类别胶囊。与前一种方法类似,该路由方法通过注意路由和动态路由方法实现,分类胶囊和类别的数量相同。得到分类囊后,通过计算分类囊的模长,可以知道图属于类别的概率。特别地,动态路由机制如算法一所示:
CapsualGNN如算法二所示:
在Algorithm2中,第1行是节点嵌入模块。第2行是GCN的第一层,第4行到第7行得到GCN各层的输出,类残差连接。8号线连接子形成初级胶囊。第10行和第11行使用注意机制来增强对分类重要的节点的重要性。第12行和第13行使用动态路由方法获取图胶囊和类胶囊。
5)模型训练
分类损失(classification loss):
重建损失(reconstruction loss):
总损失(total loss):