Skip to content
31 changes: 17 additions & 14 deletions hyperbench/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
NodeSpaceFiller,
NodeSpaceSetting,
is_transductive_setting,
validate_node_space_setting,
validate_split_ratios,
)

from hyperbench.data.hif import HIFLoader, HIFProcessor
Expand Down Expand Up @@ -146,8 +148,9 @@ def enrich_node_features(
Args:
enricher: An instance of NodeEnricher to generate structural node features from hypergraph topology.
enrichment_mode: How to combine generated features with existing ``hdata.x``.
``concatenate`` appends new features as additional columns.
``concatenate`` appends new features to the existing ones as additional columns.
``replace`` substitutes ``hdata.x`` entirely.
Defaults to ``replace`` if not provided.
"""
self.hdata = self.hdata.enrich_node_features(enricher, enrichment_mode)

Expand Down Expand Up @@ -195,10 +198,11 @@ def enrich_hyperedge_attr(
"""Enrich hyperedge features using the provided hyperedge feature enricher.

Args:
enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology.
enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_attr``.
``concatenate`` appends new features as additional columns.
enricher: An instance of HyperedgeEnricher to generate structural hyperedge attributes from hypergraph topology.
enrichment_mode: How to combine generated attributes with existing ``hdata.hyperedge_attr``.
``concatenate`` appends new attributes to the existing ones as additional columns.
``replace`` substitutes ``hdata.hyperedge_attr`` entirely.
Defaults to ``replace`` if not provided.
"""
self.hdata = self.hdata.enrich_hyperedge_attr(enricher, enrichment_mode)

Expand All @@ -210,10 +214,11 @@ def enrich_hyperedge_weights(
"""Enrich hyperedge weights using the provided hyperedge weight enricher.

Args:
enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology.
enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_weights``.
``concatenate`` appends new features as additional columns.
enricher: An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology.
enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``.
``concatenate`` appends new weights to the existing ones as additional columns.
``replace`` substitutes ``hdata.hyperedge_weights`` entirely.
Defaults to ``replace`` if not provided.
"""
self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode)

Expand Down Expand Up @@ -325,7 +330,8 @@ def split_with_ratios(
seed: int | None = None,
node_space_setting: NodeSpaceSetting = "transductive",
) -> tuple[list[Dataset], list[float]]:
"""Split the dataset and return the final hyperedge ratios.
"""
Split the dataset and return the final hyperedge ratios.

Final ratios are computed from split hyperedge counts after ratio
boundaries and any transductive rebalancing have been applied.
Expand All @@ -350,12 +356,9 @@ def split_with_ratios(
hyperedges, or a transductive first split cannot cover the full
node space.
"""
# Allow small imprecision in sum of ratios, but raise error if it's significant
# Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid)
# ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError)
# ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision)
if abs(sum(ratios) - 1.0) > 1e-6:
raise ValueError(f"Split ratios must sum to 1.0, got {sum(ratios)}.")
validate_node_space_setting(node_space_setting)
validate_split_ratios(ratios)

device = self.hdata.device

hyperedge_splitter = HyperedgeIDSplitter(self.hdata)
Expand Down
72 changes: 63 additions & 9 deletions hyperbench/data/enricher.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import warnings
import random
import torch
import warnings

from abc import ABC, abstractmethod
from torch import Tensor, optim
from typing import Literal, TypeAlias
from torch_geometric.nn import Node2Vec as PyGNode2Vec
from hyperbench.types import EdgeIndex, HyperedgeIndex
from hyperbench.models import VilLain
from hyperbench.utils import (
validate_is_between,
validate_is_finite,
validate_is_finite_when_provided,
validate_is_non_negative,
validate_is_positive,
)


EnrichmentMode: TypeAlias = Literal["concatenate", "replace"]
Expand Down Expand Up @@ -63,6 +70,8 @@ def __init__(
self.weight_decay = weight_decay
self.verbose = verbose

self.__validate()

def _empty_features(self, hyperedge_index: Tensor) -> Tensor:
"""
Return an empty feature matrix on the same device as ``hyperedge_index``.
Expand Down Expand Up @@ -147,6 +156,28 @@ def _train(self, hyperedge_index: Tensor):

return model

def __validate(self) -> None:
    """
    Check the hyperparameters stored on this instance, stopping at the first invalid one.

    Checks run in a fixed order, so when several values are invalid the error that
    surfaces belongs to the earliest failing check.

    Raises:
        ValueError: If ``labels_per_subspace`` is smaller than 2.
            The ``validate_is_*`` helpers are defined outside this block and are
            presumably the source of the errors for every other value.
    """
    # Each entry: (validator, name reported in the error, attribute holding the value).
    # Attributes are read lazily, one per check, in the listed order.
    size_checks = (
        (validate_is_positive, "num_features", "embedding_dim"),
        (validate_is_non_negative, "num_nodes", "num_nodes"),
        (validate_is_non_negative, "num_hyperedges", "num_hyperedges"),
    )
    for check, reported_name, attribute in size_checks:
        check(reported_name, getattr(self, attribute))

    # This lower bound has no dedicated helper here, so it is checked inline.
    if self.labels_per_subspace < 2:
        message = f"'labels_per_subspace' must be at least 2, got {self.labels_per_subspace}."
        raise ValueError(message)

    training_checks = (
        (validate_is_positive, "training_steps", "training_steps"),
        (validate_is_positive, "generation_steps", "generation_steps"),
        (validate_is_finite, "tau", "tau"),
        (validate_is_positive, "tau", "tau"),
        (validate_is_finite, "eps", "eps"),
        (validate_is_positive, "eps", "eps"),
        (validate_is_positive, "num_epochs", "num_epochs"),
        (validate_is_positive, "learning_rate", "learning_rate"),
        (validate_is_non_negative, "weight_decay", "weight_decay"),
        (validate_is_finite, "learning_rate", "learning_rate"),
        (validate_is_finite, "weight_decay", "weight_decay"),
    )
    for check, reported_name, attribute in training_checks:
        check(reported_name, getattr(self, attribute))


class Enricher(ABC):
"""
Expand Down Expand Up @@ -322,8 +353,9 @@ def __init__(
beta: float | None = None,
):
super().__init__(cache_dir=cache_dir)
if alpha < 0.0 or alpha > 1.0:
raise ValueError("Alpha must be between 0.0 and 1.0.")

validate_is_between("alpha", alpha, 0.0, 1.0)
validate_is_finite_when_provided("beta", beta)

self.alpha = alpha
self.beta = beta
Expand Down Expand Up @@ -407,12 +439,6 @@ def __init__(
verbose: bool = False,
):
super().__init__(cache_dir=cache_dir)
if walk_length < context_size:
raise ValueError(
f"Expected walk_length >= context_size, got "
f"walk_length={walk_length}, context_size={context_size}."
)

self.embedding_dim = num_features
self.walk_length = walk_length
self.context_size = context_size
Expand All @@ -428,6 +454,8 @@ def __init__(
self.sparse = sparse
self.verbose = verbose

self.__validate()

def enrich(self, hyperedge_index: Tensor) -> Tensor:
"""
Compute Node2Vec embeddings from the clique expansion of the hypergraph.
Expand Down Expand Up @@ -519,6 +547,28 @@ def enrich(self, hyperedge_index: Tensor) -> Tensor:
# Detach node embeddings from computation graph and return them
return x.detach().to(device)

def __validate(self) -> None:
    """
    Check the Node2Vec hyperparameters stored on this instance.

    Checks run in a fixed order and stop at the first failure. The embedding
    dimension is reported under the name ``num_features``.

    Raises:
        ValueError: If ``walk_length`` is smaller than ``context_size``.
            The ``validate_is_*`` helpers are defined outside this block and are
            presumably the source of the errors for every other value.
    """
    # Each entry: (validator, name reported in the error, attribute holding the value).
    # Attributes are read lazily, one per check, in the listed order.
    walk_shape_checks = (
        (validate_is_positive, "num_features", "embedding_dim"),
        (validate_is_positive, "walk_length", "walk_length"),
        (validate_is_positive, "context_size", "context_size"),
    )
    for check, reported_name, attribute in walk_shape_checks:
        check(reported_name, getattr(self, attribute))

    # Cross-field constraint, checked only after both values passed their own checks.
    if self.walk_length < self.context_size:
        message = (
            "Expected walk_length >= context_size, got "
            f"walk_length={self.walk_length}, context_size={self.context_size}."
        )
        raise ValueError(message)

    sampling_and_training_checks = (
        (validate_is_positive, "num_walks_per_node", "num_walks_per_node"),
        (validate_is_finite, "p", "p"),
        (validate_is_positive, "p", "p"),
        (validate_is_finite, "q", "q"),
        (validate_is_positive, "q", "q"),
        (validate_is_positive, "num_negative_samples", "num_negative_samples"),
        (validate_is_non_negative, "num_nodes", "num_nodes"),
        (validate_is_positive, "num_epochs", "num_epochs"),
        (validate_is_finite, "learning_rate", "learning_rate"),
        (validate_is_positive, "learning_rate", "learning_rate"),
        (validate_is_positive, "batch_size", "batch_size"),
    )
    for check, reported_name, attribute in sampling_and_training_checks:
        check(reported_name, getattr(self, attribute))


class LaplacianPositionalEncodingEnricher(NodeEnricher):
"""
Expand All @@ -540,6 +590,10 @@ def __init__(
cache_dir: str | None = None,
):
super().__init__(cache_dir=cache_dir)

validate_is_positive("num_features", num_features)
validate_is_non_negative("num_nodes", num_nodes)

self.num_features = num_features
self.num_nodes = num_nodes

Expand Down
13 changes: 10 additions & 3 deletions hyperbench/data/hif.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData:

# Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order
node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)}
if len(node_id_to_idx) != num_nodes:
raise ValueError("HIF node IDs must be unique.")

# Initialize edge_set only with edges that have incidences, so that
# we avoid inflating edge count due to isolated nodes/missing incidences
hyperedge_id_to_idx: dict[Any, int] = {}
Expand All @@ -81,6 +84,11 @@ def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData:
for incidence in hypergraph.incidences:
node_id = incidence.get("node", 0)
hyperedge_id = incidence.get("edge", 0)
if node_id not in node_id_to_idx:
raise ValueError(
f"Incidence references unknown node id {node_id!r}; "
"all incidence nodes must be declared in the HIF nodes list."
)

if hyperedge_id not in hyperedge_id_to_idx:
# Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences
Expand Down Expand Up @@ -149,8 +157,7 @@ def __process_hyperedge_attr(
hyperedge_id_to_idx: dict[Any, int],
num_hyperedges: int,
) -> Tensor | None:
# hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes]
hyperedge_attr = None
hyperedge_attr = None # shape [num_hyperedges, num_hyperedge_attributes]
has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0
has_any_hyperedge_attrs = has_hyperedges and any(
"attrs" in edge for edge in hypergraph.hyperedges
Expand Down Expand Up @@ -231,7 +238,7 @@ def __process_hyperedge_weights(
edge_attrs = hyperedge_id_to_attrs.get(edge_id, {})
weights.append(float(edge_attrs.get("weight", 1.0)))

return torch.tensor(weights, dtype=torch.float)
return torch.tensor(weights, dtype=torch.float) # shape [num_hyperedges,]


class HIFLoader:
Expand Down
5 changes: 1 addition & 4 deletions hyperbench/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,7 @@ def collate(self, batch: list[HData]) -> HData:

collated_x = self.__cached_dataset_hdata.x[node_ids]
collated_y = self.__cached_dataset_hdata.y[hyperedge_ids]

collated_global_node_ids = None
if self.__cached_dataset_hdata.global_node_ids is not None:
collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids]
collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids]

collated_hyperedge_attr = None
if self.__cached_dataset_hdata.hyperedge_attr is not None:
Expand Down
38 changes: 24 additions & 14 deletions hyperbench/data/negative_sampling_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from enum import Enum
from typing import Any
from typing import Any, Literal, TypeAlias
from hyperbench.types import HData
from hyperbench.data import NegativeSampler


class NegativeSamplingSchedule(Enum):
"""When to run negative sampling during training."""

FIRST_EPOCH = "first_epoch" # Only at epoch 0, cached for all subsequent epochs
EVERY_N_EPOCHS = "every_n_epochs" # Every N epochs (N provided separately)
EVERY_EPOCH = "every_epoch" # Negatives generated every epoch
NegativeSamplingSchedule: TypeAlias = Literal[
"first_epoch", # Only at epoch 0, cached for all subsequent epochs
"every_n_epochs", # Every N epochs (N provided separately)
"every_epoch", # Negatives generated every epoch
]


class NegativeSamplingScheduler:
Expand All @@ -21,14 +19,15 @@ class NegativeSamplingScheduler:

Args:
negative_sampler: An instance of a ``NegativeSampler`` that defines how to sample negatives.
negative_sampling_schedule: An instance of ``NegativeSamplingSchedule`` that specifies the schedule for sampling negatives.
negative_sampling_every_n: An integer specifying the interval for sampling negatives when the schedule is set to ``EVERY_N_EPOCHS``. This parameter is ignored for other schedules.
negative_sampling_schedule: Literal string specifying the schedule for sampling negatives.
negative_sampling_every_n: An integer specifying the interval for sampling negatives
when the schedule is set to ``"every_n_epochs"``. This parameter is ignored for other schedules.
"""

def __init__(
self,
negative_sampler: NegativeSampler,
negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH,
negative_sampling_schedule: NegativeSamplingSchedule = "every_epoch",
negative_sampling_every_n: int = 1,
) -> None:
self.negative_sampler = negative_sampler
Expand Down Expand Up @@ -56,13 +55,24 @@ def should_sample(self, epoch: int) -> bool:
Returns:
should_sample: True if negatives should be resampled for the current epoch, False otherwise.
"""
if epoch < 0:
raise ValueError(f"Epoch must be non-negative, got {epoch}.")

match self.negative_sampling_schedule:
case NegativeSamplingSchedule.EVERY_N_EPOCHS:
case "every_n_epochs":
if self.negative_sampling_every_n <= 0:
raise ValueError(
f"negative_sampling_every_n must be positive, got {self.negative_sampling_every_n}."
)
return epoch % self.negative_sampling_every_n == 0
case NegativeSamplingSchedule.FIRST_EPOCH:
case "first_epoch":
return epoch == 0
case _: # Defaults to NegativeSamplingSchedule.EVERY_EPOCH
case "every_epoch":
return True
case _:
raise ValueError(
f"Unsupported negative sampling schedule: {self.negative_sampling_schedule!r}."
)

def sample(self, batch: HData, epoch: int) -> HData:
"""
Expand Down
18 changes: 16 additions & 2 deletions hyperbench/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _normalize_index(self, index: int | list[int], size: int) -> list[int]:

Raises:
ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of sampleable items).
TypeError: If the index is not an integer or a list of integers.
"""
if isinstance(index, list):
if len(index) < 1:
Expand All @@ -57,7 +58,15 @@ def _normalize_index(self, index: int | list[int], size: int) -> list[int]:
raise ValueError(
f"Index list length ({len(index)}) cannot exceed the number of sampleable items ({size})."
)
for id in index:
if not isinstance(id, int) or isinstance(id, bool):
raise TypeError("Index list must contain only integers.")

return list(set(index))

if not isinstance(index, int) or isinstance(index, bool):
raise TypeError("Index must be an integer or a list of integers.")

return [index]

def _sample_hyperedge_index(
Expand Down Expand Up @@ -244,10 +253,15 @@ def create_sampler_from_strategy(strategy: SamplingStrategy) -> BaseSampler:
strategy: An instance of SamplingStrategy enum indicating which sampling strategy to use.

Returns:
sampler: An instance of a subclass of BaseSampler corresponding to the specified strategy. If strategy is not recognized, defaults to ``HyperedgeSampler``.
sampler: An instance of a subclass of BaseSampler corresponding to the specified strategy.

Raises:
ValueError: If ``strategy`` is not a supported `SamplingStrategy`.
"""
match strategy:
case SamplingStrategy.NODE:
return NodeSampler()
case _:
case SamplingStrategy.HYPEREDGE:
return HyperedgeSampler()
case _:
raise ValueError(f"Unsupported sampling strategy: {strategy!r}.")
Loading