Commit 68e5d6d

committed
Added article for knowledge distillation
1 parent 2aafc62 commit 68e5d6d

2 files changed

Lines changed: 231 additions & 0 deletions

_data/navigation.yml

Lines changed: 2 additions & 0 deletions
@@ -184,6 +184,8 @@ wiki:
   url: /wiki/machine-learning/multitask-learning-starter.md
 - title: Understanding Kalman Filters and Visual Tracking
   url: /wiki/machine-learning/understanding-kalman-filters-and-visual-tracking.md
+- title: Knowledge Distillation practical implementation guide
+  url: /wiki/machine-learning/knowledge-distillation-practical-implementation-guide.md
 - title: State Estimation
   url: /wiki/state-estimation/
   children:
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Knowledge Distillation practical implementation guide
2+
3+
## Introduction to Model Compression
4+
5+
Deep neural networks have achieved remarkable performance across various computer vision tasks, but this often comes at the cost of computational complexity and large model sizes. Knowledge Distillation (KD) offers an elegant solution to this challenge by transferring knowledge from a large, complex model (the teacher) to a smaller, more efficient model (the student).
6+
7+
## Understanding Knowledge Distillation
8+
9+
Knowledge Distillation, introduced by Hinton et al., works on a fundamental principle: a smaller model can achieve better performance by learning not just from ground truth labels, but also from the "soft targets" produced by a larger model. These soft targets capture rich information about inter-class similarity that is absent from one-hot encoded ground truth labels. For example, a teacher that assigns a small but non-negligible probability to "truck" for an image of a car tells the student that the two classes look alike.
10+
11+
### The Mathematics Behind Soft Targets
12+
13+
When a neural network produces outputs through its softmax layer, it generates a probability distribution across all classes. At the default temperature T=1, this distribution is typically very peaked, with most of the probability mass concentrated on one class. Dividing the logits by a temperature T > 1 before applying the softmax "softens" these probabilities:
14+
15+
```python
import torch


def softmax_with_temperature(logits, temperature=1.0):
    """Apply temperature scaling to logits and return softmax probabilities"""
    scaled_logits = logits / temperature
    # dim=1 is the class dimension for logits of shape (batch, num_classes)
    return torch.nn.functional.softmax(scaled_logits, dim=1)
```
21+
22+
Higher temperatures produce softer probability distributions, revealing more about the model's uncertainties and relative similarities between classes.
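To make the softening effect concrete, the following standalone sketch applies a temperature-scaled softmax to a single logit vector using only the standard library. The helper name `softened_probs` and the logit values are illustrative assumptions, not outputs of a real model.

```python
import math


def softened_probs(logits, temperature):
    """Temperature-scaled softmax over a plain list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three classes
logits = [6.0, 2.0, 1.0]

for t in (1.0, 2.0, 5.0):
    probs = softened_probs(logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As T grows, probability mass moves from the top class to the others while the ranking of the classes is preserved, which is the extra signal the student learns from.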
23+
24+
## Implementing Knowledge Distillation
25+
26+
### 1. Setting Up the Data Pipeline
27+
28+
First, we need to create a data pipeline that provides three components: input images, ground truth labels, and teacher predictions:
29+
30+
```python
import glob

import cv2
import torch
from torch.utils.data import Dataset


class DistillationDataset(Dataset):
    def __init__(self, transform=None):
        self.transform = transform

        # Load image paths and teacher predictions; sorting keeps the two
        # lists aligned only if the file names match one to one
        self.images = sorted(glob.glob('path/to/images/*.jpg'))
        self.teacher_preds = sorted(glob.glob('path/to/teacher_preds/*.pt'))

    def __len__(self):
        # Required by DataLoader to know the dataset size
        return len(self.images)

    def load_ground_truth(self, idx):
        # Not defined in the original article; depends on the label format
        raise NotImplementedError

    def __getitem__(self, idx):
        # Load and transform image (OpenCV reads BGR, so convert to RGB)
        image = cv2.imread(self.images[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image)

        # Load teacher predictions and ground truth
        teacher_pred = torch.load(self.teacher_preds[idx])
        ground_truth = self.load_ground_truth(idx)

        return image, ground_truth, teacher_pred
```
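The dataset above expects one `.pt` file of teacher predictions per image. A minimal sketch of how such files might be produced is shown below; the helper `save_teacher_logits`, the file naming scheme, and the toy linear "teacher" are all assumptions made for illustration. It stores raw logits rather than probabilities so that a temperature can still be applied later in the loss.

```python
import os
import tempfile

import torch
import torch.nn as nn


@torch.no_grad()
def save_teacher_logits(teacher_model, inputs, names, out_dir):
    """Run the teacher once and save one logits tensor per sample."""
    teacher_model.eval()
    os.makedirs(out_dir, exist_ok=True)
    logits = teacher_model(inputs)
    for name, sample_logits in zip(names, logits):
        # clone() so each file stores only its own row, not the whole batch
        torch.save(sample_logits.clone(), os.path.join(out_dir, f"{name}.pt"))
    return logits


# Toy stand-ins: a linear layer as the teacher and random feature vectors as inputs
teacher = nn.Linear(8, 3)
inputs = torch.randn(4, 8)
names = [f"img_{i:04d}" for i in range(4)]  # hypothetical naming scheme
out_dir = tempfile.mkdtemp()

saved = save_teacher_logits(teacher, inputs, names, out_dir)
loaded = torch.load(os.path.join(out_dir, "img_0000.pt"))
print(loaded.shape)
```

Naming the prediction files after their images keeps the two sorted lists in `DistillationDataset` aligned.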
52+
53+
### 2. Defining the Loss Function
54+
55+
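The distillation loss typically combines two components: standard cross-entropy loss with the ground truth labels and Kullback-Leibler divergence between the temperature-softened student and teacher distributions. The KL term is multiplied by T² because the gradients produced by the soft targets scale as 1/T², so this factor keeps the two terms balanced when the temperature changes: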
56+
57+
```python
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, temperature=1.0, alpha=0.5):
    """
    Compute the knowledge distillation loss

    Args:
        student_logits: Raw outputs of the student model
        teacher_logits: Raw outputs of the teacher model
        labels: Ground truth labels
        temperature: Softmax temperature
        alpha: Weight for balancing the two losses

    Returns:
        Total loss combining distillation and standard cross-entropy
    """
    # Standard cross-entropy loss
    hard_loss = F.cross_entropy(student_logits, labels)

    # Soft targets with temperature
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)

    # KL divergence loss (F.kl_div takes log-probabilities as input and
    # probabilities as target); named soft_loss to avoid shadowing the function
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

    # Combine losses
    total_loss = (1 - alpha) * hard_loss + alpha * (temperature ** 2) * soft_loss

    return total_loss
```
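Two limiting cases are a quick way to check a loss of this form: with `alpha=1` and a student that reproduces the teacher's logits exactly, the KL term should vanish, and with `alpha=0` the result should equal plain cross-entropy. The sketch below restates the loss compactly under the hypothetical name `kd_loss` and uses random tensors as stand-ins for real model outputs.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """Compact restatement of the combined loss for checking its behaviour."""
    hard = F.cross_entropy(student_logits, labels)
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )
    return (1 - alpha) * hard + alpha * temperature ** 2 * soft


torch.manual_seed(0)
teacher_logits = torch.randn(4, 5)   # random stand-ins for model outputs
student_logits = torch.randn(4, 5)
labels = torch.randint(0, 5, (4,))

# Student identical to teacher, soft term only: the KL divergence vanishes
matched = kd_loss(teacher_logits, teacher_logits, labels, temperature=4.0, alpha=1.0)
# alpha=0 switches the soft term off, leaving plain cross-entropy
hard_only = kd_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.0)
# A student that disagrees with the teacher pays a positive soft loss
mismatched = kd_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=1.0)

print(matched.item(), hard_only.item(), mismatched.item())
```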
87+
88+
### 3. Training Loop Implementation
89+
90+
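The training loop below uses the precomputed teacher predictions from the data pipeline. These must be raw logits, because the loss function applies the temperature-scaled softmax itself: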
91+
92+
```python
def train_with_distillation(student_model, teacher_model, train_loader, optimizer,
                            num_epochs, temperature=1.0, alpha=0.5, device='cuda'):
    """
    Train student model using knowledge distillation
    """
    student_model.train()
    # The teacher is not called below because its predictions are precomputed
    teacher_model.eval()

    for epoch in range(num_epochs):
        for batch_idx, (data, targets, teacher_preds) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)
            teacher_preds = teacher_preds.to(device)

            # Forward pass for student
            student_outputs = student_model(data)

            # Compute distillation loss
            loss = distillation_loss(
                student_outputs,
                teacher_preds,
                targets,
                temperature=temperature,
                alpha=alpha
            )

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```
123+
124+
### 4. Advanced Techniques and Optimizations
125+
126+
#### Temperature Scheduling
127+
128+
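Instead of using a fixed temperature, we can anneal it over the course of training. The function below decays the temperature linearly from 5.0 at the first epoch toward 1.0 at the last: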
129+
130+
```python
def get_temperature(epoch, max_epochs):
    """Linearly anneal the temperature from 5.0 (epoch 0) to 1.0 (epoch == max_epochs)"""
    return 1.0 + (4.0 * (1.0 - epoch / max_epochs))
```
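Printing the schedule for a short run shows how it behaves at the boundaries. The sketch restates the function so it runs on its own; the 10-epoch training length is an assumption chosen for illustration.

```python
def get_temperature(epoch, max_epochs):
    """Linear schedule: 5.0 at epoch 0, 1.0 when epoch == max_epochs."""
    return 1.0 + (4.0 * (1.0 - epoch / max_epochs))


max_epochs = 10  # hypothetical training length
schedule = [get_temperature(epoch, max_epochs) for epoch in range(max_epochs)]

for epoch, temperature in enumerate(schedule):
    print(f"epoch {epoch}: T = {temperature:.1f}")
```

Because `range(max_epochs)` stops at `max_epochs - 1`, the final epoch in this run uses roughly T = 1.4 rather than exactly 1.0. Dividing by `max_epochs - 1` instead would make the last epoch land on 1.0, if that is the intended endpoint.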
135+
136+
#### Online Distillation
137+
138+
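We can also perform online distillation, where the teacher's predictions are generated during training rather than loaded from disk. This avoids storing predictions and works naturally with random data augmentation, at the cost of one teacher forward pass per batch: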
139+
140+
```python
import torch


def online_distillation(student_model, teacher_model, data, temperature):
    """Perform online knowledge distillation"""
    # The teacher is frozen, so no gradients are tracked for its forward pass
    with torch.no_grad():
        teacher_logits = teacher_model(data)

    student_logits = student_model(data)
    # temperature is not applied here; pass it on to the loss function
    return student_logits, teacher_logits
```
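A single optimization step with an online teacher might look like the following sketch. The two small networks, the SGD optimizer, and the hyperparameter values are toy stand-ins chosen so the example runs quickly on CPU; none of them come from the article.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for real networks (hypothetical sizes)
teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
student = nn.Linear(16, 4)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

data = torch.randn(8, 16)
labels = torch.randint(0, 4, (8,))
temperature, alpha = 4.0, 0.5

teacher.eval()
student.train()
weights_before = student.weight.detach().clone()

# Teacher forward pass without gradient tracking
with torch.no_grad():
    teacher_logits = teacher(data)
student_logits = student(data)

# Same hard + soft combination as the distillation loss
hard = F.cross_entropy(student_logits, labels)
soft = F.kl_div(
    F.log_softmax(student_logits / temperature, dim=1),
    F.softmax(teacher_logits / temperature, dim=1),
    reduction="batchmean",
)
loss = (1 - alpha) * hard + alpha * temperature ** 2 * soft

optimizer.zero_grad()
loss.backward()
optimizer.step()

student_changed = not torch.equal(weights_before, student.weight)
teacher_has_grads = any(p.grad is not None for p in teacher.parameters())
print(student_changed, teacher_has_grads)
```

Only the student's parameters are handed to the optimizer, and the `torch.no_grad()` context keeps the teacher out of the autograd graph, so the teacher stays fixed throughout.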
149+
150+
## Best Practices and Optimization Tips
151+
152+
### 1. Model Architecture Considerations
153+
154+
The student model should maintain a similar architectural pattern to the teacher, but with reduced capacity. For example:
155+
156+
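```python
import torch.nn as nn


class StudentModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Use depth-wise separable convolutions for efficiency.
        # DepthwiseSeparableConv is not part of torch.nn and must be defined separately.
        self.features = nn.Sequential(
            DepthwiseSeparableConv(3, 64, stride=2),
            DepthwiseSeparableConv(64, 128),
            DepthwiseSeparableConv(128, 256)
        )
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        # Assumed forward pass (not in the original): global average pooling
        # reduces the (N, 256, H, W) feature map to (N, 256) for the classifier
        x = self.features(x).mean(dim=(2, 3))
        return self.classifier(x)
```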
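`DepthwiseSeparableConv` is not a built-in PyTorch layer, so it has to be supplied. One common way to write it, offered here as an assumption about what the article intends, is a 3x3 depthwise convolution followed by a 1x1 pointwise convolution; the batch normalization and ReLU are likewise illustrative choices.

```python
import torch
import torch.nn as nn


class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 convolution followed by a pointwise 1x1 convolution."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # groups=in_channels gives each input channel its own 3x3 filter
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=stride,
            padding=1, groups=in_channels, bias=False,
        )
        # The 1x1 convolution mixes channels and sets the output width
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.pointwise(self.depthwise(x))))


block = DepthwiseSeparableConv(3, 64, stride=2)
out = block(torch.randn(2, 3, 32, 32))
print(out.shape)

# Compare convolution weights against a standard 3x3 convolution
standard = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
n_separable = sum(p.numel() for p in block.depthwise.parameters()) + sum(
    p.numel() for p in block.pointwise.parameters()
)
n_standard = sum(p.numel() for p in standard.parameters())
print(n_separable, n_standard)
```

The parameter comparison at the end illustrates why this layer suits a student network: it covers the same receptive field with far fewer weights than a standard convolution.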
168+
169+
### 2. Hyperparameter Selection
170+
171+
The following hyperparameters significantly affect distillation performance:
172+
173+
```python
distillation_params = {
    'temperature': 2.0,     # Controls softness of probability distribution
    'alpha': 0.5,           # Balance between hard and soft losses
    'learning_rate': 1e-4,  # Usually lower than standard training
    'batch_size': 64        # Can be larger due to simpler model
}
```
181+
182+
### 3. Training Optimizations
183+
184+
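Gradient clipping and learning rate scheduling help stabilize training. The function below configures the optimizer and a cosine annealing schedule; gradient clipping is applied separately inside the training loop, between `loss.backward()` and `optimizer.step()`: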
185+
186+
```python
import torch


def configure_training(student_model, learning_rate, num_epochs):
    """Configure the optimizer and learning rate schedule"""
    optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)
    # Cosine annealing over num_epochs; call scheduler.step() once per epoch
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs
    )

    return optimizer, scheduler
```
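Gradient clipping itself happens in the training step. The sketch below shows one way to do it with `torch.nn.utils.clip_grad_norm_`; the toy linear model, the exaggerated input scale, and the `max_grad_norm` threshold of 1.0 are assumptions picked so that clipping visibly triggers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(10, 3)  # toy stand-in for the student
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
max_grad_norm = 1.0  # hypothetical clipping threshold

x = torch.randn(32, 10) * 100.0  # large inputs to provoke large gradients
y = torch.randint(0, 3, (32,))

loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()

# clip_grad_norm_ rescales gradients in place and returns the norm before clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

optimizer.step()
scheduler.step()  # in a real loop this runs once per epoch, not per batch

print(float(norm_before), float(norm_after), scheduler.get_last_lr()[0])
```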
196+
197+
## Performance Evaluation and Metrics
198+
199+
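To evaluate the effectiveness of knowledge distillation, we should measure the student's accuracy, the reduction in model size relative to the teacher, and the inference speedup: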
200+
201+
```python
def evaluate_distillation(student_model, teacher_model, test_loader, device):
    """Evaluate distillation performance"""
    student_model.eval()
    teacher_model.eval()

    metrics = {
        'accuracy': 0.0,
        'model_size_reduction': 0.0,
        'inference_speedup': 0.0
    }

    with torch.no_grad():
        # Implement evaluation logic
        pass

    return metrics
```
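One possible way to fill in the placeholder is sketched below. The helper names `count_parameters`, `timed_accuracy`, and `evaluate_pair` are hypothetical, the loader is assumed to yield `(data, targets)` pairs, size reduction is measured as a ratio of parameter counts, and speedup as a ratio of wall-clock times; each of these is an assumption rather than the article's definition.

```python
import time

import torch
import torch.nn as nn


def count_parameters(model):
    """Total number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())


@torch.no_grad()
def timed_accuracy(model, loader, device):
    """Return (accuracy, elapsed seconds) for one pass over the loader."""
    model.eval()
    correct, total = 0, 0
    start = time.perf_counter()
    for data, targets in loader:
        data, targets = data.to(device), targets.to(device)
        preds = model(data).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
    return correct / total, time.perf_counter() - start


def evaluate_pair(student_model, teacher_model, loader, device="cpu"):
    """Compare student and teacher on accuracy, size, and speed."""
    student_acc, student_time = timed_accuracy(student_model, loader, device)
    teacher_acc, teacher_time = timed_accuracy(teacher_model, loader, device)
    return {
        "accuracy": student_acc,
        "teacher_accuracy": teacher_acc,
        "model_size_reduction": count_parameters(teacher_model) / count_parameters(student_model),
        "inference_speedup": teacher_time / student_time,
    }


# Toy check with untrained stand-in models and random data
torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, 4))
student = nn.Linear(16, 4)
loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(3)]

metrics = evaluate_pair(student, teacher, loader)
print(metrics)
```

Wall-clock timing on a GPU would additionally need `torch.cuda.synchronize()` before reading the clock, and a few warm-up batches make the speedup figure more stable.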
219+
220+
## Conclusion
221+
222+
Knowledge Distillation is a practical approach to model compression that lets a small student retain much of the teacher's accuracy. Success depends on:
223+
224+
1. Careful selection of teacher and student architectures
225+
2. Proper tuning of temperature and loss balancing
226+
3. Implementation of training optimizations
227+
4. Comprehensive evaluation metrics
228+
229+
By following these guidelines and implementing the provided code patterns, you can effectively compress deep learning models while preserving their performance characteristics.
