Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 143 additions & 5 deletions hyperbench/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import zstandard as zstd

from enum import Enum
from typing import List, Tuple
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from hyperbench.types.hypergraph import HIFHypergraph
from hyperbench.types.hdata import HData
Expand Down Expand Up @@ -85,7 +87,6 @@ class Dataset(TorchDataset):
process(): Processes the hypergraph into HData format.
"""

# TODO: move as input to __init__()? So that users can provide new ids and names of datasets formatted in HIF
GDRIVE_FILE_ID = None
DATASET_NAME = None

Expand All @@ -96,9 +97,31 @@ def __init__(self) -> None:
def __len__(self) -> int:
    """
    Return the size of the dataset, defined as the number of hypergraph nodes.
    Returns:
        Number of entries in ``self.hypergraph.nodes``.
    """
    node_collection = self.hypergraph.nodes
    return len(node_collection)

def __getitem__(self, index: int) -> HData:
    # Returns the full processed hypergraph data; `index` is accepted but not used.
    # NOTE(review): in this view a second `__getitem__` definition follows this one
    # (it looks like the pre-change side of a diff); when both are present in a
    # class body, the later definition is the one that takes effect.
    # TODO: implement sampling of nodes with given index
    return self.hdata
def __getitem__(self, index: int | List[int]) -> HData:
    """
    Build the sub-hypergraph induced by the requested node(s).
    Args:
        index: A single node ID or a list of node IDs to sample.
    Returns:
        HData holding the features of the sampled nodes, the incidences between
        those nodes and the hyperedges they belong to (with node and edge IDs
        renumbered from 0), and the attributes of those hyperedges when the
        dataset has edge attributes and at least one hyperedge was selected.
    """
    # Normalize the index to a list of node IDs, then reject out-of-range IDs.
    requested_ids = self.__get_node_ids_to_sample(index)
    self.__validate_node_ids(requested_ids)

    # Restrict the incidence structure to the requested nodes.
    sub_edge_index, node_id_tensor, edge_id_tensor = self.__sample_edge_index(
        requested_ids
    )

    # Renumber the surviving node/edge IDs so they start at 0.
    remapped_edge_index = self.__new_edge_index(
        sub_edge_index, node_id_tensor, edge_id_tensor
    )

    node_features = self.hdata.x[node_id_tensor]
    edge_count = len(edge_id_tensor)

    # Edge attributes are only carried over when they exist and something was selected.
    if self.hdata.edge_attr is not None and edge_count > 0:
        edge_attributes = self.hdata.edge_attr[edge_id_tensor]
    else:
        edge_attributes = None

    return HData(
        x=node_features,
        edge_index=remapped_edge_index,
        edge_attr=edge_attributes,
        num_nodes=len(node_id_tensor),
        num_edges=edge_count,
    )

def download(self) -> HIFHypergraph:
"""
Expand All @@ -115,7 +138,6 @@ def process(self) -> HData:
Returns:
HData: Processed hypergraph data.
"""

num_nodes = len(self.hypergraph.nodes)
num_edges = len(self.hypergraph.edges)

Expand Down Expand Up @@ -160,6 +182,122 @@ def process(self) -> HData:

return HData(x, edge_index, edge_attr, num_nodes, num_edges)

def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]:
if isinstance(id, int):
return [id]

if isinstance(id, list):
if len(id) < 1:
raise ValueError("Index list cannot be empty.")
elif len(id) > self.__len__():
raise ValueError(
"Index list length cannot exceed number of nodes in the hypergraph."
)
return list(set(id))

def __validate_node_ids(self, node_ids: List[int]) -> None:
    """
    Check that every node ID addresses an existing node.
    Args:
        node_ids: Node IDs to check.
    Raises:
        IndexError: On the first ID that is negative or not smaller than the
            number of nodes reported by ``__len__``.
    """
    node_count = self.__len__()
    for candidate in node_ids:
        if candidate < 0 or candidate >= node_count:
            raise IndexError(
                f"Node ID {candidate} is out of bounds (0, {node_count - 1})."
            )

def __sample_edge_index(
    self,
    sampled_node_ids_list: List[int],
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Restrict the dataset's incidence structure to a set of nodes.
    Args:
        sampled_node_ids_list: Original IDs of the nodes to keep.
    Returns:
        A tuple ``(sampled_edge_index, sampled_node_ids, sampled_edge_ids)``:
        the columns of ``self.hdata.edge_index`` whose node is sampled, the
        sampled node IDs as a tensor, and the distinct IDs of the hyperedges
        that contain at least one sampled node.
    """
    edge_index = self.hdata.edge_index
    node_ids, edge_ids = edge_index[0], edge_index[1]

    # Create the ID tensor on the same device as edge_index so that torch.isin
    # does not compare tensors living on different devices.
    sampled_node_ids = torch.tensor(
        sampled_node_ids_list, dtype=torch.long, device=edge_index.device
    )

    # Find incidences where the node is in our sampled node set
    # Example: edge_index[0] = [0, 0, 1, 2, 3, 4], sampled_node_ids = [0, 3]
    #          -> incidence_mask = [True, True, False, False, True, False]
    incidence_mask = torch.isin(node_ids, sampled_node_ids)

    # Get unique hyperedges that have at least one sampled node
    # Example: edge_index[1] = [0, 0, 0, 1, 2, 2], incidence_mask as above
    #          -> sampled_edge_ids = [0, 2] as they connect to sampled nodes
    sampled_edge_ids = edge_ids[incidence_mask].unique()

    # No separate hyperedge mask is needed: every incidence selected above has,
    # by construction, its hyperedge in sampled_edge_ids, so AND-ing with
    # isin(edge_ids, sampled_edge_ids) would leave the mask unchanged.

    # Keep only the incidences that match our mask
    # Example: edge_index = [[0, 0, 1, 2, 3, 4],
    #                        [0, 0, 0, 1, 2, 2]]
    #          -> sampled_edge_index = [[0, 0, 3],
    #                                   [0, 0, 2]]
    sampled_edge_index = edge_index[:, incidence_mask]

    return sampled_edge_index, sampled_node_ids, sampled_edge_ids

