Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/boruvkas_algorithm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Boruvka's algorithm for finding minimum spanning trees."""

from boruvkas_algorithm.boruvka import Graph, find_mst_with_boruvkas_algorithm
from boruvkas_algorithm.union_find import UnionFind

__all__: list[str] = ["Graph", "UnionFind", "find_mst_with_boruvkas_algorithm"]
64 changes: 10 additions & 54 deletions src/boruvkas_algorithm/boruvka.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import matplotlib.pyplot as plt
import networkx as nx

from boruvkas_algorithm.union_find import UnionFind


class Graph:
"""A graph that contains nodes and edges."""
Expand Down Expand Up @@ -78,69 +80,23 @@ def draw_mst(self, mst_edges: list[tuple[int, int, int]]) -> None:

def find_mst_with_boruvkas_algorithm(
graph: Graph,
union_find: UnionFind | None = None,
) -> tuple[int, list[tuple[int, int, int]]]:
"""
Finds the minimum spanning tree (MST) of a graph using Boruvka's algorithm.

Args:
graph: The graph to find the MST of.
union_find: Optional UnionFind instance for tracking components. If not
provided, a new one will be created.

Returns:
A tuple containing the total weight of the MST and a list of the
edges in the MST.
"""

def find(node: int) -> int:
"""
Finds the root parent of the node using path compression.

Args:
node: The node to find the root parent of.

Returns:
The root parent of the node.
"""
cur_parent = parent[node]
while cur_parent != parent[cur_parent]:
# Compress the links as we go up the chain of parents to make
# it faster to traverse in the future - amortised O(a(n)) time,
# where a(n) is the inverse Ackermann function.
parent[cur_parent] = parent[parent[cur_parent]]
cur_parent = parent[cur_parent]
return cur_parent

def union(node1: int, node2: int) -> bool:
"""
Combines the two nodes into the larger segment.

Args:
node1: The first node to combine.
node2: The second node to combine.

Returns:
True if the nodes were combined, False if they were already in the
same segment.
"""
root1 = find(node1)
root2 = find(node2)
# If they have the same root parent, they're already connected.
if root1 == root2:
return False

# Combine the two nodes into the larger segment based on the rank.
if rank[root1] > rank[root2]:
parent[root2] = root1
rank[root1] += rank[root2]
else:
parent[root1] = root2
rank[root2] += rank[root1]
return True

num_vertices = len(graph.vertices)
# Each node is its own parent initially.
parent: list[int] = list(range(num_vertices))
# Each tree has size 1 (itself) initially.
rank: list[int] = [1] * num_vertices
if union_find is None:
union_find = UnionFind(num_vertices)

print("\nFinding MST with Boruvka's algorithm:")
graph.print_graph_info()
Expand All @@ -164,7 +120,7 @@ def union(node1: int, node2: int) -> bool:
] * num_vertices
for edge in graph.edges:
node1, node2, weight = edge
comp1, comp2 = find(node1), find(node2)
comp1, comp2 = union_find.find(node1), union_find.find(node2)

if comp1 != comp2:
current_min1 = min_edge_per_component[comp1]
Expand All @@ -178,10 +134,10 @@ def union(node1: int, node2: int) -> bool:
for edge in min_edge_per_component:
if edge is not None:
node1, node2, weight = edge
if find(node1) != find(node2):
if union_find.find(node1) != union_find.find(node2):
mst_edges.append(edge)
mst_weight += weight
union(node1, node2)
union_find.union(node1, node2)
num_components -= 1
print(f"Added edge {node1} - {node2} with weight {weight} to MST.")

Expand Down
76 changes: 76 additions & 0 deletions src/boruvkas_algorithm/union_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
class UnionFind:
    """
    Union-find (disjoint set union) data structure for tracking connected
    components with path compression and union by size.

    Elements are the integers 0 to size - 1. The structure is exposed
    through two parallel lists, ``parent`` and ``rank``.
    """

    def __init__(self, size: int) -> None:
        """
        Initialises the Union-Find structure.

        Args:
            size: The number of elements in the structure.
        """
        # Each node is its own parent initially.
        self.parent: list[int] = list(range(size))
        # Each tree has size 1 (itself) initially. Despite the name, this
        # list stores the number of nodes in the tree rooted at each index;
        # entries for non-root nodes are stale and must not be relied upon.
        self.rank: list[int] = [1] * size

    def find(self, node: int) -> int:
        """
        Finds the root parent of the node using path compression.

        Args:
            node: The node to find the root parent of.

        Returns:
            The root parent of the node.

        Raises:
            IndexError: If the node is negative or not smaller than the
                number of elements.
        """
        # A negative index would otherwise silently alias another element
        # through Python's negative list indexing.
        if node < 0:
            raise IndexError(f"node index out of range: {node}")

        parent = self.parent
        # Start at the queried node itself so that its own link is shortened
        # too; every visited node is re-pointed at its grandparent (path
        # halving), which keeps future lookups fast - amortised O(a(n)) time
        # with union by size, where a(n) is the inverse Ackermann function.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    def union(self, node1: int, node2: int) -> bool:
        """
        Combines the two nodes into the larger segment.

        Args:
            node1: The first node to combine.
            node2: The second node to combine.

        Returns:
            True if the nodes were combined, False if they were already in the
            same segment.
        """
        root1 = self.find(node1)
        root2 = self.find(node2)
        # If they have the same root parent, they're already connected.
        if root1 == root2:
            return False

        # Hang the smaller tree below the root of the larger one and add its
        # size to the surviving root. On a tie, root2 becomes the new root.
        if self.rank[root1] > self.rank[root2]:
            self.parent[root2] = root1
            self.rank[root1] += self.rank[root2]
        else:
            self.parent[root1] = root2
            self.rank[root2] += self.rank[root1]
        return True

    def is_connected(self, node1: int, node2: int) -> bool:
        """
        Checks if two nodes are in the same component.

        Args:
            node1: The first node.
            node2: The second node.

        Returns:
            True if the nodes are connected, False otherwise.
        """
        return self.find(node1) == self.find(node2)
