import numpy as np

class BinaryBinaryRBM:
    def __init__(self, n_visible, n_hidden, learning_rate=0.1):
        """
        Initialize the RBM with given parameters.
        """
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.learning_rate = learning_rate

        # Initialize weights (small Gaussian noise) and zero biases
        self.weights = np.random.normal(0, 0.01, (n_visible, n_hidden))
        self.visible_bias = np.zeros(n_visible)  # Bias for visible units
        self.hidden_bias = np.zeros(n_hidden)  # Bias for hidden units
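        # The model assigns a joint energy to each binary configuration:
        #   E(v, h) = -(v @ W @ h) - visible_bias @ v - hidden_bias @ h,
        # with p(v, h) proportional to exp(-E(v, h)); the conditionals
        # used below factor into independent sigmoids.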

    def sigmoid(self, x):
        """
        Sigmoid activation function.
        """
        # Clip the pre-activation so np.exp cannot overflow once the
        # weights grow during training
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

    def sample_hidden(self, v):
        """
        Sample binary hidden units given visible units.
        """
        # p(h_j = 1 | v) = sigmoid(hidden_bias_j + sum_i v_i * W_ij)
        activation = np.dot(v, self.weights) + self.hidden_bias
        prob_h = self.sigmoid(activation)
        return np.random.binomial(n=1, p=prob_h), prob_h

    def sample_visible(self, h):
        """
        Sample binary visible units given hidden units.
        """
        # p(v_i = 1 | h) = sigmoid(visible_bias_i + sum_j W_ij * h_j)
        activation = np.dot(h, self.weights.T) + self.visible_bias
        prob_v = self.sigmoid(activation)
        return np.random.binomial(n=1, p=prob_v), prob_v

    def contrastive_divergence(self, v0, k=1):
        """
        Contrastive Divergence (CD-k) update for a single training vector.
        """
        # Positive phase: data-driven statistics (hidden probabilities are
        # used in the outer product to reduce the variance of the estimate)
        _, prob_h0 = self.sample_hidden(v0)
        pos_associations = np.outer(v0, prob_h0)

        # Gibbs sampling (k steps)
        v_k = v0
        for _ in range(k):
            h_k, _ = self.sample_hidden(v_k)
            v_k, _ = self.sample_visible(h_k)

        # Negative phase: model-driven statistics from the end of the chain
        _, prob_h_k = self.sample_hidden(v_k)
        neg_associations = np.outer(v_k, prob_h_k)

        # Update weights and biases
        self.weights += self.learning_rate * (pos_associations - neg_associations)
        self.visible_bias += self.learning_rate * (v0 - v_k)
        self.hidden_bias += self.learning_rate * (prob_h0 - prob_h_k)
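        # These updates follow the CD-k approximation to the maximum-
        # likelihood gradient, dlog p(v0)/dW ≈ <v h>_data - <v h>_model,
        # with the model expectation estimated from the short Gibbs chain
        # (Hinton, 2002) rather than a full equilibrium sample.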

    def train(self, data, epochs=1000, batch_size=10):
        """
        Train the RBM with per-sample CD-1 updates, iterating over the data
        in shuffled mini-batches.
        """
        n_samples = data.shape[0]
        for epoch in range(epochs):
            # Shuffle via an index permutation so the caller's array
            # is left untouched
            shuffled = data[np.random.permutation(n_samples)]
            for i in range(0, n_samples, batch_size):
                batch = shuffled[i:i + batch_size]
                for v in batch:
                    self.contrastive_divergence(v)
            if (epoch + 1) % 100 == 0:
                error = self.reconstruction_error(data)
                print(f"Epoch {epoch + 1}/{epochs} - Reconstruction Error: {error:.4f}")

    def reconstruction_error(self, data):
        """
        Compute the mean reconstruction error over the dataset.
        """
        error = 0
        for v in data:
            # Use unit probabilities (rather than binary samples) for a
            # lower-variance, near-deterministic error estimate
            _, prob_h = self.sample_hidden(v)
            _, prob_v = self.sample_visible(prob_h)
            error += np.linalg.norm(v - prob_v)
        return error / len(data)

    def reconstruct(self, v):
        """
        Reconstruct a visible vector after one pass through the hidden units.
        """
        _, prob_h = self.sample_hidden(v)
        _, prob_v = self.sample_visible(prob_h)
        return prob_v

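    # A minimal sketch of sampling from the trained model (an added helper,
    # not part of the original code): run a Gibbs chain from a random
    # visible state and return the final visible configuration.
    def gibbs_sample(self, n_steps=100):
        """
        Draw an approximate sample from the model distribution via Gibbs
        sampling.
        """
        v = np.random.binomial(n=1, p=0.5, size=self.n_visible)
        for _ in range(n_steps):
            h, _ = self.sample_hidden(v)
            v, _ = self.sample_visible(h)
        return v
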
# Generate synthetic binary data with learnable structure: noisy copies of
# two complementary 6-bit prototypes (i.i.d. uniform bits would give the
# model no structure to capture)
np.random.seed(42)
prototypes = np.array([[1, 0, 1, 0, 1, 0],
                       [0, 1, 0, 1, 0, 1]])
idx = np.random.randint(0, 2, size=100)
flips = np.random.binomial(n=1, p=0.1, size=(100, 6))  # 10% bit-flip noise
data = prototypes[idx] ^ flips  # 100 samples, 6 visible units

# Initialize and train the RBM
rbm = BinaryBinaryRBM(n_visible=6, n_hidden=3, learning_rate=0.1)
rbm.train(data, epochs=1000, batch_size=10)

# Test the reconstruction
sample = np.array([1, 0, 1, 0, 1, 0])
reconstructed = rbm.reconstruct(sample)
print(f"\nOriginal: {sample}")
print(f"Reconstructed: {np.round(reconstructed, 2)}")
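
# Draw a few samples from the trained model using the gibbs_sample helper
# sketched above
for _ in range(3):
    print(f"Model sample: {rbm.gibbs_sample(n_steps=200)}")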