import pickle
from pathlib import Path
from copy import deepcopy
from typing import Union, Optional, List, Tuple, Any, Dict, TYPE_CHECKING
import torch
import numpy as np
from ..base import BaseGraph
from dhg.structure.hypergraphs import Hypergraph
# if TYPE_CHECKING:
# from ..hypergraphs import Hypergraph
class BiGraph(BaseGraph):
    r"""Class for bipartite graph.

    Args:
        ``num_u`` (``int``): The number of vertices in set :math:`\mathcal{U}`.
        ``num_v`` (``int``): The number of vertices in set :math:`\mathcal{V}`.
        ``e_list`` (``Union[List[int], List[List[int]]]``, optional): Initial edge set, given either as a single
            ``[u_idx, v_idx]`` pair or as a list of such pairs. Defaults to ``None``.
        ``e_weight`` (``Union[float, List[float]]``, optional): A list of weights for edges. Defaults to ``None``.
        ``merge_op`` (``str``): The operation to merge those conflicting edges, which can be one of ``'mean'``,
            ``'sum'``, or ``'max'``. Defaults to ``'mean'``.
        ``device`` (``torch.device``, optional): The device to store the bipartite graph. Defaults to
            ``torch.device('cpu')``.
    """

    def __init__(
        self,
        num_u: int,
        num_v: int,
        e_list: Optional[Union[List[int], List[List[int]]]] = None,
        e_weight: Optional[Union[float, List[float]]] = None,
        merge_op: str = "mean",
        device: torch.device = torch.device("cpu"),
    ):
        # The base class is initialised with the V side and the device; |U| is tracked here.
        super().__init__(num_v, device=device)
        self._num_u = num_u
        # Without an initial edge list there is nothing more to set up.
        if e_list is None:
            return
        self.add_edges(e_list, e_weight, merge_op=merge_op)
def __repr__(self) -> str:
    r"""Return a one-line summary with the sizes of both vertex sets and the number of edges."""
    template = "Bipartite Graph(num_u={}, num_v={}, num_e={})"
    return template.format(self.num_u, self.num_v, self.num_e)

@property
def state_dict(self) -> Dict[str, Any]:
    r"""Get the state dict of the bipartite graph.

    The dict holds ``num_u``, ``num_v`` and ``raw_e_dict``. ``raw_e_dict`` is the graph's own
    edge dictionary (not a copy), so mutating it mutates the graph.
    """
    state = dict(num_u=self.num_u, num_v=self.num_v)
    state["raw_e_dict"] = self._raw_e_dict
    return state
def save(self, file_path: Union[str, Path]):
    r"""Save the DHG's bipartite graph structure to a file.

    The file contains a pickled dict ``{"class": "BiGraph", "state_dict": self.state_dict}``.

    Args:
        ``file_path`` (``Union[str, Path]``): The file path to store the DHG's bipartite graph structure.
            Its parent directory must already exist.
    """
    target = Path(file_path)
    assert target.parent.exists(), "The directory does not exist."
    payload = {"class": "BiGraph", "state_dict": self.state_dict}
    with target.open("wb") as fp:
        pickle.dump(payload, fp)
@staticmethod
def load(file_path: Union[str, Path]):
    r"""Load the DHG's bipartite graph structure from a file written by :meth:`save`.

    .. warning::
        The file is read with :func:`pickle.load`, which can execute arbitrary code while loading.
        Only load files from trusted sources.

    Args:
        ``file_path`` (``Union[str, Path]``): The file path to load the DHG's bipartite graph structure.
    """
    source = Path(file_path)
    assert source.exists(), "The file does not exist."
    with source.open("rb") as fp:
        payload = pickle.load(fp)
    assert payload["class"] == "BiGraph", "The file is not a bipartite graph."
    return BiGraph.load_from_state_dict(payload["state_dict"])
@staticmethod
def load_from_state_dict(state_dict: dict):
    r"""Load the bipartite graph structure from a state dictionary.

    The edge dictionary is deep-copied, so the new graph does not share it with ``state_dict``.

    Args:
        ``state_dict`` (``dict``): The state dictionary to load the bipartite graph structure, with keys
            ``num_u``, ``num_v`` and ``raw_e_dict``.
    """
    num_u, num_v = state_dict["num_u"], state_dict["num_v"]
    restored = BiGraph(num_u, num_v)
    restored._raw_e_dict = deepcopy(state_dict["raw_e_dict"])
    return restored
