Skip to content

Commit 04c7665

Browse files
feat: add OptimalBinarySearchTree algorithm
1 parent 5e06b15 commit 04c7665

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package com.thealgorithms.dynamicprogramming;
2+
3+
import java.util.Arrays;
4+
import java.util.Comparator;
5+
6+
/**
 * Finds the smallest possible weighted search cost of a binary search tree built over a set of
 * keys with known search frequencies.
 *
 * <p>The keys are paired with their frequencies and ordered by key. An interval dynamic program
 * then fills in the best cost of every contiguous key range, narrowing the candidate roots of each
 * range with Knuth's optimization.
 *
 * <p>Example: for keys = [10, 12] with frequencies = [34, 50], rooting the tree at 12 and hanging
 * 10 below it costs 50 * 1 + 34 * 2 = 118, which is the minimum.
 *
 * <p>Reference:
 * https://en.wikipedia.org/wiki/Optimal_binary_search_tree
 */
public final class OptimalBinarySearchTree {
    private OptimalBinarySearchTree() {
    }

    /**
     * Returns the minimum weighted search cost achievable by any BST over the given keys, where a
     * key found at depth d (the root being depth 1) contributes d times its frequency.
     *
     * @param keys the BST keys, in any order; they must be pairwise distinct
     * @param frequencies the non-negative search frequency of each key, aligned with {@code keys}
     * @return the minimum search cost, or 0 when there are no keys
     * @throws IllegalArgumentException if either array is null, the lengths differ, a frequency is
     *     negative, or a key occurs more than once
     */
    public static long findOptimalCost(int[] keys, int[] frequencies) {
        checkArguments(keys, frequencies);
        int size = keys.length;
        if (size == 0) {
            return 0L;
        }

        int[][] pairs = pairAndSortByKey(keys, frequencies);
        long[] cumulative = accumulateFrequencies(pairs);

        // cost[lo][hi] is the best cost for sorted keys lo..hi; bestRoot[lo][hi] is the root
        // index that achieves it.
        long[][] cost = new long[size][size];
        int[][] bestRoot = new int[size][size];

        // A single-key range is just that key sitting at depth 1.
        for (int i = 0; i < size; i++) {
            cost[i][i] = pairs[i][1];
            bestRoot[i][i] = i;
        }

        // Widen the range one key at a time; gap is the distance between the two end indices.
        for (int gap = 1; gap < size; gap++) {
            for (int lo = 0, hi = gap; hi < size; lo++, hi++) {
                // Picking a root pushes every key of the range one level down, which adds the
                // range's total frequency exactly once.
                long rangeWeight = cumulative[hi + 1] - cumulative[lo];

                // Knuth's optimization: only roots between the best root of [lo, hi - 1] and
                // the best root of [lo + 1, hi] need to be tried.
                int firstCandidate = bestRoot[lo][hi - 1];
                int lastCandidate = bestRoot[lo + 1][hi];

                long cheapest = Long.MAX_VALUE;
                int chosen = 0;
                for (int candidate = firstCandidate; candidate <= lastCandidate; candidate++) {
                    long total = rangeWeight + rangeCost(cost, lo, candidate - 1) + rangeCost(cost, candidate + 1, hi);
                    // Strict comparison keeps the leftmost root among equally good ones.
                    if (total < cheapest) {
                        cheapest = total;
                        chosen = candidate;
                    }
                }
                cost[lo][hi] = cheapest;
                bestRoot[lo][hi] = chosen;
            }
        }

        return cost[0][size - 1];
    }

    /** Looks up the stored cost of range [from, to]; an empty range (from > to) costs nothing. */
    private static long rangeCost(long[][] cost, int from, int to) {
        if (from > to) {
            return 0L;
        }
        return cost[from][to];
    }

    /** Rejects null arrays, arrays of different lengths, and negative frequencies, in that order. */
    private static void checkArguments(int[] keys, int[] frequencies) {
        if (keys == null || frequencies == null) {
            throw new IllegalArgumentException("Keys and frequencies cannot be null");
        }
        if (keys.length != frequencies.length) {
            throw new IllegalArgumentException("Keys and frequencies must have the same length");
        }
        for (int i = 0; i < frequencies.length; i++) {
            if (frequencies[i] < 0) {
                throw new IllegalArgumentException("Frequencies cannot be negative");
            }
        }
    }

    /**
     * Builds {key, frequency} pairs ordered by key, so that positions in the result correspond to
     * the in-order sequence of any BST over the keys. Throws if two keys are equal.
     */
    private static int[][] pairAndSortByKey(int[] keys, int[] frequencies) {
        int[][] pairs = new int[keys.length][];
        for (int i = 0; i < keys.length; i++) {
            pairs[i] = new int[] {keys[i], frequencies[i]};
        }

        Arrays.sort(pairs, Comparator.comparingInt((int[] pair) -> pair[0]));

        // After sorting, equal keys can only be neighbours.
        for (int i = 0; i + 1 < pairs.length; i++) {
            if (pairs[i][0] == pairs[i + 1][0]) {
                throw new IllegalArgumentException("Keys must be distinct");
            }
        }

        return pairs;
    }

