### PCA Implementation in PyTorch
import torch
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten each 28x28 image to a 784-dim vector
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=False)

# Get all training data in a single batch
data_iter = iter(train_loader)
images, _ = next(data_iter)
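
# Quick shape check: the single-batch loader yields the full training set,
# so `images` should be a (60000, 784) tensor of flattened digits.
print(images.shape)  # torch.Size([60000, 784])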

# Convert images to float32 (ToTensor has already scaled pixel values to [0, 1])
X = images.float()
X_mean = X.mean(dim=0)
X_centered = X - X_mean  # Center the data around zero

# Compute the covariance matrix of the centered data
cov_matrix = torch.mm(X_centered.T, X_centered) / (X.shape[0] - 1)

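# Optional sanity check (assumes PyTorch >= 1.10, where torch.cov is
# available): torch.cov treats each row as a variable, so the transposed
# data matrix should reproduce the covariance computed manually above.
assert torch.allclose(cov_matrix, torch.cov(X_centered.T), atol=1e-3)
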
# Eigendecomposition of the covariance matrix. torch.eig has been removed
# from PyTorch; torch.linalg.eigh is the appropriate replacement here since
# the covariance matrix is symmetric, so its eigenvalues and eigenvectors are real.
eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix)

# eigh returns eigenvalues in ascending order; reorder eigenvalues and
# their corresponding eigenvectors in descending order
sorted_indices = torch.argsort(eigenvalues, descending=True)
eigenvalues_sorted = eigenvalues[sorted_indices]
eigenvectors_sorted = eigenvectors[:, sorted_indices]

# Select top k components (for example, k=2 for a 2D projection)
k = 2
W_k = eigenvectors_sorted[:, :k]

# Project the centered data onto the new subspace
Z_k = torch.mm(X_centered, W_k)

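# Optional: each eigenvalue is the variance along its component, so this
# ratio reports how much of the total variance the top-k projection keeps.
explained = (eigenvalues_sorted[:k].sum() / eigenvalues_sorted.sum()).item()
print(f'Variance explained by top {k} components: {explained:.2%}')
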
# Visualize the first two principal components
plt.figure(figsize=(8, 6))
plt.scatter(Z_k[:, 0].numpy(), Z_k[:, 1].numpy(), alpha=0.5)
plt.title('PCA Projection of MNIST Dataset')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.grid(True)
plt.show()

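# Optional reconstruction sketch: mapping the projection back through W_k
# approximates each image from its top-k components; the mean squared error
# shows how much detail the 2D projection discards.
X_reconstructed = torch.mm(Z_k, W_k.T) + X_mean
recon_error = (X - X_reconstructed).pow(2).mean().item()
print(f'Mean squared reconstruction error with k={k}: {recon_error:.4f}')
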
'''
### Explanation:

- **Data Loading**: The MNIST dataset is loaded in a single batch and each 28x28 image is flattened into a 784-dimensional vector.
- **Centering Data**: The mean of the dataset is computed and subtracted from each sample to center the data around zero.
- **Covariance Matrix**: The 784x784 covariance matrix is computed from the centered data.
- **Eigendecomposition**: Eigenvalues and eigenvectors of the covariance matrix are computed with torch.linalg.eigh, which is suited to symmetric matrices.
- **Sorting Components**: Eigenvalues and their corresponding eigenvectors are sorted in descending order, so the leading components capture the most variance.
- **Projection**: The centered data is projected onto the lower-dimensional subspace spanned by the top `k` principal components.
- **Visualization**: A scatter plot shows how the samples fall along the first two principal components.
'''
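
# Alternative sketch using torch.pca_lowrank, a randomized-SVD helper in
# recent PyTorch versions: it avoids forming the full 784x784 covariance
# matrix. Component signs may differ from the eigendecomposition route,
# which is harmless since eigenvectors are only defined up to sign.
U, S, V = torch.pca_lowrank(X, q=k, center=True)
Z_k_svd = torch.mm(X_centered, V[:, :k])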