def clear(self):
    r"""Remove all edges in the bipartite graph (delegates to the base class)."""
    return super(BiGraph, self).clear()

def clone(self):
    r"""Clone the bipartite graph.

    Returns a new ``BiGraph`` on the same device, holding deep copies of this graph's raw edge
    dictionary and cache.
    """
    copied = BiGraph(self.num_u, self.num_v, device=self.device)
    if self._raw_e_dict is not None:
        copied._raw_e_dict = deepcopy(self._raw_e_dict)
    copied.cache = deepcopy(self.cache)
    return copied

def to(self, device: torch.device):
    r"""Move the bipartite graph to the specified device (delegates to the base class).

    Args:
        ``device`` (``torch.device``): The device to store the bipartite graph.
    """
    return super(BiGraph, self).to(device)
# utils
def _format_edges(
    self, e_list: Union[List[int], List[List[int]]], e_weight: Optional[Union[float, List[float]]] = None,
) -> Tuple[List[List[int]], List[float]]:
    r"""Validate the input edge list and normalize it into a list of ``(u_idx, v_idx)`` pairs with one weight per edge.

    .. note::
        If ``e_weight`` is ``None``, every edge gets the default weight ``1.0``. A scalar ``e_weight`` is
        shared by all edges.

    Args:
        ``e_list`` (``Union[List[int], List[List[int]]]``): A single edge ``[u_idx, v_idx]`` or a list of such pairs.
        ``e_weight`` (``Union[float, List[float]]``, optional): Edge weight(s). Defaults to ``None``.

    Returns:
        ``Tuple[List[List[int]], List[float]]``: The edge list and the weight list, which have equal length.
        ``None`` or an empty ``e_list`` yields ``([], [])``.
    """
    if e_list is None or len(e_list) == 0:
        return [], []
    # A bare pair such as ``[u_idx, v_idx]`` denotes a single edge.
    if isinstance(e_list[0], (int, np.integer)) and len(e_list) == 2:
        e_list = [e_list]
    e_array = np.array(e_list)
    assert e_array.ndim == 2 and e_array.shape[1] == 2, "Each edge in e_list should be a pair of (u_idx, v_idx)."
    assert e_array[:, 0].min() >= 0 and e_array[:, 0].max() < self.num_u, "The u_idx in e_list is out of range."
    assert e_array[:, 1].min() >= 0 and e_array[:, 1].max() < self.num_v, "The v_idx in e_list is out of range."
    # complete the weight
    if e_weight is None:
        e_weight = [1.0] * len(e_list)
    elif np.ndim(e_weight) == 0:
        # A scalar weight (python/numpy number or 0-d tensor) is shared by every edge.
        e_weight = [e_weight] * len(e_list)
    else:
        e_weight = list(e_weight)
    # Without this check ``zip`` in the callers would silently drop the unmatched tail.
    assert len(e_weight) == len(e_list), "The length of e_weight must match the number of edges in e_list."
    return e_list, e_weight
# =====================================================================================
# some construction functions
@staticmethod
def from_adj_list(
    num_u: int, num_v: int, adj_list: List[List[int]], device: torch.device = torch.device("cpu"),
) -> "BiGraph":
    r"""Construct a bipartite graph from the adjacency list.

    Each line of the adjacency list starts with a ``u_idx``; the remaining elements are the ``v_idx``
    vertices connected to that ``u_idx``. Lines with fewer than two elements are skipped.

    .. note::
        This function can only construct the unweighted bipartite graph.

    Args:
        ``num_u`` (``int``): The number of vertices in set :math:`\mathcal{U}`.
        ``num_v`` (``int``): The number of vertices in set :math:`\mathcal{V}`.
        ``adj_list`` (``List[List[int]]``): Adjacency list.
        ``device`` (``torch.device``): The device to store the bipartite graph. Defaults to ``torch.device('cpu')``.
    """
    pairs = [[row[0], v_idx] for row in adj_list if len(row) > 1 for v_idx in row[1:]]
    return BiGraph(num_u, num_v, pairs, device=device)
