-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_specific_fold.py
More file actions
262 lines (223 loc) · 13.7 KB
/
evaluate_specific_fold.py
File metadata and controls
262 lines (223 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from PIL import Image
import timm
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score, f1_score
import os
from tqdm import tqdm
from torch.cuda.amp import autocast # For modern PyTorch, directly from torch.amp
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import cv2
import time # For unique output directory
# --- Global constants (must stay in sync with the training script) ---
# Per-channel mean / std handed to A.Normalize (these are the common
# ImageNet statistics).
MEAN_VIZ_GLOBAL = [
    0.485,
    0.456,
    0.406,
]
STD_VIZ_GLOBAL = [
    0.229,
    0.224,
    0.225,
]
# Class folder names. Kept sorted so that a class's label index is simply
# its position in this list.
CLASS_NAMES_DRIVE_DATA = sorted(
    [
        "actinic_keratoses",
        "basal_cell_carcinoma",
        "benign_keratosis-like_lesions",
        "dermatofibroma",
        "melanocytic_Nevi",
        "melanoma",
        "vascular_lesions",
    ]
)
NUM_CLASSES_DRIVE_DATA = len(CLASS_NAMES_DRIVE_DATA)
# --- Utility functions (mirrors the training script) ---
def get_val_transforms(img_size_param=224):
    """Build the deterministic validation/test preprocessing pipeline.

    Args:
        img_size_param: Target side length; images are resized to a square
            of ``int(img_size_param)`` pixels.

    Returns:
        An ``A.Compose`` that resizes with bilinear interpolation, normalizes
        with ``MEAN_VIZ_GLOBAL`` / ``STD_VIZ_GLOBAL`` and applies ``ToTensorV2``.
    """
    side = int(img_size_param)
    steps = [
        A.Resize(height=side, width=side, interpolation=cv2.INTER_LINEAR),
        A.Normalize(mean=MEAN_VIZ_GLOBAL, std=STD_VIZ_GLOBAL),
        ToTensorV2(),
    ]
    return A.Compose(steps)
class ImageFolderDataset(Dataset):
    """Dataset over a ``root_dir/<class_name>/<image file>`` folder layout.

    Each item is ``(image_tensor, label)``, where ``label`` is the index of
    the image's class folder within the sorted class-name list.

    Args:
        root_dir: Root folder containing one sub-folder per class. May be
            ``None`` to build an image-less dataset that only carries the
            class mapping (``class_names_list`` is then required).
        transform: Albumentations-style callable invoked as
            ``transform(image=np_array)['image']``. When falsy, a default
            resize + normalize + ToTensorV2 pipeline is used instead.
        class_names_list: Explicit class names (sorted internally). When
            omitted, the sub-directories of ``root_dir`` are used.
        img_size: Side length used by the default transform and for the
            black placeholder returned when an image cannot be read.

    Raises:
        FileNotFoundError: ``root_dir`` is given but does not exist.
        ValueError: neither ``root_dir`` nor ``class_names_list`` is given.
    """

    # Compared case-insensitively so that e.g. ``.JPG`` files are not
    # skipped on case-sensitive file systems.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir, transform=None, class_names_list=None, img_size=224):
        self.root_dir = root_dir
        self.transform = transform
        self.img_size = img_size
        self.image_paths = []
        self.image_labels_int = []
        self._fallback_transform = None  # built lazily by _default_transform()

        if root_dir is None and not class_names_list:
            raise ValueError("Must provide root_dir or class_names_list for ImageFolderDataset")
        if root_dir is not None and not os.path.exists(root_dir):
            raise FileNotFoundError(f"Root directory not found: {root_dir}")

        if class_names_list:
            self.class_names = sorted(class_names_list)
        else:
            self.class_names = sorted(
                d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
            )
        self.class_to_int = {cls_name: i for i, cls_name in enumerate(self.class_names)}
        self.int_to_class = {i: cls_name for cls_name, i in self.class_to_int.items()}

        if root_dir is None:
            return  # mapping-only dataset: there are no images to index

        for class_idx, class_name in enumerate(self.class_names):
            class_dir_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir_path):
                print(f"Warning: Class directory not found: {class_dir_path}")
                continue
            # Sorted listing gives a reproducible sample order across runs.
            for file_name in sorted(os.listdir(class_dir_path)):
                # Hidden entries (e.g. macOS "._*" AppleDouble files) are not
                # real images; glob-style patterns skip them as well.
                if file_name.startswith('.'):
                    continue
                if not file_name.lower().endswith(self.IMG_EXTENSIONS):
                    continue
                img_path = os.path.join(class_dir_path, file_name)
                if os.path.isfile(img_path):
                    self.image_paths.append(img_path)
                    self.image_labels_int.append(class_idx)
        if not self.image_paths:
            print(f"Warning: No images found in {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.image_labels_int[idx]
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(img_path) as image_pil:
                img_np = np.array(image_pil.convert('RGB'))
        except Exception as e:  # best-effort: keep going on unreadable files
            # The placeholder keeps the file's true label, so it still counts
            # toward the evaluation metrics.
            print(f"Error loading image {img_path}: {e}. Returning black placeholder.")
            img_np = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        transform = self.transform if self.transform else self._default_transform()
        img_tensor = transform(image=img_np)['image']
        return img_tensor, torch.tensor(label, dtype=torch.long)

    def _default_transform(self):
        """Return the resize/normalize/ToTensorV2 fallback pipeline, building it once."""
        if getattr(self, '_fallback_transform', None) is None:
            self._fallback_transform = A.Compose([
                A.Resize(self.img_size, self.img_size, interpolation=cv2.INTER_LINEAR),
                A.Normalize(mean=MEAN_VIZ_GLOBAL, std=STD_VIZ_GLOBAL),
                ToTensorV2(),
            ])
        return self._fallback_transform
# --- Evaluation function (adapted from the main script's evaluate_model_from_drive) ---
def _load_trained_model(model, model_path, device):
    """Load the checkpoint at ``model_path`` into ``model``; return the model or None.

    Accepted checkpoint contents:
      * a plain state_dict;
      * a dict wrapping one under ``'state_dict'``, ``'model_state_dict'`` or
        ``'model'``;
      * a state_dict whose keys all carry a ``'module.'`` prefix (as saved
        from a DataParallel/DDP-wrapped model) - the prefix is stripped;
      * an entire pickled ``nn.Module``, which is returned as-is.
    On any failure the reason is printed and ``None`` is returned.
    """
    try:
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints. Recent PyTorch releases default to weights_only=True,
        # which loads state_dicts but rejects whole pickled model objects.
        checkpoint = torch.load(model_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint file: {e}")
        print("Please ensure the model_path_to_evaluate points to a valid PyTorch state_dict or model file compatible with the architecture.")
        return None

    if isinstance(checkpoint, nn.Module):
        print("Checkpoint contains an entire model object; using it directly.")
        return checkpoint
    if not isinstance(checkpoint, dict):
        print(f"Unrecognized checkpoint content of type {type(checkpoint).__name__}.")
        return None

    state_dict = checkpoint
    for wrapper_key in ('state_dict', 'model_state_dict', 'model'):
        # State_dict values are tensors, so a nested dict here is a wrapper.
        if isinstance(checkpoint.get(wrapper_key), dict):
            state_dict = checkpoint[wrapper_key]
            break
    prefix = 'module.'
    if state_dict and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model state_dict: {e}")
        print("Please ensure the model_path_to_evaluate points to a valid PyTorch state_dict or model file compatible with the architecture.")
        return None
    return model


def _compute_metrics(labels, preds):
    """Return ``(report_str, cm_arr, bal_acc, macro_f1)`` over all known classes.

    On a ``ValueError`` from scikit-learn the error is printed and sentinel
    values are returned: ``"N/A due to error"``, an empty array, -1.0, -1.0.
    """
    metric_labels = list(range(NUM_CLASSES_DRIVE_DATA))
    try:
        report_str = classification_report(labels, preds, target_names=CLASS_NAMES_DRIVE_DATA, digits=4, zero_division=0, labels=metric_labels)
        cm_arr = confusion_matrix(labels, preds, labels=metric_labels)
        bal_acc = balanced_accuracy_score(labels, preds)
        macro_f1 = f1_score(labels, preds, average='macro', zero_division=0, labels=metric_labels)
    except ValueError as e:
        print(f"Error during metrics calculation: {e}. Check if all classes are present in predictions/labels.")
        return "N/A due to error", np.array([]), -1.0, -1.0
    return report_str, cm_arr, bal_acc, macro_f1


def _save_confusion_matrix_plot(cm_arr, model_name_for_file, cm_plot_path):
    """Render ``cm_arr`` as an annotated heatmap and save it to ``cm_plot_path``."""
    fig = plt.figure(figsize=(max(8, NUM_CLASSES_DRIVE_DATA * 1.1), max(6, NUM_CLASSES_DRIVE_DATA * 0.8)))
    try:
        sns.heatmap(cm_arr, annot=True, fmt='d', cmap='Blues', xticklabels=CLASS_NAMES_DRIVE_DATA, yticklabels=CLASS_NAMES_DRIVE_DATA)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title(f'CM - Test Set - {model_name_for_file}')
        plt.savefig(cm_plot_path)
    finally:
        plt.close(fig)  # release the figure even if saving fails
    print(f"Confusion matrix plot saved to {cm_plot_path}")


def evaluate_specific_model(
        model_path_to_evaluate,    # Path to the .pth file of the model to test
        model_architecture,        # e.g. 'timm/maxvit_small_tf_224.in1k'
        global_test_dataset_path,  # Root of the global test set (one sub-folder per class)
        img_size_eval,
        batch_size_eval,
        num_workers_eval,
        output_dir_for_this_eval,  # Where to save results for this specific evaluation
        device_str=None):
    """Evaluate one trained checkpoint on the global test set and save the results.

    Builds a timm model of ``model_architecture`` with
    ``NUM_CLASSES_DRIVE_DATA`` outputs, loads the checkpoint, runs inference
    (float16 autocast on CUDA only), prints and writes a text report, and
    saves a confusion-matrix heatmap into ``output_dir_for_this_eval``.

    Args:
        device_str: Explicit torch device string; when falsy, CUDA is used if
            available, otherwise CPU.

    Returns:
        A dict with keys ``'balanced_accuracy'``, ``'macro_f1'``,
        ``'classification_report'`` and ``'confusion_matrix'`` on success, or
        ``None`` if there was no test data or the checkpoint could not be loaded.
    """
    device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_dir_for_this_eval, exist_ok=True)

    test_transform = get_val_transforms(img_size_eval)
    test_dataset = ImageFolderDataset(root_dir=global_test_dataset_path, transform=test_transform, class_names_list=CLASS_NAMES_DRIVE_DATA, img_size=img_size_eval)
    if len(test_dataset) == 0:
        print(f"Error: No images found in the global test dataset path: {global_test_dataset_path}")
        return None

    print("\n--- Evaluating Specific Model ---")
    print(f"Model Path: {model_path_to_evaluate}")
    print(f"Model Architecture: {model_architecture}")
    print(f"Test Data Path: {global_test_dataset_path}")
    print(f"Device: {device}")
    print(f"Test samples: {len(test_dataset)}, Classes: {NUM_CLASSES_DRIVE_DATA} -> {CLASS_NAMES_DRIVE_DATA}")

    # Pinned host memory only helps host->GPU copies; it just warns on CPU.
    test_loader = DataLoader(test_dataset, batch_size=batch_size_eval, shuffle=False, num_workers=num_workers_eval, pin_memory=(device.type == 'cuda'))

    model = timm.create_model(model_architecture, pretrained=False, num_classes=NUM_CLASSES_DRIVE_DATA)
    print(f"Loading trained weights from: {model_path_to_evaluate}")
    model = _load_trained_model(model, model_path_to_evaluate, device)
    if model is None:
        return None
    model = model.to(device)
    model.eval()

    all_preds_list, all_labels_list = [], []
    use_amp = device.type == 'cuda'
    print("Running inference on the test set...")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating Test Set", leave=False):
            inputs = inputs.to(device, non_blocking=True)
            with torch.amp.autocast(device_type=device.type, enabled=use_amp, dtype=torch.float16):
                outputs = model(inputs)
            all_preds_list.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            all_labels_list.extend(labels.cpu().numpy())

    print("\n--- Test Set Evaluation Results ---")
    if not all_labels_list:
        print("No data processed.")
        return None
    report_str, cm_arr, bal_acc, macro_f1 = _compute_metrics(all_labels_list, all_preds_list)

    results_summary = (f"\nOverall Metrics ({len(all_labels_list)} samples):\n"
                       f" Balanced Accuracy: {bal_acc:.4f}\n Macro F1: {macro_f1:.4f}\n\n"
                       f"Classification Report:\n{report_str}\n\nConfusion Matrix:\n{cm_arr if cm_arr.size > 0 else 'N/A'}\n")
    print(results_summary)

    # Same stem logic as the script entry point (strips only the final extension).
    model_name_for_file = os.path.splitext(os.path.basename(model_path_to_evaluate))[0]
    results_file_path = os.path.join(output_dir_for_this_eval, f"evaluation_report_test_set_{model_name_for_file}.txt")
    with open(results_file_path, "w", encoding="utf-8") as f:
        f.write(f"Evaluation of Model: {model_path_to_evaluate}\n")
        f.write(f"Architecture: {model_architecture}\n")
        f.write(f"Data: Global Test Set ({global_test_dataset_path})\n")
        f.write(results_summary)
    print(f"Test report saved to {results_file_path}")

    if cm_arr.size > 0:
        cm_plot_path = os.path.join(output_dir_for_this_eval, f"confusion_matrix_test_set_{model_name_for_file}.png")
        _save_confusion_matrix_plot(cm_arr, model_name_for_file, cm_plot_path)

    return {
        'balanced_accuracy': bal_acc,
        'macro_f1': macro_f1,
        'classification_report': report_str,
        'confusion_matrix': cm_arr,
    }
# --- Configuration for evaluation ---
if __name__ == "__main__":
    import sys  # local: only the script entry point needs sys.exit

    # !!! --- YOU NEED TO SET THESE PATHS --- !!!
    # Path to the specific .pth model file you want to evaluate (e.g., Fold 4's best model)
    # Example: "/home/dgx-s-user2/controlleddiffusion/EIS/Colab_Runs_Server_KFold/maxvit_small_tf_224.in1k_k5_e20_20231105-103000/fold_4/maxvit_small_tf_224_in1k_fold4_img224_bs32_lr3e-05_e20_best_loss.pth"
    MODEL_PATH_TO_TEST = "/home/dgx-s-user2/controlleddiffusion/EIS/Server_Runs_KFold/maxvit_small_tf_224.in1k_k5_e20_20250523-212155/fold_4/maxvit_small_tf_224_in1k_fold4_img224_bs32_lr3e-05_e20_best_loss.pth"  # <--- *** MODIFY THIS ***
    # Base path of your main run's output (where fold_X directories are).
    # Used to create a sensible output directory for this specific evaluation.
    # Example: "/home/dgx-s-user2/controlleddiffusion/EIS/Colab_Runs_Server_KFold/maxvit_small_tf_224.in1k_k5_e20_20231105-103000"
    BASE_OUTPUT_DIR_OF_RUN = "/home/dgx-s-user2/controlleddiffusion/EIS/Specific_fold"  # <--- *** MODIFY THIS ***
    # Path to the root of your extracted dataset on the server
    # Example: "/home/dgx-s-user2/controlleddiffusion/EIS/HAM10000_extracted_server/HAM10000_local"
    BASE_EXTRACTED_DATA_PATH_SERVER = "/home/dgx-s-user2/controlleddiffusion/EIS/HAM10000_extracted/HAM10000_local"  # <--- *** MODIFY THIS ***
    # -------------------------------------------------
    # --- These should match your training configuration ---
    MODEL_ARCHITECTURE = 'timm/maxvit_small_tf_224.in1k'  # Or whatever you used
    IMG_SIZE = 224
    BATCH_SIZE = 32  # Can be larger for H100 during inference if memory allows
    NUM_WORKERS = 4  # Or your preferred number for data loading
    DEVICE = "cuda"  # Or "cpu" or None for auto
    # -------------------------------------------------

    # Validate every input before creating output directories, so a bad
    # configuration does not leave empty result folders behind. Errors exit
    # with status 1 so shell scripts / schedulers can detect the failure.
    if MODEL_PATH_TO_TEST == "/path/to/your/fold_4_best_model.pth":
        print("ERROR: Please modify 'MODEL_PATH_TO_TEST' in the script to point to your actual model file.")
        sys.exit(1)
    if not os.path.exists(MODEL_PATH_TO_TEST):
        print(f"ERROR: Model path not found: {MODEL_PATH_TO_TEST}")
        sys.exit(1)
    if BASE_EXTRACTED_DATA_PATH_SERVER == "/path/to/your/extracted_dataset_root":
        print("ERROR: Please modify 'BASE_EXTRACTED_DATA_PATH_SERVER' in the script.")
        sys.exit(1)
    GLOBAL_TEST_SET_FULL_PATH = os.path.join(BASE_EXTRACTED_DATA_PATH_SERVER, "test")  # Assuming 'test' is the subdir name
    if not os.path.isdir(GLOBAL_TEST_SET_FULL_PATH):
        print(f"ERROR: Global test set directory not found: {GLOBAL_TEST_SET_FULL_PATH}")
        print("Please ensure 'BASE_EXTRACTED_DATA_PATH_SERVER' is correct and contains a 'test' subdirectory.")
        sys.exit(1)

    if BASE_OUTPUT_DIR_OF_RUN == "/path/to/your/main_run_output_directory":
        print("WARNING: 'BASE_OUTPUT_DIR_OF_RUN' is using placeholder. Output will be in a generic directory.")
        eval_output_parent_dir = os.path.join(os.path.dirname(MODEL_PATH_TO_TEST), "global_test_evaluation")
    else:
        eval_output_parent_dir = os.path.join(BASE_OUTPUT_DIR_OF_RUN, "specific_fold_evaluations")

    # Unique, timestamped subdirectory for this evaluation's results.
    model_filename_stem = os.path.splitext(os.path.basename(MODEL_PATH_TO_TEST))[0]
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    CURRENT_EVAL_OUTPUT_DIR = os.path.join(eval_output_parent_dir, f"eval_{model_filename_stem}_{timestamp}")
    os.makedirs(CURRENT_EVAL_OUTPUT_DIR, exist_ok=True)
    print(f"Results for this evaluation will be saved in: {CURRENT_EVAL_OUTPUT_DIR}")

    evaluate_specific_model(
        model_path_to_evaluate=MODEL_PATH_TO_TEST,
        model_architecture=MODEL_ARCHITECTURE,
        global_test_dataset_path=GLOBAL_TEST_SET_FULL_PATH,
        img_size_eval=IMG_SIZE,
        batch_size_eval=BATCH_SIZE,
        num_workers_eval=NUM_WORKERS,
        output_dir_for_this_eval=CURRENT_EVAL_OUTPUT_DIR,
        device_str=DEVICE,
    )
    print(f"\nEvaluation script finished. Check results in {CURRENT_EVAL_OUTPUT_DIR}")