    /**
     * Returns an array whose entry i is the total frequency of the first i sorted keys, so the
     * frequency of any range is the difference of two entries.
     */
    private static long[] accumulateFrequencies(int[][] pairs) {
        long[] cumulative = new long[pairs.length + 1];
        long running = 0L;
        for (int i = 0; i < pairs.length; i++) {
            running += pairs[i][1];
            cumulative[i + 1] = running;
        }
        return cumulative;
    }
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package com.thealgorithms.dynamicprogramming;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import java.util.Arrays;
7+
import java.util.stream.Stream;
8+
import org.junit.jupiter.params.ParameterizedTest;
9+
import org.junit.jupiter.params.provider.Arguments;
10+
import org.junit.jupiter.params.provider.MethodSource;
11+
12+
class OptimalBinarySearchTreeTest {
13+
14+
@ParameterizedTest
15+
@MethodSource("validTestCases")
16+
void testFindOptimalCost(int[] keys, int[] frequencies, long expectedCost) {
17+
assertEquals(expectedCost, OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
18+
}
19+
20+
private static Stream<Arguments> validTestCases() {
21+
return Stream.of(Arguments.of(new int[] {}, new int[] {}, 0L), Arguments.of(new int[] {15}, new int[] {9}, 9L), Arguments.of(new int[] {10, 12}, new int[] {34, 50}, 118L), Arguments.of(new int[] {20, 10, 30}, new int[] {50, 34, 8}, 134L),
22+
Arguments.of(new int[] {12, 10, 20, 42, 25, 37}, new int[] {8, 34, 50, 3, 40, 30}, 324L), Arguments.of(new int[] {1, 2, 3}, new int[] {0, 0, 0}, 0L));
23+
}
24+
25+
@ParameterizedTest
26+
@MethodSource("crossCheckTestCases")
27+
void testFindOptimalCostAgainstBruteForce(int[] keys, int[] frequencies) {
28+
assertEquals(bruteForceOptimalCost(keys, frequencies), OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
29+
}
30+
31+
private static Stream<Arguments> crossCheckTestCases() {
32+
return Stream.of(Arguments.of(new int[] {3, 1, 2}, new int[] {4, 2, 6}), Arguments.of(new int[] {5, 2, 8, 6}, new int[] {3, 7, 1, 4}), Arguments.of(new int[] {9, 4, 11, 2}, new int[] {1, 8, 2, 5}));
33+
}
34+
35+
@ParameterizedTest
36+
@MethodSource("invalidTestCases")
37+
void testFindOptimalCostInvalidInput(int[] keys, int[] frequencies) {
38+
assertThrows(IllegalArgumentException.class, () -> OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
39+
}
40+
41+
private static Stream<Arguments> invalidTestCases() {
42+
return Stream.of(Arguments.of(null, new int[] {}), Arguments.of(new int[] {}, null), Arguments.of(new int[] {1, 2}, new int[] {3}), Arguments.of(new int[] {1, 1}, new int[] {2, 3}), Arguments.of(new int[] {1, 2}, new int[] {3, -1}));
43+
}
44+
45+
private static long bruteForceOptimalCost(int[] keys, int[] frequencies) {
46+
int[][] sortedNodes = new int[keys.length][2];
47+
for (int index = 0; index < keys.length; index++) {
48+
sortedNodes[index][0] = keys[index];
49+
sortedNodes[index][1] = frequencies[index];
50+
}
51+
Arrays.sort(sortedNodes, java.util.Comparator.comparingInt(node -> node[0]));
52+
53+
int[] sortedFrequencies = new int[sortedNodes.length];
54+
for (int index = 0; index < sortedNodes.length; index++) {
55+
sortedFrequencies[index] = sortedNodes[index][1];
56+
}
57+
58+
return bruteForceOptimalCost(sortedFrequencies, 0, sortedFrequencies.length - 1, 1);
59+
}
60+
61+
private static long bruteForceOptimalCost(int[] frequencies, int start, int end, int depth) {
62+
if (start > end) {
63+
return 0L;
64+
}
65+
66+
long minimumCost = Long.MAX_VALUE;
67+
for (int root = start; root <= end; root++) {
68+
long currentCost = (long) depth * frequencies[root] + bruteForceOptimalCost(frequencies, start, root - 1, depth + 1) + bruteForceOptimalCost(frequencies, root + 1, end, depth + 1);
69+
minimumCost = Math.min(minimumCost, currentCost);
70+
}
71+
return minimumCost;
72+
}
73+
}

0 commit comments

Comments
 (0)