@staticmethod
def from_hypergraph(
    hypergraph: Hypergraph,
    vertex_as_U: bool = True,
    weighted: bool = False,
    device: torch.device = torch.device("cpu"),
) -> "BiGraph":
    r"""Construct a bipartite graph from the hypergraph.

    Every (vertex, hyperedge) incidence of the hypergraph becomes one edge of the bipartite graph.

    Args:
        ``hypergraph`` (``Hypergraph``): Hypergraph.
        ``vertex_as_U`` (``bool``): If set to ``True``, vertices in hypergraph will be transformed to vertices in set :math:`U`, and hyperedges in hypergraph will be transformed to vertices in set :math:`V`. Otherwise, vertices in hypergraph will be transformed to vertices in set :math:`V`, and hyperedges in hypergraph will be transformed to vertices in set :math:`U`. Defaults to ``True``.
        ``weighted`` (``bool``): If set to ``True``, the bipartite graph will be constructed with weighted edges. The weight of each edge is assigned by the weight of the associated hyperedge in the original hypergraph. Defaults to ``False``.
        ``device`` (``torch.device``): The device to store the bipartite graph. Defaults to ``torch.device('cpu')``.
    """
    assert isinstance(hypergraph, Hypergraph), "The input `hypergraph` should be a instance of `Hypergraph` class."
    # ``hypergraph.e`` is only read below (never mutated), so no defensive deep copy is needed.
    raw_e_list, raw_e_weight = hypergraph.e
    if vertex_as_U:
        num_u, num_v = hypergraph.num_v, hypergraph.num_e
        e_list = [(v_idx, e_idx) for e_idx, v_list in enumerate(raw_e_list) for v_idx in v_list]
    else:
        num_u, num_v = hypergraph.num_e, hypergraph.num_v
        e_list = [(e_idx, v_idx) for e_idx, v_list in enumerate(raw_e_list) for v_idx in v_list]
    e_weight = None
    if weighted:
        # Each incidence inherits its hyperedge's weight; the iteration order (hyperedge by hyperedge,
        # then its vertices) matches the construction of ``e_list`` in both orientations.
        e_weight = [w for v_list, w in zip(raw_e_list, raw_e_weight) for _ in v_list]
    return BiGraph(num_u, num_v, e_list, e_weight, device=device)
# =====================================================================================
# some structure modification functions
def add_edges(
    self,
    e_list: Union[List[int], List[List[int]]],
    e_weight: Optional[Union[float, List[float]]] = None,
    merge_op: str = "mean",
):
    r"""Add edges to the bipartite graph.

    An empty ``e_list`` is a no-op. Edges that already exist are merged according to ``merge_op``.

    Args:
        ``e_list`` (``Union[List[int], List[List[int]]]``): Edge list.
        ``e_weight`` (``Union[float, List[float]], optional``): A list of weights for edges. Defaults to ``None``.
        ``merge_op`` (``str``): The operation to merge those conflicting edges, which can be one of ``'mean'``, ``'sum'``, or ``'max'``. Defaults to ``'mean'``.
    """
    if len(e_list) > 0:
        edges, weights = self._format_edges(e_list, e_weight)
        for (u_idx, v_idx), w in zip(edges, weights):
            self._add_edge(u_idx, v_idx, w, merge_op)
        self._clear_cache()
def _add_edge(self, src: int, dst: int, w: float, merge_op: str):
    r"""Add the edge ``(src, dst)`` with weight ``w`` to the bipartite graph.

    If the edge already exists, its stored weight and ``w`` are combined: ``'mean'`` averages the two,
    ``'max'`` keeps the larger one and ``'sum'`` adds them. Repeated ``'mean'`` merges are therefore a
    pairwise running average, not the arithmetic mean of all duplicates.

    Args:
        ``src`` (``int``): Source vertex index.
        ``dst`` (``int``): Destination vertex index.
        ``w`` (``float``): Edge weight.
        ``merge_op`` (``str``): The merge operation for the conflicting edges.

    Raises:
        ``ValueError``: If ``merge_op`` is not ``'mean'``, ``'max'`` or ``'sum'``.
    """
    if merge_op == "mean":
        combine = lambda old, new: (old + new) / 2
    elif merge_op == "max":
        combine = max
    elif merge_op == "sum":
        combine = lambda old, new: old + new
    else:
        raise ValueError(f"Unknown edge merge operation: {merge_op}.")
    key = (src, dst)
    if key in self._raw_e_dict:
        # Conflicting edge: fold the new weight into the existing one.
        w = combine(self._raw_e_dict[key], w)
    self._raw_e_dict[key] = w
    self._clear_cache()
