# Source code for dhg.nn.convs.graphs.gin_conv

import torch
import torch.nn as nn

from dhg.structure.graphs import Graph


class GINConv(nn.Module):
    r"""The GIN convolution layer proposed in `How Powerful are Graph Neural Networks?
    <https://arxiv.org/pdf/1810.00826>`_ paper (ICLR 2019).

    Sparse Format:

    .. math::
        \mathbf{x}^{\prime}_i = MLP \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right).

    Matrix Format:

    .. math::
        \mathbf{X}^{\prime} = MLP \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right).

    Args:
        ``MLP`` (``nn.Module``): The neural network applied after message passing, e.g. ``nn.Linear`` or ``nn.Sequential``.
        ``eps`` (``float``): The initial value of :math:`\epsilon`. Defaults to ``0.0``.
        ``train_eps`` (``bool``): If set to ``True``, :math:`\epsilon` is a learnable parameter;
            otherwise it is kept as a fixed Python number. Defaults to ``False``.
    """

    def __init__(self, MLP: nn.Module, eps: float = 0.0, train_eps: bool = False):
        super().__init__()
        self.MLP = MLP
        if train_eps:
            # Shape-(1,) tensor in the default floating dtype. The explicit dtype
            # (instead of the legacy ``torch.Tensor`` constructor) keeps an integer
            # ``eps`` such as ``0`` from producing an int64 tensor, which could not
            # require gradients.
            self.eps = nn.Parameter(torch.tensor([eps], dtype=torch.get_default_dtype()))
        else:
            # Plain attribute, not a registered parameter/buffer, so it is not
            # stored in the module's state_dict.
            self.eps = eps

    def forward(self, X: torch.Tensor, g: "Graph") -> torch.Tensor:
        r"""The forward function.

        Args:
            X (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N_v, C_{in})`.
            g (``dhg.Graph``): The graph structure that contains :math:`N_v` vertices.

        Returns:
            ``torch.Tensor``: ``self.MLP`` applied to :math:`(1 + \epsilon) \mathbf{X}`
            plus the sum-aggregated neighbor features.
        """
        # Self term scaled by (1 + eps) plus the sum of neighbor messages.
        X = (1 + self.eps) * X + g.v2v(X, aggr="sum")
        X = self.MLP(X)
        return X