168 changes: 160 additions & 8 deletions tests/test_boruvka.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from boruvkas_algorithm.boruvka import Graph, find_mst_with_boruvkas_algorithm
from boruvkas_algorithm.union_find import UnionFind


@pytest.fixture
Expand All @@ -15,6 +16,16 @@ def setup_graph():
return Graph(9) # Example graph with 9 vertices.


def test_graph_initialization():
    """
    Verifies that a newly constructed graph exposes the requested number of
    vertices and starts out with an empty edge collection.
    """
    empty_graph = Graph(5)
    vertex_count = len(empty_graph.vertices)
    assert vertex_count == 5, "Graph should have 5 vertices"
    edge_count = len(empty_graph.edges)
    assert edge_count == 0, "Graph should be initialised with no edges"


def test_add_edge(setup_graph: Graph):
"""
Tests that edges are correctly added by checking the length of the edge
Expand All @@ -38,6 +49,131 @@ def test_add_edge_invalid_vertices(setup_graph: Graph):
graph.add_edge(10, 11, 5)


# =============================================================================
# UnionFind Tests
# =============================================================================


def test_union_find_initialization():
    """Checks the parent and rank arrays of a freshly built UnionFind."""
    fresh = UnionFind(5)
    expected_parents = [0, 1, 2, 3, 4]
    assert fresh.parent == expected_parents, "Each node should be its own parent"
    expected_ranks = [1, 1, 1, 1, 1]
    assert fresh.rank == expected_ranks, "Each node should have rank 1"


def test_union_find_find_single_node():
    """Checks that find on an untouched node yields that same node."""
    singles = UnionFind(5)
    # Probe the first and the last element of the structure.
    for node in (0, 4):
        assert singles.find(node) == node


def test_union_find_union_two_nodes():
    """Checks that a union of two separate nodes joins them under one root."""
    pair = UnionFind(5)
    merged = pair.union(0, 1)
    assert merged is True, "Union should return True when nodes are combined"
    root_of_first = pair.find(0)
    root_of_second = pair.find(1)
    assert (
        root_of_first == root_of_second
    ), "Nodes should have the same root after union"


def test_union_find_union_already_connected():
    """Checks that repeating a union of the same pair reports no merge."""
    dsu = UnionFind(5)
    # The first call performs the actual merge; only the repeat is checked.
    dsu.union(0, 1)
    repeat = dsu.union(0, 1)
    assert repeat is False, "Union should return False when already connected"


def test_union_find_union_by_size():
    """Tests that smaller trees are merged into larger trees."""
    uf = UnionFind(5)
    # Build a three-node tree out of {0, 1, 2}. Which of the nodes ends up as
    # the root is an implementation detail, so look it up instead of
    # assuming it.
    uf.union(0, 1)
    uf.union(0, 2)
    large_root = uf.find(0)
    # Node 3 is a singleton, so it is the smaller tree no matter in which
    # order the arguments are passed to union.
    uf.union(3, 0)
    # Capturing the root before the union is what makes this a real check of
    # union by size: the larger tree's root must survive the merge.
    assert uf.find(0) == large_root, "Root of the larger tree should remain the root"
    assert uf.find(3) == large_root, "Smaller tree should be merged into larger tree"


def test_union_find_path_compression():
    """Tests that path compression flattens the tree structure."""
    uf = UnionFind(5)
    # Create a chain: 0 <- 1 <- 2 <- 3 (node 4 stays in its own component).
    uf.parent = [0, 0, 1, 2, 4]
    uf.rank = [4, 1, 1, 1, 1]

    def depth(node: int) -> int:
        """Counts the parent links from node to its root without mutating."""
        hops = 0
        while uf.parent[node] != node:
            node = uf.parent[node]
            hops += 1
        return hops

    depth_before = depth(3)
    # Find on node 3 should compress the path.
    root = uf.find(3)
    assert root == 0, "Root should be 0"
    # Compare depths rather than a specific parent entry: the entry for node 2
    # is already 1 before the call, so checking it alone would pass even if
    # no compression happened at all.
    assert depth(3) < depth_before, "Path compression should shorten the path"


def test_union_find_multiple_components():
    """Exercises UnionFind with two disjoint groups that are later joined."""
    forest = UnionFind(6)
    # Build the groups {0, 1, 2} and {3, 4, 5} link by link.
    for left, right in ((0, 1), (1, 2), (3, 4), (4, 5)):
        forest.union(left, right)

    # Every member of a group must report the same root.
    first_group = [forest.find(member) for member in (0, 1, 2)]
    assert first_group[0] == first_group[1] == first_group[2]
    second_group = [forest.find(member) for member in (3, 4, 5)]
    assert second_group[0] == second_group[1] == second_group[2]
    assert forest.find(0) != forest.find(3), "Components should be separate"

    # A single link between the groups joins all six nodes.
    forest.union(2, 3)
    assert forest.find(0) == forest.find(5), "Components should be merged"


def test_union_find_is_connected():
    """Exercises is_connected before and after successive unions."""
    links = UnionFind(5)
    connected_at_start = links.is_connected(0, 1)
    assert not connected_at_start, "Nodes should not be connected initially"

    links.union(0, 1)
    direct = links.is_connected(0, 1)
    assert direct, "Nodes should be connected after union"
    unrelated = links.is_connected(0, 2)
    assert not unrelated, "Unconnected nodes should return False"

    links.union(1, 2)
    transitive = links.is_connected(0, 2)
    assert transitive, "Transitively connected nodes should return True"


# =============================================================================
# MST Algorithm Tests
# =============================================================================


def test_mst_with_injected_union_find(setup_graph: Graph):
    """Checks the MST result when the caller supplies its own UnionFind."""
    # (node1, node2, weight) triples, added in this exact order.
    weighted_edges = [
        (0, 1, 4),
        (0, 6, 7),
        (1, 6, 11),
        (1, 7, 20),
        (1, 2, 9),
        (2, 3, 6),
        (2, 4, 2),
        (3, 4, 10),
        (3, 5, 5),
        (4, 5, 15),
        (4, 7, 1),
        (4, 8, 5),
        (5, 8, 12),
        (6, 7, 1),
        (7, 8, 3),
    ]
    for source, target, weight in weighted_edges:
        setup_graph.add_edge(source, target, weight)

    # Hand a pre-built UnionFind to the algorithm instead of letting it
    # create one.
    injected = UnionFind(len(setup_graph.vertices))
    result = find_mst_with_boruvkas_algorithm(setup_graph, injected)
    mst_weight, mst_edges = result

    assert mst_weight == 29, "MST weight should be 29"
    assert len(mst_edges) == 8, "MST should have 8 edges for 9 vertices"


def test_mst(setup_graph: Graph):
"""
Tests that the MST has the correct total weight and structure by comparing
Expand Down Expand Up @@ -80,11 +216,27 @@ def test_mst(setup_graph: Graph):
)


