Skip to content

Commit 5d9effe

Browse files
committed
Added article for neural network model pruning
1 parent 68e5d6d commit 5d9effe

2 files changed

Lines changed: 231 additions & 0 deletions

File tree

_data/navigation.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ wiki:
186186
url: /wiki/machine-learning/understanding-kalman-filters-and-visual-tracking.md
187187
- title: Knowledge Distillation practical implementation guide
188188
url: /wiki/machine-learning/knowledge-distillation-practical-implementation-guide.md
189+
- title: Neural Network optimization using model pruning
190+
url: /wiki/machine-learning/neural-network-optimization-using-model-pruning.md
189191
- title: State Estimation
190192
url: /wiki/state-estimation/
191193
children:
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Neural Network optimization using model pruning
2+
3+
## Understanding Model Pruning
4+
5+
Model pruning is a fundamental technique in deep learning model optimization where we systematically remove weights or neurons from a neural network while maintaining its performance. This process is analogous to biological neural pruning, where the brain eliminates less important neural connections to improve efficiency.
6+
7+
## Theoretical Foundation
8+
9+
Neural networks typically contain redundant parameters that contribute minimally to the model's outputs. Pruning identifies and removes these parameters by:
10+
11+
1. Evaluating parameter importance using specific criteria
12+
2. Removing parameters deemed less important
13+
3. Fine-tuning the remaining parameters to maintain performance
14+
15+
## Implementation with PyTorch
16+
17+
Let's explore different pruning techniques using PyTorch's pruning utilities:
18+
19+
```python
20+
import torch.nn.utils.prune as prune
21+
import torch.nn as nn
22+
from copy import deepcopy
23+
import numpy as np
24+
```
25+
26+
### Basic Pruning Setup
27+
28+
First, let's create a simple linear layer to demonstrate pruning concepts:
29+
30+
```python
31+
# Create a test module
32+
fc_test = nn.Linear(10, 10)
33+
module = deepcopy(fc_test)
34+
35+
# Examine initial parameters
36+
print('Before pruning:')
37+
print(list(module.named_parameters()))
38+
print(list(module.named_buffers()))
39+
```
40+
41+
## Unstructured Pruning
42+
43+
### L1 Unstructured Pruning
44+
45+
L1 unstructured pruning removes individual weights based on their absolute magnitude. This is the most flexible form of pruning but results in sparse matrices that may not provide practical speed benefits without specialized hardware.
46+
47+
```python
48+
def apply_l1_unstructured_pruning(module, amount=0.3):
49+
"""
50+
Apply L1 unstructured pruning to a module
51+
52+
Args:
53+
module: PyTorch module to prune
54+
amount: Fraction of weights to prune (0.3 = 30%)
55+
"""
56+
prune.l1_unstructured(module, name='weight', amount=amount)
57+
58+
# Examine the pruned weights
59+
weight = module.weight.cpu().detach().numpy()
60+
mask = module.get_buffer('weight_mask').cpu().numpy()
61+
62+
return weight, mask
63+
```
64+
65+
The process works by:
66+
1. Computing the L1 norm (absolute values) of all weights
67+
2. Sorting weights by magnitude
68+
3. Setting the smallest weights to zero based on the specified amount
69+
70+
### Visualizing Unstructured Pruning
71+
72+
```python
73+
def visualize_pruning_pattern(weight, mask, title):
74+
"""
75+
Visualize weight matrix before and after pruning
76+
"""
77+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
78+
79+
# Original weights
80+
im1 = ax1.imshow(weight, cmap='viridis')
81+
ax1.set_title('Original Weights')
82+
83+
# Pruned weights
84+
pruned_weight = weight * mask
85+
im2 = ax2.imshow(pruned_weight, cmap='viridis')
86+
ax2.set_title(f'After {title}')
87+
88+
plt.colorbar(im1, ax=ax1)
89+
plt.colorbar(im2, ax=ax2)
90+
plt.tight_layout()
91+
92+
return fig
93+
```
94+
95+
## Structured Pruning
96+
97+
### L1 Structured Pruning
98+
99+
Structured pruning removes entire groups of weights (e.g., neurons or channels) based on their collective importance. This approach results in dense but smaller matrices that can provide immediate speed benefits.
100+
101+
```python
102+
def apply_l1_structured_pruning(module, amount=0.3, dim=0):
103+
"""
104+
Apply L1 structured pruning to a module
105+
106+
Args:
107+
module: PyTorch module to prune
108+
amount: Fraction of structures to prune
109+
dim: Dimension along which to prune (0=rows, 1=columns)
110+
"""
111+
prune.ln_structured(
112+
module,
113+
name='weight',
114+
amount=amount,
115+
n=1, # L1 norm
116+
dim=dim
117+
)
118+
119+
return module.weight.cpu().detach().numpy()
120+
```
121+
122+
The process works by:
123+
1. Computing the L1 norm of each structure (row/column)
124+
2. Sorting structures by their total magnitude
125+
3. Removing entire structures with lowest magnitude
126+
127+
## Advanced Pruning Techniques
128+
129+
### Iterative Pruning
130+
131+
Iterative pruning gradually removes weights over multiple rounds, allowing the network to adapt:
132+
133+
```python
134+
def iterative_pruning(model, pruning_schedule, fine_tune_steps=1000):
135+
"""
136+
Iteratively prune a model according to a schedule
137+
138+
Args:
139+
model: PyTorch model to prune
140+
pruning_schedule: List of (epoch, amount) tuples
141+
fine_tune_steps: Number of steps to fine-tune after each pruning
142+
"""
143+
for epoch, amount in pruning_schedule:
144+
# Apply pruning
145+
for name, module in model.named_modules():
146+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
147+
prune.l1_unstructured(module, 'weight', amount=amount)
148+
149+
# Fine-tune
150+
fine_tune_model(model, steps=fine_tune_steps)
151+
```
152+
153+
### Global Pruning
154+
155+
Instead of pruning each layer independently, global pruning considers the importance of weights across the entire network:
156+
157+
```python
158+
def global_magnitude_pruning(model, amount):
159+
"""
160+
Prune weights globally across the model based on magnitude
161+
162+
Args:
163+
model: PyTorch model to prune
164+
amount: Fraction of weights to prune globally
165+
"""
166+
# Collect all weights
167+
all_weights = []
168+
for name, module in model.named_modules():
169+
if isinstance(module, (nn.Linear, nn.Conv2d)):
170+
all_weights.extend(module.weight.data.abs().cpu().numpy().flatten())
171+
172+
# Compute global threshold
173+
threshold = np.percentile(all_weights, amount * 100)
174+
175+
# Apply pruning
176+
for name, module in model.named_modules():
177+
if isinstance(module, (nn.Linear, nn.Conv2d)):
178+
mask = module.weight.data.abs() > threshold
179+
module.weight.data *= mask
180+
```
181+
182+
## Best Practices for Model Pruning
183+
184+
### 1. Pruning Strategy Selection
185+
186+
Choose your pruning strategy based on your requirements:
187+
188+
```python
189+
def select_pruning_strategy(model_type, hardware_target):
190+
"""
191+
Select appropriate pruning strategy based on model and hardware
192+
"""
193+
if hardware_target == 'gpu':
194+
return 'structured' # Better for parallel processing
195+
elif hardware_target == 'sparse_accelerator':
196+
return 'unstructured' # Better for specialized hardware
197+
else:
198+
return 'structured' # Default to structured for general purpose
199+
```
200+
201+
### 2. Performance Monitoring
202+
203+
Monitor key metrics during pruning:
204+
205+
```python
206+
def evaluate_pruning(model, test_loader, original_accuracy):
207+
"""
208+
Evaluate the impact of pruning
209+
"""
210+
metrics = {
211+
'accuracy': compute_accuracy(model, test_loader),
212+
'model_size': get_model_size(model),
213+
'inference_time': measure_inference_time(model),
214+
'compression_ratio': compute_compression_ratio(model)
215+
}
216+
217+
return metrics
218+
```
219+
220+
## Conclusion
221+
222+
Effective model pruning requires:
223+
224+
1. Understanding different pruning techniques and their trade-offs
225+
2. Careful selection of pruning parameters and schedules
226+
3. Proper monitoring of model performance during pruning
227+
4. Consideration of hardware constraints and deployment targets
228+
229+
When implemented correctly, pruning can significantly reduce model size and improve inference speed while maintaining most of the original model's accuracy.

0 commit comments

Comments
 (0)