Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ typecheck:

test:
@echo '=== Tests ==='
$(UV) run $(PYTEST)
$(UV) run $(PYTEST) --cov=hyperbench --cov-report=term-missing

clean:
@echo '=== Cleaning up ==='
Expand Down
10 changes: 7 additions & 3 deletions hyperbench/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from typing import List, Tuple
from typing import List, Optional, Tuple
from torch import Tensor
from torch.utils.data import DataLoader as TorchDataLoader
from hyperbench.data import Dataset
Expand All @@ -9,7 +9,11 @@

class DataLoader(TorchDataLoader):
def __init__(
self, dataset: Dataset, batch_size: int = 1, shuffle: bool = False, **kwargs
self,
dataset: Dataset,
batch_size: int = 1,
shuffle: Optional[bool] = False,
**kwargs,
) -> None:
super().__init__(
dataset=dataset,
Expand Down Expand Up @@ -101,7 +105,7 @@ def __batch_node_features(self, batch: List[HData]) -> Tuple[Tensor, int]:

return batched_node_features, total_nodes

def __batch_edges(self, batch: List[HData]) -> Tuple[Tensor, Tensor | None, int]:
def __batch_edges(self, batch: List[HData]) -> Tuple[Tensor, Optional[Tensor], int]:
"""Batches hyperedge indices and attributes, adjusting indices for concatenated nodes.
Hyperedge indices must be offset so they point to the correct nodes in the batched node tensor.

Expand Down
31 changes: 6 additions & 25 deletions hyperbench/tests/mock/mock.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,11 @@
from typing import Any, List
from hyperbench.data import Dataset
from hyperbench import utils
from unittest.mock import MagicMock


MOCK_BASE_PATH = "hyperbench/tests/mock"


class MockDataset(Dataset):
def __init__(self, data_list: list[Any]):
super().__init__()
self.data_list = data_list
self.hypergraph = utils.empty_hifhypergraph() # Not used in this mock
self.hdata = utils.empty_hdata() # Not used in this mock

def __len__(self):
return len(self.data_list)

def __getitem__(self, index: int | List[int]) -> Any:
if isinstance(index, list):
return [self.data_list[i] for i in index]
return self.data_list[index]

def download(self):
# Not implemented for mock as we don't need it
pass

def process(self):
# Not implemented for mock as we don't need it
pass
def new_mock_trainer() -> MagicMock:
trainer = MagicMock()
trainer.fit = MagicMock()
trainer.test = MagicMock(return_value=[{"acc": 0.9}])
return trainer
223 changes: 223 additions & 0 deletions hyperbench/tests/train/trainer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import pytest

from unittest.mock import MagicMock, patch
from hyperbench.train import MultiModelTrainer
from hyperbench.types import ModelConfig
from hyperbench.tests import new_mock_trainer


@pytest.fixture
def mock_model_configs():
model_configs = []

for i in range(2):
model = MagicMock()
model.name = f"model{i}"
model.version = f"{i}"

model_config = MagicMock(spec=ModelConfig)
model_config.name = f"model{i}"
model_config.version = f"{i}"
model_config.model = model
model_config.trainer = None
model_config.full_model_name = (
lambda self=model_config: f"{self.name}:{self.version}"
)

model_configs.append(model_config)

return model_configs


@patch("hyperbench.train.trainer.L.Trainer")
def test_trainer_initialization(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

assert len(multi_model_trainer.model_configs) == len(mock_model_configs)
for config in multi_model_trainer.model_configs:
assert config.trainer is not None


@patch("hyperbench.train.trainer.L.Trainer")
def test_trainer_initialization_with_initialized_trainer(
mock_trainer, mock_model_configs
):
mock_model_configs[0].trainer = mock_trainer

multi_model_trainer = MultiModelTrainer(mock_model_configs)

assert len(multi_model_trainer.model_configs) == len(mock_model_configs)
for config in multi_model_trainer.model_configs:
assert config.trainer is not None


@patch("hyperbench.train.trainer.L.Trainer")
def test_models_property_returns_models(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)
models = multi_model_trainer.models

assert len(models) == len(mock_model_configs)


@patch("hyperbench.train.trainer.L.Trainer")
def test_models_property_returns_empty_when_no_models(_):
multi_model_trainer = MultiModelTrainer([])
models = multi_model_trainer.models

assert len(models) == 0


@patch("hyperbench.train.trainer.L.Trainer")
def test_model_returns_model_when_correct_name_and_no_version(_, mock_model_configs):
mock_model_configs[0].version = "default"
mock_model_configs[0].model.version = "default"

multi_model_trainer = MultiModelTrainer(mock_model_configs)
found = multi_model_trainer.model(name="model0")

assert found is not None
assert found.name == "model0"
assert found.version == "default"


@patch("hyperbench.train.trainer.L.Trainer")
def test_model_returns_None_when_incorrect_name_and_no_version(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)
found = multi_model_trainer.model(name="nonexistent")

assert found is None


@patch("hyperbench.train.trainer.L.Trainer")
def test_model_returns_model_when_correct_name_and_version(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)
found = multi_model_trainer.model(name="model0", version="0")

assert found is not None
assert found.name == "model0"
assert found.version == "0"


@patch("hyperbench.train.trainer.L.Trainer")
def test_model_returns_None_when_incorrect_name_and_version(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)
not_found = multi_model_trainer.model(name="nonexistent", version="100")

assert not_found is None


@patch("hyperbench.train.trainer.L.Trainer")
def test_model_returns_None_when_incorrect_name_and_correct_version(
_, mock_model_configs
):
multi_model_trainer = MultiModelTrainer(mock_model_configs)
not_found = multi_model_trainer.model(name="nonexistent", version="0")

assert not_found is None


@patch(
"hyperbench.train.trainer.L.Trainer",
side_effect=lambda *args, **kwargs: new_mock_trainer(),
)
def test_fit_all_calls_fit(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

multi_model_trainer.fit_all(verbose=False)
for config in mock_model_configs:
config.trainer.fit.assert_called_once()


@patch("hyperbench.train.trainer.L.Trainer")
def test_fit_all_with_no_models(_):
multi_model_trainer = MultiModelTrainer([])

with pytest.raises(ValueError, match="No models to fit."):
multi_model_trainer.fit_all(verbose=False)


@patch("hyperbench.train.trainer.L.Trainer", return_value=None)
def test_fit_all_raises_when_None_trainer(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

with pytest.raises(
ValueError,
match=f"Trainer not defined for model {mock_model_configs[0].full_model_name()}.",
):
multi_model_trainer.fit_all(verbose=False)


@patch(
"hyperbench.train.trainer.L.Trainer",
side_effect=lambda *args, **kwargs: new_mock_trainer(),
)
def test_fit_all_with_verbose_true_prints(_, mock_model_configs, caplog):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

with caplog.at_level("INFO"):
multi_model_trainer.fit_all(verbose=True)

for config in mock_model_configs:
config.trainer.fit.assert_called_once()

logs = [
record.message for record in caplog.records if "Fit model" in record.message
]
assert len(logs) == len(mock_model_configs)


@patch(
"hyperbench.train.trainer.L.Trainer",
side_effect=lambda *args, **kwargs: new_mock_trainer(),
)
def test_test_all_calls_test_and_returns_results(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

results = multi_model_trainer.test_all(verbose=False)

assert all("acc" in v for v in results.values())

for config in mock_model_configs:
config.trainer.test.assert_called_once()


@patch("hyperbench.train.trainer.L.Trainer")
def test_test_all_with_no_models(_):
multi_model_trainer = MultiModelTrainer([])

with pytest.raises(ValueError, match="No models to test."):
multi_model_trainer.test_all(verbose=False)


@patch("hyperbench.train.trainer.L.Trainer", return_value=None)
def test_test_all_raises_when_None_trainer(_, mock_model_configs):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

with pytest.raises(
ValueError,
match=f"Trainer not defined for model {mock_model_configs[0].full_model_name()}.",
):
multi_model_trainer.test_all(verbose=False)


@patch(
"hyperbench.train.trainer.L.Trainer",
side_effect=lambda *args, **kwargs: new_mock_trainer(),
)
def test_test_all_with_verbose_true_prints(_, mock_model_configs, caplog):
multi_model_trainer = MultiModelTrainer(mock_model_configs)

with caplog.at_level("INFO"):
multi_model_trainer.test_all(verbose=True)

for config in mock_model_configs:
config.trainer.test.assert_called_once()

logs = [
record.message for record in caplog.records if "Test model" in record.message
]
assert len(logs) == len(mock_model_configs)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
47 changes: 47 additions & 0 deletions hyperbench/tests/types/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from hyperbench.types import ModelConfig
from unittest.mock import MagicMock
from hyperbench.types.model import ModelConfig


@pytest.fixture
def mock_model():
return MagicMock()


@pytest.fixture
def mock_trainer():
return MagicMock()


def test_model_config_initialization_with_trainer(mock_model, mock_trainer):
model_config = ModelConfig(
name="model", model=mock_model, version="0", trainer=mock_trainer
)

assert model_config.name == "model"
assert model_config.version == "0"
assert model_config.model is mock_model
assert model_config.trainer is mock_trainer


def test_model_config_initialization_without_trainer(mock_model):
mock_config = ModelConfig(name="test_model", model=mock_model)

assert mock_config.name == "test_model"
assert mock_config.version == "default"
assert mock_config.model is mock_model
assert mock_config.trainer is None


def test_full_model_name(mock_model):
mock_config = ModelConfig(name="foo", model=mock_model, version="bar")

assert mock_config.full_model_name() == "foo:bar"


def test_full_model_name_default_version(mock_model):
mock_config = ModelConfig(name="foo", model=mock_model)

assert mock_config.full_model_name() == "foo:default"
5 changes: 5 additions & 0 deletions hyperbench/train/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .trainer import MultiModelTrainer

__all__ = [
"MultiModelTrainer",
]
Loading