-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_ofa.py
More file actions
174 lines (145 loc) · 7.85 KB
/
train_ofa.py
File metadata and controls
174 lines (145 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader # Your ImageFolderDataset
from tqdm import tqdm
import random
import os
import time
import numpy as np
# --- Project-specific imports (dataset, transforms, etc.) ---
# from HAM10000 import ImageFolderDataset, get_train_transforms, get_val_transforms, CLASS_NAMES_DRIVE_DATA, NUM_CLASSES_DRIVE_DATA
# (Assumes these live in HAM10000.py or a similar module; adjust the import path.)
# This script either redefines them briefly or assumes they are importable.
# --- OFA dynamic components (implement these fully before training) ---
# class DynamicLinear(nn.Linear): ...
# class DynamicConv2d(nn.Conv2d): ...
# class DynamicBatchNorm2d(nn.BatchNorm2d): ...  # Or SwitchableBatchNorm
# class DynamicMaxViTBlock(nn.Module): ...
# class DynamicMaxViTStage(nn.Module): ...
# class OFAMaxViT(nn.Module): ...  # The supernet definition

# --- Configuration ---
### PLACEHOLDER: Define your final SEARCH_SPACE_CONFIG here ###
SEARCH_SPACE_CONFIG = {
    'max_depths_per_stage': [2, 2, 4, 1],
    'max_channels_per_stage': [64, 128, 256, 512],  # Example, VERIFY!
    'max_mlp_ratios_glob': 4.0,
    'stage_0_depth_choices': [1, 2],
    'stage_1_depth_choices': [1, 2],
    'stage_2_depth_choices': [2, 3, 4],
    'stage_3_depth_choices': [1],
    'channel_scale_choices': [0.5, 0.75, 1.0],
    'mlp_ratio_choices': [2.5, 3.0, 3.5, 4.0],
}


def validate_search_space_config(config):
    """Raise ValueError if ``config`` is internally inconsistent.

    Checks that:
    * ``max_channels_per_stage`` has one entry per stage in ``max_depths_per_stage``;
    * for every stage ``i``, ``stage_{i}_depth_choices`` exists, is non-empty and
      never exceeds that stage's entry in ``max_depths_per_stage``;
    * ``mlp_ratio_choices`` is non-empty and never exceeds ``max_mlp_ratios_glob``.

    Args:
        config: Search-space dict with the keys used by ``SEARCH_SPACE_CONFIG``.

    Raises:
        ValueError: On the first inconsistency found.
    """
    max_depths = config['max_depths_per_stage']
    if len(config['max_channels_per_stage']) != len(max_depths):
        raise ValueError(
            "max_channels_per_stage and max_depths_per_stage must have the same length"
        )
    for stage_idx, max_depth in enumerate(max_depths):
        key = f'stage_{stage_idx}_depth_choices'
        choices = config.get(key)
        if not choices:
            raise ValueError(f"{key} is missing or empty")
        if max(choices) > max_depth:
            raise ValueError(f"{key}={choices} exceeds max depth {max_depth} for stage {stage_idx}")
    mlp_choices = config['mlp_ratio_choices']
    if not mlp_choices:
        raise ValueError("mlp_ratio_choices is empty")
    if max(mlp_choices) > config['max_mlp_ratios_glob']:
        raise ValueError(
            f"mlp_ratio_choices={mlp_choices} exceeds max_mlp_ratios_glob={config['max_mlp_ratios_glob']}"
        )


# Fail at startup rather than deep inside supernet construction or training.
validate_search_space_config(SEARCH_SPACE_CONFIG)

# Reproducibility: seed Python's `random`, NumPy and PyTorch (torch.manual_seed also
# seeds every CUDA device). This makes weight init repeatable and, presumably, subnet
# sampling too if get_random_config() draws from one of these RNGs — verify. Full
# bitwise determinism on GPU additionally requires deterministic cuDNN settings.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training Hyperparameters
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS_SUPERNET = 150  # Example: OFA supernets often train for many epochs
BATCH_SIZE = 32  # Adjust based on H100 VRAM for the MAX supernet size
LEARNING_RATE = 1e-3  # OFA often uses a specific LR schedule
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
IMG_SIZE = 224

# Paths
### PLACEHOLDER: Define your dataset paths ###
# TRAIN_DATA_PATH = "/path/to/your/ham10000_train_data_for_supernet"
# VAL_DATA_PATH = "/path/to/your/ham10000_val_data_for_supernet_eval"  # For periodic eval
OUTPUT_DIR = "./ofa_supernet_training_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Data Loading ---
# train_transforms = get_train_transforms(IMG_SIZE)  # Your existing train transforms
# val_transforms = get_val_transforms(IMG_SIZE)  # Your existing val transforms
# train_dataset = ImageFolderDataset(root_dir=TRAIN_DATA_PATH, transform=train_transforms, ...)
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
# val_dataset = ImageFolderDataset(root_dir=VAL_DATA_PATH, transform=val_transforms, ...)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
### PLACEHOLDER: Make sure your Dataset and DataLoader setup is correct ###
# Fail fast, before the model and optimizer are built, instead of hitting a
# NameError on `train_loader` when the first training epoch starts. The check
# looks the name up rather than assigning a placeholder, so it keeps working
# once the DataLoader lines above are uncommented.
if "train_loader" not in globals():
    raise NotImplementedError(
        "train_loader is not defined; set up the training Dataset/DataLoader above."
    )
# --- Model Initialization ---
# Your OFAMaxViT class must be fully implemented according to SEARCH_SPACE_CONFIG.
# model_supernet = OFAMaxViT(num_classes=NUM_CLASSES_DRIVE_DATA, search_space_config=SEARCH_SPACE_CONFIG)
### PLACEHOLDER: Instantiate your OFAMaxViT supernet ###
model_supernet = None  # Replace with actual model
if model_supernet is None:
    raise NotImplementedError("OFAMaxViT supernet model is not implemented/instantiated.")
# Move the supernet to DEVICE unconditionally: the training loop moves every batch
# to DEVICE, so a model left on the CPU would fail with a device mismatch. This is
# done before the optimizer is constructed, as PyTorch recommends; moving a module
# that is already on DEVICE is harmless.
model_supernet.to(DEVICE)
# --- Optimizer, LR schedule and loss ---
# One AdamW instance updates the shared supernet weights for every subnet.
optimizer = optim.AdamW(
    params=model_supernet.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# Cosine decay: T_max is the number of scheduler.step() calls (one per epoch in
# this script) until the LR reaches its minimum.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS_SUPERNET,
)
# Plain cross-entropy; pass label_smoothing=0.1 to enable label smoothing if desired.
criterion = nn.CrossEntropyLoss()
# --- OFA Supernet Training Loop ---
# Sandwich rule: every batch trains the largest subnet, the smallest subnet and
# NUM_RANDOM_SUBNETS randomly sampled subnets (adapt per the OFA paper /
# Progressive Shrinking if needed).
NUM_RANDOM_SUBNETS = 2
EVAL_EVERY_EPOCHS = 10  # Periodic validation of the max/min subnets
CHECKPOINT_EVERY_EPOCHS = 25  # A checkpoint is also always written after the last epoch


def _train_subnet_on_batch(model, config, inputs, labels, loss_fn):
    """Activate ``config`` on ``model``, forward one batch and backpropagate its loss.

    ``backward()`` runs immediately, so only one subnet's autograd graph is alive
    at a time. Gradients still accumulate in ``.grad`` across calls, which yields
    mathematically the same gradients as back-propagating the summed loss once,
    at a fraction of the peak activation memory. The caller is responsible for
    ``optimizer.zero_grad()`` before the first call and ``optimizer.step()`` after
    the last one.

    Returns:
        The detached scalar loss tensor (no host synchronization here).
    """
    model.set_active_subnet(config)
    # NOTE: do not recalibrate BN statistics here. A call such as
    # model.recal_bn_stats(train_loader_for_bn_recal, config) runs over a data
    # loader, so doing it for every training batch would be prohibitively slow.
    # If your dynamic BN needs recalibration, do it before evaluating a subnet
    # (see the evaluation section of the epoch loop).
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    return loss.detach()


print(f"Starting OFA Supernet training on {DEVICE} for {NUM_EPOCHS_SUPERNET} epochs.")
# Config getters in sandwich order. Each one is called right before its subnet is
# trained, preserving the max -> min -> random -> random sampling order.
subnet_config_getters = (
    [model_supernet.get_max_config, model_supernet.get_min_config]
    + [model_supernet.get_random_config] * NUM_RANDOM_SUBNETS
)
passes_per_batch = len(subnet_config_getters)

for epoch in range(NUM_EPOCHS_SUPERNET):
    model_supernet.train()
    epoch_loss = 0.0
    epoch_start = time.perf_counter()

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS_SUPERNET} Training"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        batch_loss = sum(
            _train_subnet_on_batch(model_supernet, get_config(), inputs, labels, criterion)
            for get_config in subnet_config_getters
        )
        # One optimizer step on the gradients accumulated from all sandwich subnets.
        optimizer.step()
        epoch_loss += batch_loss.item()  # single host sync per batch

    # Read the LR actually used during this epoch before the scheduler advances it.
    epoch_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    num_batches = len(train_loader)
    # Average per subnet forward pass; NaN (instead of ZeroDivisionError) for an empty loader.
    avg_epoch_loss = epoch_loss / (num_batches * passes_per_batch) if num_batches else float("nan")
    print(
        f"Epoch {epoch+1} Supernet Training Loss: {avg_epoch_loss:.4f}, "
        f"LR: {epoch_lr:.6f}, Time: {time.perf_counter() - epoch_start:.1f}s"
    )

    # --- Optional: periodic evaluation of specific subnets on the validation set ---
    if (epoch + 1) % EVAL_EVERY_EPOCHS == 0:
        model_supernet.eval()
        with torch.no_grad():
            for subnet_name, get_config in (
                ("Max", model_supernet.get_max_config),
                ("Min", model_supernet.get_min_config),
            ):
                model_supernet.set_active_subnet(get_config())
                ### PLACEHOLDER: recalibrate BN statistics for this subnet here if your
                # dynamic BN requires it, e.g.
                # model_supernet.recal_bn_stats(train_loader_for_bn_recal, get_config())
                # val_accuracy = evaluate_subnet_on_val(model_supernet, val_loader, DEVICE, criterion)
                # print(f"Epoch {epoch+1} - {subnet_name} Subnet Val Accuracy: {val_accuracy:.4f}")
        model_supernet.train()  # Set back to train mode

    # --- Save supernet checkpoint (periodically and after the final epoch) ---
    if (epoch + 1) % CHECKPOINT_EVERY_EPOCHS == 0 or epoch == NUM_EPOCHS_SUPERNET - 1:
        checkpoint_path = os.path.join(OUTPUT_DIR, f"ofa_maxvit_supernet_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model_supernet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'search_space_config': SEARCH_SPACE_CONFIG,  # Save the config with the model
        }, checkpoint_path)
        print(f"Supernet checkpoint saved to {checkpoint_path}")

