EmbeddingRegularization

class dhg.nn.EmbeddingRegularization(*args, **kwargs)

Bases: torch.nn.Module

Regularization term for embeddings, controlled by the power p and scaled by weight_decay.

Parameters
  • p (int) – The power to use in the regularization. Defaults to 2.

  • weight_decay (float) – The weight of the regularization. Defaults to 1e-4.

forward(*embs)

Computes the regularization term for the given embeddings.

Parameters
  • embs (List[torch.Tensor]) – The input embeddings, passed as separate positional arguments.
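
To make the interface concrete, here is a minimal sketch of a module with the same constructor arguments and the same forward(*embs) calling convention. The class name EmbeddingRegularizationSketch is hypothetical, and the formula inside forward is an assumption for illustration (a p-th power penalty summed over features and averaged over rows); it may not match what dhg.nn.EmbeddingRegularization computes internally.

```python
import torch
import torch.nn as nn


class EmbeddingRegularizationSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented interface (not the DHG source)."""

    def __init__(self, p: int = 2, weight_decay: float = 1e-4):
        super().__init__()
        self.p = p                        # power used in the penalty
        self.weight_decay = weight_decay  # scale applied to the summed penalty

    def forward(self, *embs: torch.Tensor) -> torch.Tensor:
        # Assumed formula: for each 2-D embedding matrix, raise entries to the
        # p-th power, sum over the feature dimension, average over rows,
        # and divide by p. The per-matrix terms are then added together.
        loss = torch.zeros(())
        for emb in embs:
            loss = loss + emb.pow(self.p).sum(dim=1).mean() / self.p
        return self.weight_decay * loss


# Embeddings are passed as separate positional arguments, not as one list.
user_emb = torch.ones(4, 8)  # 4 rows, 8 features each
item_emb = torch.ones(6, 8)  # 6 rows, 8 features each
reg = EmbeddingRegularizationSketch(p=2, weight_decay=1e-4)
penalty = reg(user_emb, item_emb)  # scalar tensor
```

In a training loop, a term like penalty would typically be added to the task loss before calling backward(), so that larger embedding values are discouraged in proportion to weight_decay.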