Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 152 additions & 2 deletions hyperbench/tests/types/hdata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@pytest.fixture
def mock_hdata():
def mock_hdata() -> HData:
x = torch.randn(5, 4) # 5 nodes with 4 features each
hyperedge_index = torch.tensor(
[
Expand All @@ -26,7 +26,7 @@ def mock_hdata():


@pytest.fixture
def mock_hdata_stats():
def mock_hdata_stats() -> HData:
x = torch.tensor(
[
[0.0, 1.0, 2.0, 3.0],
Expand All @@ -45,6 +45,24 @@ def mock_hdata_stats():
return HData(x=x, hyperedge_index=hyperedge_index)


@pytest.fixture
def hdata_with_all_mutable_tensors() -> HData:
x = torch.arange(10, dtype=torch.float).reshape(5, 2)
hyperedge_index = torch.tensor([[0, 1, 2, 2, 3, 4], [0, 0, 1, 1, 2, 2]])
hyperedge_weights = torch.tensor([0.1, 0.2, 0.3])
hyperedge_attr = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
global_node_ids = torch.tensor([10, 20, 30, 40, 50])
y = torch.tensor([1.0, 0.0, 1.0])
return HData(
x=x,
hyperedge_index=hyperedge_index,
hyperedge_weights=hyperedge_weights,
hyperedge_attr=hyperedge_attr,
global_node_ids=global_node_ids,
y=y,
)


@pytest.fixture
def mock_negative_sampler() -> tuple[NegativeSampler, MagicMock]:
def sample(data: HData, seed: int | None = None) -> HData:
Expand Down Expand Up @@ -494,6 +512,36 @@ def test_cat_same_node_space_drops_hyperedge_attr_when_partially_missing():
assert result.hyperedge_attr is None


def test_cat_same_node_space_does_not_share_mutable_storage_with_inputs(
hdata_with_all_mutable_tensors,
):
hdata = hdata_with_all_mutable_tensors
other_hdata = HData(
x=hdata.x,
hyperedge_index=torch.tensor([[0, 4], [3, 3]]),
hyperedge_weights=torch.tensor([0.4]),
hyperedge_attr=torch.tensor([[4.0, 40.0]]),
global_node_ids=hdata.global_node_ids,
y=torch.tensor([0.0]),
)

result = HData.cat_same_node_space([hdata, other_hdata])

__assert_mutating_result_keeps_source_tensors_unchanged(hdata, result)
__assert_mutating_result_keeps_source_tensors_unchanged(other_hdata, result)


def test_cat_same_node_space_clones_provided_x(hdata_with_all_mutable_tensors):
hdata = hdata_with_all_mutable_tensors
custom_x = torch.full_like(hdata.x, 9.0)
original_custom_x = custom_x.clone()

result = HData.cat_same_node_space([hdata], x=custom_x)
result.x.flatten()[0].add_(1)

assert torch.equal(custom_x, original_custom_x)


def test_add_negative_samples_combines_positive_and_negative_hyperedges(mock_negative_sampler):
hdata = HData(
x=torch.arange(4, dtype=torch.float).unsqueeze(1),
Expand Down Expand Up @@ -739,6 +787,20 @@ def test_split_transductive_handles_none_global_node_ids():
assert torch.equal(result.global_node_ids, torch.arange(hdata.num_nodes))


def test_split_transductive_does_not_share_mutable_storage_with_source(
hdata_with_all_mutable_tensors,
):
hdata = hdata_with_all_mutable_tensors

result = HData.split(
hdata,
split_hyperedge_ids=torch.tensor([1]),
node_space_setting="transductive",
)

__assert_mutating_result_keeps_source_tensors_unchanged(hdata, result)


def test_split_subsets_edge_attr():
x = torch.randn(4, 2)
hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
Expand Down Expand Up @@ -815,6 +877,14 @@ def test_with_y_zeros_returns_all_zeros(mock_hdata):
assert torch.equal(hdata.y, torch.zeros(mock_hdata.num_hyperedges, dtype=torch.float))


def test_with_y_to_does_not_share_mutable_storage_with_source(hdata_with_all_mutable_tensors):
hdata = hdata_with_all_mutable_tensors

result = hdata.with_y_to(0.5)

__assert_mutating_result_keeps_source_tensors_unchanged(hdata, result)


def test_enrich_node_features_replace(mock_hdata):
enricher = MagicMock(spec=NodeEnricher)
enriched_x = torch.randn(5, 3)
Expand Down Expand Up @@ -1073,6 +1143,40 @@ def test_enrich_node_features_from_non_transductive_raises_on_fill_value_shape_m
)


def test_enrich_methods_do_not_share_mutable_storage_with_source(hdata_with_all_mutable_tensors):
node_enricher = MagicMock(spec=NodeEnricher)
node_enricher.enrich.return_value = torch.ones(
(hdata_with_all_mutable_tensors.num_nodes, 1), dtype=hdata_with_all_mutable_tensors.x.dtype
)

feature_source_hdata = hdata_with_all_mutable_tensors.clone()
feature_source_hdata.x = torch.full((hdata_with_all_mutable_tensors.num_nodes, 2), 7.0)

weight_enricher = MagicMock(spec=HyperedgeEnricher)
weight_enricher.enrich.return_value = torch.full(
(hdata_with_all_mutable_tensors.num_hyperedges,), 0.8
)

attr_enricher = MagicMock(spec=HyperedgeEnricher)
attr_enricher.enrich.return_value = torch.full(
(hdata_with_all_mutable_tensors.num_hyperedges, 2), 8.0
)

results = [
hdata_with_all_mutable_tensors.enrich_node_features(
node_enricher, enrichment_mode="concatenate"
),
hdata_with_all_mutable_tensors.enrich_node_features_from(feature_source_hdata),
hdata_with_all_mutable_tensors.enrich_hyperedge_weights(weight_enricher),
hdata_with_all_mutable_tensors.enrich_hyperedge_attr(attr_enricher),
]

for result in results:
__assert_mutating_result_keeps_source_tensors_unchanged(
hdata_with_all_mutable_tensors, result
)


def test_enrich_hyperedge_weights_replace():
x = torch.tensor([[1.0], [2.0], [3.0]])
hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]])
Expand Down Expand Up @@ -1333,6 +1437,24 @@ def test_shuffle_with_no_seed_set(mock_hdata):
assert shuffled_hdata1.hyperedge_index.shape == mock_hdata.hyperedge_index.shape


def test_shuffle_does_not_share_mutable_storage_with_source(hdata_with_all_mutable_tensors):
hdata = hdata_with_all_mutable_tensors

result = hdata.shuffle(seed=42)

__assert_mutating_result_keeps_source_tensors_unchanged(hdata, result)


def test_from_hyperedge_index_clones_caller_owned_tensor():
hyperedge_index = torch.tensor([[0, 1], [0, 0]])
original_hyperedge_index = hyperedge_index.clone()

result = HData.from_hyperedge_index(hyperedge_index)
result.hyperedge_index.flatten()[0].add_(1)

assert torch.equal(hyperedge_index, original_hyperedge_index)


def test_stats_returns_correct_statistics(mock_hdata_stats):
expected_stats = {
"shape_x": torch.Size([4, 4]),
Expand Down Expand Up @@ -1583,3 +1705,31 @@ def test_stats_with_empty_hdata():
stats = empty_hdata.stats()

assert stats == expected_stats


def __assert_mutating_result_keeps_source_tensors_unchanged(
Comment thread
ddevin96 marked this conversation as resolved.
source: HData,
result: HData,
field_names: tuple[str, ...] = (
"x",
"hyperedge_index",
"hyperedge_weights",
"hyperedge_attr",
"global_node_ids",
"y",
),
) -> None:
original_tensors = {
field_name: getattr(source, field_name).clone()
for field_name in field_names
if getattr(source, field_name) is not None
}

for field_name in field_names:
source_tensor = getattr(source, field_name)
result_tensor = getattr(result, field_name)
if source_tensor is None or result_tensor is None or result_tensor.numel() == 0:
continue

result_tensor.flatten()[0].add_(1)
assert torch.equal(source_tensor, original_tensors[field_name])
27 changes: 27 additions & 0 deletions hyperbench/tests/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from hyperbench.utils import (
clone_optional_tensor,
empty_nodefeatures,
empty_hyperedgeindex,
empty_edgeattr,
Expand All @@ -10,6 +11,32 @@
)


def test_clone_optional_tensor_with_none():
result = clone_optional_tensor(None)

assert result is None


def test_clone_optional_tensor_with_tensor_preserves_values():
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

result = clone_optional_tensor(tensor)

assert result is not None
assert torch.equal(result, tensor)


def test_clone_optional_tensor_with_tensor_does_not_share_storage():
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

result = clone_optional_tensor(tensor)
assert result is not None

result[0, 0] = 99.0
print(tensor, result)
assert tensor[0, 0] == 1.0


def test_empty_edgeindex():
result = empty_hyperedgeindex()

Expand Down
Loading
Loading