# Source code for dhg.models.graphs.gin

import torch
import torch.nn as nn

import dhg
from dhg.nn import MLP
from dhg.nn import GINConv


class GIN(nn.Module):
    r"""Graph Isomorphism Network proposed in `How Powerful are Graph Neural Networks?
    <https://arxiv.org/pdf/1810.00826>`_ paper (ICLR 2019).

    The model stacks ``num_layers`` ``GINConv`` layers, each of which wraps its own
    ``MLP``. A linear readout head is attached to the raw input and to the output of
    every convolution. The forward pass returns the sum of all ``num_layers + 1``
    head outputs.

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes of the classification task.
        ``num_layers`` (``int``): The number of ``GINConv`` layers in the model; must be at least ``2``. In the original `code <https://github.com/weihua916/powerful-gnns/blob/master/main.py#L102>`_, it is set to ``5``.
        ``num_mlp_layers`` (``int``): The number of layers in each MLP. Defaults to ``2``.
        ``eps`` (``float``): The epsilon value given to every ``GINConv``. Defaults to ``0.0``.
        ``train_eps`` (``bool``): If set to ``True``, the epsilon value will be trainable. Defaults to ``False``.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization inside each MLP. Defaults to ``False``.
        ``drop_rate`` (``float``): The dropout ratio given to each MLP. Defaults to ``0.5``.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        num_layers: int,
        num_mlp_layers: int = 2,
        eps: float = 0.0,
        train_eps: bool = False,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        assert num_layers >= 2, "num_layers must be greater than or equal to 2."

        def mlp_channels(depth: int) -> list:
            # Only the first MLP consumes the raw input features; every later
            # one works purely in the hidden space. A new list is built per call.
            width_in = in_channels if depth == 0 else hid_channels
            return [width_in] + [hid_channels] * num_mlp_layers

        # Built layer by layer: each MLP is created right before the GINConv
        # that wraps it, matching the registration order of ``self.layers``.
        self.layers = nn.ModuleList(
            GINConv(
                MLP(mlp_channels(depth), use_bn=use_bn, drop_rate=drop_rate),
                eps,
                train_eps,
            )
            for depth in range(num_layers)
        )

        # One readout head for the raw input, then one after each convolution.
        head_widths = [in_channels] + [hid_channels] * num_layers
        self.pred_layers = nn.ModuleList(
            nn.Linear(width, num_classes) for width in head_widths
        )

    def forward(self, X: torch.Tensor, g: "dhg.Graph") -> torch.Tensor:
        r"""Run every GIN layer and sum the class scores of all readout heads.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``g`` (``dhg.Graph``): The graph structure that contains :math:`N` vertices.

        Returns:
            ``torch.Tensor``: The sum of the readout-head outputs computed on the
            input and on the output of every layer.
        """
        hidden = X
        # Head 0 scores the raw input; head ``depth`` scores layer ``depth``'s output.
        pred = self.pred_layers[0](hidden)
        for depth, conv in enumerate(self.layers, start=1):
            hidden = conv(hidden, g)
            pred += self.pred_layers[depth](hidden)
        return pred