def remove_edges(self, e_list: Union[List[int], List[List[int]]]):
    r"""Remove specified edges in the bipartite graph.

    Args:
        ``e_list`` (``Union[List[int], List[List[int]]]``): A single ``[u_idx, v_idx]`` pair or a list of such
            pairs to be removed. ``None`` or an empty list is a no-op.
    """
    # Same early exit as ``add_edges``; guarding here keeps this method independent of whether
    # ``_format_edges`` tolerates an empty list (it indexes ``e_list[0]``).
    if e_list is None or len(e_list) == 0:
        return
    e_list, _ = self._format_edges(e_list)
    for src, dst in e_list:
        self._remove_edge(src, dst)
    self._clear_cache()
def switch_uv(self):
    r"""Switch the set :math:`\mathcal{U}` and set :math:`\mathcal{V}` of the bipartite graph, and return the vertex set switched bipartite graph.

    ``self`` is left untouched. The returned graph lives on the same device, and every edge ``(u, v)``
    with weight ``w`` becomes ``(v, u)`` with the same weight.
    """
    # Build the swapped graph directly. Cloning ``self`` first would deep-copy the edge dict and the
    # cache only to discard both, and would require patching the base class's private ``_num_v``.
    _g = BiGraph(self.num_v, self.num_u, device=self.device)
    _g._raw_e_dict = {(v, u): w for (u, v), w in self._raw_e_dict.items()}
    _g._clear_cache()
    return _g
# ==============================================================================
# properties for representation
@property
def u(self) -> List[int]:
    r"""Return the list of vertices in set :math:`\mathcal{U}`, i.e. ``[0, 1, ..., num_u - 1]``."""
    return [*range(self.num_u)]

@property
def v(self) -> List[int]:
    r"""Return the list of vertices in set :math:`\mathcal{V}` (provided by the base class)."""
    return super(BiGraph, self).v

@property
def e(self) -> Tuple[List[List[int]], List[float]]:
    r"""Return the edge list and weight list in the bipartite graph (provided by the base class)."""
    return super(BiGraph, self).e

@property
def num_u(self) -> int:
    r"""Return the number of vertices in set :math:`\mathcal{U}`."""
    return self._num_u

@property
def num_v(self) -> int:
    r"""Return the number of vertices in set :math:`\mathcal{V}` (provided by the base class)."""
    return super(BiGraph, self).num_v

@property
def num_e(self) -> int:
    r"""Return the number of edges in the bipartite graph (provided by the base class)."""
    return super(BiGraph, self).num_e
@property
def deg_u(self) -> List[float]:
    r"""Return the degree list of vertices in set :math:`\mathcal{U}`, in vertex-index order.

    Each degree is the sum of the weights of the vertex's incident edges (the diagonal of ``D_u``).
    """
    return self.D_u._values().cpu().tolist()

@property
def deg_v(self) -> List[float]:
    r"""Return the degree list of vertices in set :math:`\mathcal{V}`, in vertex-index order.

    Each degree is the sum of the weights of the vertex's incident edges (the diagonal of ``D_v``).
    """
    return self.D_v._values().cpu().tolist()

def nbr_v(self, u_idx: int) -> List[int]:
    r"""Return a neighbor vertex list in set :math:`\mathcal{V}` of the specified vertex ``u_idx``.

    Args:
        ``u_idx`` (``int``): The index of the vertex in set :math:`\mathcal{U}`.
    """
    return self.N_v(u_idx).cpu().tolist()

def nbr_u(self, v_idx: int) -> List[int]:
    r"""Return a neighbor vertex list in set :math:`\mathcal{U}` of the specified vertex ``v_idx``.

    Args:
        ``v_idx`` (``int``): The index of the vertex in set :math:`\mathcal{V}`.
    """
    return self.N_u(v_idx).cpu().tolist()
# =====================================================================================
# properties for deep learning
@property
def vars_for_DL(self) -> List[str]:
    r"""Return a name list of available variables for deep learning in the bipartite graph including

    Sparse Matrices:

    .. math::
        \mathbf{A}, \mathbf{B}, \mathbf{B}^\top

    Sparse Diagonal Matrices:

    .. math::
        \mathbf{D}_u, \mathbf{D}_v, \mathbf{D}_u^{-1}, \mathbf{D}_v^{-1}

    Vectors:

    .. math::
        \vec{e}_{u}, \vec{e}_{v}, \vec{e}_{weight}
    """
    sparse_mats = ["A", "B", "B_T"]
    diag_mats = ["D_u", "D_v", "D_u_neg_1", "D_v_neg_1"]
    vectors = ["e_u", "e_v", "e_weight"]
    return sparse_mats + diag_mats + vectors
