Source code for dhg.structure.graphs.bipartite_graph

import random
import pickle
from pathlib import Path
from copy import deepcopy
from typing import Union, Optional, List, Tuple, Any, Dict, TYPE_CHECKING

import torch
import numpy as np

from dhg.structure.hypergraphs import Hypergraph
from dhg.visualization.structure.draw import draw_bigraph
from ..base import BaseGraph
from dhg.utils.sparse import sparse_dropout


[docs]class BiGraph(BaseGraph): r""" Class for bipartite graph. Args: ``num_u`` (``int``): The Number of vertices in set :math:`\mathcal{U}`. ``num_v`` (``int``): The Number of vertices in set :math:`\mathcal{V}`. ``e_list`` (``Union[List[int], List[List[int]]], optional``): Initial edge set. Defaults to ``None``. ``e_weight`` (``Union[float, List[float]], optional``): A list of weights for edges. Defaults to ``None``. ``merge_op`` (``str``): The operation to merge those conflicting edges, which can be one of ``'mean'``, ``'sum'``, or ``'max'``. Defaults to ``'mean'``. ``device`` (``torch.device``, optional): The device to store the bipartite graph. Defaults to ``torch.device('cpu')``. """ def __init__( self, num_u: int, num_v: int, e_list: Optional[Union[List[int], List[List[int]]]] = None, e_weight: Optional[Union[float, List[float]]] = None, merge_op: str = "mean", device: torch.device = torch.device("cpu"), ): super().__init__(num_v, device=device) self._num_u = num_u if e_list is not None: self.add_edges(e_list, e_weight, merge_op=merge_op) def __repr__(self) -> str: r"""Print the bipartite graph information. """ return f"Bipartite Graph(num_u={self.num_u}, num_v={self.num_v}, num_e={self.num_e})" @property def state_dict(self) -> Dict[str, Any]: r"""Get the state dict of the bipartite graph. """ return { "num_u": self.num_u, "num_v": self.num_v, "raw_e_dict": self._raw_e_dict, }
[docs] def save(self, file_path: Union[str, Path]): r"""Save the DHG's bipartite graph structure to a file. Args: ``file_path`` (``Union[str, Path]``): The file path to store the DHG's bipartite graph structure. """ file_path = Path(file_path) assert file_path.parent.exists(), "The directory does not exist." data = { "class": "BiGraph", "state_dict": self.state_dict, } with open(file_path, "wb") as fp: pickle.dump(data, fp)
[docs] @staticmethod def load(file_path: Union[str, Path]): r"""Load the DHG's bipartite graph structure from a file. Args: ``file_path`` (``Union[str, Path]``): The file path to load the DHG's bipartite graph structure. """ file_path = Path(file_path) assert file_path.exists(), "The file does not exist." with open(file_path, "rb") as fp: data = pickle.load(fp) assert data["class"] == "BiGraph", "The file is not a bipartite graph." return BiGraph.from_state_dict(data["state_dict"])
[docs] def draw( self, e_style: str = "line", u_label: Optional[List[str]] = None, u_size: Union[float, list] = 1.0, u_color: Union[str, list] = "m", u_line_width: Union[str, list] = 1.0, v_label: Optional[List[str]] = None, v_size: Union[float, list] = 1.0, v_color: Union[str, list] = "r", v_line_width: Union[str, list] = 1.0, e_color: Union[str, list] = "gray", e_line_width: Union[str, list] = 1.0, u_font_size: float = 1.0, v_font_size: float = 1.0, font_family: str = "sans-serif", push_u_strength: float = 1.0, push_v_strength: float = 1.0, push_e_strength: float = 1.0, pull_e_strength: float = 1.0, pull_u_center_strength: float = 1.0, pull_v_center_strength: float = 1.0, ): r"""Draw the bipartite graph structure. Args: ``e_style`` (``str``): The edge style. The supported edge styles are only ``'line'``. Defaults to ``'line'``. ``u_label`` (``list``): The label of vertices in set :math:`\mathcal{U}`. Defaults to ``None``. ``u_size`` (``Union[str, list]``): The size of vertices in set :math:`\mathcal{U}`. If ``u_size`` is a ``float``, all vertices will have the same size. If ``u_size`` is a ``list``, the size of each vertex will be set according to the corresponding element in the list. Defaults to ``1.0``. ``u_color`` (``Union[str, list]``): The `color <https://matplotlib.org/stable/gallery/color/named_colors.html>`_ of vertices in set :math:`\mathcal{U}`. If ``u_color`` is a ``str``, all vertices will have the same color. If ``u_color`` is a ``list``, the color of each vertex will be set according to the corresponding element in the list. Defaults to ``'m'``. ``u_line_width`` (``Union[str, list]``): The line width of vertices in set :math:`\mathcal{U}`. If ``u_line_width`` is a ``float``, all vertices will have the same line width. If ``u_line_width`` is a ``list``, the line width of each vertex will be set according to the corresponding element in the list. Defaults to ``1.0``. ``v_label`` (``list``): The label of vertices in set :math:`\mathcal{V}`. Defaults to ``None``. ``v_size`` (``Union[str, list]``): The size of vertices in set :math:`\mathcal{V}`. If ``v_size`` is a ``float``, all vertices will have the same size. If ``v_size`` is a ``list``, the size of each vertex will be set according to the corresponding element in the list. Defaults to ``1.0``. ``v_color`` (``Union[str, list]``): The `color <https://matplotlib.org/stable/gallery/color/named_colors.html>`_ of vertices in set :math:`\mathcal{V}`. If ``v_color`` is a ``str``, all vertices will have the same color. If ``v_color`` is a ``list``, the color of each vertex will be set according to the corresponding element in the list. Defaults to ``'r'``. ``v_line_width`` (``Union[str, list]``): The line width of vertices in set :math:`\mathcal{V}`. If ``v_line_width`` is a ``float``, all vertices will have the same line width. If ``v_line_width`` is a ``list``, the line width of each vertex will be set according to the corresponding element in the list. Defaults to ``1.0``. ``e_color`` (``Union[str, list]``): The `color <https://matplotlib.org/stable/gallery/color/named_colors.html>`_ of edges. If ``e_color`` is a ``str``, all edges will have the same color. If ``e_color`` is a ``list``, the color of each edge will be set according to the corresponding element in the list. Defaults to ``'gray'``. ``e_line_width`` (``Union[str, list]``): The line width of edges. If ``e_line_width`` is a ``float``, all edges will have the same line width. If ``e_line_width`` is a ``list``, the line width of each edge will be set according to the corresponding element in the list. Defaults to ``1.0``. ``u_font_size`` (``float``): The font size of vertex labels in set :math:`\mathcal{U}`. Defaults to ``1.0``. ``v_font_size`` (``float``): The font size of vertex labels in set :math:`\mathcal{V}`. Defaults to ``1.0``. ``font_family`` (``str``): The font family of vertex labels. Defaults to ``'sans-serif'``. ``push_u_strength`` (``float``): The strength of pushing vertices in set :math:`\mathcal{U}`. Defaults to ``1.0``. ``push_v_strength`` (``float``): The strength of pushing vertices in set :math:`\mathcal{V}`. Defaults to ``1.0``. ``push_e_strength`` (``float``): The strength of pushing edges. Defaults to ``1.0``. ``pull_e_strength`` (``float``): The strength of pulling edges. Defaults to ``1.0``. ``pull_u_center_strength`` (``float``): The strength of pulling vertices in set :math:`\mathcal{U}` to the center. Defaults to ``1.0``. ``pull_v_center_strength`` (``float``): The strength of pulling vertices in set :math:`\mathcal{V}` to the center. Defaults to ``1.0``. """ draw_bigraph( self, e_style, u_label, u_size, u_color, u_line_width, v_label, v_size, v_color, v_line_width, e_color, e_line_width, u_font_size, v_font_size, font_family, push_u_strength, push_v_strength, push_e_strength, pull_e_strength, pull_u_center_strength, pull_v_center_strength, )
[docs] def clear(self): r"""Remove all edges in the bipartite graph. """ return super().clear()
[docs] def clone(self): r"""Clone the bipartite graph. """ _g = BiGraph(self.num_u, self.num_v, device=self.device) if self._raw_e_dict is not None: _g._raw_e_dict = deepcopy(self._raw_e_dict) _g.cache = deepcopy(self.cache) return _g
[docs] def to(self, device: torch.device): r"""Move the bipartite graph to the specified device. Args: ``device`` (``torch.device``): The device to store the bipartite graph. """ return super().to(device)
# utils def _format_edges( self, e_list: Union[List[int], List[List[int]]], e_weight: Optional[Union[float, List[float]]] = None, ) -> Tuple[List[List[int]], List[float]]: r"""Check the format of input e_list, and convert raw edge list into edge list. .. note:: If edges in ``e_list`` only have two elements, we will append default weight ``1`` to all edges. Args: ``e_list`` (``List[List[int]]``): Edge list should be a list of edge with pair elements. ``e_weight`` (``List[float]``, optional): Edge weights for each edge. Defaults to ``None``. """ if e_list is None: return [], [] # only one edge if isinstance(e_list[0], int) and len(e_list) == 2: e_list = [e_list] if e_weight is not None: e_weight = [e_weight] e_array = np.array(e_list) assert e_array[:, 0].max() < self.num_u, "The u_idx in e_list is out of range." assert e_array[:, 1].max() < self.num_v, "The v_idx in e_list is out of range." # complete the weight if e_weight is None: e_weight = [1.0] * len(e_list) return e_list, e_weight # ===================================================================================== # some construction functions
[docs] @staticmethod def from_state_dict(state_dict: dict): r"""Load the bipartite graph structure from a state dictionary. Args: ``state_dict`` (``dict``): The state dictionary to load the bipartite graph structure. """ _g = BiGraph(state_dict["num_u"], state_dict["num_v"]) _g._raw_e_dict = deepcopy(state_dict["raw_e_dict"]) return _g
[docs] @staticmethod def from_adj_list( num_u: int, num_v: int, adj_list: List[List[int]], device: torch.device = torch.device("cpu"), ) -> "BiGraph": r"""Construct a bipartite graph from the adjacency list. Each line in the adjacency list has two components. The first element in each line is the ``u_idx``, and the rest elements are the ``v_idx`` that connected to the ``u_idx``. .. note:: This function can only construct the unweighted bipartite graph. Args: ``num_u`` (``int``): The number of vertices in set :math:`\mathcal{U}`. ``num_v`` (``int``): The number of vertices in set :math:`\mathcal{V}`. ``adj_list`` (``List[List[int]]``): Adjacency list. ``device`` (``torch.device``): The device to store the bipartite graph. Defaults to ``torch.device('cpu')``. """ e_list = [] for line in adj_list: if len(line) <= 1: continue u_idx = line[0] e_list.extend([[u_idx, v_idx] for v_idx in line[1:]]) _g = BiGraph(num_u, num_v, e_list, device=device) return _g
[docs] @staticmethod def from_hypergraph( hypergraph: Hypergraph, vertex_as_U: bool = True, weighted: bool = False, device: torch.device = torch.device("cpu"), ) -> "BiGraph": r"""Construct a bipartite graph from the hypergraph. Args: ``hypergraph`` (``Hypergraph``): Hypergraph. ``vertex_as_U`` (``bool``): If set to ``True``, vertices in hypergraph will be transformed to vertices in set :math:`U`, and hyperedges in hypergraph will be transformed to vertices in set :math:`V`. Otherwise, vertices in hypergraph will be transformed to vertices in set :math:`V`, and hyperedges in hypergraph will be transformed to vertices in set :math:`U`. Defaults to ``True``. ``weighted`` (``bool``): If set to ``True``, the bipartite graph will be constructed with weighted edges. The weight of each edge is assigned by the weight of the associated hyperedge in the original hypergraph. Defaults to ``False``. ``device`` (``torch.device``): The device to store the bipartite graph. Defaults to ``torch.device('cpu')``. """ assert isinstance(hypergraph, Hypergraph), "The input `hypergraph` should be a instance of `Hypergraph` class." raw_e_list, raw_e_weight = deepcopy(hypergraph.e) e_weight = None if vertex_as_U: num_u, num_v = hypergraph.num_v, hypergraph.num_e e_list = [(v_idx, e_idx) for e_idx, v_list in enumerate(raw_e_list) for v_idx in v_list] if weighted: e_weight = [ e_weight for e_idx, e_weight in enumerate(raw_e_weight) for _ in range(len(raw_e_list[e_idx])) ] else: num_u, num_v = hypergraph.num_e, hypergraph.num_v e_list = [(e_idx, v_idx) for e_idx, v_list in enumerate(raw_e_list) for v_idx in v_list] if weighted: e_weight = [ e_weight for e_idx, e_weight in enumerate(raw_e_weight) for _ in range(len(raw_e_list[e_idx])) ] _g = BiGraph(num_u, num_v, e_list, e_weight, device=device) return _g
# ===================================================================================== # some structure modification functions
[docs] def add_edges( self, e_list: Union[List[int], List[List[int]]], e_weight: Optional[Union[float, List[float]]] = None, merge_op: str = "mean", ): r"""Add edges to the bipartite graph. Args: ``e_list`` (``Union[List[int], List[List[int]]]``): Edge list. ``e_weight`` (``Union[float, List[float]], optional``): A list of weights for edges. Defaults to ``None``. ``merge_op`` (``str``): The operation to merge those conflicting edges, which can be one of ``'mean'``, ``'sum'``, or ``'max'``. Defaults to ``'mean'``. """ if len(e_list) == 0: return e_list, e_weight = self._format_edges(e_list, e_weight) for (src, dst), w in zip(e_list, e_weight): self._add_edge(src, dst, w, merge_op) self._clear_cache()
def _add_edge(self, src: int, dst: int, w: float, merge_op: str): r"""Add an edge to the bipartite graph. Args: ``src`` (``int``): Source vertex index. ``dst`` (``int``): Destination vertex index. ``w`` (``float``): Edge weight. ``merge_op`` (``str``): The merge operation for the conflicting edges. """ if merge_op == "mean": merge_func = lambda x, y: (x + y) / 2 elif merge_op == "max": merge_func = lambda x, y: max(x, y) elif merge_op == "sum": merge_func = lambda x, y: x + y else: raise ValueError(f"Unknown edge merge operation: {merge_op}.") if (src, dst) in self._raw_e_dict: self._raw_e_dict[(src, dst)] = merge_func(self._raw_e_dict[(src, dst)], w) else: self._raw_e_dict[(src, dst)] = w self._clear_cache()
[docs] def remove_edges(self, e_list: Union[List[int], List[List[int]]]): r"""Remove specifed edges in the bipartite graph. Args: ``e_list`` (``Union[List[int], List[List[int]]]``): Edges to be removed. """ e_list, _ = self._format_edges(e_list) for src, dst in e_list: self._remove_edge(src, dst) self._clear_cache()
[docs] def switch_uv(self): r"""Switch the set :math:`\mathcal{U}` and set :math:`\mathcal{V}` of the bipartite graph, and return the vertex set switched bipartite graph. """ _g = self.clone() _g._num_u, _g._num_v = self.num_v, self.num_u _g._raw_e_dict = {(v, u): w for (u, v), w in self._raw_e_dict.items()} _g._clear_cache() return _g
[docs] def drop_edges(self, drop_rate: float, ord: str = "uniform"): r"""Randomly drop edges from the bipartite graph. This function will return a new bipartite graph with non-dropped edges. Args: ``drop_rate`` (``float``): The drop rate of edges. ``ord`` (``str``): The order of dropping edges. Currently, only ``'uniform'`` is supported. Defaults to ``uniform``. """ if ord == "uniform": _raw_e_dict = {k: v for k, v in self._raw_e_dict.items() if random.random() > drop_rate} state_dict = { "num_u": self.num_u, "num_v": self.num_v, "raw_e_dict": _raw_e_dict, } _g = BiGraph.from_state_dict(state_dict) _g = _g.to(self.device) else: raise ValueError(f"Unknown drop order: {ord}.") return _g
# ============================================================================== # properties for representation @property def u(self) -> List[int]: r"""Return the list of vertices in set :math:`\mathcal{U}`. """ return list(range(self.num_u)) @property def v(self) -> List[int]: r"""Return the list of vertices in set :math:`\mathcal{V}`. """ return super().v @property def e(self) -> Tuple[List[List[int]], List[float]]: r"""Return edges and their weights in the bipartite graph with ``(edge_list, edge_weight_list)`` format. ``i-th`` element in the ``edge_list`` denotes ``i-th`` edge, :math:`[u \longleftrightarrow v]`. ``i-th`` element in ``edge_weight_list`` denotes the weight of ``i-th`` edge, :math:`e_{w}`. The lenght of the two lists are both :math:`|\mathcal{E}|`. """ return super().e @property def num_u(self) -> int: r"""Return the number of vertices in set :math:`\mathcal{U}`. """ return self._num_u @property def num_v(self) -> int: r"""Return the number of vertices in set :math:`\mathcal{V}`. """ return super().num_v @property def num_e(self) -> int: r"""Return the number of edges in the bipartite graph. """ return super().num_e @property def deg_u(self) -> torch.Tensor: r"""Return the degree list of vertices in set :math:`\mathcal{U}`. """ return self.D_u._values().cpu().numpy().tolist() @property def deg_v(self) -> torch.Tensor: r"""Return the degree list of vertices in set :math:`\mathcal{V}`. """ return self.D_v._values().cpu().numpy().tolist()
[docs] def nbr_v(self, u_idx: int) -> torch.Tensor: r"""Return a neighbor vertex list in set :math:`\mathcal{V}` of the specified vertex ``u_idx``. Args: ``u_idx`` (``int``): The index of the vertex in set :math:`\mathcal{U}`. """ return self.N_v(u_idx).cpu().numpy().tolist()
[docs] def nbr_u(self, v_idx: int) -> torch.Tensor: r"""Return a neighbor vertex list in set :math:`\mathcal{U}` of the specified vertex ``v_idx``. Args: ``v_idx`` (``int``): The index of the vertex in set :math:`\mathcal{V}`. """ return self.N_u(v_idx).cpu().numpy().tolist()
# ===================================================================================== # properties for deep learning @property def vars_for_DL(self) -> List[str]: r"""Return a name list of available variables for deep learning in the bipartite graph including Sparse Matrices: .. math:: \mathbf{A}, \mathbf{B}, \mathbf{B}^\top Sparse Diagnal Matrices: .. math:: \mathbf{D}_u, \mathbf{D}_v, \mathbf{D}_u^{-1}, \mathbf{D}_v^{-1} Vectors: .. math:: \vec{e}_{u}, \vec{e}_{v}, \vec{e}_{weight} """ return [ "A", "B", "B_T", "D_u", "D_v", "D_u_neg_1", "D_v_neg_1", "e_u", "e_v", "e_weight", ] @property def A(self) -> torch.Tensor: r"""Return the adjacency matrix :math:`\mathbf{A}` of the bipartite graph with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}| + |\mathcal{V}|, |\mathcal{U}| + |\mathcal{V}|)`. """ if self.cache.get("A", None) is None: UU = torch.sparse_coo_tensor(size=(self.num_u, self.num_u), device=self.device) VV = torch.sparse_coo_tensor(size=(self.num_v, self.num_v), device=self.device) A_up = torch.hstack([UU, self.B]) A_down = torch.hstack([self.B_T, VV]) self.cache["A"] = torch.vstack([A_up, A_down]).coalesce() return self.cache["A"] @property def B(self) -> torch.Tensor: r"""Return the bipartite adjacency matrix :math:`\mathbf{B}` of the bipartite graph with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}|, |\mathcal{V}|)`. """ if self.cache.get("B", None) is None: if self.num_e == 0: self.cache["B"] = torch.sparse_coo_tensor(size=(self.num_u, self.num_v)) else: e_list, e_weight = self.e self.cache["B"] = torch.sparse_coo_tensor( indices=torch.tensor(e_list).t(), values=torch.tensor(e_weight), size=(self.num_u, self.num_v), device=self.device, ).coalesce() return self.cache["B"] @property def B_T(self) -> torch.Tensor: r"""Return the transposed bipartite adjacency matrix :math:`\mathbf{B}^\top` of the bipartite graph with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{V}|, |\mathcal{U}|)`. """ if self.cache.get("B_T", None) is None: self.cache["B_T"] = self.B.t().coalesce() return self.cache["B_T"] @property def D_u(self) -> torch.Tensor: r"""Return the diagnal matrix of vertex in degree :math:`\mathbf{D}_u` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}|, |\mathcal{U}|)`. """ if self.cache.get("D_u", None) is None: _tmp = torch.sparse.sum(self.B, dim=1).to_dense().clone().view(-1) self.cache["D_u"] = torch.sparse_coo_tensor( indices=torch.arange(0, self.num_u, device=self.device).view(1, -1).repeat(2, 1), values=_tmp, size=torch.Size([self.num_u, self.num_u]), device=self.device, ).coalesce() return self.cache["D_u"] @property def D_v(self) -> torch.Tensor: r"""Return the diagnal matrix of vertex out degree :math:`\mathbf{D}_v` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{V}|, |\mathcal{V}|)`. """ if self.cache.get("D_v", None) is None: _tmp = torch.sparse.sum(self.B_T, dim=1).to_dense().clone().view(-1) self.cache["D_v"] = torch.sparse_coo_tensor( indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), values=_tmp, size=torch.Size([self.num_v, self.num_v]), device=self.device, ).coalesce() return self.cache["D_v"] @property def D_u_neg_1(self) -> torch.Tensor: r"""Return the nomalized diagnal matrix of vertex in degree :math:`\mathbf{D}_u^{-1}` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}|, |\mathcal{U}|)`. """ if self.cache.get("D_u_neg_1", None) is None: _mat = self.D_u.clone() _val = _mat._values() ** -1 _val[torch.isinf(_val)] = 0 self.cache["D_u_neg_1"] = torch.sparse_coo_tensor( _mat._indices(), _val, _mat.size(), device=self.device ).coalesce() return self.cache["D_u_neg_1"] @property def D_v_neg_1(self) -> torch.Tensor: r"""Return the nomalized diagnal matrix of vertex out degree :math:`\mathbf{D}_v^{-1}` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{V}|, |\mathcal{V}|)`. """ if self.cache.get("D_v_neg_1", None) is None: _mat = self.D_v.clone() _val = _mat._values() ** -1 _val[torch.isinf(_val)] = 0 self.cache["D_v_neg_1"] = torch.sparse_coo_tensor( _mat._indices(), _val, _mat.size(), device=self.device ).coalesce() return self.cache["D_v_neg_1"]
[docs] def N_v(self, u_idx: int) -> torch.Tensor: r"""Return neighbor vertices in set :math:`\mathcal{V}` of the specified vertex ``u_idx`` with ``torch.Tensor`` format. Args: ``u_idx`` (``int``): The index of the vertex. """ sub_v_set = self.B[u_idx]._indices()[0].clone() return sub_v_set
[docs] def N_u(self, v_idx: int) -> torch.Tensor: r"""Return neighbor vertices in set :math:`\mathcal{U}` of the specified vertex ``v_idx`` with ``torch.Tensor`` format. Args: ``v_idx`` (``int``): The index of the vertex. """ sub_u_set = self.B_T[v_idx]._indices()[0].clone() return sub_u_set
@property def e_u(self) -> torch.Tensor: r"""Return the index vector :math:`\vec{e}_{u}` of vertices in set :math:`\mathcal{U}` in the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{E}|,)`. """ return self.B._indices()[0, :].clone() @property def e_v(self) -> torch.Tensor: r"""Return the index vector :math:`\vec{e}_{v}` of vertices in set :math:`\mathcal{V}` in the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{E}|,)`. """ return self.B._indices()[1, :].clone() @property def e_weight(self) -> torch.Tensor: r"""Return the weight vector :math:`\vec{e}_{weight}` of edges in the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{E}|,)`. """ return self.B._values().clone() # ============================================================================== # spectral-based convolution/smoothing
[docs] def smoothing(self, X: torch.Tensor, L: torch.Tensor, lamb: float) -> torch.Tensor: return super().smoothing(X, L, lamb)
@property def L_GCN(self) -> torch.Tensor: r"""Return the GCN Laplacian matrix of the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{U}| + |\mathcal{V}|, |\mathcal{U}| + |\mathcal{V}|)`. """ if self.cache.get("L_GCN", None) is None: selfloop_indices = torch.arange(0, self.num_u + self.num_v).view(1, -1).repeat(2, 1) selfloop_values = torch.ones(self.num_u + self.num_v).view(-1) A_ = torch.sparse_coo_tensor( indices=torch.hstack([self.A._indices().cpu(), selfloop_indices]), values=torch.hstack([self.A._values().cpu(), selfloop_values]), size=torch.Size([self.num_u + self.num_v, self.num_u + self.num_v]), device=self.device, ).coalesce() D_v_neg_1_2 = torch.sparse.sum(A_, dim=1).to_dense().view(-1) ** (-0.5) D_v_neg_1_2[torch.isinf(D_v_neg_1_2)] = 0 D_v_neg_1_2 = torch.sparse_coo_tensor( indices=selfloop_indices, values=D_v_neg_1_2, size=torch.Size([self.num_u + self.num_v, self.num_u + self.num_v]), device=self.device, ).coalesce() self.cache["L_GCN"] = D_v_neg_1_2.mm(A_).mm(D_v_neg_1_2).clone().coalesce() return self.cache["L_GCN"]
[docs] def smoothing_with_GCN(self, X: torch.Tensor, drop_rate: float = 0.0) -> torch.Tensor: r"""Return the smoothed feature matrix with GCN Laplacian matrix :math:`\mathcal{L}_{GCN}`. Args: ``X`` (``torch.Tensor``): Vertex feature matrix of the bipartite graph. Size :math:`(|\mathcal{U}| + |\mathcal{V}|, C)`. ``drop_rate`` (``float``): Dropout rate. Randomly dropout the connections in adjacency matrix with probability ``drop_rate``. Default: ``0.0``. """ if self.device != X.device: X = X.to(self.device) if drop_rate > 0.0: L_GCN = sparse_dropout(self.L_GCN, drop_rate) else: L_GCN = self.L_GCN return L_GCN.mm(X)
# ============================================================================== # spatial-based convolution/message-passing functions # general message passing
[docs] def u2v( self, X: torch.Tensor, aggr: str = "mean", e_weight: Optional[torch.Tensor] = None, drop_rate: float = 0.0 ) -> torch.Tensor: r"""Message passing from vertices in set :math:`\mathcal{U}` to vertices in set :math:`\mathcal{V}` on the bipartite graph structure. Args: ``X`` (``torch.Tensor``): Feature matrix of vertices in set :math:`\mathcal{U}`. Size: :math:`(|\mathcal{U}|, C)`. ``aggr`` (``str``, optional): Aggregation function for neighbor messages, which can be ``'mean'``, ``'sum'``, or ``'softmax_then_sum'``. Default: ``'mean'``. ``e_weight`` (``torch.Tensor``, optional): The edge weight vector. Size: :math:`(|\mathcal{E}|,)`. Defaults to ``None``. ``drop_rate`` (``float``): Dropout rate. Randomly dropout the connections in adjacency matrix with probability ``drop_rate``. Default: ``0.0``. """ assert aggr in ["mean", "sum", "softmax_then_sum",], "aggr must be one of ['mean', 'sum', 'softmax_then_sum']" if self.device != X.device: self.to(X.device) if e_weight is None: if drop_rate > 0.0: P = sparse_dropout(self.B_T, drop_rate) else: P = self.B_T # message passing if aggr == "mean": X = torch.sparse.mm(P, X) X = torch.sparse.mm(self.D_v_neg_1, X) elif aggr == "sum": X = torch.sparse.mm(P, X) elif aggr == "softmax_then_sum": P = torch.sparse.softmax(P, dim=1) X = torch.sparse.mm(P, X) else: pass else: # init adjacency matrix assert ( e_weight.shape[0] == self.e_weight.shape[0] ), "The size of e_weight must be equal to the size of self.e_weight." P = torch.sparse_coo_tensor(self.B._indices(), e_weight, self.B.shape, device=self.device).t().coalesce() if drop_rate > 0.0: P = sparse_dropout(P, drop_rate) # message passing if aggr == "mean": X = torch.sparse.mm(P, X) D_v_neg_1 = torch.sparse.sum(P, dim=1).to_dense().view(-1, 1) D_v_neg_1[torch.isinf(D_v_neg_1)] = 0 X = D_v_neg_1 * X elif aggr == "sum": X = torch.sparse.mm(P, X) elif aggr == "softmax_then_sum": P = torch.sparse.softmax(P, dim=1) X = torch.sparse.mm(P, X) else: pass return X
[docs] def v2u( self, X: torch.Tensor, aggr: str = "mean", e_weight: Optional[torch.Tensor] = None, drop_rate: float = 0.0 ) -> torch.Tensor: r"""Message passing from vertices in set :math:`\mathcal{V}` to vertices in set :math:`\mathcal{U}` on the bipartite graph structure. Args: ``X`` (``torch.Tensor``): Feature matrix of vertices in set :math:`\mathcal{V}`. Size: :math:`(|\mathcal{V}|, C)`. ``aggr`` (``str``, optional): Aggregation function for neighbor messages, which can be ``'mean'``, ``'sum'``, or ``'softmax_then_sum'``. Default: ``'mean'``. ``e_weight`` (``torch.Tensor``, optional): The edge weight vector. Size: :math:`(|\mathcal{E}|,)`. Defaults to ``None``. ``drop_rate`` (``float``): Dropout rate. Randomly dropout the connections in adjacency matrix with probability ``drop_rate``. Default: ``0.0``. """ assert aggr in ["mean", "sum", "softmax_then_sum",], "aggr must be one of ['mean', 'sum', 'softmax_then_sum']" if self.device != X.device: self.to(X.device) if e_weight is None: if drop_rate > 0.0: P = sparse_dropout(self.B, drop_rate) else: P = self.B # message passing if aggr == "mean": X = torch.sparse.mm(P, X) X = torch.sparse.mm(self.D_u_neg_1, X) elif aggr == "sum": X = torch.sparse.mm(P, X) elif aggr == "softmax_then_sum": P = torch.sparse.softmax(P, dim=1) X = torch.sparse.mm(P, X) else: pass else: # init adjacency matrix assert ( e_weight.shape[0] == self.e_weight.shape[0] ), "The size of e_weight must be equal to the size of self.e_weight." P = torch.sparse_coo_tensor(self.B._indices(), e_weight, self.B.shape, device=self.device).coalesce() if drop_rate > 0.0: P = sparse_dropout(P, drop_rate) # message passing if aggr == "mean": X = torch.sparse.mm(P, X) D_u_neg_1 = torch.sparse.sum(P, dim=1).to_dense().view(-1, 1) D_u_neg_1[torch.isinf(D_u_neg_1)] = 0 X = D_u_neg_1 * X elif aggr == "sum": X = torch.sparse.mm(P, X) elif aggr == "softmax_then_sum": P = torch.sparse.softmax(P, dim=1) X = torch.sparse.mm(P, X) else: pass return X