
Commit 95a9ee6

rbm+vae codes

1 parent 2b3d91b

File tree

11 files changed, +227 -0 lines

doc/src/week11/programs/rbm.py

Lines changed: 134 additions & 0 deletions
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# Hyperparameters
input_size = 784          # 28x28 images, flattened
hidden_size = 128
batch_size = 64
learning_rate = 0.01
num_epochs = 10
k = 1                     # Number of Gibbs sampling steps (CD-k)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset: flatten each 28x28 image to a 784-dimensional vector
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Restricted Boltzmann Machine
class RBM(nn.Module):
    def __init__(self, visible_dim, hidden_dim):
        super(RBM, self).__init__()
        self.W = nn.Parameter(torch.randn(hidden_dim, visible_dim) * 0.01)
        self.h_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.v_bias = nn.Parameter(torch.zeros(visible_dim))

    def sample_from_p(self, p):
        # Draw binary samples from Bernoulli probabilities
        return torch.bernoulli(p)

    def v_to_h(self, v):
        # p(h = 1 | v) = sigmoid(W v + h_bias)
        p_h_given_v = torch.sigmoid(torch.matmul(v, self.W.t()) + self.h_bias)
        return p_h_given_v, self.sample_from_p(p_h_given_v)

    def h_to_v(self, h):
        # p(v = 1 | h) = sigmoid(W^T h + v_bias)
        p_v_given_h = torch.sigmoid(torch.matmul(h, self.W) + self.v_bias)
        return p_v_given_h, self.sample_from_p(p_v_given_h)

    def forward(self, v):
        # Block Gibbs sampling, starting from the data
        h_prob, h_sample = self.v_to_h(v)
        for _ in range(k):  # k is the module-level hyperparameter above
            v_prob, v_sample = self.h_to_v(h_sample)
            h_prob, h_sample = self.v_to_h(v_sample)
        return v, v_prob

    def free_energy(self, v):
        # F(v) = -v . v_bias - sum_j softplus((W v + h_bias)_j)
        vbias_term = torch.matmul(v, self.v_bias)
        wx_b = torch.matmul(v, self.W.t()) + self.h_bias
        # softplus is a numerically stable log(1 + exp(.))
        hidden_term = torch.sum(F.softplus(wx_b), dim=1)
        return -hidden_term - vbias_term

# Initialize RBM
rbm = RBM(visible_dim=input_size, hidden_dim=hidden_size).to(device)

# Optimizer
optimizer = optim.SGD(rbm.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)

        # Forward pass: run the Gibbs chain to get the reconstruction
        v, v_prob = rbm(data)

        # Contrastive-divergence surrogate loss; the negative sample is
        # detached so it is treated as a constant, as CD-k prescribes
        loss = rbm.free_energy(data) - rbm.free_energy(v_prob.detach())
        loss = loss.mean()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}')

# Function to visualize reconstructed images
def visualize_reconstructions(rbm, data_loader, num_images=5):
    rbm.eval()
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(data_loader):
            data = data.to(device)
            _, v_prob = rbm(data)
            v_prob = v_prob.view(-1, 28, 28).cpu()
            data = data.view(-1, 28, 28).cpu()

            for i in range(num_images):
                plt.figure(figsize=(6, 3))
                plt.subplot(1, 2, 1)
                plt.imshow(data[i], cmap='gray')
                plt.title('Original')
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.imshow(v_prob[i], cmap='gray')
                plt.title('Reconstructed')
                plt.axis('off')

                plt.show()
            break  # only visualize the first batch

# Visualize some reconstructed images
visualize_reconstructions(rbm, train_loader)
```
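For reference, these are the standard RBM expressions that `v_to_h`, `h_to_v`, and `free_energy` implement, with visible bias $b$ (`v_bias`), hidden bias $c$ (`h_bias`), and weights $W$:

$$
p(h_j = 1 \mid v) = \sigma\bigl((W v + c)_j\bigr), \qquad
p(v_i = 1 \mid h) = \sigma\bigl((W^{\top} h + b)_i\bigr),
$$

$$
F(v) = -\,b^{\top} v - \sum_{j} \log\bigl(1 + e^{(W v + c)_j}\bigr),
$$

where $\sigma$ is the logistic sigmoid; the $\log(1 + e^{x})$ terms are the numerically stable `softplus` calls in the code.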
"""
116+
### Explanation:
117+
1. **RBM Class**:
118+
- The `RBM` class defines the weights (`W`), hidden biases (`h_bias`), and visible biases (`v_bias`).
119+
- It includes methods for sampling from probabilities (`sample_from_p`), converting visible to hidden units (`v_to_h`), and converting hidden to visible units (`h_to_v`).
120+
- The `forward` method performs Gibbs sampling to reconstruct the input.
121+
- The `free_energy` method computes the free energy of the RBM, which is used in the loss function.
122+
123+
2. **Training**:
   - The training loop uses Contrastive Divergence (CD-k) to update the weights and biases.
   - The loss is the difference in free energy between the original data and the Gibbs-chain reconstruction; the surrogate objective is written out after this list.
3. **Visualization**:
   - After training, the `visualize_reconstructions` function displays a few original images side by side with their reconstructions, giving a qualitative check of the RBM's performance.
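To make the training item concrete, the surrogate loss minimized in the loop can be written as

$$
\mathcal{L} = \frac{1}{B} \sum_{n=1}^{B} \Bigl( F\bigl(v_0^{(n)}\bigr) - F\bigl(v_k^{(n)}\bigr) \Bigr),
$$

where $v_0^{(n)}$ is a training vector, $v_k^{(n)}$ is the corresponding visible state after $k$ Gibbs steps, and $B$ is the batch size. Treating $v_k$ as a constant (the `.detach()` call in the code), the gradient of $\mathcal{L}$ is the usual CD-k approximation to the gradient of the negative log-likelihood.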
### Notes:
- RBMs are unsupervised models, so the labels are not used during training.
- The number of Gibbs sampling steps (`k`) is typically kept small (e.g., 1 or 2) for efficiency.
- You can experiment with hyperparameters such as `hidden_size`, `learning_rate`, and `num_epochs` to improve performance.

doc/src/week13/programs/.DS_Store

6 KB, binary file not shown.

Eight further binary files deleted (-7.48 MB, -1.57 MB, -9.77 KB, -4.44 KB, -44.9 MB, -9.45 MB, -58.6 KB, -28.2 KB); binary files not shown.

0 commit comments