# Source code for dhg.models.graphs.lightgcn

from typing import Tuple

import torch
import torch.nn as nn

from dhg.structure.graphs import BiGraph


class LightGCN(nn.Module):
    r"""The LightGCN model proposed in `LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation <https://arxiv.org/pdf/2002.02126>`_ paper (SIGIR 2020).

    Every user and every item owns a free embedding vector. The vectors are
    propagated over the user-item bipartite graph with
    ``BiGraph.smoothing_with_GCN``, and the embeddings from every propagation
    depth (including the un-propagated ones) are averaged. This module adds
    no learnable weights to the propagation step; its only parameters are the
    two embedding tables.

    .. note::
        The user and item embeddings are initialized with a normal distribution
        (mean ``0``, standard deviation ``0.1``).

    Args:
        ``num_users`` (``int``): The number of users.
        ``num_items`` (``int``): The number of items.
        ``emb_dim`` (``int``): Embedding dimension.
        ``num_layers`` (``int``): The number of propagation layers. Defaults to ``3``.
        ``drop_rate`` (``float``): Dropout rate. Connections are randomly dropped with probability ``drop_rate``, during training only. Defaults to ``0.0``.
    """

    def __init__(
        self, num_users: int, num_items: int, emb_dim: int, num_layers: int = 3, drop_rate: float = 0.0
    ) -> None:
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_layers = num_layers
        self.drop_rate = drop_rate
        # Registration order (users first, then items) fixes parameter order
        # in ``parameters()`` / ``state_dict()``.
        self.u_embedding = nn.Embedding(num_users, emb_dim)
        self.i_embedding = nn.Embedding(num_items, emb_dim)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-draw both embedding tables from ``N(0, 0.1^2)``, users first, then items.
        """
        for table in (self.u_embedding, self.i_embedding):
            nn.init.normal_(table.weight, mean=0.0, std=0.1)

    def forward(self, ui_bigraph: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""The forward function.

        Args:
            ``ui_bigraph`` (``dhg.BiGraph``): The user-item bipartite graph.

        Returns:
            A tuple ``(user_embeddings, item_embeddings)`` of sizes
            ``(num_users, emb_dim)`` and ``(num_items, emb_dim)``. Each row is
            the mean of the initial embedding and its ``num_layers`` smoothed
            versions.
        """
        # Edge dropout is only active while the module is in training mode.
        if self.training:
            edge_drop = self.drop_rate
        else:
            edge_drop = 0.0

        # Users occupy the first ``num_users`` rows and items the remaining
        # ``num_items`` rows; the final split below relies on this layout.
        h = torch.cat([self.u_embedding.weight, self.i_embedding.weight], dim=0)
        per_layer = [h]
        for _ in range(self.num_layers):
            h = ui_bigraph.smoothing_with_GCN(h, drop_rate=edge_drop)
            per_layer.append(h)

        # Uniform layer combination: every depth, including depth 0, has equal weight.
        fused = torch.stack(per_layer, dim=1).mean(dim=1)
        user_out, item_out = fused.split([self.num_users, self.num_items], dim=0)
        return user_out, item_out