def test_graph_initialization():
    """
    Verify that a newly constructed graph exposes the requested number of
    vertices and starts out with an empty edge collection.
    """
    # Build a graph with 5 vertices and no edges added.
    five_node_graph = Graph(5)
    vertex_total = len(five_node_graph.vertices)
    assert vertex_total == 5, "Graph should have 5 vertices"
    edge_total = len(five_node_graph.edges)
    assert edge_total == 0, "Graph should be initialized with no edges"
def test_mst_simple_triangle():
    """Checks the MST of a three-node cycle with distinct weights."""
    triangle = Graph(3)
    # (node1, node2, weight) triples, added in this exact order.
    for source, target, weight in ((0, 1, 1), (1, 2, 2), (0, 2, 3)):
        triangle.add_edge(source, target, weight)

    result = find_mst_with_boruvkas_algorithm(triangle)
    mst_weight, mst_edges = result

    assert mst_weight == 3, "MST weight should be 3 (edges 1 + 2)"
    assert len(mst_edges) == 2, "MST should have 2 edges for 3 vertices"


def test_mst_linear_graph():
    """Checks the MST of a path graph, which is already a tree."""
    path_graph = Graph(4)
    # (node1, node2, weight) triples, added in this exact order.
    for source, target, weight in ((0, 1, 1), (1, 2, 2), (2, 3, 3)):
        path_graph.add_edge(source, target, weight)

    result = find_mst_with_boruvkas_algorithm(path_graph)
    mst_weight, mst_edges = result

    assert mst_weight == 6, "MST weight should be 6 (1 + 2 + 3)"
    assert len(mst_edges) == 3, "MST should have 3 edges for 4 vertices"
Loading