Skip to content

Commit bb3973b

Browse files
committed
Added article on multitask learning starter code
1 parent 8cb06bd commit bb3973b

2 files changed

Lines changed: 238 additions & 1 deletion

File tree

_data/navigation.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,10 @@ wiki:
178178
url: /wiki/machine-learning/comprehensive-guide-to-albumentations.md
179179
- title: Kornia technical guide
180180
url: /wiki/machine-learning/kornia-technical-guide.md
181-
- title: Integrating OLLAMA LLMs with Franka Arm
181+
- title: Integrating OLLAMA LLMs with Franka Arm
182182
url: /wiki/machine-learning/integrating-ollama-llms-with-franka-arm.md
183+
- title: Multi-task learning A starter guide
184+
url: /wiki/machine-learning/multitask-learning-starter.md
183185
- title: State Estimation
184186
url: /wiki/state-estimation/
185187
children:
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Multi-task learning: A starter guide
2+
3+
## Introduction to Multi-Task Learning in Computer Vision
4+
5+
Multi-task learning represents a powerful paradigm in deep learning where a single neural network learns to perform multiple related tasks simultaneously. In computer vision, this approach is particularly valuable because many vision tasks share common low-level features. For instance, both depth estimation and semantic segmentation benefit from understanding edges, textures, and object boundaries in an image.
6+
7+
In this guide, we'll explore how to build a HydraNet architecture that performs two complementary tasks:
8+
9+
1. Monocular depth estimation: Predicting the depth of each pixel from a single RGB image
10+
2. Semantic segmentation: Classifying each pixel into predefined semantic categories
11+
12+
The power of this approach lies in the shared learning of features that are useful for both tasks, leading to more efficient and often more accurate predictions than training separate models for each task.
13+
14+
## Understanding the System Architecture
15+
16+
The HydraNet architecture consists of three main components working in harmony:
17+
18+
### 1. MobileNetV2 Encoder
19+
20+
The encoder serves as the backbone of our network, converting RGB images into rich feature representations. We choose MobileNetV2 for several reasons:
21+
22+
- Efficient design with inverted residual blocks
23+
- Strong feature extraction capabilities
24+
- Lower computational requirements compared to heavier architectures
25+
- Good balance of speed and accuracy
26+
27+
### 2. Lightweight RefineNet Decoder
28+
29+
The decoder takes the encoded features and processes them through refinement stages. Its key characteristics include:
30+
31+
- Chained Residual Pooling (CRP) blocks for effective feature refinement
32+
- Skip connections to preserve spatial information
33+
- Gradual upsampling to restore resolution
34+
35+
### 3. Task-Specific Heads
36+
37+
Two separate heads branch out from the decoder:
38+
39+
- Depth head: Outputs continuous depth values
40+
- Segmentation head: Outputs class probabilities for each pixel
41+
42+
## Detailed Implementation Guide
43+
44+
### 1. Environment Setup and Prerequisites
45+
46+
First, let's understand the constants we'll be using for image processing:
47+
48+
```python
49+
import torch
50+
import torch.nn as nn
51+
import torch.nn.functional as F
52+
import numpy as np
53+
from torch.autograd import Variable
54+
55+
# Image normalization constants
56+
IMG_SCALE = 1./255 # Scale pixel values to [0,1]
57+
IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) # ImageNet means
58+
IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) # ImageNet stds
59+
```
60+
61+
These constants are crucial for ensuring our input images match the distribution that our pre-trained MobileNetV2 encoder expects. The normalization process transforms our images to have similar statistical properties to the ImageNet dataset, which helps with transfer learning.
62+
63+
### 2. HydraNet Core Architecture
64+
65+
The HydraNet class serves as our model's foundation. Let's examine its structure in detail:
66+
67+
```python
68+
class HydraNet(nn.Module):
69+
def __init__(self):
70+
super().__init__()
71+
self.num_tasks = 2 # Depth estimation and segmentation
72+
self.num_classes = 6 # Number of segmentation classes
73+
74+
# Initialize network components
75+
self.define_mobilenet() # Encoder
76+
self.define_lightweight_refinenet() # Decoder
77+
```
78+
79+
This initialization sets up our multi-task framework. The `num_tasks` parameter defines how many outputs our network will produce, while `num_classes` specifies the number of semantic categories for segmentation.
80+
81+
### 3. Understanding the MobileNetV2 Encoder
82+
83+
The encoder uses inverted residual blocks, a key innovation of MobileNetV2. Here's how they work:
84+
85+
```python
86+
class InvertedResidualBlock(nn.Module):
87+
def __init__(self, in_channels, out_channels, stride, expansion_factor):
88+
super().__init__()
89+
90+
hidden_dim = in_channels * expansion_factor
91+
92+
self.output = nn.Sequential(
93+
# Step 1: Channel Expansion - Increases the number of channels
94+
nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
95+
nn.BatchNorm2d(hidden_dim),
96+
nn.ReLU6(inplace=True),
97+
98+
# Step 2: Depthwise Convolution - Spatial filtering
99+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,
100+
groups=hidden_dim, bias=False),
101+
nn.BatchNorm2d(hidden_dim),
102+
nn.ReLU6(inplace=True),
103+
104+
# Step 3: Channel Reduction - Projects back to a smaller dimension
105+
nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
106+
nn.BatchNorm2d(out_channels)
107+
)
108+
```
109+
110+
Each inverted residual block performs three key operations:
111+
112+
1. Channel expansion: Increases the feature dimensions to allow for more expressive transformations
113+
2. Depthwise convolution: Applies spatial filtering efficiently by processing each channel separately
114+
3. Channel reduction: Compresses the features back to a manageable size
115+
116+
The name "inverted residual" comes from the fact that the block expands channels before the depthwise convolution, unlike traditional residual blocks that reduce dimensions first.
117+
118+
### 4. Lightweight RefineNet Decoder Deep Dive
119+
120+
The decoder's CRP blocks are crucial for effective feature refinement:
121+
122+
```python
123+
def _make_crp(self, in_planes, out_planes, stages):
124+
layers = [
125+
# Initial projection to desired number of channels
126+
nn.Conv2d(in_planes, out_planes, 1, 1, bias=False),
127+
nn.BatchNorm2d(out_planes),
128+
nn.ReLU(inplace=True)
129+
]
130+
131+
# Create chain of pooling operations
132+
for i in range(stages):
133+
layers.extend([
134+
nn.MaxPool2d(5, stride=1, padding=2), # Maintains spatial size
135+
nn.Conv2d(out_planes, out_planes, 1, 1, bias=False),
136+
nn.BatchNorm2d(out_planes),
137+
nn.ReLU(inplace=True)
138+
])
139+
140+
return nn.Sequential(*layers)
141+
```
142+
143+
The CRP blocks serve several important purposes:
144+
145+
- They capture multi-scale context through repeated pooling operations
146+
- The chain structure allows for refinement of features at different receptive fields
147+
- The 1x1 convolutions after each pooling operation help in feature adaptation
148+
- The residual connections help maintain gradient flow
149+
150+
### 5. Task-Specific Heads in Detail
151+
152+
The heads are designed to transform shared features into task-specific predictions:
153+
154+
```python
155+
def define_heads(self):
156+
# Segmentation head: Transforms features into class probabilities
157+
self.segm_head = nn.Sequential(
158+
nn.Conv2d(self.feature_dim, self.feature_dim, 3, padding=1),
159+
nn.BatchNorm2d(self.feature_dim),
160+
nn.ReLU(inplace=True),
161+
nn.Conv2d(self.feature_dim, self.num_classes, 1)
162+
)
163+
164+
# Depth head: Transforms features into depth values
165+
self.depth_head = nn.Sequential(
166+
nn.Conv2d(self.feature_dim, self.feature_dim, 3, padding=1),
167+
nn.BatchNorm2d(self.feature_dim),
168+
nn.ReLU(inplace=True),
169+
nn.Conv2d(self.feature_dim, 1, 1)
170+
)
171+
```
172+
173+
Each head follows a similar structure but serves different purposes:
174+
175+
- The segmentation head outputs logits for each class at each pixel
176+
- The depth head outputs a single continuous value per pixel
177+
- The 3x3 convolution captures local spatial context
178+
- The final 1x1 convolution projects to the required output dimensions
179+
180+
### 6. Forward Pass and Loss Functions
181+
182+
The forward pass coordinates the flow of information through the network:
183+
184+
```python
185+
def compute_loss(depth_pred, depth_gt, segm_pred, segm_gt, weights):
186+
# Depth loss: L1 loss for continuous values
187+
depth_loss = F.l1_loss(depth_pred, depth_gt)
188+
189+
# Segmentation loss: Cross-entropy for classification
190+
segm_loss = F.cross_entropy(segm_pred, segm_gt)
191+
192+
# Weighted combination of losses
193+
total_loss = weights['depth'] * depth_loss + weights['segm'] * segm_loss
194+
return total_loss
195+
```
196+
197+
The loss function balancing is crucial for successful multi-task learning:
198+
199+
- The depth loss measures absolute differences in depth predictions
200+
- The segmentation loss measures classification accuracy
201+
- The weights help balance the contribution of each task
202+
- These weights can be fixed or learned during training
203+
204+
## Training Considerations and Best Practices
205+
206+
When training a multi-task model like HydraNet, several factors require careful attention:
207+
208+
### 1. Data Balancing
209+
210+
- Ensure both tasks have sufficient and balanced training data
211+
- Consider the relative difficulty of each task
212+
- Use appropriate data augmentation for each task
213+
214+
### 2. Loss Balancing
215+
216+
- Monitor individual task losses during training
217+
- Adjust task weights if one task dominates
218+
- Consider uncertainty-based loss weighting
219+
220+
### 3. Optimization Strategy
221+
222+
- Start with lower learning rates
223+
- Use appropriate learning rate scheduling
224+
- Monitor task-specific metrics separately
225+
- Implement early stopping based on validation performance
226+
227+
## Conclusion
228+
229+
The HydraNet architecture demonstrates the power of multi-task learning in computer vision. By sharing features between depth estimation and segmentation tasks, we achieve:
230+
231+
- More efficient use of model parameters
232+
- Better generalization through shared representations
233+
- Fast inference times suitable for real-world applications
234+
235+
Success with this architecture requires careful attention to implementation details, particularly in the areas of loss balancing, training dynamics, and architecture design. The code provided here serves as a foundation that can be adapted and extended based on specific requirements and constraints.

0 commit comments

Comments
 (0)