@property
def A(self) -> torch.Tensor:
    r"""Return the adjacency matrix :math:`\mathbf{A}` of the bipartite graph with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}| + |\mathcal{V}|, |\mathcal{U}| + |\mathcal{V}|)`.

    The matrix is :math:`\begin{bmatrix} \mathbf{0} & \mathbf{B} \\ \mathbf{B}^\top & \mathbf{0} \end{bmatrix}`,
    so vertices of :math:`\mathcal{U}` come first and vertices of :math:`\mathcal{V}` are offset by ``num_u``.
    """
    cached = self.cache.get("A", None)
    if cached is not None:
        return cached
    zeros_uu = torch.sparse_coo_tensor(size=(self.num_u, self.num_u), device=self.device)
    zeros_vv = torch.sparse_coo_tensor(size=(self.num_v, self.num_v), device=self.device)
    top = torch.hstack([zeros_uu, self.B])
    bottom = torch.hstack([self.B_T, zeros_vv])
    self.cache["A"] = torch.vstack([top, bottom]).coalesce()
    return self.cache["A"]
@property
def B(self) -> torch.Tensor:
    r"""Return the bipartite adjacency matrix :math:`\mathbf{B}` of the bipartite graph with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}|, |\mathcal{V}|)`.

    Entry ``(u, v)`` holds the weight of edge ``(u, v)``. Integer edge weights are promoted to the
    default floating dtype. The result is always placed on the graph's device and coalesced.
    """
    if self.cache.get("B", None) is None:
        if self.num_e == 0:
            # The empty matrix must live on the graph's device too; otherwise stacking it with
            # device tensors (e.g. in ``A``) fails for non-CPU graphs.
            self.cache["B"] = torch.sparse_coo_tensor(size=(self.num_u, self.num_v), device=self.device).coalesce()
        else:
            e_list, e_weight = self.e
            values = torch.tensor(e_weight)
            if not values.is_floating_point():
                # Integer weights would give an integer matrix, which breaks downstream float math
                # (``D_u ** -1`` and ``torch.sparse.mm`` with float features).
                values = values.to(torch.get_default_dtype())
            self.cache["B"] = torch.sparse_coo_tensor(
                indices=torch.tensor(e_list).t(),
                values=values,
                size=(self.num_u, self.num_v),
                device=self.device,
            ).coalesce()
    return self.cache["B"]
@property
def B_T(self) -> torch.Tensor:
    r"""Return the transposed bipartite adjacency matrix :math:`\mathbf{B}^\top` of the bipartite graph with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{V}|, |\mathcal{U}|)`.

    Computed once from ``B`` and cached.
    """
    transposed = self.cache.get("B_T", None)
    if transposed is None:
        transposed = self.B.t().coalesce()
        self.cache["B_T"] = transposed
    return transposed
def _row_sum_diag(self, mat: torch.Tensor, n: int) -> torch.Tensor:
    r"""Build an ``(n, n)`` sparse diagonal matrix on the graph's device whose diagonal holds the row sums of the sparse matrix ``mat``."""
    row_sum = torch.sparse.sum(mat, dim=1).to_dense().clone().view(-1)
    diag_idx = torch.arange(0, n).view(1, -1).repeat(2, 1)
    return torch.sparse_coo_tensor(
        indices=diag_idx, values=row_sum, size=torch.Size([n, n]), device=self.device,
    ).coalesce()

@property
def D_u(self) -> torch.Tensor:
    r"""Return the diagonal degree matrix :math:`\mathbf{D}_u` of vertices in set :math:`\mathcal{U}` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}|, |\mathcal{U}|)`.

    Each diagonal entry is the row sum of ``B``, i.e. the summed weight of the vertex's edges.
    """
    if self.cache.get("D_u", None) is None:
        self.cache["D_u"] = self._row_sum_diag(self.B, self.num_u)
    return self.cache["D_u"]

@property
def D_v(self) -> torch.Tensor:
    r"""Return the diagonal degree matrix :math:`\mathbf{D}_v` of vertices in set :math:`\mathcal{V}` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{V}|, |\mathcal{V}|)`.

    Each diagonal entry is the row sum of ``B_T``, i.e. the summed weight of the vertex's edges.
    """
    if self.cache.get("D_v", None) is None:
        self.cache["D_v"] = self._row_sum_diag(self.B_T, self.num_v)
    return self.cache["D_v"]
