Source code for dhg.data.walmart

from typing import Optional
from functools import partial

from .base import BaseData
from dhg.datapipe import load_from_pickle, to_tensor, to_long_tensor


class WalmartTrips(BaseData):
    r"""Walmart Trips: a user-product network dataset for the vertex classification task.

    Vertices represent products purchased at Walmart, each hyperedge is a set of products
    that were purchased together, and the vertex labels are the product categories.
    For more details see `this <https://www.cs.cornell.edu/~arb/data/walmart-trips/>`_ and the
    `YOU ARE ALLSET: A MULTISET LEARNING FRAMEWORK FOR HYPERGRAPH NEURAL NETWORKS <https://openreview.net/pdf?id=hpBTIv2uy_E>`_ paper.

    The content of the Walmart Trips dataset includes the following:

    - ``num_classes``: The number of classes: :math:`12`.
    - ``num_vertices``: The number of vertices: :math:`88,860`.
    - ``num_edges``: The number of edges: :math:`69,906`.
    - ``edge_list``: The edge list. ``List`` with length :math:`69,906`.
    - ``labels``: The label list. ``torch.LongTensor`` with size :math:`(88,860, )`.

    Args:
        ``data_root`` (``str``, optional): The directory in which the data is stored.
            If set to ``None``, the data is auto-downloaded from the server and saved into the
            default directory ``~/.dhg/datasets/``. Defaults to ``None``.
    """

    def __init__(self, data_root: Optional[str] = None) -> None:
        super().__init__("walmart_trips", data_root)

        # File-backed items. Each one names the file(s) it is built "upon" together with an
        # MD5 string, plus the callable that loads it.
        # NOTE(review): presumably BaseData uses the MD5 to verify the file - confirm.
        edge_list_item = dict(
            upon=[dict(filename="edge_list.pkl", md5="0827c278282601b9c584f80c3b686a72")],
            loader=load_from_pickle,
        )
        # Labels additionally pass through ``to_long_tensor`` after loading.
        labels_item = dict(
            upon=[dict(filename="labels.pkl", md5="981684b84f9e7917e86b5aff08d0c594")],
            loader=load_from_pickle,
            preprocess=[to_long_tensor],
        )

        # Scalar statistics come first, followed by the file-backed items
        # (same key order as before).
        self._content = dict(
            num_classes=12,
            num_vertices=88860,
            num_edges=69906,
            edge_list=edge_list_item,
            labels=labels_item,
        )