MultiHeadWrapper

class dhg.nn.MultiHeadWrapper(*args, **kwargs)

Bases: torch.nn.Module

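A wrapper that applies multiple heads to a given layer: it creates num_heads independent instances of the layer, runs each one on the same input, and merges their outputs with the chosen readout.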

Parameters
  • num_heads (int) – The number of heads.

  • readout (str) – How the per-head outputs are combined. Can be "mean", "max", "sum", or "concat".

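  • layer (nn.Module subclass) – The layer class to replicate across heads, for example GATConv. It is instantiated once per head.

  • **kwargs – Keyword arguments passed to the layer's constructor, for example in_channels and out_channels for GATConv.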
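The readout options can be illustrated without PyTorch. The sketch below is a plain-Python approximation, assuming each head returns a num_v x C feature matrix (a list of rows) and that "mean", "max" and "sum" reduce element-wise across heads; combine_heads is a hypothetical helper, not part of DHG. It shows the practical consequence of the choice: "concat" multiplies the output width by num_heads, while the other three keep it at C.

```python
from typing import List

# Each head's output: one row per vertex, one column per feature channel.
Matrix = List[List[float]]


def combine_heads(head_outputs: List[Matrix], readout: str) -> Matrix:
    """Merge per-head outputs of shape (num_v, C) with a readout (hypothetical helper).

    "concat" joins the heads along the channel axis -> (num_v, num_heads * C);
    "mean", "max" and "sum" reduce element-wise across heads -> (num_v, C).
    """
    num_v = len(head_outputs[0])
    if readout == "concat":
        # Row i is row i of head 0, followed by row i of head 1, and so on.
        return [[x for head in head_outputs for x in head[i]] for i in range(num_v)]

    reducers = {
        "mean": lambda vals: sum(vals) / len(vals),
        "max": max,
        "sum": sum,
    }
    if readout not in reducers:
        raise ValueError(f"unknown readout: {readout!r}")
    reducer = reducers[readout]
    channels = len(head_outputs[0][0])
    return [
        [reducer([head[i][j] for head in head_outputs]) for j in range(channels)]
        for i in range(num_v)
    ]


# Two heads, each producing 2 vertices x 2 channels.
heads = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 0.0], [1.0, 8.0]],
]
print(combine_heads(heads, "concat"))  # [[1.0, 2.0, 5.0, 0.0], [3.0, 4.0, 1.0, 8.0]]
print(combine_heads(heads, "mean"))    # [[3.0, 1.0], [2.0, 6.0]]
```

With "concat", the next layer must therefore expect num_heads * out_channels input features; with the reducing readouts it keeps out_channels.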

Example

>>> import torch
>>> import dhg
>>> from dhg.nn import GATConv, MultiHeadWrapper
>>> multi_head_layer = MultiHeadWrapper(
...     4,
...     "concat",
...     GATConv,
...     in_channels=16,
...     out_channels=8,
... )
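>>> X = torch.rand(20, 16)  # 20 vertices, 16 input features each
>>> g = dhg.random.graph_Gnm(20, 15)  # random graph with 20 vertices and 15 edges
>>> X_ = multi_head_layer(X=X, g=g)  # "concat": 4 heads x 8 channels -> 32 channels per vertex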

forward(**kwargs)

Runs every head on the given keyword arguments and combines their outputs according to readout.

Note

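Inputs must be passed as keyword arguments, because forward only accepts **kwargs and hands them unchanged to each head. For example, if the layer is GATConv, call the wrapper as multi_head_layer(X=X, g=g), not multi_head_layer(X, g).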
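The keyword-only requirement follows from the signature forward(**kwargs): Python binds no positional parameters to such a function, so a positional call raises TypeError before any head runs. The plain-Python sketch below illustrates the pattern under that assumption; OffsetLayer and MultiHeadSketch are hypothetical stand-ins, not DHG classes, and the readout step is left out so the list of per-head outputs is returned directly.

```python
class OffsetLayer:
    """Hypothetical stand-in for a layer class such as GATConv.

    Adds a fixed offset to every feature of X. The g argument only mirrors
    the X=..., g=... calling convention and is ignored.
    """

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def __call__(self, X, g=None):
        return [[v + self.offset for v in row] for row in X]


class MultiHeadSketch:
    """Minimal sketch of the multi-head pattern; not DHG's implementation."""

    def __init__(self, num_heads: int, layer, **kwargs) -> None:
        # The layer class is instantiated once per head with the same kwargs,
        # so every head is a separate object.
        self.heads = [layer(**kwargs) for _ in range(num_heads)]

    def forward(self, **kwargs):
        # Only keyword arguments are accepted; the same dict is handed to
        # every head. The readout step is omitted: per-head outputs are returned.
        return [head(**kwargs) for head in self.heads]

    # Mimic nn.Module, where calling the object runs forward.
    __call__ = forward


wrapper = MultiHeadSketch(2, OffsetLayer, offset=1.0)
print(wrapper(X=[[0.0, 1.0]], g=None))  # [[[1.0, 2.0]], [[1.0, 2.0]]]

try:
    wrapper([[0.0, 1.0]], None)  # positional call, rejected by forward(**kwargs)
except TypeError as err:
    print("TypeError:", err)
```

Both heads return the same values here only because OffsetLayer has no random state; layers such as GATConv get independently initialised parameters per instance, so their heads produce different outputs.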