def _inverse_diag(self, diag: torch.Tensor) -> torch.Tensor:
    r"""Return the element-wise inverse of a sparse diagonal matrix, mapping zero entries (``inf`` after inversion) to ``0``."""
    mat = diag.clone()
    inv = mat._values() ** -1
    inv[torch.isinf(inv)] = 0
    return torch.sparse_coo_tensor(mat._indices(), inv, mat.size(), device=self.device).coalesce()

@property
def D_u_neg_1(self) -> torch.Tensor:
    r"""Return the inverted degree matrix :math:`\mathbf{D}_u^{-1}` of vertices in set :math:`\mathcal{U}` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{U}|, |\mathcal{U}|)`.

    Isolated vertices (degree 0) get ``0`` instead of ``inf``.
    """
    if self.cache.get("D_u_neg_1", None) is None:
        self.cache["D_u_neg_1"] = self._inverse_diag(self.D_u)
    return self.cache["D_u_neg_1"]

@property
def D_v_neg_1(self) -> torch.Tensor:
    r"""Return the inverted degree matrix :math:`\mathbf{D}_v^{-1}` of vertices in set :math:`\mathcal{V}` with ``torch.sparse_coo_tensor`` format. Size :math:`(|\mathcal{V}|, |\mathcal{V}|)`.

    Isolated vertices (degree 0) get ``0`` instead of ``inf``.
    """
    if self.cache.get("D_v_neg_1", None) is None:
        self.cache["D_v_neg_1"] = self._inverse_diag(self.D_v)
    return self.cache["D_v_neg_1"]
def N_v(self, u_idx: int) -> torch.Tensor:
    r"""Return neighbor vertices in set :math:`\mathcal{V}` of the specified vertex ``u_idx`` with ``torch.Tensor`` format.

    Args:
        ``u_idx`` (``int``): The index of the vertex in set :math:`\mathcal{U}`.
    """
    row = self.B[u_idx]
    return row._indices()[0].clone()

def N_u(self, v_idx: int) -> torch.Tensor:
    r"""Return neighbor vertices in set :math:`\mathcal{U}` of the specified vertex ``v_idx`` with ``torch.Tensor`` format.

    Args:
        ``v_idx`` (``int``): The index of the vertex in set :math:`\mathcal{V}`.
    """
    row = self.B_T[v_idx]
    return row._indices()[0].clone()
@property
def e_u(self) -> torch.Tensor:
    r"""Return the index vector :math:`\vec{e}_{u}` of vertices in set :math:`\mathcal{U}` in the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{E}|,)`.

    This is the first row of ``B``'s indices, i.e. the ``u`` end of every edge in ``B``'s edge order.
    """
    indices = self.B._indices()
    return indices[0].clone()

@property
def e_v(self) -> torch.Tensor:
    r"""Return the index vector :math:`\vec{e}_{v}` of vertices in set :math:`\mathcal{V}` in the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{E}|,)`.

    This is the second row of ``B``'s indices, i.e. the ``v`` end of every edge in ``B``'s edge order.
    """
    indices = self.B._indices()
    return indices[1].clone()

@property
def e_weight(self) -> torch.Tensor:
    r"""Return the weight vector :math:`\vec{e}_{weight}` of edges in the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{E}|,)`.

    The order matches :attr:`e_u` and :attr:`e_v`.
    """
    values = self.B._values()
    return values.clone()
# ==============================================================================
# spectral-based convolution/smoothing
def smoothing(self, X: torch.Tensor, L: torch.Tensor, lamb: float) -> torch.Tensor:
    r"""Smooth the feature matrix ``X`` with the matrix ``L`` and coefficient ``lamb``.

    Delegates unchanged to the base-class ``smoothing``; see it for the exact formula.

    Args:
        ``X`` (``torch.Tensor``): Feature matrix (presumably of size :math:`(|\mathcal{U}| + |\mathcal{V}|, C)` — confirm against the base class).
        ``L`` (``torch.Tensor``): The matrix used for smoothing.
        ``lamb`` (``float``): The smoothing coefficient.
    """
    return super(BiGraph, self).smoothing(X, L, lamb)