def __new_edge_index(
    self,
    sampled_edge_index: Tensor,
    sampled_node_ids: Tensor,
    sampled_edge_ids: Tensor,
) -> Tensor:
    """
    Create new edge_index with 0-based node and edge IDs.
    Args:
        sampled_edge_index: Original edge_index tensor restricted to the sampled incidences.
        sampled_node_ids: Tensor of sampled original node IDs.
        sampled_edge_ids: Tensor of sampled original edge IDs.
    Returns:
        New edge_index tensor with 0-based node and edge IDs.
    """
    # Example: sampled_edge_index = [[1, 1, 3],
    #                                [0, 2, 2]]
    #          sampled_node_ids = [1, 3],
    #          sampled_edge_ids = [0, 2]
    #          -> remapped_nodes = [0, 0, 1], remapped_edges = [0, 1, 1]
    node_row, edge_row = sampled_edge_index[0], sampled_edge_index[1]
    remapped_nodes = self.__to_0based_ids(
        node_row, sampled_node_ids, self.hdata.num_nodes
    )
    remapped_edges = self.__to_0based_ids(
        edge_row, sampled_edge_ids, self.hdata.num_edges
    )

    # Continuing the example: result = [[0, 0, 1],
    #                                   [0, 1, 1]]
    return torch.stack((remapped_nodes, remapped_edges), dim=0)

def __to_0based_ids(
    self,
    original_ids: Tensor,
    ids_to_keep: Tensor,
    n: int,
) -> Tensor:
    """
    Map original IDs to 0-based ids.
    Each ID in ``ids_to_keep`` is assigned its position in ``ids_to_keep``, and
    every entry of ``original_ids`` is replaced by the position assigned to it.
    The output has the same length as ``original_ids``; nothing is filtered out.
    Example:
        original_ids: [3, 3, 7]
        ids_to_keep: [3, 7]
        n = 8  # total number of elements (nodes or edges) in the original hypergraph
        Returned 0-based IDs: [0, 0, 1]
    Args:
        original_ids: Tensor of original IDs. Every value should also appear in
            ``ids_to_keep``; a value that does not is mapped to 0, which is
            indistinguishable from the first kept ID.
        ids_to_keep: Tensor of selected original IDs to be mapped to 0-based.
        n: Total number of original IDs (size of the lookup table).
    Returns:
        Integer (``torch.long``) tensor of 0-based ids.
    """
    device = original_ids.device
    # The lookup table must be integer-typed: torch.zeros defaults to a floating
    # dtype, which would make the remapped IDs (and any edge_index stacked from
    # them) floating point instead of usable indices.
    id_to_0based_id = torch.zeros(n, dtype=torch.long, device=device)
    id_to_0based_id[ids_to_keep] = torch.arange(
        len(ids_to_keep), dtype=torch.long, device=device
    )
    return id_to_0based_id[original_ids]


class AlgebraDataset(Dataset):
DATASET_NAME = "ALGEBRA"
Expand Down
Loading