print("OFA Supernet training finished.")
# --- Helper function for validation ---
# NOTE: this file is a flat script, so the training loop above executes before this
# definition does. Move the function above the loop before calling it from there.
def evaluate_subnet_on_val(model, val_loader, device, criterion, *, return_loss=False):
    """Evaluate the currently active subnet of ``model`` on ``val_loader``.

    The caller selects the subnet (``set_active_subnet``) beforehand. This function
    puts ``model`` in eval mode and does not restore the previous mode.

    Args:
        model: Network to evaluate; called as ``model(inputs)`` and assumed to
            return class logits shaped ``(batch, num_classes)``.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.
        criterion: Loss ``criterion(logits, labels)``, assumed to return a batch
            mean (as ``nn.CrossEntropyLoss()`` does by default). Only used when
            ``return_loss`` is True.
        return_loss: If True, also return the sample-weighted mean loss.

    Returns:
        Top-1 accuracy in ``[0, 1]`` (``0.0`` for an empty loader), or
        ``(accuracy, mean_loss)`` when ``return_loss`` is True (``mean_loss`` is
        NaN for an empty loader).
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            batch_size = labels.size(0)
            if return_loss:
                # Weight the batch-mean loss by batch size so a short last batch is not over-counted.
                loss_sum += criterion(logits, labels).item() * batch_size
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += batch_size
    accuracy = correct / total if total else 0.0
    if return_loss:
        mean_loss = loss_sum / total if total else float("nan")
        return accuracy, mean_loss
    return accuracy