@property
def L_GCN(self) -> torch.Tensor:
    r"""Return the GCN Laplacian matrix of the bipartite graph with ``torch.Tensor`` format. Size :math:`(|\mathcal{U}| + |\mathcal{V}|, |\mathcal{U}| + |\mathcal{V}|)`.

    Computed as :math:`\hat{\mathbf{D}}^{-1/2} (\mathbf{A} + \mathbf{I}) \hat{\mathbf{D}}^{-1/2}`, where
    :math:`\hat{\mathbf{D}}` holds the row sums of :math:`\mathbf{A} + \mathbf{I}`. The result is cached.
    """
    if self.cache.get("L_GCN", None) is None:
        n = self.num_u + self.num_v
        loop_idx = torch.arange(0, n).view(1, -1).repeat(2, 1)
        loop_val = torch.ones(n)
        adj = self.A
        # A + I, assembled on the CPU and then placed on the graph's device.
        adj_hat = torch.sparse_coo_tensor(
            indices=torch.hstack([adj._indices().cpu(), loop_idx]),
            values=torch.hstack([adj._values().cpu(), loop_val]),
            size=torch.Size([n, n]),
            device=self.device,
        ).coalesce()
        deg_inv_sqrt = torch.sparse.sum(adj_hat, dim=1).to_dense().view(-1) ** (-0.5)
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
        d_inv_sqrt = torch.sparse_coo_tensor(
            indices=loop_idx,
            values=deg_inv_sqrt,
            size=torch.Size([n, n]),
            device=self.device,
        ).coalesce()
        normalized = torch.sparse.mm(d_inv_sqrt, torch.sparse.mm(adj_hat, d_inv_sqrt))
        self.cache["L_GCN"] = normalized.clone().coalesce()
    return self.cache["L_GCN"]
def smoothing_with_GCN(self, X: torch.Tensor):
    r"""Return the smoothed feature matrix with GCN Laplacian matrix :math:`\mathcal{L}_{GCN}`.

    ``X`` is moved to the graph's device if needed; the graph itself is not moved.

    Args:
        ``X`` (``torch.Tensor``): Vertex feature matrix of the bipartite graph. Size :math:`(|\mathcal{U}| + |\mathcal{V}|, C)`.
    """
    features = X if X.device == self.device else X.to(self.device)
    return torch.sparse.mm(self.L_GCN, features)
# ==============================================================================
# spatial-based convolution/message-passing functions
# general message passing
def u2v(self, X: torch.Tensor, aggr: str = "mean", e_weight: Optional[torch.Tensor] = None,) -> torch.Tensor:
    r"""Message passing from vertices in set :math:`\mathcal{U}` to vertices in set :math:`\mathcal{V}` on the bipartite graph structure.

    With ``aggr='mean'`` each vertex in :math:`\mathcal{V}` receives the weighted sum of its neighbors'
    features divided by the sum of its incident edge weights. Vertices without incident edges receive
    zeros. ``'sum'`` skips the division. ``'softmax_then_sum'`` first normalizes the edge weights of each
    :math:`\mathcal{V}` vertex with a sparse softmax.

    .. note::
        If ``X`` is on a different device than the graph, the graph is moved with ``self.to(X.device)``.

    Args:
        ``X`` (``torch.Tensor``): Feature matrix of vertices in set :math:`\mathcal{U}`. Size: :math:`(|\mathcal{U}|, C)`.
        ``aggr`` (``str``, optional): Aggregation function for neighbor messages, which can be ``'mean'``, ``'sum'``, or ``'softmax_then_sum'``. Default: ``'mean'``.
        ``e_weight`` (``torch.Tensor``, optional): Edge weight vector that replaces the stored edge weights for this call.
            It must follow the same edge order as :attr:`e_weight`. Size: :math:`(|\mathcal{E}|,)`. Defaults to ``None``.

    Returns:
        ``torch.Tensor``: Feature matrix of vertices in set :math:`\mathcal{V}`. Size: :math:`(|\mathcal{V}|, C)`.
    """
    assert aggr in ["mean", "sum", "softmax_then_sum",], "aggr must be one of ['mean', 'sum', 'softmax_then_sum']"
    if self.device != X.device:
        self.to(X.device)
    if e_weight is None:
        # message passing with the stored edge weights
        if aggr == "mean":
            X = torch.sparse.mm(self.B_T, X)
            X = torch.sparse.mm(self.D_v_neg_1, X)
        elif aggr == "sum":
            X = torch.sparse.mm(self.B_T, X)
        elif aggr == "softmax_then_sum":
            P = torch.sparse.softmax(self.B_T, dim=1)
            X = torch.sparse.mm(P, X)
    else:
        # init the (|V|, |U|) adjacency matrix carrying the custom edge weights
        assert (
            e_weight.shape[0] == self.e_weight.shape[0]
        ), "The size of e_weight must be equal to the size of self.e_weight."
        P = torch.sparse_coo_tensor(self.B._indices(), e_weight, self.B.shape, device=self.device).t().coalesce()
        # message passing
        if aggr == "mean":
            X = torch.sparse.mm(P, X)
            # The row sums of P are the weighted degrees of V vertices. Invert them so "mean" divides by
            # the degree; zero-degree rows invert to inf and are reset to 0.
            D_v_neg_1 = torch.sparse.sum(P, dim=1).to_dense().view(-1, 1) ** -1
            D_v_neg_1[torch.isinf(D_v_neg_1)] = 0
            X = D_v_neg_1 * X
        elif aggr == "sum":
            X = torch.sparse.mm(P, X)
        elif aggr == "softmax_then_sum":
            P = torch.sparse.softmax(P, dim=1)
            X = torch.sparse.mm(P, X)
    return X
