On Hypergraph

Hint

In the following examples, three typical graph/hypergraph neural networks are used to perform the vertex classification task on a hypergraph structure.

Models

The examples use GCN (dhg.models.GCN), HGNN (dhg.models.HGNN), and HGNN+ (dhg.models.HGNNP).

Dataset

The Cooking 200 dataset (dhg.data.Cooking200) was collected from Yummly.com for the vertex classification task. It is a hypergraph dataset in which each vertex denotes a dish and each hyperedge denotes an ingredient. Each dish also has a category label that indicates its cuisine, such as Chinese, Japanese, French, or Russian.
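To make the data layout concrete, the sketch below describes a toy hypergraph with a vertex count and a list of hyperedges. This mirrors the `Hypergraph(data["num_vertices"], data["edge_list"])` call used later. The dishes and ingredients are made-up values, and `incidence_matrix` is a hypothetical helper for illustration, not part of DHG.

```python
# Toy hypergraph: 4 dishes (vertices 0..3) and 3 ingredients (hyperedges).
# Each hyperedge lists the dishes that use that ingredient.
# These values are invented for illustration; they are not Cooking200 data.
num_vertices = 4
edge_list = [
    (0, 1, 2),  # a hypothetical ingredient used by dishes 0, 1, 2
    (2, 3),     # a hypothetical ingredient used by dishes 2, 3
    (0, 3),     # a hypothetical ingredient used by dishes 0, 3
]

def incidence_matrix(num_vertices, edge_list):
    """Return H as nested lists: H[v][e] = 1 if vertex v belongs to hyperedge e."""
    H = [[0] * len(edge_list) for _ in range(num_vertices)]
    for e, members in enumerate(edge_list):
        for v in members:
            H[v][e] = 1
    return H

H = incidence_matrix(num_vertices, edge_list)
vertex_degree = [sum(row) for row in H]                 # ingredients per dish
edge_degree = [len(members) for members in edge_list]   # dishes per ingredient
```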

Note

Because the dataset is a hypergraph, the GCN model cannot use it directly. Instead, clique expansion reduces the hypergraph structure to a graph structure.
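Clique expansion replaces each hyperedge with a clique, so every pair of vertices that shares a hyperedge gets an ordinary edge. The plain-Python sketch below uses an assumed weighting rule in which a pair's weight is the number of hyperedges it shares. It is not DHG's `Graph.from_hypergraph_clique`, and the exact weighting used with `weighted=True` may differ.

```python
from collections import Counter
from itertools import combinations

def clique_expansion(edge_list):
    """Map each unordered vertex pair (u < v) to the number of hyperedges containing both.

    Assumption for illustration: pair weight = count of shared hyperedges.
    A single-vertex hyperedge contributes no pairwise edge.
    """
    weights = Counter()
    for members in edge_list:
        for u, v in combinations(sorted(set(members)), 2):
            weights[(u, v)] += 1
    return dict(weights)

# Toy hypergraph (invented data): hyperedges {0, 1, 2}, {1, 2}, {2, 3}.
graph = clique_expansion([(0, 1, 2), (1, 2), (2, 3)])
```

A hyperedge with k vertices produces k(k-1)/2 pairwise edges, so large hyperedges make the expanded graph much denser than the original hypergraph.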

Note

The dataset does not contain vertex features. Instead, an identity matrix is generated and used as the vertex features.
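With identity features, each vertex gets a one-hot vector of its own index. The feature dimension therefore equals the number of vertices, which is why `ft_dim = X.shape[1]` matches `data["num_vertices"]` in the scripts that follow. Below is a minimal plain-Python equivalent of `torch.eye(n)`.

```python
def identity_features(num_vertices):
    """Return an n x n identity matrix as nested lists (one one-hot row per vertex)."""
    return [[1.0 if i == j else 0.0 for j in range(num_vertices)]
            for i in range(num_vertices)]

X = identity_features(4)
feature_dim = len(X[0])  # equals the number of vertices
```

These features carry no information beyond vertex identity. Multiplying them by the first layer's weight matrix simply selects one weight row per vertex, so that layer effectively learns a separate embedding for every vertex.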

Warning

Using an identity matrix as vertex features leads to unstable parameters during training. For this reason, batch normalization (use_bn=True) is enabled for the GCN, HGNN, and HGNN+ models in the following examples.
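Batch normalization standardizes each hidden feature across the samples in a forward pass, then applies a learnable scale and shift. The sketch below computes the training-mode transform for a single feature column in plain Python. `gamma`, `beta`, and `eps` follow the usual BatchNorm definition, and the input values are illustrative. It does not show where `use_bn=True` places the normalization inside the DHG models.

```python
import math

def batch_norm_column(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm for one feature column.

    Uses the biased (population) variance, as BatchNorm does when it
    normalizes in training mode.
    """
    n = len(values)
    mean = sum(values) / n
    var = sum((x - mean) ** 2 for x in values) / n
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in values]

out = batch_norm_column([1.0, 2.0, 3.0, 4.0])
```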

Results

Model    Accuracy on Validation    Accuracy on Testing    F1 Score on Testing
GCN      0.500                     0.434                  0.356
HGNN     0.485                     0.495                  0.376
HGNN+    0.475                     0.520                  0.391

GCN on Cooking200

Import Libraries

import time
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F

from dhg import Graph, Hypergraph
from dhg.data import Cooking200
from dhg.models import GCN
from dhg.random import set_seed
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator

Define Functions

def train(net, X, A, lbls, train_idx, optimizer, epoch):
    net.train()

    st = time.time()
    optimizer.zero_grad()
    outs = net(X, A)
    outs, lbls = outs[train_idx], lbls[train_idx]
    loss = F.cross_entropy(outs, lbls)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch}, Time: {time.time()-st:.5f}s, Loss: {loss.item():.5f}")
    return loss.item()


@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False):
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    if not test:
        res = evaluator.validate(lbls, outs)
    else:
        res = evaluator.test(lbls, outs)
    return res
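In `train`, the network runs on all vertices (a transductive setting), but the cross-entropy loss is computed only on the rows selected by `train_idx`, which receives the boolean `train_mask`. `infer` applies the same selection with the validation or test mask. The sketch below mimics that selection and the mean reduction of `F.cross_entropy` in plain Python. The logits and labels are made-up values, and `masked_cross_entropy` is a hypothetical name.

```python
import math

def masked_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over the rows where mask is True.

    logits: per-vertex score lists; labels: class indices; mask: booleans
    (playing the role of a train/val/test mask).
    """
    losses = []
    for scores, y, keep in zip(logits, labels, mask):
        if not keep:
            continue  # vertices outside the mask do not contribute
        # log-sum-exp with max subtraction for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_z - scores[y])
    return sum(losses) / len(losses)

logits = [[2.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
labels = [0, 1, 1]
train_mask = [True, True, False]
loss = masked_cross_entropy(logits, labels, train_mask)
```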

Main

Note

More details about the metric Evaluator can be found in the Building Evaluator section.

if __name__ == "__main__":
    set_seed(2021)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
    data = Cooking200()

    X, lbl = torch.eye(data["num_vertices"]), data["labels"]
    ft_dim = X.shape[1]
    HG = Hypergraph(data["num_vertices"], data["edge_list"])
    G = Graph.from_hypergraph_clique(HG, weighted=True)
    train_mask = data["train_mask"]
    val_mask = data["val_mask"]
    test_mask = data["test_mask"]

    net = GCN(ft_dim, 32, data["num_classes"], use_bn=True)
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    X, lbl = X.to(device), lbl.to(device)
    G = G.to(device)
    net = net.to(device)

    best_state = None
    best_epoch, best_val = 0, 0
    for epoch in range(200):
        # train
        train(net, X, G, lbl, train_mask, optimizer, epoch)
        # validation
        if epoch % 1 == 0:
            with torch.no_grad():
                val_res = infer(net, X, G, lbl, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_epoch = epoch
                best_val = val_res
                best_state = deepcopy(net.state_dict())
    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")
    # test
    print("test...")
    net.load_state_dict(best_state)
    res = infer(net, X, G, lbl, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Outputs

Epoch: 0, Time: 7.29884s, Loss: 3.02374
update best: 0.05000
Epoch: 1, Time: 0.02545s, Loss: 2.47223
Epoch: 2, Time: 0.02411s, Loss: 2.41279
update best: 0.05500
Epoch: 3, Time: 0.02656s, Loss: 2.36803
update best: 0.07500
Epoch: 4, Time: 0.02486s, Loss: 2.33794
Epoch: 5, Time: 0.02224s, Loss: 2.30590
Epoch: 6, Time: 0.02089s, Loss: 2.28631
Epoch: 7, Time: 0.02136s, Loss: 2.25775
Epoch: 8, Time: 0.02186s, Loss: 2.24081
update best: 0.08000
Epoch: 9, Time: 0.02203s, Loss: 2.22660
update best: 0.09500
Epoch: 10, Time: 0.02155s, Loss: 2.20722
update best: 0.14500
Epoch: 11, Time: 0.02141s, Loss: 2.19497
Epoch: 12, Time: 0.02263s, Loss: 2.17880
Epoch: 13, Time: 0.02199s, Loss: 2.16433
Epoch: 14, Time: 0.02258s, Loss: 2.15038
Epoch: 15, Time: 0.02230s, Loss: 2.13811
Epoch: 16, Time: 0.02135s, Loss: 2.12440
Epoch: 17, Time: 0.02217s, Loss: 2.11146
Epoch: 18, Time: 0.02183s, Loss: 2.10333
Epoch: 19, Time: 0.03591s, Loss: 2.09031
Epoch: 20, Time: 0.02081s, Loss: 2.07710
Epoch: 21, Time: 0.02111s, Loss: 2.06423
Epoch: 22, Time: 0.02114s, Loss: 2.05410
Epoch: 23, Time: 0.02137s, Loss: 2.04545
update best: 0.15500
Epoch: 24, Time: 0.02159s, Loss: 2.03412
update best: 0.16000
Epoch: 25, Time: 0.02189s, Loss: 2.01589
update best: 0.17500
Epoch: 26, Time: 0.02204s, Loss: 2.01508
Epoch: 27, Time: 0.02206s, Loss: 1.99630
Epoch: 28, Time: 0.02180s, Loss: 1.98635
update best: 0.18500
Epoch: 29, Time: 0.02168s, Loss: 1.97526
update best: 0.20000
Epoch: 30, Time: 0.02155s, Loss: 1.96057
update best: 0.21000
Epoch: 31, Time: 0.02147s, Loss: 1.95878
update best: 0.21500
Epoch: 32, Time: 0.02174s, Loss: 1.94054
Epoch: 33, Time: 0.02147s, Loss: 1.93238
Epoch: 34, Time: 0.02176s, Loss: 1.92268
update best: 0.23000
Epoch: 35, Time: 0.02169s, Loss: 1.91224
update best: 0.24000
Epoch: 36, Time: 0.02141s, Loss: 1.89593
update best: 0.25000
Epoch: 37, Time: 0.02133s, Loss: 1.89175
update best: 0.25500
Epoch: 38, Time: 0.02230s, Loss: 1.88137
Epoch: 39, Time: 0.02201s, Loss: 1.87121
Epoch: 40, Time: 0.02050s, Loss: 1.85513
Epoch: 41, Time: 0.02120s, Loss: 1.85149
Epoch: 42, Time: 0.02102s, Loss: 1.83702
update best: 0.27000
Epoch: 43, Time: 0.02095s, Loss: 1.82509
update best: 0.27500
Epoch: 44, Time: 0.02139s, Loss: 1.81752
update best: 0.29000
Epoch: 45, Time: 0.02115s, Loss: 1.80817
Epoch: 46, Time: 0.02119s, Loss: 1.79938
update best: 0.29500
Epoch: 47, Time: 0.02088s, Loss: 1.78561
update best: 0.33000
Epoch: 48, Time: 0.02106s, Loss: 1.78137
update best: 0.34000
Epoch: 49, Time: 0.02088s, Loss: 1.76117
update best: 0.34500
Epoch: 50, Time: 0.02143s, Loss: 1.75598
update best: 0.36000
Epoch: 51, Time: 0.02129s, Loss: 1.74965
Epoch: 52, Time: 0.02177s, Loss: 1.73695
Epoch: 53, Time: 0.02160s, Loss: 1.72132
update best: 0.36500
Epoch: 54, Time: 0.02177s, Loss: 1.71943
update best: 0.37000
Epoch: 55, Time: 0.02115s, Loss: 1.71475
update best: 0.37500
Epoch: 56, Time: 0.02157s, Loss: 1.69237
update best: 0.38500
Epoch: 57, Time: 0.02164s, Loss: 1.68571
update best: 0.39500
Epoch: 58, Time: 0.02150s, Loss: 1.67695
update best: 0.40000
Epoch: 59, Time: 0.02156s, Loss: 1.66385
Epoch: 60, Time: 0.02155s, Loss: 1.65498
Epoch: 61, Time: 0.02102s, Loss: 1.65138
update best: 0.41000
Epoch: 62, Time: 0.02167s, Loss: 1.63215
update best: 0.42000
Epoch: 63, Time: 0.02174s, Loss: 1.62920
update best: 0.43500
Epoch: 64, Time: 0.02154s, Loss: 1.61913
update best: 0.44000
Epoch: 65, Time: 0.02159s, Loss: 1.61141
Epoch: 66, Time: 0.02195s, Loss: 1.60337
Epoch: 67, Time: 0.02069s, Loss: 1.58908
update best: 0.45500
Epoch: 68, Time: 0.02115s, Loss: 1.57248
Epoch: 69, Time: 0.02138s, Loss: 1.57386
update best: 0.46500
Epoch: 70, Time: 0.02106s, Loss: 1.56231
Epoch: 71, Time: 0.02118s, Loss: 1.55329
Epoch: 72, Time: 0.02242s, Loss: 1.54713
Epoch: 73, Time: 0.02136s, Loss: 1.53178
Epoch: 74, Time: 0.02172s, Loss: 1.52513
Epoch: 75, Time: 0.02200s, Loss: 1.51584
Epoch: 76, Time: 0.02123s, Loss: 1.50966
update best: 0.47000
Epoch: 77, Time: 0.02147s, Loss: 1.50546
update best: 0.47500
Epoch: 78, Time: 0.02270s, Loss: 1.49482
Epoch: 79, Time: 0.02264s, Loss: 1.47653
Epoch: 80, Time: 0.02349s, Loss: 1.46740
Epoch: 81, Time: 0.02231s, Loss: 1.46205
Epoch: 82, Time: 0.02251s, Loss: 1.44632
Epoch: 83, Time: 0.02184s, Loss: 1.44394
Epoch: 84, Time: 0.02175s, Loss: 1.43398
Epoch: 85, Time: 0.02109s, Loss: 1.43450
Epoch: 86, Time: 0.02110s, Loss: 1.41855
Epoch: 87, Time: 0.02112s, Loss: 1.41488
Epoch: 88, Time: 0.02119s, Loss: 1.40113
Epoch: 89, Time: 0.02133s, Loss: 1.38627
Epoch: 90, Time: 0.02178s, Loss: 1.38061
Epoch: 91, Time: 0.02106s, Loss: 1.38012
Epoch: 92, Time: 0.02245s, Loss: 1.36612
Epoch: 93, Time: 0.02165s, Loss: 1.36384
Epoch: 94, Time: 0.02169s, Loss: 1.35315
Epoch: 95, Time: 0.02287s, Loss: 1.33591
Epoch: 96, Time: 0.02321s, Loss: 1.33441
Epoch: 97, Time: 0.02267s, Loss: 1.32461
Epoch: 98, Time: 0.02246s, Loss: 1.31650
Epoch: 99, Time: 0.02192s, Loss: 1.30920
Epoch: 100, Time: 0.02145s, Loss: 1.29616
Epoch: 101, Time: 0.02106s, Loss: 1.28773
Epoch: 102, Time: 0.02128s, Loss: 1.28913
Epoch: 103, Time: 0.02125s, Loss: 1.27793
Epoch: 104, Time: 0.02174s, Loss: 1.27127
Epoch: 105, Time: 0.02135s, Loss: 1.26090
Epoch: 106, Time: 0.02187s, Loss: 1.25673
Epoch: 107, Time: 0.02137s, Loss: 1.23971
Epoch: 108, Time: 0.02163s, Loss: 1.23427
Epoch: 109, Time: 0.02173s, Loss: 1.23829
Epoch: 110, Time: 0.02228s, Loss: 1.21614
Epoch: 111, Time: 0.02190s, Loss: 1.22033
Epoch: 112, Time: 0.02146s, Loss: 1.21155
update best: 0.48000
Epoch: 113, Time: 0.02183s, Loss: 1.19760
Epoch: 114, Time: 0.02472s, Loss: 1.20577
Epoch: 115, Time: 0.02249s, Loss: 1.18268
Epoch: 116, Time: 0.02274s, Loss: 1.17723
Epoch: 117, Time: 0.02290s, Loss: 1.16582
Epoch: 118, Time: 0.02262s, Loss: 1.16943
Epoch: 119, Time: 0.02180s, Loss: 1.16023
Epoch: 120, Time: 0.02193s, Loss: 1.14612
update best: 0.48500
Epoch: 121, Time: 0.02191s, Loss: 1.14254
Epoch: 122, Time: 0.02162s, Loss: 1.13199
Epoch: 123, Time: 0.02136s, Loss: 1.12077
Epoch: 124, Time: 0.02165s, Loss: 1.11500
Epoch: 125, Time: 0.02177s, Loss: 1.11730
Epoch: 126, Time: 0.02150s, Loss: 1.10626
Epoch: 127, Time: 0.02119s, Loss: 1.09788
Epoch: 128, Time: 0.02119s, Loss: 1.09148
Epoch: 129, Time: 0.02130s, Loss: 1.08841
Epoch: 130, Time: 0.02211s, Loss: 1.08878
Epoch: 131, Time: 0.02171s, Loss: 1.08039
Epoch: 132, Time: 0.02172s, Loss: 1.06337
Epoch: 133, Time: 0.02185s, Loss: 1.05798
Epoch: 134, Time: 0.02197s, Loss: 1.05995
Epoch: 135, Time: 0.02310s, Loss: 1.04716
Epoch: 136, Time: 0.02271s, Loss: 1.03834
update best: 0.49000
Epoch: 137, Time: 0.02218s, Loss: 1.03407
Epoch: 138, Time: 0.02329s, Loss: 1.02641
Epoch: 139, Time: 0.02310s, Loss: 1.02540
Epoch: 140, Time: 0.02245s, Loss: 1.02152
Epoch: 141, Time: 0.02171s, Loss: 1.01990
Epoch: 142, Time: 0.02151s, Loss: 1.00520
Epoch: 143, Time: 0.02128s, Loss: 1.01225
Epoch: 144, Time: 0.02179s, Loss: 1.00302
Epoch: 145, Time: 0.02164s, Loss: 0.98153
Epoch: 146, Time: 0.02117s, Loss: 0.97740
Epoch: 147, Time: 0.02110s, Loss: 0.97149
Epoch: 148, Time: 0.02131s, Loss: 0.97149
Epoch: 149, Time: 0.02128s, Loss: 0.97657
Epoch: 150, Time: 0.02155s, Loss: 0.95241
Epoch: 151, Time: 0.02171s, Loss: 0.96010
Epoch: 152, Time: 0.02174s, Loss: 0.94509
Epoch: 153, Time: 0.02167s, Loss: 0.94987
Epoch: 154, Time: 0.02262s, Loss: 0.94258
Epoch: 155, Time: 0.02226s, Loss: 0.93526
Epoch: 156, Time: 0.02236s, Loss: 0.93201
Epoch: 157, Time: 0.02148s, Loss: 0.92291
Epoch: 158, Time: 0.02158s, Loss: 0.93494
Epoch: 159, Time: 0.02159s, Loss: 0.91413
Epoch: 160, Time: 0.02150s, Loss: 0.91853
Epoch: 161, Time: 0.02143s, Loss: 0.90566
Epoch: 162, Time: 0.02117s, Loss: 0.90713
Epoch: 163, Time: 0.02124s, Loss: 0.89651
Epoch: 164, Time: 0.02103s, Loss: 0.89034
Epoch: 165, Time: 0.02168s, Loss: 0.88661
Epoch: 166, Time: 0.02163s, Loss: 0.88348
Epoch: 167, Time: 0.02174s, Loss: 0.87290
Epoch: 168, Time: 0.02185s, Loss: 0.87435
Epoch: 169, Time: 0.02155s, Loss: 0.86458
Epoch: 170, Time: 0.02088s, Loss: 0.87389
Epoch: 171, Time: 0.02264s, Loss: 0.86114
Epoch: 172, Time: 0.02286s, Loss: 0.84979
Epoch: 173, Time: 0.02272s, Loss: 0.85025
Epoch: 174, Time: 0.02237s, Loss: 0.85343
Epoch: 175, Time: 0.02243s, Loss: 0.84297
Epoch: 176, Time: 0.02235s, Loss: 0.84274
Epoch: 177, Time: 0.02185s, Loss: 0.83616
Epoch: 178, Time: 0.02188s, Loss: 0.83237
Epoch: 179, Time: 0.02110s, Loss: 0.83829
Epoch: 180, Time: 0.02102s, Loss: 0.83292
Epoch: 181, Time: 0.02157s, Loss: 0.82355
Epoch: 182, Time: 0.02148s, Loss: 0.82146
Epoch: 183, Time: 0.02148s, Loss: 0.82488
Epoch: 184, Time: 0.02128s, Loss: 0.81608
Epoch: 185, Time: 0.02128s, Loss: 0.81082
Epoch: 186, Time: 0.02121s, Loss: 0.81338
Epoch: 187, Time: 0.02183s, Loss: 0.81301
Epoch: 188, Time: 0.02234s, Loss: 0.79188
Epoch: 189, Time: 0.02182s, Loss: 0.79709
update best: 0.50000
Epoch: 190, Time: 0.02134s, Loss: 0.78706
Epoch: 191, Time: 0.02183s, Loss: 0.77257
Epoch: 192, Time: 0.02276s, Loss: 0.77896
Epoch: 193, Time: 0.02326s, Loss: 0.77773
Epoch: 194, Time: 0.02287s, Loss: 0.76515
Epoch: 195, Time: 0.02281s, Loss: 0.76747
Epoch: 196, Time: 0.02164s, Loss: 0.76833
Epoch: 197, Time: 0.02182s, Loss: 0.75029
Epoch: 198, Time: 0.02136s, Loss: 0.76452
Epoch: 199, Time: 0.02135s, Loss: 0.75916

train finished!
best val: 0.50000
test...
final result: epoch: 189
{'accuracy': 0.4340996742248535, 'f1_score': 0.35630662515488015, 'f1_score -> average@micro': 0.43409967156932744}
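The loop validates every epoch, because `epoch % 1 == 0` is always true. It keeps the checkpoint from the epoch with the highest validation accuracy, and that checkpoint is then tested (here epoch 189, with 0.500). The comparison is a strict `>`, so a later epoch that only ties the best score does not replace it. The hypothetical helper `select_best_epoch` below reproduces that rule on a list of per-epoch validation scores.

```python
def select_best_epoch(val_scores):
    """Return (best_epoch, best_val) using the same strict '>' rule as the loop.

    Starts from best_val = 0, so a history with no positive score returns (0, 0).
    """
    best_epoch, best_val = 0, 0
    for epoch, score in enumerate(val_scores):
        if score > best_val:
            best_epoch, best_val = epoch, score
    return best_epoch, best_val

# Ties do not move the checkpoint: epoch 1 is kept even though epoch 3 matches it.
result = select_best_epoch([0.40, 0.50, 0.45, 0.50])
```

In the script, if no validation score ever exceeds 0, `best_state` stays `None` and `net.load_state_dict(best_state)` would fail.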

HGNN on Cooking200
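Unlike the GCN example, which runs on the clique-expanded graph, HGNN consumes the hypergraph directly. The HGNN paper defines its convolution as X' = σ(Dv^{-1/2} H W De^{-1} H^T Dv^{-1/2} X Θ). Here H is the incidence matrix, W holds the diagonal hyperedge weights, and Dv and De are the vertex and hyperedge degree matrices. The plain-Python sketch below builds the fixed smoothing matrix Dv^{-1/2} H W De^{-1} H^T Dv^{-1/2} for a toy hypergraph, assuming unit hyperedge weights. It leaves out the learnable Θ and the nonlinearity, and it is not DHG's implementation.

```python
import math

def hgnn_smoothing_matrix(num_vertices, edge_list):
    """Return Dv^-1/2 H W De^-1 H^T Dv^-1/2 as nested lists.

    Assumes unit hyperedge weights (W = I), no repeated vertices inside a
    hyperedge, and that every vertex belongs to at least one hyperedge.
    """
    # Incidence matrix H (vertices x hyperedges).
    H = [[0.0] * len(edge_list) for _ in range(num_vertices)]
    for e, members in enumerate(edge_list):
        for v in members:
            H[v][e] = 1.0
    dv = [sum(row) for row in H]                   # vertex degrees
    de = [len(members) for members in edge_list]   # hyperedge degrees
    S = [[0.0] * num_vertices for _ in range(num_vertices)]
    for i in range(num_vertices):
        for j in range(num_vertices):
            # Sum over hyperedges containing both i and j, each scaled by 1/deg(e).
            s = sum(H[i][e] * H[j][e] / de[e] for e in range(len(edge_list)))
            S[i][j] = s / math.sqrt(dv[i] * dv[j])
    return S

S = hgnn_smoothing_matrix(3, [(0, 1, 2), (1, 2)])
```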

Import Libraries

import time
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F

from dhg import Hypergraph
from dhg.data import Cooking200
from dhg.models import HGNN
from dhg.random import set_seed
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator

Define Functions

def train(net, X, A, lbls, train_idx, optimizer, epoch):
    net.train()

    st = time.time()
    optimizer.zero_grad()
    outs = net(X, A)
    outs, lbls = outs[train_idx], lbls[train_idx]
    loss = F.cross_entropy(outs, lbls)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch}, Time: {time.time()-st:.5f}s, Loss: {loss.item():.5f}")
    return loss.item()


@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False):
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    if not test:
        res = evaluator.validate(lbls, outs)
    else:
        res = evaluator.test(lbls, outs)
    return res

Main

Note

More details about the metric Evaluator can be found in the Building Evaluator section.

if __name__ == "__main__":
    set_seed(2021)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
    data = Cooking200()

    X, lbl = torch.eye(data["num_vertices"]), data["labels"]
    G = Hypergraph(data["num_vertices"], data["edge_list"])
    train_mask = data["train_mask"]
    val_mask = data["val_mask"]
    test_mask = data["test_mask"]

    net = HGNN(X.shape[1], 32, data["num_classes"], use_bn=True)
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    X, lbl = X.to(device), lbl.to(device)
    G = G.to(device)
    net = net.to(device)

    best_state = None
    best_epoch, best_val = 0, 0
    for epoch in range(200):
        # train
        train(net, X, G, lbl, train_mask, optimizer, epoch)
        # validation
        if epoch % 1 == 0:
            with torch.no_grad():
                val_res = infer(net, X, G, lbl, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_epoch = epoch
                best_val = val_res
                best_state = deepcopy(net.state_dict())
    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")
    # test
    print("test...")
    net.load_state_dict(best_state)
    res = infer(net, X, G, lbl, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Outputs

Epoch: 0, Time: 0.57807s, Loss: 2.99290
update best: 0.10000
Epoch: 1, Time: 0.02624s, Loss: 2.28624
Epoch: 2, Time: 0.02707s, Loss: 2.15988
Epoch: 3, Time: 0.02373s, Loss: 2.05894
Epoch: 4, Time: 0.02545s, Loss: 1.99918
Epoch: 5, Time: 0.02619s, Loss: 1.92948
Epoch: 6, Time: 0.02215s, Loss: 1.88097
Epoch: 7, Time: 0.02229s, Loss: 1.83393
Epoch: 8, Time: 0.02181s, Loss: 1.79070
Epoch: 9, Time: 0.02256s, Loss: 1.75345
Epoch: 10, Time: 0.02264s, Loss: 1.70969
Epoch: 11, Time: 0.02248s, Loss: 1.68242
Epoch: 12, Time: 0.02248s, Loss: 1.64419
Epoch: 13, Time: 0.02257s, Loss: 1.60876
Epoch: 14, Time: 0.02238s, Loss: 1.58108
Epoch: 15, Time: 0.02194s, Loss: 1.54466
Epoch: 16, Time: 0.02172s, Loss: 1.52140
Epoch: 17, Time: 0.02130s, Loss: 1.48225
Epoch: 18, Time: 0.02156s, Loss: 1.46237
Epoch: 19, Time: 0.02133s, Loss: 1.43527
Epoch: 20, Time: 0.02148s, Loss: 1.40451
Epoch: 21, Time: 0.02133s, Loss: 1.39555
Epoch: 22, Time: 0.02182s, Loss: 1.36368
Epoch: 23, Time: 0.02151s, Loss: 1.33732
Epoch: 24, Time: 0.02178s, Loss: 1.32686
Epoch: 25, Time: 0.02232s, Loss: 1.30681
Epoch: 26, Time: 0.02289s, Loss: 1.28287
Epoch: 27, Time: 0.02245s, Loss: 1.28563
Epoch: 28, Time: 0.02210s, Loss: 1.24644
Epoch: 29, Time: 0.02195s, Loss: 1.22813
Epoch: 30, Time: 0.02205s, Loss: 1.20336
Epoch: 31, Time: 0.02245s, Loss: 1.20308
Epoch: 32, Time: 0.02129s, Loss: 1.16802
Epoch: 33, Time: 0.02144s, Loss: 1.17182
Epoch: 34, Time: 0.02215s, Loss: 1.14047
Epoch: 35, Time: 0.02195s, Loss: 1.13377
Epoch: 36, Time: 0.02233s, Loss: 1.09250
Epoch: 37, Time: 0.02283s, Loss: 1.09588
Epoch: 38, Time: 0.02356s, Loss: 1.09042
Epoch: 39, Time: 0.02211s, Loss: 1.08532
Epoch: 40, Time: 0.02340s, Loss: 1.04074
update best: 0.11000
Epoch: 41, Time: 0.02125s, Loss: 1.05056
update best: 0.13500
Epoch: 42, Time: 0.02302s, Loss: 1.02834
update best: 0.14000
Epoch: 43, Time: 0.02278s, Loss: 0.99903
update best: 0.14500
Epoch: 44, Time: 0.02238s, Loss: 1.01756
update best: 0.15000
Epoch: 45, Time: 0.02286s, Loss: 0.99652
update best: 0.17500
Epoch: 46, Time: 0.02251s, Loss: 0.97935
update best: 0.21500
Epoch: 47, Time: 0.02234s, Loss: 0.97873
update best: 0.24500
Epoch: 48, Time: 0.02245s, Loss: 0.95888
update best: 0.26000
Epoch: 49, Time: 0.02228s, Loss: 0.95761
update best: 0.28000
Epoch: 50, Time: 0.02254s, Loss: 0.94229
Epoch: 51, Time: 0.02264s, Loss: 0.92833
update best: 0.29000
Epoch: 52, Time: 0.02238s, Loss: 0.92601
update best: 0.30000
Epoch: 53, Time: 0.02311s, Loss: 0.90252
update best: 0.31000
Epoch: 54, Time: 0.02189s, Loss: 0.89501
update best: 0.32500
Epoch: 55, Time: 0.02193s, Loss: 0.89724
Epoch: 56, Time: 0.02246s, Loss: 0.87068
update best: 0.33500
Epoch: 57, Time: 0.02181s, Loss: 0.87531
update best: 0.34000
Epoch: 58, Time: 0.02287s, Loss: 0.84288
update best: 0.34500
Epoch: 59, Time: 0.02227s, Loss: 0.84243
update best: 0.36500
Epoch: 60, Time: 0.02149s, Loss: 0.83892
update best: 0.38500
Epoch: 61, Time: 0.02253s, Loss: 0.83062
update best: 0.40000
Epoch: 62, Time: 0.02271s, Loss: 0.82245
update best: 0.42000
Epoch: 63, Time: 0.02195s, Loss: 0.81214
update best: 0.43000
Epoch: 64, Time: 0.02162s, Loss: 0.80847
update best: 0.44000
Epoch: 65, Time: 0.02136s, Loss: 0.78325
Epoch: 66, Time: 0.02245s, Loss: 0.79052
update best: 0.45500
Epoch: 67, Time: 0.02248s, Loss: 0.78128
Epoch: 68, Time: 0.02295s, Loss: 0.77049
Epoch: 69, Time: 0.02315s, Loss: 0.75469
Epoch: 70, Time: 0.02331s, Loss: 0.74771
Epoch: 71, Time: 0.02317s, Loss: 0.73701
Epoch: 72, Time: 0.02307s, Loss: 0.74350
Epoch: 73, Time: 0.02176s, Loss: 0.73698
Epoch: 74, Time: 0.02164s, Loss: 0.72565
Epoch: 75, Time: 0.02148s, Loss: 0.70553
update best: 0.46500
Epoch: 76, Time: 0.02136s, Loss: 0.71696
Epoch: 77, Time: 0.02111s, Loss: 0.72410
Epoch: 78, Time: 0.02111s, Loss: 0.71131
update best: 0.47000
Epoch: 79, Time: 0.02180s, Loss: 0.68748
Epoch: 80, Time: 0.02095s, Loss: 0.68774
Epoch: 81, Time: 0.02147s, Loss: 0.70136
Epoch: 82, Time: 0.02122s, Loss: 0.66882
Epoch: 83, Time: 0.02164s, Loss: 0.64563
Epoch: 84, Time: 0.02149s, Loss: 0.66794
Epoch: 85, Time: 0.02194s, Loss: 0.65860
Epoch: 86, Time: 0.02157s, Loss: 0.66000
Epoch: 87, Time: 0.02267s, Loss: 0.65452
Epoch: 88, Time: 0.02250s, Loss: 0.64512
Epoch: 89, Time: 0.02169s, Loss: 0.64318
Epoch: 90, Time: 0.02175s, Loss: 0.63814
Epoch: 91, Time: 0.02177s, Loss: 0.62040
Epoch: 92, Time: 0.02108s, Loss: 0.61942
Epoch: 93, Time: 0.02111s, Loss: 0.61757
Epoch: 94, Time: 0.02118s, Loss: 0.60520
Epoch: 95, Time: 0.02112s, Loss: 0.58358
Epoch: 96, Time: 0.02129s, Loss: 0.58866
Epoch: 97, Time: 0.02171s, Loss: 0.58599
Epoch: 98, Time: 0.02220s, Loss: 0.59330
Epoch: 99, Time: 0.02243s, Loss: 0.56555
Epoch: 100, Time: 0.02262s, Loss: 0.57273
Epoch: 101, Time: 0.02240s, Loss: 0.57785
Epoch: 102, Time: 0.02086s, Loss: 0.56949
Epoch: 103, Time: 0.02111s, Loss: 0.55187
Epoch: 104, Time: 0.02136s, Loss: 0.55166
Epoch: 105, Time: 0.02119s, Loss: 0.54706
Epoch: 106, Time: 0.02107s, Loss: 0.55239
Epoch: 107, Time: 0.02136s, Loss: 0.53656
Epoch: 108, Time: 0.02115s, Loss: 0.53478
Epoch: 109, Time: 0.02146s, Loss: 0.52564
Epoch: 110, Time: 0.02189s, Loss: 0.52242
Epoch: 111, Time: 0.02248s, Loss: 0.52779
Epoch: 112, Time: 0.02191s, Loss: 0.50813
Epoch: 113, Time: 0.02182s, Loss: 0.51623
Epoch: 114, Time: 0.02143s, Loss: 0.51834
Epoch: 115, Time: 0.02220s, Loss: 0.49232
Epoch: 116, Time: 0.02117s, Loss: 0.51582
Epoch: 117, Time: 0.02116s, Loss: 0.49434
Epoch: 118, Time: 0.02110s, Loss: 0.49518
Epoch: 119, Time: 0.02147s, Loss: 0.49155
Epoch: 120, Time: 0.02122s, Loss: 0.48029
Epoch: 121, Time: 0.02153s, Loss: 0.49079
Epoch: 122, Time: 0.02151s, Loss: 0.48253
Epoch: 123, Time: 0.02170s, Loss: 0.46945
Epoch: 124, Time: 0.02259s, Loss: 0.47764
Epoch: 125, Time: 0.02228s, Loss: 0.47102
Epoch: 126, Time: 0.02196s, Loss: 0.45784
Epoch: 127, Time: 0.02184s, Loss: 0.46020
Epoch: 128, Time: 0.02245s, Loss: 0.45922
Epoch: 129, Time: 0.02191s, Loss: 0.46458
Epoch: 130, Time: 0.02215s, Loss: 0.46924
Epoch: 131, Time: 0.02222s, Loss: 0.45952
Epoch: 132, Time: 0.02226s, Loss: 0.44490
Epoch: 133, Time: 0.02174s, Loss: 0.44763
Epoch: 134, Time: 0.02143s, Loss: 0.45225
Epoch: 135, Time: 0.02149s, Loss: 0.42556
Epoch: 136, Time: 0.02141s, Loss: 0.42714
Epoch: 137, Time: 0.02150s, Loss: 0.43604
Epoch: 138, Time: 0.02171s, Loss: 0.42259
Epoch: 139, Time: 0.02168s, Loss: 0.41784
Epoch: 140, Time: 0.02149s, Loss: 0.41759
Epoch: 141, Time: 0.02125s, Loss: 0.41633
Epoch: 142, Time: 0.02220s, Loss: 0.42547
Epoch: 143, Time: 0.02271s, Loss: 0.41790
Epoch: 144, Time: 0.02280s, Loss: 0.39776
Epoch: 145, Time: 0.02264s, Loss: 0.41429
Epoch: 146, Time: 0.02128s, Loss: 0.39543
Epoch: 147, Time: 0.02141s, Loss: 0.39529
Epoch: 148, Time: 0.02100s, Loss: 0.41145
Epoch: 149, Time: 0.02103s, Loss: 0.40083
Epoch: 150, Time: 0.02170s, Loss: 0.39246
Epoch: 151, Time: 0.02154s, Loss: 0.39613
Epoch: 152, Time: 0.02188s, Loss: 0.38080
Epoch: 153, Time: 0.02213s, Loss: 0.39159
Epoch: 154, Time: 0.02236s, Loss: 0.38570
Epoch: 155, Time: 0.02209s, Loss: 0.38382
Epoch: 156, Time: 0.02146s, Loss: 0.37949
update best: 0.47500
Epoch: 157, Time: 0.02179s, Loss: 0.37078
Epoch: 158, Time: 0.02223s, Loss: 0.37063
Epoch: 159, Time: 0.02219s, Loss: 0.37556
Epoch: 160, Time: 0.02217s, Loss: 0.37468
Epoch: 161, Time: 0.02146s, Loss: 0.38581
update best: 0.48500
Epoch: 162, Time: 0.02278s, Loss: 0.36664
Epoch: 163, Time: 0.02172s, Loss: 0.35075
Epoch: 164, Time: 0.02139s, Loss: 0.35056
Epoch: 165, Time: 0.02156s, Loss: 0.36339
Epoch: 166, Time: 0.02149s, Loss: 0.36245
Epoch: 167, Time: 0.02133s, Loss: 0.34675
Epoch: 168, Time: 0.02141s, Loss: 0.36043
Epoch: 169, Time: 0.02148s, Loss: 0.34538
Epoch: 170, Time: 0.02128s, Loss: 0.34694
Epoch: 171, Time: 0.02138s, Loss: 0.33723
Epoch: 172, Time: 0.02260s, Loss: 0.34017
Epoch: 173, Time: 0.02259s, Loss: 0.33932
Epoch: 174, Time: 0.02307s, Loss: 0.33170
Epoch: 175, Time: 0.02290s, Loss: 0.31819
Epoch: 176, Time: 0.02261s, Loss: 0.33577
Epoch: 177, Time: 0.02269s, Loss: 0.34146
Epoch: 178, Time: 0.02284s, Loss: 0.33086
Epoch: 179, Time: 0.02215s, Loss: 0.34498
Epoch: 180, Time: 0.02317s, Loss: 0.33026
Epoch: 181, Time: 0.02228s, Loss: 0.32811
Epoch: 182, Time: 0.02216s, Loss: 0.33203
Epoch: 183, Time: 0.02248s, Loss: 0.31955
Epoch: 184, Time: 0.02239s, Loss: 0.34238
Epoch: 185, Time: 0.02253s, Loss: 0.30963
Epoch: 186, Time: 0.02240s, Loss: 0.31527
Epoch: 187, Time: 0.02199s, Loss: 0.31484
Epoch: 188, Time: 0.02200s, Loss: 0.32514
Epoch: 189, Time: 0.02171s, Loss: 0.32029
Epoch: 190, Time: 0.02169s, Loss: 0.32122
Epoch: 191, Time: 0.02157s, Loss: 0.30233
Epoch: 192, Time: 0.02125s, Loss: 0.30417
Epoch: 193, Time: 0.02159s, Loss: 0.30060
Epoch: 194, Time: 0.02142s, Loss: 0.29333
Epoch: 195, Time: 0.02155s, Loss: 0.29596
Epoch: 196, Time: 0.02158s, Loss: 0.30458
Epoch: 197, Time: 0.02204s, Loss: 0.29744
Epoch: 198, Time: 0.02227s, Loss: 0.29473
Epoch: 199, Time: 0.02259s, Loss: 0.30488

train finished!
best val: 0.48500
test...
final result: epoch: 161
{'accuracy': 0.4949307441711426, 'f1_score': 0.37618299381063885, 'f1_score -> average@micro': 0.49493074396687137}
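In the printed results, `f1_score -> average@micro` matches `accuracy` up to floating-point rounding. This is expected for single-label multi-class classification. Each misclassified vertex counts as exactly one false positive (for the predicted class) and one false negative (for the true class), so micro-averaged precision, recall, and F1 all reduce to accuracy. The plain-Python sketch below checks this on made-up labels. For contrast, it also computes macro-F1, the unweighted mean of per-class F1 scores, which can differ from accuracy.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where prediction equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_scores(y_true, y_pred):
    """Return (micro_f1, macro_f1) for single-label multi-class predictions."""
    classes = sorted(set(y_true) | set(y_pred))
    tp = {c: 0 for c in classes}
    fp = {c: 0 for c in classes}
    fn = {c: 0 for c in classes}
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # the predicted class gains a false positive
            fn[t] += 1  # the true class gains a false negative
    # Micro: pool the counts over all classes, then compute F1 once.
    TP, FP, FN = sum(tp.values()), sum(fp.values()), sum(fn.values())
    micro = 2 * TP / (2 * TP + FP + FN)
    # Macro: per-class F1, then an unweighted mean over classes.
    per_class = [2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) for c in classes]
    macro = sum(per_class) / len(per_class)
    return micro, macro

y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 1, 1, 0]
acc = accuracy(y_true, y_pred)
micro, macro = f1_scores(y_true, y_pred)
```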

HGNN+ on Cooking200
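HGNN+ replaces HGNN's symmetric normalization with two-stage message passing. First, each hyperedge averages the features of its vertices (vertex to hyperedge). Then each vertex averages the messages of the hyperedges it belongs to (hyperedge to vertex). In matrix form this corresponds to Dv^{-1} H W De^{-1} H^T X. The plain-Python sketch below shows this mean-based propagation on one scalar feature per vertex, assuming unit hyperedge weights and mean aggregation in both stages. It illustrates the idea and is not DHG's `HGNNP` code.

```python
def hgnnp_propagate(num_vertices, edge_list, x):
    """Two-stage mean propagation: vertex -> hyperedge -> vertex.

    x holds one scalar feature per vertex. Assumes unit hyperedge weights and
    that every vertex belongs to at least one hyperedge.
    """
    # Stage 1 (vertex -> hyperedge): mean of the member vertices' features.
    edge_msg = [sum(x[v] for v in members) / len(members) for members in edge_list]
    # Stage 2 (hyperedge -> vertex): mean of the incident hyperedges' messages.
    total = [0.0] * num_vertices
    count = [0] * num_vertices
    for e, members in enumerate(edge_list):
        for v in members:
            total[v] += edge_msg[e]
            count[v] += 1
    return [total[v] / count[v] for v in range(num_vertices)]

out = hgnnp_propagate(3, [(0, 1, 2), (1, 2)], [3.0, 6.0, 12.0])
```

Because both stages are averages, a constant feature vector passes through unchanged.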

Import Libraries

import time
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F

from dhg import Hypergraph
from dhg.data import Cooking200
from dhg.models import HGNNP
from dhg.random import set_seed
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator

Define Functions

def train(net, X, A, lbls, train_idx, optimizer, epoch):
    net.train()

    st = time.time()
    optimizer.zero_grad()
    outs = net(X, A)
    outs, lbls = outs[train_idx], lbls[train_idx]
    loss = F.cross_entropy(outs, lbls)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch}, Time: {time.time()-st:.5f}s, Loss: {loss.item():.5f}")
    return loss.item()


@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False):
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    if not test:
        res = evaluator.validate(lbls, outs)
    else:
        res = evaluator.test(lbls, outs)
    return res

Main

Note

More details about the metric Evaluator can be found in the Building Evaluator section.

if __name__ == "__main__":
    set_seed(2021)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
    data = Cooking200()

    X, lbl = torch.eye(data["num_vertices"]), data["labels"]
    G = Hypergraph(data["num_vertices"], data["edge_list"])
    train_mask = data["train_mask"]
    val_mask = data["val_mask"]
    test_mask = data["test_mask"]

    net = HGNNP(X.shape[1], 32, data["num_classes"], use_bn=True)
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    X, lbl = X.to(device), lbl.to(device)
    G = G.to(device)
    net = net.to(device)

    best_state = None
    best_epoch, best_val = 0, 0
    for epoch in range(200):
        # train
        train(net, X, G, lbl, train_mask, optimizer, epoch)
        # validation
        if epoch % 1 == 0:
            with torch.no_grad():
                val_res = infer(net, X, G, lbl, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_epoch = epoch
                best_val = val_res
                best_state = deepcopy(net.state_dict())
    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")
    # test
    print("test...")
    net.load_state_dict(best_state)
    res = infer(net, X, G, lbl, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Outputs

Epoch: 0, Time: 0.52802s, Loss: 2.98654
update best: 0.05000
Epoch: 1, Time: 0.00738s, Loss: 2.28235
Epoch: 2, Time: 0.00829s, Loss: 2.15288
Epoch: 3, Time: 0.00929s, Loss: 2.05343
Epoch: 4, Time: 0.00716s, Loss: 1.99081
Epoch: 5, Time: 0.00703s, Loss: 1.92390
Epoch: 6, Time: 0.01025s, Loss: 1.87569
Epoch: 7, Time: 0.01015s, Loss: 1.83000
Epoch: 8, Time: 0.00870s, Loss: 1.78668
update best: 0.06500
Epoch: 9, Time: 0.00811s, Loss: 1.75019
Epoch: 10, Time: 0.00792s, Loss: 1.70593
Epoch: 11, Time: 0.00855s, Loss: 1.68245
Epoch: 12, Time: 0.00940s, Loss: 1.64045
Epoch: 13, Time: 0.00667s, Loss: 1.60735
Epoch: 14, Time: 0.00808s, Loss: 1.58477
Epoch: 15, Time: 0.00863s, Loss: 1.54530
Epoch: 16, Time: 0.00839s, Loss: 1.52168
Epoch: 17, Time: 0.00863s, Loss: 1.48935
Epoch: 18, Time: 0.01009s, Loss: 1.46205
Epoch: 19, Time: 0.00998s, Loss: 1.43605
Epoch: 20, Time: 0.00808s, Loss: 1.40635
Epoch: 21, Time: 0.00765s, Loss: 1.39397
Epoch: 22, Time: 0.00749s, Loss: 1.36317
Epoch: 23, Time: 0.00791s, Loss: 1.34086
Epoch: 24, Time: 0.00627s, Loss: 1.32558
Epoch: 25, Time: 0.00784s, Loss: 1.30849
Epoch: 26, Time: 0.00752s, Loss: 1.27822
Epoch: 27, Time: 0.00628s, Loss: 1.28945
Epoch: 28, Time: 0.00731s, Loss: 1.24414
Epoch: 29, Time: 0.00741s, Loss: 1.22858
Epoch: 30, Time: 0.00677s, Loss: 1.20161
Epoch: 31, Time: 0.00777s, Loss: 1.19882
Epoch: 32, Time: 0.00707s, Loss: 1.16460
Epoch: 33, Time: 0.00730s, Loss: 1.16780
Epoch: 34, Time: 0.00787s, Loss: 1.13391
update best: 0.07000
Epoch: 35, Time: 0.00747s, Loss: 1.13935
update best: 0.08500
Epoch: 36, Time: 0.00683s, Loss: 1.08887
update best: 0.12000
Epoch: 37, Time: 0.00780s, Loss: 1.08907
Epoch: 38, Time: 0.00782s, Loss: 1.08394
Epoch: 39, Time: 0.00626s, Loss: 1.07832
Epoch: 40, Time: 0.00783s, Loss: 1.03877
update best: 0.12500
Epoch: 41, Time: 0.00795s, Loss: 1.03990
update best: 0.13500
Epoch: 42, Time: 0.00626s, Loss: 1.02008
update best: 0.14500
Epoch: 43, Time: 0.00709s, Loss: 0.99529
update best: 0.16000
Epoch: 44, Time: 0.00763s, Loss: 1.01162
update best: 0.17500
Epoch: 45, Time: 0.00749s, Loss: 0.99196
update best: 0.20500
Epoch: 46, Time: 0.00629s, Loss: 0.97237
update best: 0.21000
Epoch: 47, Time: 0.00754s, Loss: 0.97511
update best: 0.22500
Epoch: 48, Time: 0.00805s, Loss: 0.95078
update best: 0.23000
Epoch: 49, Time: 0.00745s, Loss: 0.94715
update best: 0.24500
Epoch: 50, Time: 0.00643s, Loss: 0.93461
update best: 0.25500
Epoch: 51, Time: 0.00743s, Loss: 0.92102
update best: 0.27500
Epoch: 52, Time: 0.00772s, Loss: 0.91536
update best: 0.29500
Epoch: 53, Time: 0.00714s, Loss: 0.89386
update best: 0.30500
Epoch: 54, Time: 0.00722s, Loss: 0.88108
Epoch: 55, Time: 0.00777s, Loss: 0.88809
Epoch: 56, Time: 0.00717s, Loss: 0.85739
Epoch: 57, Time: 0.00724s, Loss: 0.86278
update best: 0.31000
Epoch: 58, Time: 0.00804s, Loss: 0.83276
update best: 0.32500
Epoch: 59, Time: 0.00786s, Loss: 0.83001
update best: 0.35000
Epoch: 60, Time: 0.00629s, Loss: 0.83385
update best: 0.37500
Epoch: 61, Time: 0.00712s, Loss: 0.82473
update best: 0.39500
Epoch: 62, Time: 0.00904s, Loss: 0.81101
update best: 0.41000
Epoch: 63, Time: 0.00745s, Loss: 0.80212
Epoch: 64, Time: 0.00715s, Loss: 0.79534
update best: 0.42000
Epoch: 65, Time: 0.00705s, Loss: 0.77077
Epoch: 66, Time: 0.00710s, Loss: 0.77775
update best: 0.43000
Epoch: 67, Time: 0.00717s, Loss: 0.77026
update best: 0.43500
Epoch: 68, Time: 0.00789s, Loss: 0.75978
Epoch: 69, Time: 0.00747s, Loss: 0.74209
Epoch: 70, Time: 0.00639s, Loss: 0.73636
Epoch: 71, Time: 0.00689s, Loss: 0.72454
Epoch: 72, Time: 0.00793s, Loss: 0.72910
Epoch: 73, Time: 0.00729s, Loss: 0.72512
Epoch: 74, Time: 0.00775s, Loss: 0.71034
update best: 0.44500
Epoch: 75, Time: 0.00766s, Loss: 0.69282
update best: 0.45000
Epoch: 76, Time: 0.00627s, Loss: 0.70622
update best: 0.46000
Epoch: 77, Time: 0.00706s, Loss: 0.70540
update best: 0.47500
Epoch: 78, Time: 0.00849s, Loss: 0.69790
Epoch: 79, Time: 0.00731s, Loss: 0.66718
Epoch: 80, Time: 0.00748s, Loss: 0.67149
Epoch: 81, Time: 0.00900s, Loss: 0.68492
Epoch: 82, Time: 0.00624s, Loss: 0.65467
Epoch: 83, Time: 0.00713s, Loss: 0.63049
Epoch: 84, Time: 0.00852s, Loss: 0.65693
Epoch: 85, Time: 0.00622s, Loss: 0.64821
Epoch: 86, Time: 0.00717s, Loss: 0.64481
Epoch: 87, Time: 0.00784s, Loss: 0.64284
Epoch: 88, Time: 0.00630s, Loss: 0.62653
Epoch: 89, Time: 0.00726s, Loss: 0.62808
Epoch: 90, Time: 0.00786s, Loss: 0.62135
Epoch: 91, Time: 0.00729s, Loss: 0.59833
Epoch: 92, Time: 0.00731s, Loss: 0.60561
Epoch: 93, Time: 0.00801s, Loss: 0.60091
Epoch: 94, Time: 0.00630s, Loss: 0.58819
Epoch: 95, Time: 0.00763s, Loss: 0.56774
Epoch: 96, Time: 0.00743s, Loss: 0.57335
Epoch: 97, Time: 0.00662s, Loss: 0.56947
Epoch: 98, Time: 0.00899s, Loss: 0.57430
Epoch: 99, Time: 0.00751s, Loss: 0.56189
Epoch: 100, Time: 0.00719s, Loss: 0.55171
Epoch: 101, Time: 0.00791s, Loss: 0.56934
Epoch: 102, Time: 0.00627s, Loss: 0.54815
Epoch: 103, Time: 0.00731s, Loss: 0.54027
Epoch: 104, Time: 0.00817s, Loss: 0.54291
Epoch: 105, Time: 0.00623s, Loss: 0.52773
Epoch: 106, Time: 0.00737s, Loss: 0.53735
Epoch: 107, Time: 0.00790s, Loss: 0.51841
Epoch: 108, Time: 0.00631s, Loss: 0.51548
Epoch: 109, Time: 0.00753s, Loss: 0.51153
Epoch: 110, Time: 0.00822s, Loss: 0.50702
Epoch: 111, Time: 0.00689s, Loss: 0.50974
Epoch: 112, Time: 0.00648s, Loss: 0.49094
Epoch: 113, Time: 0.00768s, Loss: 0.50044
Epoch: 114, Time: 0.00808s, Loss: 0.50632
Epoch: 115, Time: 0.00744s, Loss: 0.48155
Epoch: 116, Time: 0.00774s, Loss: 0.49875
Epoch: 117, Time: 0.00633s, Loss: 0.48650
Epoch: 118, Time: 0.00742s, Loss: 0.48026
Epoch: 119, Time: 0.00928s, Loss: 0.48162
Epoch: 120, Time: 0.00687s, Loss: 0.46713
Epoch: 121, Time: 0.00679s, Loss: 0.46894
Epoch: 122, Time: 0.00891s, Loss: 0.47300
Epoch: 123, Time: 0.00639s, Loss: 0.45836
Epoch: 124, Time: 0.00676s, Loss: 0.46030
Epoch: 125, Time: 0.00940s, Loss: 0.45373
Epoch: 126, Time: 0.00926s, Loss: 0.44894
Epoch: 127, Time: 0.00701s, Loss: 0.45110
Epoch: 128, Time: 0.00710s, Loss: 0.43749
Epoch: 129, Time: 0.00913s, Loss: 0.45104
Epoch: 130, Time: 0.00706s, Loss: 0.45284
Epoch: 131, Time: 0.00693s, Loss: 0.44452
Epoch: 132, Time: 0.00937s, Loss: 0.43088
Epoch: 133, Time: 0.00810s, Loss: 0.43557
Epoch: 134, Time: 0.00713s, Loss: 0.44251
Epoch: 135, Time: 0.00822s, Loss: 0.41227
Epoch: 136, Time: 0.00981s, Loss: 0.41414
Epoch: 137, Time: 0.00706s, Loss: 0.42148
Epoch: 138, Time: 0.00649s, Loss: 0.40822
Epoch: 139, Time: 0.00860s, Loss: 0.41343
Epoch: 140, Time: 0.00616s, Loss: 0.39754
Epoch: 141, Time: 0.00644s, Loss: 0.39057
Epoch: 142, Time: 0.00860s, Loss: 0.41271
Epoch: 143, Time: 0.00631s, Loss: 0.39916
Epoch: 144, Time: 0.00675s, Loss: 0.37878
Epoch: 145, Time: 0.00897s, Loss: 0.40234
Epoch: 146, Time: 0.00621s, Loss: 0.38136
Epoch: 147, Time: 0.00864s, Loss: 0.38960
Epoch: 148, Time: 0.00633s, Loss: 0.40494
Epoch: 149, Time: 0.00629s, Loss: 0.38099
Epoch: 150, Time: 0.00883s, Loss: 0.37809
Epoch: 151, Time: 0.00621s, Loss: 0.38888
Epoch: 152, Time: 0.00633s, Loss: 0.35971
Epoch: 153, Time: 0.00842s, Loss: 0.37553
Epoch: 154, Time: 0.00622s, Loss: 0.36924
Epoch: 155, Time: 0.00739s, Loss: 0.37269
Epoch: 156, Time: 0.00864s, Loss: 0.36131
Epoch: 157, Time: 0.00627s, Loss: 0.35630
Epoch: 158, Time: 0.00854s, Loss: 0.36315
Epoch: 159, Time: 0.00648s, Loss: 0.37506
Epoch: 160, Time: 0.00638s, Loss: 0.36177
Epoch: 161, Time: 0.00867s, Loss: 0.37122
Epoch: 162, Time: 0.00632s, Loss: 0.35660
Epoch: 163, Time: 0.00641s, Loss: 0.34108
Epoch: 164, Time: 0.00873s, Loss: 0.34228
Epoch: 165, Time: 0.00619s, Loss: 0.34731
Epoch: 166, Time: 0.00656s, Loss: 0.34604
Epoch: 167, Time: 0.00881s, Loss: 0.33136
Epoch: 168, Time: 0.00620s, Loss: 0.35096
Epoch: 169, Time: 0.00874s, Loss: 0.33567
Epoch: 170, Time: 0.00766s, Loss: 0.32705
Epoch: 171, Time: 0.00628s, Loss: 0.32490
Epoch: 172, Time: 0.00880s, Loss: 0.32892
Epoch: 173, Time: 0.00619s, Loss: 0.32556
Epoch: 174, Time: 0.00631s, Loss: 0.32410
Epoch: 175, Time: 0.00878s, Loss: 0.30940
Epoch: 176, Time: 0.00629s, Loss: 0.33027
Epoch: 177, Time: 0.00636s, Loss: 0.32709
Epoch: 178, Time: 0.00887s, Loss: 0.32104
Epoch: 179, Time: 0.00625s, Loss: 0.33687
Epoch: 180, Time: 0.00694s, Loss: 0.31593
Epoch: 181, Time: 0.00861s, Loss: 0.31409
Epoch: 182, Time: 0.00627s, Loss: 0.31477
Epoch: 183, Time: 0.00847s, Loss: 0.30355
Epoch: 184, Time: 0.00642s, Loss: 0.33237
Epoch: 185, Time: 0.00630s, Loss: 0.30555
Epoch: 186, Time: 0.00839s, Loss: 0.29973
Epoch: 187, Time: 0.00631s, Loss: 0.30695
Epoch: 188, Time: 0.00645s, Loss: 0.30313
Epoch: 189, Time: 0.00899s, Loss: 0.30699
Epoch: 190, Time: 0.00626s, Loss: 0.31283
Epoch: 191, Time: 0.00654s, Loss: 0.28851
Epoch: 192, Time: 0.00879s, Loss: 0.28803
Epoch: 193, Time: 0.00621s, Loss: 0.28213
Epoch: 194, Time: 0.00846s, Loss: 0.27823
Epoch: 195, Time: 0.00704s, Loss: 0.29048
Epoch: 196, Time: 0.00638s, Loss: 0.28898
Epoch: 197, Time: 0.00894s, Loss: 0.29096
Epoch: 198, Time: 0.00642s, Loss: 0.27857
Epoch: 199, Time: 0.00817s, Loss: 0.29117

train finished!
best val: 0.47500
test...
final result: epoch: 77
{'accuracy': 0.5203484296798706, 'f1_score': 0.39131907709452823, 'f1_score -> average@micro': 0.5203484221048122}
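In this log, the final `update best: 0.47500` follows the `Epoch: 77` line, and the reported `final result: epoch: 77` matches it. This suggests that the test metrics are computed with the checkpoint that had the highest validation accuracy, not with the last epoch (199). That checkpoint is reached long before the training loss bottoms out around 0.28. Below is a minimal pure-Python sketch of that selection rule. The helper `select_best` and the toy values are hypothetical. It assumes the tutorial's loop snapshots the model state with `deepcopy` (imported at the top of the example) whenever validation accuracy strictly improves. The real loop would presumably store `net.state_dict()` rather than a plain dict.

```python
from copy import deepcopy


def select_best(val_accs, states):
    """Keep the checkpoint with the highest validation accuracy.

    Hypothetical helper mirroring the 'update best' lines in the log:
    a strictly higher validation score replaces the stored snapshot.
    """
    best_val, best_epoch, best_state = 0.0, -1, None
    for epoch, (val, state) in enumerate(zip(val_accs, states)):
        if val > best_val:  # strict '>' so a tie keeps the earlier epoch
            print(f"update best: {val:.5f}")
            best_val, best_epoch = val, epoch
            best_state = deepcopy(state)  # snapshot, not a live reference
    return best_val, best_epoch, best_state


# Toy run: epoch 4 ties epoch 2 but does not replace it.
val_accs = [0.30, 0.45, 0.475, 0.47, 0.475]
states = [{"w": i} for i in range(5)]
print(select_best(val_accs, states))
```

The `deepcopy` matters because a state dict holds references to live parameter tensors. Without a copy, the stored "best" state would keep changing as training continues.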