Source code for dhg.nn.convs.common

from typing import List, Optional

import torch
import torch.nn as nn


class MLP(nn.Module):
    r"""A Multi-Layer Perceptron (MLP) model.

    Every entry pair ``(channel_list[i - 1], channel_list[i])`` becomes one
    ``nn.Linear`` layer. Each linear layer is followed by the activation, an
    optional ``nn.BatchNorm1d`` and an ``nn.Dropout`` -- except the final linear
    layer when ``is_last`` is ``True``, which is left bare.

    Args:
        ``channel_list`` (``List[int]``): The list of channels of each layer. Must contain at least two elements.
        ``act_name`` (``str``): The name of activation function can be any `activation layer <https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity>`_ in Pytorch. Defaults to ``"ReLU"``.
        ``act_kwargs`` (``dict``, optional): The keyword arguments of activation function. Defaults to ``None``.
        ``use_bn`` (``bool``): Whether to use batch normalization. Defaults to ``True``.
        ``drop_rate`` (``float``): Dropout ratio. Defaults to ``0.5``.
        ``is_last`` (``bool``): If set to True, the last layer will not use activation, batch normalization, and dropout. Defaults to ``False``.

    Raises:
        AssertionError: If ``channel_list`` has fewer than two elements.
        AttributeError: If ``act_name`` is not an attribute of ``torch.nn.modules.activation``.
    """

    def __init__(
        self,
        channel_list: List[int],
        act_name: str = "ReLU",
        act_kwargs: Optional[dict] = None,
        use_bn: bool = True,
        drop_rate: float = 0.5,
        is_last: bool = False,
    ) -> None:
        # Raised explicitly (instead of via ``assert``) so the check survives
        # ``python -O``; the exception type and message are kept unchanged for
        # callers that already rely on them.
        if len(channel_list) < 2:
            raise AssertionError("At least two elements in channel_list.")
        super().__init__()
        act_module = getattr(nn.modules.activation, act_name)
        # Resolve the optional kwargs once rather than for every layer.
        act_kwargs = {} if act_kwargs is None else act_kwargs
        last_idx = len(channel_list) - 1

        # NOTE: the order in which modules are appended determines their index
        # in ``self.layers`` (and therefore the ``state_dict`` keys), so it must
        # stay Linear -> activation -> [BatchNorm1d] -> Dropout.
        self.layers = nn.ModuleList()
        for idx in range(1, len(channel_list)):
            in_channels, out_channels = channel_list[idx - 1], channel_list[idx]
            self.layers.append(nn.Linear(in_channels, out_channels))
            if is_last and idx == last_idx:
                # The output layer of the network is a bare linear projection.
                break
            self.layers.append(act_module(**act_kwargs))
            if use_bn:
                self.layers.append(nn.BatchNorm1d(out_channels))
            self.layers.append(nn.Dropout(drop_rate))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        r"""The forward function.

        Args:
            ``X`` (``torch.Tensor``): Input tensor. It is passed through every module in ``self.layers`` in order, so its last dimension must match ``channel_list[0]``.

        Returns:
            ``torch.Tensor``: The output of the final module in ``self.layers``.
        """
        for layer in self.layers:
            X = layer(X)
        return X