def v2u(self, X: torch.Tensor, aggr: str = "mean", e_weight: Optional[torch.Tensor] = None,) -> torch.Tensor:
    r"""Message passing from vertices in set :math:`\mathcal{V}` to vertices in set :math:`\mathcal{U}` on the bipartite graph structure.

    With ``aggr='mean'`` each vertex in :math:`\mathcal{U}` receives the weighted sum of its neighbors'
    features divided by the sum of its incident edge weights. Vertices without incident edges receive
    zeros. ``'sum'`` skips the division. ``'softmax_then_sum'`` first normalizes the edge weights of each
    :math:`\mathcal{U}` vertex with a sparse softmax.

    .. note::
        If ``X`` is on a different device than the graph, the graph is moved with ``self.to(X.device)``.

    Args:
        ``X`` (``torch.Tensor``): Feature matrix of vertices in set :math:`\mathcal{V}`. Size: :math:`(|\mathcal{V}|, C)`.
        ``aggr`` (``str``, optional): Aggregation function for neighbor messages, which can be ``'mean'``, ``'sum'``, or ``'softmax_then_sum'``. Default: ``'mean'``.
        ``e_weight`` (``torch.Tensor``, optional): Edge weight vector that replaces the stored edge weights for this call.
            It must follow the same edge order as :attr:`e_weight`. Size: :math:`(|\mathcal{E}|,)`. Defaults to ``None``.

    Returns:
        ``torch.Tensor``: Feature matrix of vertices in set :math:`\mathcal{U}`. Size: :math:`(|\mathcal{U}|, C)`.
    """
    assert aggr in ["mean", "sum", "softmax_then_sum",], "aggr must be one of ['mean', 'sum', 'softmax_then_sum']"
    if self.device != X.device:
        self.to(X.device)
    if e_weight is None:
        # message passing with the stored edge weights
        if aggr == "mean":
            X = torch.sparse.mm(self.B, X)
            X = torch.sparse.mm(self.D_u_neg_1, X)
        elif aggr == "sum":
            X = torch.sparse.mm(self.B, X)
        elif aggr == "softmax_then_sum":
            P = torch.sparse.softmax(self.B, dim=1)
            X = torch.sparse.mm(P, X)
    else:
        # init the (|U|, |V|) adjacency matrix carrying the custom edge weights
        assert (
            e_weight.shape[0] == self.e_weight.shape[0]
        ), "The size of e_weight must be equal to the size of self.e_weight."
        P = torch.sparse_coo_tensor(self.B._indices(), e_weight, self.B.shape, device=self.device).coalesce()
        # message passing
        if aggr == "mean":
            X = torch.sparse.mm(P, X)
            # The row sums of P are the weighted degrees of U vertices. Invert them so "mean" divides by
            # the degree; zero-degree rows invert to inf and are reset to 0.
            D_u_neg_1 = torch.sparse.sum(P, dim=1).to_dense().view(-1, 1) ** -1
            D_u_neg_1[torch.isinf(D_u_neg_1)] = 0
            X = D_u_neg_1 * X
        elif aggr == "sum":
            X = torch.sparse.mm(P, X)
        elif aggr == "softmax_then_sum":
            P = torch.sparse.softmax(P, dim=1)
            X = torch.sparse.mm(P, X)
    return X