MultiHeadWrapper
- class dhg.nn.MultiHeadWrapper(*args, **kwargs)
Bases: torch.nn.Module
A wrapper that applies multiple heads to a given layer.
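- Parameters
  - num_heads (int) – The number of heads.
  - readout (str) – The readout method used to combine the outputs of the heads. Can be "mean", "max", "sum", or "concat".
  - layer (nn.Module) – The layer to apply multiple heads to. As in the example below, this is the layer class (e.g. GATConv), which is instantiated once per head.
  - **kwargs – The keyword arguments used to construct the layer.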
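To make the readout options concrete, here is a minimal sketch of how such a wrapper can be built in plain PyTorch. `SimpleMultiHead` and `ToyLayer` are hypothetical names used only for illustration; this is not DHG's source. It assumes each head is an independent instance of the layer class, that every head receives the same keyword inputs, and that "concat" joins outputs along the last (feature) dimension while "mean", "max", and "sum" reduce across heads element-wise.

```python
import torch
import torch.nn as nn


class SimpleMultiHead(nn.Module):
    """Hypothetical re-implementation for illustration only; not DHG's source."""

    def __init__(self, num_heads: int, readout: str, layer: type, **kwargs):
        super().__init__()
        if readout not in ("mean", "max", "sum", "concat"):
            raise ValueError(f"unknown readout: {readout!r}")
        # One independent instance of the layer class per head, all built with the same kwargs.
        self.heads = nn.ModuleList([layer(**kwargs) for _ in range(num_heads)])
        self.readout = readout

    def forward(self, **inputs) -> torch.Tensor:
        # Every head receives the same keyword inputs (e.g. X=..., g=...).
        outs = [head(**inputs) for head in self.heads]
        if self.readout == "concat":
            # Concatenate along the feature (last) dimension.
            return torch.cat(outs, dim=-1)
        stacked = torch.stack(outs, dim=0)  # shape: (num_heads, N, C)
        if self.readout == "mean":
            return stacked.mean(dim=0)
        if self.readout == "max":
            return stacked.max(dim=0).values
        return stacked.sum(dim=0)


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a graph layer: ignores structure, applies a linear map."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.lin(X)


X = torch.rand(20, 16)
concat_layer = SimpleMultiHead(4, "concat", ToyLayer, in_channels=16, out_channels=8)
mean_layer = SimpleMultiHead(4, "mean", ToyLayer, in_channels=16, out_channels=8)
print(concat_layer(X=X).shape)  # torch.Size([20, 32])
print(mean_layer(X=X).shape)    # torch.Size([20, 8])
```

Because each head is a separate module in an `nn.ModuleList`, the heads have independent parameters and are all registered with the wrapper, so they move with `.to(device)` and appear in `.parameters()`.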
Example
>>> import torch
>>> import dhg
>>> from dhg.nn import GATConv, MultiHeadWrapper
>>> multi_head_layer = MultiHeadWrapper(
...     4,
...     "concat",
...     GATConv,
...     in_channels=16,
...     out_channels=8,
... )
>>> X = torch.rand(20, 16)
>>> g = dhg.random.graph_Gnm(20, 15)
>>> X_ = multi_head_layer(X=X, g=g)
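The readout choice determines the feature width of the output, which matters when a second layer is stacked on top. Assuming "concat" joins the head outputs along the last (feature) dimension, the example's 4 heads with out_channels=8 should give X_ 4 × 8 = 32 columns, while "mean", "max", and "sum" keep the width at 8. `readout_width` below is a hypothetical helper that sketches this arithmetic; it is not part of DHG.

```python
def readout_width(num_heads: int, out_channels: int, readout: str) -> int:
    """Hypothetical helper: feature width produced by a multi-head readout.

    Assumes "concat" joins head outputs along the last (feature) dimension,
    while "mean", "max", and "sum" reduce across heads element-wise.
    """
    if readout == "concat":
        return num_heads * out_channels
    if readout in ("mean", "max", "sum"):
        return out_channels
    raise ValueError(f"unknown readout: {readout!r}")


# Settings from the example: 4 heads, out_channels=8, "concat" readout.
hidden = readout_width(4, 8, "concat")
print(hidden)  # 32, so a following layer would need in_channels=32
```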