-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsecondrunOFA.py
More file actions
604 lines (510 loc) · 28.7 KB
/
secondrunOFA.py
File metadata and controls
604 lines (510 loc) · 28.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader # Your HAM10000 DataLoader
import random
import os
import numpy as np
from tqdm import tqdm
from ofa.utils import make_divisible # Use OFA's version
# Your OFA building blocks and OFAMaxViT model
from ofa_maxvit_building_blocks import OFAMaxViT # Assuming this file has all classes
from sklearn.metrics import f1_score as calculate_macro_f1
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms # torchvision.datasets.ImageFolder
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import cv2 # For Albumentations interpolation
# --- Global constants for data ---
# Per-channel RGB normalisation statistics (standard ImageNet mean / std).
MEAN_IMG = [0.485, 0.456, 0.406]
STD_IMG = [0.229, 0.224, 0.225]

# --- Search-space defaults used to instantiate OFAMaxViT ---
STEM_OUT_CHANNELS = 64  # width of the (fixed) stem output
# Global elastic choices shared by all stages (a stage config may override them).
GLOBAL_K_CHOICES = [3]  # MBConv depthwise kernel sizes; MaxViT-Small uses 3x3
GLOBAL_E_CHOICES = [4, 6]  # MBConv expansion ratios
GLOBAL_MLP_RATIO_CHOICES = [2.0, 4.0]  # attention-block MLP expansion ratios
# SE fixed reduced channels choices for MBConv. This needs to be a list of options per stage,
# as the expanded channels in MBConv change significantly.
# Or, use se_rd_ratio. For MaxViT, fixed values are more common.
# Let's define SE choices per stage based on *their* max expanded channels.
# Stage-specific configurations for OFAMaxViT
# We need to ensure C_out_choices are divisible by their respective num_heads_attn_max
# MaxViT-Small head_dim is often 32.
# MaxViT-Timm `maxvit_small_tf_224.in1k` has head_dim = 32 for all stages.
# Stage 0: 96 channels -> 96/32 = 3 heads
# Stage 1: 192 channels -> 192/32 = 6 heads
# Stage 2: 384 channels -> 384/32 = 12 heads
# Stage 3: 768 channels -> 768/32 = 24 heads
# # Helper to generate channel choices
# def get_channel_choices(base_channel, head_dim=32, multipliers=[0.5, 0.75, 1.0]):
# choices = []
# for m in multipliers:
# c = make_divisible(base_channel * m, head_dim) # Ensure divisible by head_dim equivalent
# # Or make_divisible by 8 (common for GPUs)
# # and also by num_heads derived from it.
# # Simpler: make_divisible by (head_dim or common_divisor like 8)
# # For MaxViT head_dim=32 is a strong constraint.
# # The num_heads for a stage is C_out_stage / head_dim.
# # So C_out_stage must be div by head_dim.
# # Here, num_heads_attn_max for the stage IS C_out_max_stage / head_dim.
# # So any chosen C_out_stage must be div by (C_out_max_stage / num_heads_attn_max) if head_dim is to be const.
# # OR: num_heads is fixed per stage, and C_out_stage must be div by num_heads.
# # Let's use the num_heads from the max config of the stage.
# # e.g. Stage 0: max_C_out=96, num_heads_max=3. All choices must be div by 3.
# # This is implicitly handled by OFAMaxxVitBlock/Stage init checks if num_heads_max is passed.
# # For choice generation, ensure they are sensible.
# # Let's use num_heads of the stage and ensure C_out is div by it
# # No, C_out determines num_heads if head_dim is fixed.
# # Let's just use make_divisible by 8 for general good practice.
# # The check `c % num_heads_attn_max == 0` in OFAMaxxVitBlock/Stage will fail if
# # num_heads_attn_max (passed to stage) is not a divisor of a C_out_choice.
# # This means `num_heads_attn_max` argument to OFAMaxxVitStage should be the fixed head count FOR THAT STAGE'S MAX CONFIG.
# # And all C_out_choices for that stage must be divisible by that fixed num_heads_attn_max.
# # Let's make choices divisible by a common factor (e.g., 16 or 32 for head_dim)
# # This implies num_heads for a choice `c` would be `c / head_dim_fixed`.
# # The `num_heads_attn_max` passed to stage init refers to the head count at max channels.
# # `OFAAttentionCl` then needs to handle `active_dim` being div by `num_heads_max`.
# c_candidate = make_divisible(base_channel * m, 8)
# if c_candidate > 0:
# choices.append(c_candidate)
# return sorted(list(set(choices)))
# # MaxViT Small Original Params:
# # Stage 0: C=96, Depth=2, Heads=3 (assuming head_dim=32)
# # Stage 1: C=192, Depth=2, Heads=6
# # Stage 2: C=384, Depth=5, Heads=12
# # Stage 3: C=768, Depth=2, Heads=24
# # SE choices: based on max expanded channels in each stage's MBConv.
# # Max E for MBConv is max(GLOBAL_E_CHOICES) = 6.
# # Stage 0: in_C_max=64, max_E=6 -> mid_max=384. Sensible SE_rd_choices: [16, 24, 32]
# # Stage 1: in_C_max=96, max_E=6 -> mid_max=576. Sensible SE_rd_choices: [24, 32, 48]
# # Stage 2: in_C_max=192, max_E=6 -> mid_max=1152. Sensible SE_rd_choices: [32, 48, 64, 96]
# # Stage 3: in_C_max=384, max_E=6 -> mid_max=2304. Sensible SE_rd_choices: [64, 96, 128]
# STAGE_CONFIG_PARAMS = [
# { # Stage 0
# 'C_out_stage_choices': get_channel_choices(96), # e.g. [48, 72, 96]
# 'depth_choices': [1, 2], # Original 2
# 'stride_first_block': 2, # First block of stage 0 is strided (total stride 4 after stem)
# 'num_heads_attn_max': 3, # For C_max=96, head_dim=32
# 'se_fixed_rd_channels_choices_mbconv': [16, 24, 32], # For MBConv in this stage
# },
# { # Stage 1
# 'C_out_stage_choices': get_channel_choices(192), # e.g. [96, 144, 192]
# 'depth_choices': [1, 2], # Original 2
# 'stride_first_block': 2,
# 'num_heads_attn_max': 6, # For C_max=192, head_dim=32
# 'se_fixed_rd_channels_choices_mbconv': [24, 32, 48],
# },
# { # Stage 2
# 'C_out_stage_choices': get_channel_choices(384), # e.g. [192, 288, 384]
# 'depth_choices': [2, 3, 4, 5], # Original 5
# 'stride_first_block': 2,
# 'num_heads_attn_max': 12, # For C_max=384, head_dim=32
# 'se_fixed_rd_channels_choices_mbconv': [32, 48, 64, 96],
# },
# { # Stage 3
# 'C_out_stage_choices': get_channel_choices(768), # e.g. [384, 576, 768]
# 'depth_choices': [1, 2], # Original 2
# 'stride_first_block': 2,
# 'num_heads_attn_max': 24, # For C_max=768, head_dim=32
# 'se_fixed_rd_channels_choices_mbconv': [64, 96, 128],
# }
# ]
# # Need to ensure all C_out_stage_choices are divisible by their respective num_heads_attn_max
# # Our OFAMaxxVitBlock/Stage init already filters C_out_choices based on this.
# # Let's re-check get_channel_choices with this constraint.
# # A simpler way: head_dim is fixed (e.g. 32), num_heads = active_C_out / head_dim.
# # This means active_C_out must be divisible by head_dim.
# FIXED_HEAD_DIM = 32
# def get_channel_choices_v2(base_channel, head_dim=FIXED_HEAD_DIM, multipliers=[0.5, 0.75, 1.0]):
# choices = []
# for m in multipliers:
# c_candidate = make_divisible(base_channel * m, head_dim) # Ensure divisible by head_dim
# if c_candidate > 0:
# choices.append(c_candidate)
# return sorted(list(set(choices)))
# STAGE_CONFIG_PARAMS_V2 = [
# { # Stage 0
# 'C_out_stage_choices': get_channel_choices_v2(96), # e.g., [64, 96] if head_dim=32 (make_divisible rounds 48 up to 64 and 72 up to 96)
# 'depth_choices': [1, 2],
# 'stride_first_block': 2,
# 'num_heads_attn_max': 96 // FIXED_HEAD_DIM, # 3
# 'se_fixed_rd_channels_choices_mbconv': [16, 24, 32],
# },
# { # Stage 1
# 'C_out_stage_choices': get_channel_choices_v2(192), # e.g., [96, 160, 192] if make_divisible(192*0.75, 32) = 160
# 'depth_choices': [1, 2],
# 'stride_first_block': 2,
# 'num_heads_attn_max': 192 // FIXED_HEAD_DIM, # 6
# 'se_fixed_rd_channels_choices_mbconv': [24, 32, 48],
# },
# { # Stage 2
# 'C_out_stage_choices': get_channel_choices_v2(384),
# 'depth_choices': [2, 3, 4, 5],
# 'stride_first_block': 2,
# 'num_heads_attn_max': 384 // FIXED_HEAD_DIM, # 12
# 'se_fixed_rd_channels_choices_mbconv': [32, 48, 64, 96],
# },
# { # Stage 3
# 'C_out_stage_choices': get_channel_choices_v2(768),
# 'depth_choices': [1, 2],
# 'stride_first_block': 2,
# 'num_heads_attn_max': 768 // FIXED_HEAD_DIM, # 24
# 'se_fixed_rd_channels_choices_mbconv': [64, 96, 128],
# }
# ]
def get_ofa_train_transforms(image_size=224):
    """Build the Albumentations augmentation pipeline used for supernet training.

    The pipeline resizes to ``image_size`` x ``image_size``, applies random 90-degree
    rotations, horizontal/vertical flips and a mild affine warp, then colour jitter
    and Gaussian blur, and finally normalises with MEAN_IMG / STD_IMG and converts
    the result to a PyTorch tensor.

    Args:
        image_size: Output height and width in pixels.

    Returns:
        An ``A.Compose`` pipeline, called as ``pipeline(image=ndarray)['image']``.
    """
    geometric_steps = [
        A.Resize(height=image_size, width=image_size, interpolation=cv2.INTER_LINEAR),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Affine(scale=(0.9, 1.1), translate_percent=(-0.07, 0.07), rotate=(-30, 30), p=0.7),
    ]
    photometric_steps = [
        A.ColorJitter(brightness=0.20, contrast=0.20, saturation=0.20, hue=0.06, p=0.6),
        A.GaussianBlur(blur_limit=(3, 5), p=0.3),
        # Stronger policies (AutoAugment / RandAugment-style) could be added here later.
    ]
    tensor_steps = [
        A.Normalize(mean=MEAN_IMG, std=STD_IMG),
        ToTensorV2(),
    ]
    return A.Compose(geometric_steps + photometric_steps + tensor_steps)
def get_ofa_val_transforms(image_size=224):
    """Build the deterministic evaluation pipeline: resize, normalise, to tensor.

    Args:
        image_size: Output height and width in pixels.

    Returns:
        An ``A.Compose`` pipeline with no random augmentation.
    """
    resize = A.Resize(height=image_size, width=image_size, interpolation=cv2.INTER_LINEAR)
    normalize = A.Normalize(mean=MEAN_IMG, std=STD_IMG)
    return A.Compose([resize, normalize, ToTensorV2()])
class AlbumentationsImageFolder(datasets.ImageFolder):
    """``ImageFolder`` variant that applies an Albumentations pipeline to each image.

    The Albumentations pipeline passed as ``transform`` is stored in
    ``self.alb_transform`` and applied to a NumPy copy of each loaded image.
    torchvision's own ``transform`` slot is deliberately set to ``None`` so the
    parent class never applies a PIL-based transform.
    """

    def __init__(self, root, transform=None, target_transform=None, loader=datasets.folder.default_loader):
        super().__init__(root, transform=None, target_transform=target_transform, loader=loader)
        # Albumentations pipeline, invoked as pipeline(image=ndarray)['image'].
        self.alb_transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``image`` is the output of ``alb_transform`` (a tensor when the pipeline
        ends in ``ToTensorV2``), or the image returned by ``loader`` when no
        pipeline was supplied. ``target`` is the class index from ``samples``,
        passed through ``target_transform`` when one was supplied.
        """
        path, target = self.samples[index]
        image = self.loader(path)
        if self.alb_transform:
            # Albumentations operates on NumPy arrays, not PIL images.
            augmented = self.alb_transform(image=np.array(image))
            image = augmented['image']
        # This override replaces ImageFolder.__getitem__, so target_transform must
        # be honoured here or it would be silently ignored.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target
def get_ham10000_dataloaders(base_dataset_path, image_size, batch_size, num_workers=4):
    """Create the HAM10000 train and validation DataLoaders.

    Expects ImageFolder-style sub-directories under ``base_dataset_path``: ``train``
    and ``validation`` (falling back to ``val`` when ``validation`` is absent).

    The training loader draws ``len(train_dataset)`` samples per epoch with
    replacement from a WeightedRandomSampler whose per-sample weight is inversely
    proportional to its class count, so every class is drawn equally often in
    expectation. The validation loader iterates in order without shuffling.

    Args:
        base_dataset_path: Root directory containing the split folders.
        image_size: Square input resolution for the transforms.
        batch_size: Batch size for both loaders.
        num_workers: DataLoader worker processes; workers persist across epochs when > 0.

    Returns:
        ``(train_loader, val_loader, class_names)``, where ``class_names`` are the
        training dataset's class folder names.

    Raises:
        FileNotFoundError: If neither ``validation`` nor ``val`` exists, or ``train`` does not exist.
    """
    train_dir = os.path.join(base_dataset_path, "train")
    # Try 'validation' first, then 'val', as documented.
    val_candidates = [os.path.join(base_dataset_path, name) for name in ("validation", "val")]
    val_dir = next((candidate for candidate in val_candidates if os.path.isdir(candidate)), None)
    if val_dir is None:
        raise FileNotFoundError(f"Validation directory not found (looked for {' or '.join(val_candidates)})")
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"Train directory not found at {train_dir}")

    train_dataset = AlbumentationsImageFolder(root=train_dir, transform=get_ofa_train_transforms(image_size))
    val_dataset = AlbumentationsImageFolder(root=val_dir, transform=get_ofa_val_transforms(image_size))
    print(f"Train dataset classes: {train_dataset.classes}")
    print(f"Train dataset class_to_idx: {train_dataset.class_to_idx}")

    # HAM10000 is class-imbalanced: weight each sample by the inverse frequency of its class.
    train_targets = [target for _, target in train_dataset.samples]
    class_counts = np.bincount(train_targets)
    num_samples = len(train_targets)
    class_weights = [
        num_samples / (len(class_counts) * count) if count > 0 else 0
        for count in class_counts
    ]
    sample_weights = [class_weights[target] for target in train_targets]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),  # one "epoch" = N draws
        replacement=True,
    )

    keep_workers_alive = num_workers > 0  # persistent_workers requires num_workers > 0
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,  # the sampler replaces shuffle=True
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=keep_workers_alive,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=keep_workers_alive,
    )
    print(f"Number of training samples: {len(train_dataset)}, Batches: {len(train_loader)}")
    print(f"Number of validation samples: {len(val_dataset)}, Batches: {len(val_loader)}")
    return train_loader, val_loader, train_dataset.classes
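# --- CONFIGURATIONS ---
# STAGE_CONFIG_PARAMS and STAGE_CONFIG_PARAMS_V2 above are commented-out drafts. The active
# per-stage search space is STAGE_CONFIG_PARAMS_FINAL below, whose C_out choices are multiples
# of 8 from get_channel_choices; the OFAMaxxVitStage init is relied on to filter out choices
# that are not divisible by the stage's num_heads_attn_max.
# Helper: candidate channel widths for one stage.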
def get_channel_choices(base_channel, common_divisor=8, multipliers=(0.5, 0.75, 1.0)):
    """Return the sorted, de-duplicated candidate output widths for one stage.

    Each candidate is ``make_divisible(base_channel * m, common_divisor)`` (ofa's
    rounding helper) for a multiplier ``m``; non-positive results are discarded.

    Args:
        base_channel: The stage's width at its maximum configuration.
        common_divisor: Divisor handed to ``make_divisible`` for every candidate.
        multipliers: Width multipliers applied to ``base_channel``. The default is
            an immutable tuple so it cannot be mutated across calls; any iterable
            of numbers (e.g. a list) is accepted.

    Returns:
        Ascending list of unique positive widths. If no multiplier yields a positive
        width, falls back to ``[make_divisible(base_channel, common_divisor)]``.
    """
    candidates = {make_divisible(base_channel * m, common_divisor) for m in multipliers}
    positive = sorted(c for c in candidates if c > 0)
    return positive or [make_divisible(base_channel, common_divisor)]
# Per-stage search space for OFAMaxViT.
# STEM_OUT_CHANNELS and the GLOBAL_*_CHOICES values are defined once near the top of
# the file and are intentionally NOT re-assigned here: a second assignment would
# silently override the documented definitions.
# Each dict gives: candidate output widths, candidate depths, stride of the stage's
# first block, attention head count at maximum width, and the SE reduced-channel
# choices for the stage's MBConv blocks.
STAGE_CONFIG_PARAMS_FINAL = [
    {  # Stage 0 (MaxViT-Small original: C=96, depth 2, 3 heads)
        'C_out_stage_choices': get_channel_choices(96),
        'depth_choices': [1, 2],
        'stride_first_block': 2,
        'num_heads_attn_max': 3,
        'se_fixed_rd_channels_choices_mbconv': [16, 24, 32],
    },
    {  # Stage 1 (original: C=192, depth 2, 6 heads)
        'C_out_stage_choices': get_channel_choices(192),
        'depth_choices': [1, 2],
        'stride_first_block': 2,
        'num_heads_attn_max': 6,
        'se_fixed_rd_channels_choices_mbconv': [24, 32, 48],
    },
    {  # Stage 2 (original: C=384, depth 5, 12 heads)
        'C_out_stage_choices': get_channel_choices(384),
        'depth_choices': [2, 3, 4, 5],
        'stride_first_block': 2,
        'num_heads_attn_max': 12,
        'se_fixed_rd_channels_choices_mbconv': [32, 48, 64, 96],
    },
    {  # Stage 3 (original: C=768, depth 2, 24 heads)
        'C_out_stage_choices': get_channel_choices(768),
        'depth_choices': [1, 2],
        'stride_first_block': 2,
        'num_heads_attn_max': 24,
        'se_fixed_rd_channels_choices_mbconv': [64, 96, 128],
    },
]
# --- Training hyperparameters ---
NUM_CLASSES_HAM10000 = 7  # number of HAM10000 diagnostic classes
TRAIN_IMAGE_SIZE = 224  # input resolution for supernet training
BATCH_SIZE = 32  # tune to the available GPU memory
LEARNING_RATE = 0.01  # peak SGD learning rate, reached at the end of warmup
WEIGHT_DECAY = 1e-4
MOMENTUM_SGD = 0.9
NUM_EPOCHS_SUPERNET = 200  # total supernet training epochs
WARMUP_EPOCHS = 10  # linear LR warmup epochs before cosine annealing

# Knowledge distillation (applied to the largest subnet only).
KD_ALPHA = 0.4  # weight of the hard-label CE term; (1 - KD_ALPHA) weights the KD term
KD_TEMPERATURE = 4.0  # softmax temperature applied to teacher and student logits

# --- Paths (edit for your machine) ---
BASE_HAM10000_PATH = "/home/dgx-s-user2/controlleddiffusion/EIS/HAM10000_extracted/HAM10000_local"
# Checkpoint of the best fold-4 MaxViT-Small model, used as the distillation teacher.
TEACHER_MODEL_PATH = "/home/dgx-s-user2/controlleddiffusion/EIS/Server_Runs_KFold/maxvit_small_tf_224.in1k_k5_e20_20250523-212155/fold_4/maxvit_small_tf_224_in1k_fold4_img224_bs32_lr3e-05_e20_best_loss.pth"
# Where supernet checkpoints are written; the directory is created at import time.
OUTPUT_DIR = "./ofa_maxvit_supernet_training_200"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Sandwich rule ---
# Random subnets trained per step, in addition to the largest and smallest subnets.
NUM_RANDOM_SAMPLES_PER_STEP = 2
def set_seed(seed=42):
    """Seed every RNG the training script touches.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the RNGs
    of the current and all CUDA devices, and sets the PYTHONHASHSEED environment
    variable. cuDNN determinism flags are intentionally left untouched because
    they can slow training down.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    # To trade speed for full determinism, also set:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
def get_teacher_model(path, num_classes, device):
    """Build the MaxViT-Small teacher and load its weights from ``path``.

    The architecture is timm's ``maxvit_small_tf_224.in1k`` with a ``num_classes``
    head and no pretrained download. The checkpoint at ``path`` is loaded onto the
    CPU as a state_dict before the model is moved to ``device``.

    Returns:
        The teacher in eval mode on ``device``.
    """
    import timm  # imported lazily: only this function needs timm

    print(f"Loading teacher model from: {path}")
    model = timm.create_model('maxvit_small_tf_224.in1k', pretrained=False, num_classes=num_classes)
    weights = torch.load(path, map_location='cpu')
    model.load_state_dict(weights)
    model.eval()
    return model.to(device)
def _kd_ce_loss(student_logits, teacher_logits, true_labels, criterion_ce, criterion_kd):
    """Return KD_ALPHA * CE + (1 - KD_ALPHA) * T^2 * KD for one subnet's logits.

    ``criterion_kd`` receives the student's temperature-scaled log-probabilities and
    the teacher's temperature-scaled probabilities (the argument order expected by
    ``nn.KLDivLoss``). The T^2 factor is the conventional rescaling from Hinton et
    al. (2015).
    """
    loss_ce = criterion_ce(student_logits, true_labels)
    loss_kd = criterion_kd(
        F.log_softmax(student_logits / KD_TEMPERATURE, dim=1),
        F.softmax(teacher_logits / KD_TEMPERATURE, dim=1),
    ) * (KD_TEMPERATURE * KD_TEMPERATURE)
    return KD_ALPHA * loss_ce + (1 - KD_ALPHA) * loss_kd


def _backward_share(loss, num_shares):
    """Backpropagate ``loss / num_shares`` now and return the detached, unscaled loss.

    Calling this once per subnet accumulates exactly the gradient of the mean loss
    over ``num_shares`` subnets. Each subnet's autograd graph is freed right away,
    instead of every graph being kept alive until one combined backward pass.
    """
    (loss / num_shares).backward()
    return loss.detach()


def _evaluate_subnet(supernet, active_config, val_loader, criterion_ce, device, desc):
    """Activate ``active_config`` and run one no-grad pass over ``val_loader``.

    Returns:
        ``(avg_loss, preds, labels, num_samples)``. ``avg_loss`` is the per-sample
        mean of ``criterion_ce``, or ``inf`` when the loader yields no samples.
    """
    supernet.set_active_subnet(active_config)
    preds, labels = [], []
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for val_images, val_true_labels in tqdm(val_loader, desc=desc, leave=False):
            val_images, val_true_labels = val_images.to(device), val_true_labels.to(device)
            val_logits = supernet(val_images)
            batch_size = val_images.size(0)
            # criterion_ce averages over the batch; re-weight by batch size so the
            # final value is a true per-sample mean even if the last batch is short.
            loss_sum += criterion_ce(val_logits, val_true_labels).item() * batch_size
            num_samples += batch_size
            preds.extend(torch.argmax(val_logits, dim=1).cpu().numpy())
            labels.extend(val_true_labels.cpu().numpy())
    avg_loss = loss_sum / num_samples if num_samples > 0 else float('inf')
    return avg_loss, preds, labels, num_samples


def main():
    """Train the OFA-MaxViT supernet on HAM10000 with the sandwich rule.

    On every batch the largest subnet is trained with CE plus distillation from the
    MaxViT-Small teacher. The smallest subnet and NUM_RANDOM_SAMPLES_PER_STEP random
    subnets are trained with CE only. One SGD step is then taken on the mean of
    these losses.

    The LR warms up linearly for WARMUP_EPOCHS and is then cosine-annealed. Every
    5 epochs (and on the last epoch) the largest, smallest and two random subnets
    are validated. Checkpoints are written to OUTPUT_DIR.
    """
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    best_val_loss_overall = float('inf')

    # 1. DataLoaders
    print("Setting up DataLoaders...")
    if BASE_HAM10000_PATH == "/path/to/your/HAM10000_dataset_root":  # placeholder check
        print("ERROR: Please update BASE_HAM10000_PATH with the actual path to your dataset.")
        return
    train_loader, val_loader, class_names = get_ham10000_dataloaders(
        base_dataset_path=BASE_HAM10000_PATH,
        image_size=TRAIN_IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        num_workers=4,
    )
    if NUM_CLASSES_HAM10000 != len(class_names):
        print(f"Warning: NUM_CLASSES_HAM10000 ({NUM_CLASSES_HAM10000}) != detected classes ({len(class_names)}: {class_names})")
    print(f"Dataloaders ready. Num classes: {len(class_names)}")

    # 2. OFA-MaxViT supernet
    print("Initializing OFA-MaxViT Supernet...")
    supernet = OFAMaxViT(
        stem_out_channels=STEM_OUT_CHANNELS,
        stage_configs=STAGE_CONFIG_PARAMS_FINAL,
        num_classes=NUM_CLASSES_HAM10000,
        global_k_mbconv_choices=GLOBAL_K_CHOICES,
        global_e_mbconv_choices=GLOBAL_E_CHOICES,
        global_mlp_ratio_attn_choices=GLOBAL_MLP_RATIO_CHOICES,
    ).to(device)
    print("Supernet initialized.")

    # 3. Teacher model
    if TEACHER_MODEL_PATH == "/path/to/your/fold_4_best_model.pth":
        print("ERROR: Update TEACHER_MODEL_PATH with your actual model file!")
        return
    teacher_model = get_teacher_model(TEACHER_MODEL_PATH, NUM_CLASSES_HAM10000, device)
    print("Teacher model loaded.")

    # 4. Optimizer, scheduler and losses
    optimizer = optim.SGD(supernet.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM_SGD, weight_decay=WEIGHT_DECAY)
    # The cosine schedule spans only the post-warmup epochs; it is stepped from epoch WARMUP_EPOCHS onward.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS_SUPERNET - WARMUP_EPOCHS)
    criterion_ce = nn.CrossEntropyLoss(label_smoothing=0.1)
    criterion_kd = nn.KLDivLoss(reduction='batchmean')
    # Largest + smallest + random subnets trained on every batch.
    num_subnets_per_step = 2 + NUM_RANDOM_SAMPLES_PER_STEP

    print("Starting Supernet Training with Sandwich Rule...")
    for epoch in range(NUM_EPOCHS_SUPERNET):
        supernet.train()
        teacher_model.eval()
        is_last_epoch = epoch == NUM_EPOCHS_SUPERNET - 1

        # Linear warmup; at the first post-warmup epoch hand the base LR to the cosine scheduler.
        if epoch < WARMUP_EPOCHS:
            current_lr = LEARNING_RATE * (epoch + 1) / WARMUP_EPOCHS
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
        elif epoch == WARMUP_EPOCHS:
            for param_group in optimizer.param_groups:
                param_group['lr'] = LEARNING_RATE

        epoch_loss_sum = 0.0
        num_optimizer_steps = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS_SUPERNET} Training")
        for images, true_labels in progress_bar:
            images, true_labels = images.to(device), true_labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher_model(images)

            # --- Sandwich rule ---
            # Each subnet's loss is backpropagated (scaled by 1/N) immediately after its
            # forward pass. The accumulated gradient equals that of the averaged loss,
            # but only one subnet's graph is alive at a time.
            # 1. Largest subnet: CE + KD.
            supernet.set_active_subnet(supernet.get_max_subnet_config())
            step_loss = _backward_share(
                _kd_ce_loss(supernet(images), teacher_logits, true_labels, criterion_ce, criterion_kd),
                num_subnets_per_step,
            )
            # 2. Smallest subnet: CE only.
            supernet.set_active_subnet(supernet.get_min_subnet_config())
            step_loss = step_loss + _backward_share(criterion_ce(supernet(images), true_labels), num_subnets_per_step)
            # 3. Random subnets: CE only.
            for _ in range(NUM_RANDOM_SAMPLES_PER_STEP):
                supernet.set_active_subnet(supernet.sample_active_subnet_config())
                step_loss = step_loss + _backward_share(criterion_ce(supernet(images), true_labels), num_subnets_per_step)
            optimizer.step()  # single update for all subnets of this batch

            avg_loss_for_step = (step_loss / num_subnets_per_step).item()
            epoch_loss_sum += avg_loss_for_step
            num_optimizer_steps += 1
            progress_bar.set_postfix(loss=avg_loss_for_step, lr=optimizer.param_groups[0]['lr'])

        avg_epoch_loss = epoch_loss_sum / num_optimizer_steps if num_optimizer_steps > 0 else 0.0
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{NUM_EPOCHS_SUPERNET} - Avg Step Loss: {avg_epoch_loss:.4f}, LR: {current_lr:.6f}")
        if epoch >= WARMUP_EPOCHS:
            scheduler.step()

        # --- Periodic validation and checkpointing ---
        if not ((epoch + 1) % 5 == 0 or is_last_epoch):
            continue
        supernet.eval()
        print(f"--- Validating Epoch {epoch+1} ---")
        # NOTE(review): if OFAMaxViT's BatchNorm layers share running statistics across
        # subnets (typical for OFA supernets - confirm), smaller/random subnets are
        # evaluated without BN re-calibration and their losses may look pessimistic.
        current_epoch_min_val_loss = float('inf')
        configs_to_validate = {
            "largest": supernet.get_max_subnet_config(),
            "smallest": supernet.get_min_subnet_config(),
            "random_val_1": supernet.sample_active_subnet_config(),
            "random_val_2": supernet.sample_active_subnet_config(),
        }
        for config_name, active_config_val in configs_to_validate.items():
            print(f" Validating subnet: {config_name}")
            avg_val_loss_subnet, val_preds, val_labels, num_samples_for_subnet = _evaluate_subnet(
                supernet, active_config_val, val_loader, criterion_ce, device,
                desc=f"Validating {config_name}",
            )
            current_epoch_min_val_loss = min(current_epoch_min_val_loss, avg_val_loss_subnet)
            if val_labels and len(np.unique(val_labels)) > 1:
                # calculate_macro_f1 is sklearn.metrics.f1_score, imported at file top.
                macro_f1_val = calculate_macro_f1(val_labels, val_preds, average='macro', zero_division=0)
                print(f" Epoch {epoch+1} - Val ({config_name} subnet): AvgLoss={avg_val_loss_subnet:.4f}, MacroF1={macro_f1_val:.4f} (on {num_samples_for_subnet} samples)")
            else:
                print(f" Epoch {epoch+1} - Val ({config_name} subnet): AvgLoss={avg_val_loss_subnet:.4f} (F1 not computed, {num_samples_for_subnet} samples)")

        # 1. Latest supernet checkpoint every 10 epochs and at the end.
        if (epoch + 1) % 10 == 0 or is_last_epoch:
            latest_checkpoint_path = os.path.join(OUTPUT_DIR, f"ofa_maxvit_supernet_latest_epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'supernet_state_dict': supernet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'avg_train_loss': avg_epoch_loss,
                'current_min_val_loss_this_epoch': current_epoch_min_val_loss,
            }, latest_checkpoint_path)
            print(f"Latest supernet checkpoint saved to {latest_checkpoint_path}")

        # 2. "Best" checkpoint: lowest loss of ANY validated subnet this epoch (this
        # includes the two random subnets, so the criterion is somewhat noisy).
        if current_epoch_min_val_loss < best_val_loss_overall:
            best_val_loss_overall = current_epoch_min_val_loss
            best_checkpoint_path = os.path.join(OUTPUT_DIR, "ofa_maxvit_supernet_best_val_loss.pth")
            torch.save({
                'epoch': epoch + 1,
                'supernet_state_dict': supernet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss_achieved': best_val_loss_overall,
                'avg_train_loss': avg_epoch_loss,
            }, best_checkpoint_path)
            print(f"*** New best validation loss ({best_val_loss_overall:.4f})! Supernet checkpoint saved to {best_checkpoint_path} ***")

    print("Supernet training finished.")
    final_checkpoint_path = os.path.join(OUTPUT_DIR, "ofa_maxvit_supernet_final.pth")
    torch.save(supernet.state_dict(), final_checkpoint_path)
    print(f"Final supernet model state_dict saved to {final_checkpoint_path}")
if __name__ == '__main__':
    # Script entry point. Before running, set BASE_HAM10000_PATH (dataset root with
    # train/ and validation/ folders) and TEACHER_MODEL_PATH in the configuration
    # section above.
    main()