Source code for dhg.metrics.base

import abc
from functools import partial
from typing import Union, List, Dict
from collections import defaultdict

import numpy as np
import torch

from dhg._global import AUTHOR_EMAIL


def format_metric_configs(task: str, metric_configs: List[Union[str, Dict[str, dict]]]):
    r"""Format metric_configs.
    
    Args:
        ``task`` (``str``): The type of the task. The supported types include: ``classification``, ``retrieval`` and ``recommender``.
        ``metric_configs`` (``Dict[str, Dict[str, Union[str, int]]]``): The metric configs.
    """
    task = task.lower()
    if task == "classification":
        import dhg.metrics.classification as module

        available_metrics = module.available_classification_metrics()
    elif task == "retrieval":
        import dhg.metrics.retrieval as module

        available_metrics = module.available_retrieval_metrics()
    elif task == "recommender":
        import dhg.metrics.recommender as module

        available_metrics = module.available_recommender_metrics()
    else:
        raise ValueError(f"Task {task} is not supported yet. Please email '{AUTHOR_EMAIL}' to add it.")
    metric_list = []
    for metric in metric_configs:
        if isinstance(metric, str):
            marker, func_name = metric, metric
            assert (
                func_name in available_metrics
            ), f"{func_name} is not supported yet. Please email '{AUTHOR_EMAIL}' to add it."
            func = getattr(module, func_name)
        elif isinstance(metric, dict):
            assert len(metric) == 1
            func_name = list(metric.keys())[0]
            assert (
                func_name in available_metrics
            ), f"{func_name} is not supported yet. Please email '{AUTHOR_EMAIL}' to add it."
            params = metric[func_name]
            func = getattr(module, func_name)
            func = partial(func, **params)
            markder_list = []
            for k, v in params.items():
                _m = f"{k}@"
                if isinstance(v, str):
                    _m += v
                elif isinstance(v, int):
                    _m += str(v)
                elif isinstance(v, float):
                    _m += f"{v:.4f}"
                elif isinstance(v, list) or isinstance(v, tuple) or isinstance(v, set):
                    _m += "_".join([str(_v) for _v in v])
                else:
                    _m += str(v)
                markder_list.append(_m)
            marker = f"{func_name} -> {' | '.join(markder_list)}"
        else:
            raise ValueError
        metric_list.append({"marker": marker, "func": func, "func_name": func_name})
    return metric_list


[docs]class BaseEvaluator: r"""The base class for task-specified metric evaluators. Args: ``task`` (``str``): The type of the task. The supported types include: ``classification``, ``retrieval`` and ``recommender``. ``metric_configs`` (``List[Union[str, Dict[str, dict]]]``): The metric configurations. The key is the metric name and the value is the metric parameters. ``validate_index`` (``int``): The specified metric index used for validation. Defaults to ``0``. """ def __init__( self, task: str, metric_configs: List[Union[str, Dict[str, dict]]], validate_index: int = 0, ): self.validate_index = validate_index metric_configs = format_metric_configs(task, metric_configs) assert validate_index >= 0 and validate_index < len( metric_configs ), "The specified validate metric index is out of range." self.marker_list, self.func_list = [], [] for metric in metric_configs: self.marker_list.append(metric["marker"]) self.func_list.append(metric["func"]) # init batch data containers self.validate_res = [] self.test_res_dict = defaultdict(list) self.last_validate_res, self.last_test_res = None, {} @abc.abstractmethod def __repr__(self) -> str: r"""Print the Evaluator information. """
[docs] def validate_add_batch(self, batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor): r"""Add batch data for validation. Args: ``batch_y_true`` (``torch.Tensor``): The ground truth data. Size :math:`(N_{batch}, -)`. ``batch_y_pred`` (``torch.Tensor``): The predicted data. Size :math:`(N_{batch}, -)`. """ batch_res = self.func_list[self.validate_index](batch_y_true, batch_y_pred, ret_batch=True) batch_res = np.array(batch_res) if len(batch_res.shape) == 1: batch_res = batch_res[:, np.newaxis] self.validate_res.append(batch_res)
[docs] def validate_epoch_res(self): r"""For all added batch data, return the result of the evaluation on the specified ``validate_index``-th metric. """ if self.validate_res == [] and self.last_validate_res is not None: return self.last_validate_res assert self.validate_res != [], "No batch data added for validation." self.last_validate_res = np.vstack(self.validate_res).mean(0).item() # clear batch cache self.validate_res = [] return self.last_validate_res
[docs] def test_add_batch(self, batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor): r"""Add batch data for testing. Args: ``batch_y_true`` (``torch.Tensor``): The ground truth data. Size :math:`(N_{batch}, -)`. ``batch_y_pred`` (``torch.Tensor``): The predicted data. Size :math:`(N_{batch}, -)`. """ for name, func in zip(self.marker_list, self.func_list): batch_res = func(batch_y_true, batch_y_pred, ret_batch=True) if not isinstance(batch_res, tuple): batch_res = np.array(batch_res) if len(batch_res.shape) == 1: batch_res = batch_res[:, np.newaxis] self.test_res_dict[name].append(batch_res) else: if self.test_res_dict[name] == []: self.test_res_dict[name] = [list() for _ in range(len(batch_res))] for idx, batch_sub_res in enumerate(batch_res): batch_sub_res = np.array(batch_sub_res) if len(batch_sub_res.shape) == 1: batch_sub_res = batch_sub_res[:, np.newaxis] self.test_res_dict[name][idx].append(batch_sub_res)
[docs] def test_epoch_res(self): r"""For all added batch data, return results of the evaluation on all the metrics in ``metric_configs``. """ if self.test_res_dict == {} and self.last_test_res is not None: return self.last_test_res assert self.test_res_dict != {}, "No batch data added for testing." for name, res_list in self.test_res_dict.items(): if not isinstance(res_list[0], list): self.last_test_res[name] = np.vstack(res_list).mean(0).squeeze().tolist() else: self.last_test_res[name] = [ np.vstack(sub_res_list).mean(0).squeeze().tolist() for sub_res_list in res_list ] # clear batch cache self.test_res_dict = defaultdict(list) return self.last_test_res
[docs] def validate(self, y_true: torch.LongTensor, y_pred: torch.Tensor): r"""Return the result of the evaluation on the specified ``validate_index``-th metric. Args: ``y_true`` (``torch.LongTensor``): The ground truth labels. Size :math:`(N_{samples}, -)`. ``y_pred`` (``torch.Tensor``): The predicted labels. Size :math:`(N_{samples}, -)`. """ return self.func_list[self.validate_index](y_true, y_pred)
[docs] def test(self, y_true: torch.LongTensor, y_pred: torch.Tensor): r"""Return results of the evaluation on all the metrics in ``metric_configs``. Args: ``y_true`` (``torch.LongTensor``): The ground truth labels. Size :math:`(N_{samples}, -)`. ``y_pred`` (``torch.Tensor``): The predicted labels. Size :math:`(N_{samples}, -)`. """ return {name: func(y_true, y_pred) for name, func in zip(self.marker_list, self.func_list)}