-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanalyze_subnets_complexity.py
More file actions
215 lines (180 loc) · 11.1 KB
/
analyze_subnets_complexity.py
File metadata and controls
215 lines (180 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import torch
import torch.nn as nn
import pandas as pd
import json
import os
import sys
from tqdm import tqdm
import re
# FVCore for complexity analysis (optional dependency).
# FVCORE_AVAILABLE records whether the import succeeded; get_subnet_complexity
# checks it and returns error indicators instead of measuring when it is False.
try:
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
except ImportError:
    print("Warning: fvcore not found. FLOPs and detailed MParams calculation will be skipped.")
    print("Please install fvcore: pip install fvcore")
    FVCORE_AVAILABLE = False
else:
    FVCORE_AVAILABLE = True
# OFA building blocks and the OFAMaxViT supernet definition.
# ofa_maxvit_building_blocks.py must be importable (current directory or
# PYTHONPATH). The script cannot do anything without it, so a failed import
# prints a hint and terminates the process with exit status 1.
# make_divisible is used by get_channel_choices below.
try:
    from ofa_maxvit_building_blocks import OFAMaxViT, make_divisible  # Or your actual import path
except ImportError as import_error:
    print(f"Error importing OFAMaxViT: {import_error}")
    print("Please ensure ofa_maxvit_building_blocks.py is in your PYTHONPATH or current directory.")
    sys.exit(1)
# --- CONFIGURATIONS ---
# These must exactly match how the OFAMaxViT supernet was initialized during training AND search
# Values from your search_ofa_subnets.py or train_ofa_maxvit.py
def get_channel_choices(base_channel, common_divisor=8, multipliers=(0.5, 0.75, 1.0)):
    """Return the sorted, de-duplicated channel-width choices for one stage.

    Each multiplier scales ``base_channel``. A positive scaled width below
    ``common_divisor`` is bumped up to ``common_divisor``. Any other width is
    rounded with ``make_divisible(width, common_divisor)``. Non-positive
    candidates are discarded.

    Args:
        base_channel: Reference channel count of the stage.
        common_divisor: Divisor handed to ``make_divisible`` (and the minimum
            width used for tiny positive candidates).
        multipliers: Width multipliers applied to ``base_channel``. The default
            is an immutable tuple; any iterable of numbers works.

    Returns:
        Sorted list of unique channel counts. If no candidate survives, the
        fallback is ``[make_divisible(base_channel, common_divisor)]`` for a
        positive ``base_channel``, ``[0]`` for ``base_channel == 0``, and
        ``[]`` otherwise.
    """
    choices = set()
    for m in multipliers:
        raw_c = base_channel * m
        if 0 < raw_c < common_divisor:
            c_candidate = common_divisor
        else:
            c_candidate = make_divisible(raw_c, common_divisor)
        if c_candidate > 0:
            choices.add(c_candidate)
    if not choices:
        if base_channel > 0:
            choices.add(make_divisible(base_channel, common_divisor))
        elif base_channel == 0:
            return [0]  # Allow zero channel choices if base is zero
    return sorted(choices)
# Supernet construction parameters. main_analysis passes these to OFAMaxViT(...),
# so they must match the values the checkpoint was trained with.
STEM_OUT_CHANNELS = 64
GLOBAL_K_CHOICES = [3]
GLOBAL_E_CHOICES = [4, 6]
GLOBAL_MLP_RATIO_CHOICES = [2.0, 4.0]
# Per-stage search space, passed as OFAMaxViT(stage_configs=...).
# This STAGE_CONFIG_PARAMS_FINAL should be EXACTLY as used to initialize OFAMaxViT when it was trained/searched
STAGE_CONFIG_PARAMS_FINAL = [
    {'C_out_stage_choices': get_channel_choices(96), 'depth_choices': [1, 2], 'stride_first_block': 2, 'num_heads_attn_max': 3, 'se_fixed_rd_channels_choices_mbconv': [16, 24, 32]},
    {'C_out_stage_choices': get_channel_choices(192), 'depth_choices': [1, 2], 'stride_first_block': 2, 'num_heads_attn_max': 6, 'se_fixed_rd_channels_choices_mbconv': [24, 32, 48]},
    {'C_out_stage_choices': get_channel_choices(384), 'depth_choices': [2, 3, 4, 5], 'stride_first_block': 2, 'num_heads_attn_max': 12, 'se_fixed_rd_channels_choices_mbconv': [32, 48, 64, 96]},
    {'C_out_stage_choices': get_channel_choices(768), 'depth_choices': [1, 2], 'stride_first_block': 2, 'num_heads_attn_max': 24, 'se_fixed_rd_channels_choices_mbconv': [64, 96, 128]},
]
NUM_CLASSES_HAM10000 = 7
EVAL_IMAGE_SIZE = 224  # Input image size for FLOPs calculation

# File locations. Each one can be overridden by an environment variable of the
# same name, so the script can be pointed at other runs without editing the
# source. The literals below are used when the variable is not set.
# Path to your trained supernet checkpoint
SUPERNET_CHECKPOINT_PATH = os.environ.get(
    "SUPERNET_CHECKPOINT_PATH",
    "/home/dgx-s-user2/controlleddiffusion/EIS/ofa_maxvit_supernet_training/ofa_maxvit_supernet_best_val_loss.pth",
)
# Path to the CSV file from your subnet search
SEARCH_RESULTS_CSV_PATH = os.environ.get(
    "SEARCH_RESULTS_CSV_PATH",
    "/home/dgx-s-user2/controlleddiffusion/EIS/ofa_maxvit_subnet_search_results/subnet_search_results_upto_1000.csv",
)
# Output path for the new CSV with complexity metrics
OUTPUT_CSV_WITH_COMPLEXITY = os.environ.get(
    "OUTPUT_CSV_WITH_COMPLEXITY",
    "/home/dgx-s-user2/controlleddiffusion/EIS/subnet_results_with_complexity_100_best_loss.csv",
)
def get_subnet_complexity(supernet_model_instance, active_config_dict,
                          input_res=(EVAL_IMAGE_SIZE, EVAL_IMAGE_SIZE), device='cpu'):
    """Measure the compute cost and parameter count of one OFA sub-network.

    The supernet is put in eval mode, moved to ``device`` and reconfigured in
    place with ``set_active_subnet(active_config_dict)``. fvcore then traces a
    single random input of shape ``(1, 3, input_res[0], input_res[1])``.

    Args:
        supernet_model_instance: The OFAMaxViT supernet. It is mutated: its
            mode, device placement and active sub-network all change.
        active_config_dict: Sub-network configuration, passed verbatim to
            ``set_active_subnet``.
        input_res: ``(H, W)`` of the dummy input image.
        device: Device for the model and the dummy input.

    Returns:
        Tuple ``(gmacs, mparams)``:

        * ``gmacs`` is ``FlopCountAnalysis.total() / 1e9``. fvcore counts one
          fused multiply-add as one "flop", so this is effectively GigaMACs.
          It is ``-1.0`` if the fvcore analysis raised.
        * ``mparams`` is the number of parameters held by the supernet object,
          in millions. Dynamic OFA layers keep their maximum-size weights, so
          this OVERESTIMATES the active sub-network's parameters. A true count
          requires extracting a static sub-network.

        Both values are ``-1.0`` when fvcore is not installed.
    """
    if not FVCORE_AVAILABLE:
        return -1.0, -1.0  # Return error indicators if fvcore is not available

    supernet_model_instance.eval().to(device)
    supernet_model_instance.set_active_subnet(active_config_dict)
    dummy_input = torch.randn(1, 3, input_res[0], input_res[1]).to(device)

    gmacs = -1.0
    try:
        flops_analyzer = FlopCountAnalysis(supernet_model_instance, dummy_input)
        # This function runs once per searched subnet. Inactive supernet branches
        # are never called, so fvcore's per-analysis warnings about unsupported
        # ops and uncalled modules would flood the output.
        flops_analyzer.unsupported_ops_warnings(False)
        flops_analyzer.uncalled_modules_warnings(False)
        gmacs = flops_analyzer.total() / 1e9
    except Exception as e:
        print(f" fvcore FLOPs calculation error: {e}")

    # parameter_count_table() renders a pipe table of abbreviated sizes (e.g.
    # "| model | 31.2M |") without any "Total parameters:" line. Parsing it with
    # a regex therefore never matched, so count directly instead.
    # Module.parameters() yields each shared tensor only once.
    mparams = sum(p.numel() for p in supernet_model_instance.parameters()) / 1e6
    return gmacs, mparams
def _strip_ddp_prefix(state_dict, prefix='module.'):
    """Remove a leading DDP ``prefix`` from state-dict keys.

    Only a *leading* occurrence is stripped. ``str.replace`` would also mangle
    keys that merely contain the prefix, e.g. ``'blocks.0.se_module.fc.weight'``
    would become ``'blocks.0.se_fc.weight'``. The input is returned unchanged
    when no key starts with the prefix.
    """
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_supernet(device):
    """Build OFAMaxViT from the module-level config and load its checkpoint.

    Returns the supernet in eval mode on ``device``. If
    ``SUPERNET_CHECKPOINT_PATH`` does not exist, prints an error and returns
    ``None``.
    """
    if not os.path.exists(SUPERNET_CHECKPOINT_PATH):
        print(f"ERROR: Supernet checkpoint not found: {SUPERNET_CHECKPOINT_PATH}")
        return None
    supernet = OFAMaxViT(
        stem_out_channels=STEM_OUT_CHANNELS,
        stage_configs=STAGE_CONFIG_PARAMS_FINAL,
        num_classes=NUM_CLASSES_HAM10000,  # Ensure this matches the trained supernet's head
        global_k_mbconv_choices=GLOBAL_K_CHOICES,
        global_e_mbconv_choices=GLOBAL_E_CHOICES,
        global_mlp_ratio_attn_choices=GLOBAL_MLP_RATIO_CHOICES,
    ).to(device)
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
    # rejects checkpoints containing arbitrary pickled objects. Pass
    # weights_only=False only if this checkpoint file is trusted.
    checkpoint = torch.load(SUPERNET_CHECKPOINT_PATH, map_location=device)
    # Accept {'supernet_state_dict': ...}, {'model_state_dict': ...} or a bare state dict.
    state_dict = checkpoint.get('supernet_state_dict', checkpoint.get('model_state_dict', checkpoint))
    supernet.load_state_dict(_strip_ddp_prefix(state_dict))
    supernet.eval()
    print("Supernet loaded and set to eval mode.")
    return supernet


def _compute_complexities(df_search_results, supernet, device):
    """Measure every row's subnet; return parallel lists ``(gmacs, mparams)``.

    Rows whose config cannot be decoded or measured get ``-1.0`` in both lists.
    """
    gmacs_list = []
    mparams_list = []  # Supernet-object parameter counts, not the active subnet's
    for idx, row in tqdm(df_search_results.iterrows(), total=len(df_search_results), desc="Calculating Complexity"):
        # Fall back to the row index so error reporting never raises itself.
        config_name = row.get('config_name', idx)
        config_details_str = row['config_details']
        try:
            active_config = json.loads(config_details_str)
            gmacs, mparams = get_subnet_complexity(supernet, active_config, device=device)
        except json.JSONDecodeError:
            print(f" Error decoding JSON for config: {config_name}")
            gmacs, mparams = -1.0, -1.0
        except Exception as e:
            print(f" Error getting complexity for {config_name}: {e}")
            gmacs, mparams = -1.0, -1.0
        gmacs_list.append(gmacs)
        mparams_list.append(mparams)
    return gmacs_list, mparams_list


def main_analysis():
    """Add complexity columns to the subnet search results and save them ranked.

    Steps:

    1. Read ``SEARCH_RESULTS_CSV_PATH``. Each row's ``config_details`` column
       must hold a JSON subnet config.
    2. Load the supernet checkpoint.
    3. Measure every subnet with ``get_subnet_complexity``.
    4. Write the table to ``OUTPUT_CSV_WITH_COMPLEXITY`` with two new columns,
       ``gmacs`` and ``mparams_supernet_state``. Rows are sorted by
       ``macro_f1`` (descending), then ``gmacs`` (ascending).

    If either input file is missing, prints an error and returns early.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Analyzing subnets using device: {device}")
    if not os.path.exists(SEARCH_RESULTS_CSV_PATH):
        print(f"ERROR: Search results CSV not found: {SEARCH_RESULTS_CSV_PATH}")
        return
    df_search_results = pd.read_csv(SEARCH_RESULTS_CSV_PATH)
    print(f"Loaded {len(df_search_results)} subnet results from {SEARCH_RESULTS_CSV_PATH}")

    supernet = _load_supernet(device)
    if supernet is None:
        return

    gmacs_list, mparams_list = _compute_complexities(df_search_results, supernet, device)
    df_search_results['gmacs'] = gmacs_list
    df_search_results['mparams_supernet_state'] = mparams_list  # Clearly label this column

    # Sort by F1 (desc), then GMACs (asc). Failed measurements carry negative
    # sentinel values. Mask them as NaN for sorting only, so they sort last
    # instead of ranking as the "cheapest" subnets. The CSV keeps the sentinels.
    valid_gmacs = df_search_results['gmacs'].where(df_search_results['gmacs'] >= 0)
    df_sorted = (
        df_search_results.assign(_gmacs_sort=valid_gmacs)
        .sort_values(by=['macro_f1', '_gmacs_sort'], ascending=[False, True], na_position='last')
        .drop(columns='_gmacs_sort')
    )

    output_dir = os.path.dirname(OUTPUT_CSV_WITH_COMPLEXITY)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df_sorted.to_csv(OUTPUT_CSV_WITH_COMPLEXITY, index=False)
    print(f"\nAnalysis complete. Results with complexity saved to: {OUTPUT_CSV_WITH_COMPLEXITY}")
    print("\nTop performing subnets (by Macro F1, then GMACs):")
    print(df_sorted[['config_name', 'macro_f1', 'val_loss', 'gmacs', 'mparams_supernet_state']].head(20))
    print("\nCAVEAT: 'mparams_supernet_state' is the parameter count of the supernet object in its current active state as measured by fvcore.")
    print("For OFA, the true parameters of an *extracted static subnet* would be lower.")
    print("GMACs should be a good indicator of computational cost for the active path.")
if __name__ == '__main__':
    # Before running, review the configuration section near the top of this file:
    # - the three paths SUPERNET_CHECKPOINT_PATH, SEARCH_RESULTS_CSV_PATH and
    #   OUTPUT_CSV_WITH_COMPLEXITY;
    # - STAGE_CONFIG_PARAMS_FINAL and the other OFAMaxViT init parameters, which
    #   must match the trained supernet.
    main_analysis()