Skip to content

Commit c61dfe2

Browse files
committed
added code
1 parent 371a341 commit c61dfe2

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed
Binary file not shown.
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
import torch.nn.functional as F
5+
from torchvision import datasets, transforms
6+
from torch.utils.data import DataLoader
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
10+
# Define the VAE
11+
class VAE(nn.Module):
12+
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
13+
super(VAE, self).__init__()
14+
self.fc1 = nn.Linear(input_dim, hidden_dim)
15+
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
16+
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
17+
self.fc3 = nn.Linear(latent_dim, hidden_dim)
18+
self.fc4 = nn.Linear(hidden_dim, input_dim)
19+
20+
def encode(self, x):
21+
h1 = F.relu(self.fc1(x))
22+
return self.fc_mu(h1), self.fc_logvar(h1)
23+
24+
def reparameterize(self, mu, logvar):
25+
std = torch.exp(0.5 * logvar)
26+
eps = torch.randn_like(std)
27+
return mu + eps * std
28+
29+
def decode(self, z):
30+
h3 = F.relu(self.fc3(z))
31+
return torch.sigmoid(self.fc4(h3))
32+
33+
def forward(self, x):
34+
mu, logvar = self.encode(x.view(-1, 784))
35+
z = self.reparameterize(mu, logvar)
36+
return self.decode(z), mu, logvar
37+
38+
# Loss function
39+
def loss_function(recon_x, x, mu, logvar):
40+
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
41+
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
42+
return BCE + KLD
43+
44+
# Prepare data
45+
transform = transforms.ToTensor()
46+
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
47+
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
48+
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
49+
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
50+
51+
# Set device
52+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53+
model = VAE().to(device)
54+
optimizer = optim.Adam(model.parameters(), lr=1e-3)
55+
56+
# Training function
57+
def train(epoch):
58+
model.train()
59+
train_loss = 0
60+
for batch_idx, (data, _) in enumerate(train_loader):
61+
data = data.to(device)
62+
optimizer.zero_grad()
63+
recon_batch, mu, logvar = model(data)
64+
loss = loss_function(recon_batch, data, mu, logvar)
65+
loss.backward()
66+
train_loss += loss.item()
67+
optimizer.step()
68+
print(f"Epoch {epoch}: Avg Loss: {train_loss / len(train_loader.dataset):.4f}")
69+
70+
# Run training
71+
for epoch in range(1, 11):
72+
train(epoch)
73+
74+
# Reconstruction Visualization
75+
model.eval()
76+
with torch.no_grad():
77+
data, _ = next(iter(test_loader))
78+
data = data.to(device)
79+
recon_batch, _, _ = model(data)
80+
81+
n = 8
82+
comparison = torch.cat([data[:n], recon_batch.view(-1, 1, 28, 28)[:n]])
83+
comparison = comparison.cpu()
84+
85+
plt.figure(figsize=(12, 3))
86+
for i in range(n):
87+
plt.subplot(2, n, i + 1)
88+
plt.imshow(comparison[i][0], cmap='gray')
89+
plt.axis('off')
90+
plt.subplot(2, n, i + 1 + n)
91+
plt.imshow(comparison[i + n][0], cmap='gray')
92+
plt.axis('off')
93+
plt.suptitle("Top: Original | Bottom: Reconstructed")
94+
plt.show()
95+
96+
# Latent Space Visualization
97+
model.eval()
98+
z_list = []
99+
label_list = []
100+
101+
with torch.no_grad():
102+
for data, labels in test_loader:
103+
data = data.to(device)
104+
mu, _ = model.encode(data.view(-1, 784))
105+
z_list.append(mu.cpu())
106+
label_list.append(labels)
107+
108+
z = torch.cat(z_list).numpy()
109+
labels = torch.cat(label_list).numpy()
110+
111+
plt.figure(figsize=(8, 6))
112+
scatter = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap='tab10', alpha=0.7, s=10)
113+
plt.colorbar(scatter, ticks=range(10))
114+
plt.title("2D Latent Space of MNIST")
115+
plt.xlabel("z[0]")
116+
plt.ylabel("z[1]")
117+
plt.grid(True)
118+
plt.show()

0 commit comments

Comments
 (0)