-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
109 lines (89 loc) · 4.36 KB
/
predict.py
File metadata and controls
109 lines (89 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import (DEVICE, SAVE_DIR, TEST_DIR, BATCH_SIZE,
K_FOLDS, NUM_WORKERS, NUM_GROUPS)
from dataset import WBCDataset, get_transforms
def predict_ensemble(test_df, le, model_class, tag="model",
                     use_clahe=True, out_csv=None):
    """Average the saved fold models' probabilities and write the submission CSV.

    One checkpoint per fold is loaded from
    ``{SAVE_DIR}/best_{tag}_fold{k}.pth`` for k = 1..K_FOLDS. Each model is run
    over the test set with the validation transform, its softmax outputs are
    summed, and the class with the highest mean probability is kept per row.

    Args:
        test_df: DataFrame of test samples; must contain an ``ID`` column
            (it is written to the CSV).
        le: Fitted label encoder exposing ``classes_`` and
            ``inverse_transform`` (presumably a scikit-learn ``LabelEncoder``
            -- confirm against the caller).
        model_class: Callable built as ``model_class(NUM_GROUPS, n_classes)``;
            the model's forward pass must return a 2-tuple whose second
            element is softmaxed over dim 1.
        tag: Name fragment used in the checkpoint and default CSV file names.
        use_clahe: Forwarded unchanged to ``WBCDataset``.
        out_csv: Destination path; defaults to
            ``{SAVE_DIR}/submission_{tag}.csv``.

    Returns:
        A copy of ``test_df`` with an added ``label`` column of decoded
        predictions.
    """
    _, eval_tf = get_transforms()
    test_set = WBCDataset(test_df, TEST_DIR, eval_tf,
                          use_clahe=use_clahe, is_test=True)
    # shuffle=False keeps predictions aligned with the rows of test_df.
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=NUM_WORKERS)

    prob_sum = np.zeros((len(test_df), len(le.classes_)))
    for fold in range(K_FOLDS):
        model_path = f"{SAVE_DIR}/best_{tag}_fold{fold+1}.pth"
        model = model_class(NUM_GROUPS, len(le.classes_)).to(DEVICE)
        weights = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(weights)
        model.eval()

        batch_probs = []
        with torch.no_grad():
            for imgs, _ in tqdm(test_loader, desc=f"Predict Fold {fold+1}"):
                # Only the second model output feeds the class prediction.
                _, out_f = model(imgs.to(DEVICE))
                batch_probs.append(out_f.softmax(dim=1).cpu())
        prob_sum += torch.cat(batch_probs, dim=0).numpy()

    mean_probs = prob_sum / K_FOLDS
    final_preds = mean_probs.argmax(axis=1)

    result = test_df.copy()  # never mutate the caller's frame
    result['label'] = le.inverse_transform(final_preds)

    out_csv = f"{SAVE_DIR}/submission_{tag}.csv" if out_csv is None else out_csv
    result[['ID', 'label']].to_csv(out_csv, index=False)
    print(f"Submission sauvegardée : {out_csv}")
    return result
def predict_ensemble_tta(test_df, le, model_class, tag="model",
                         use_clahe=True, n_tta=5, out_csv=None):
    """Predict with the K-fold ensemble combined with test-time augmentation.

    For every fold checkpoint found at ``{SAVE_DIR}/best_{tag}_fold{k}.pth``
    the test set is scored once per augmented view; the softmax outputs of
    all (fold, view) passes are averaged and the arg-max class is kept.
    Missing checkpoints are reported and skipped.

    The available views, in order, are: the validation transform, horizontal
    flip, vertical flip, resize-to-256 + centre crop, and a 90-degree
    rotation. ``n_tta`` selects the first ``n_tta`` of them (values above 5
    simply use all 5).

    Args:
        test_df: DataFrame of test samples; must contain an ``ID`` column
            (it is written to the CSV).
        le: Fitted label encoder exposing ``classes_`` and
            ``inverse_transform`` (presumably a scikit-learn ``LabelEncoder``
            -- confirm against the caller).
        model_class: Callable built as ``model_class(NUM_GROUPS, n_classes)``;
            the model's forward pass must return a 2-tuple whose second
            element is softmaxed over dim 1.
        tag: Name fragment used in the checkpoint and default CSV file names.
        use_clahe: Forwarded unchanged to ``WBCDataset``.
        n_tta: Number of augmented views to use; must be >= 1.
        out_csv: Destination path; defaults to
            ``{SAVE_DIR}/submission_{tag}_tta.csv``.

    Returns:
        A copy of ``test_df`` with an added ``label`` column of decoded
        predictions.

    Raises:
        ValueError: If ``n_tta`` is smaller than 1.
        FileNotFoundError: If no fold checkpoint exists, so no prediction
            could be computed.
    """
    # `transforms` and `IMG_SIZE` are used below but are not imported at
    # module level, so the function used to fail with NameError. Importing
    # them here keeps the fix local to this function.
    # NOTE(review): IMG_SIZE is assumed to live in `config` next to the other
    # constants -- confirm.
    from torchvision import transforms
    from config import IMG_SIZE

    # A negative n_tta would slice views from the end and, with the old
    # K_FOLDS * n_tta divisor, flip the arg-max into an arg-min; 0 produced a
    # 0/0 division. Reject both explicitly.
    if n_tta < 1:
        raise ValueError(f"n_tta doit être >= 1 (reçu : {n_tta})")

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    def _view(*geometric):
        # Geometric ops followed by the shared tensor conversion/normalisation.
        return transforms.Compose([*geometric,
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean, std)])

    _, val_tf = get_transforms()
    tta_tfs = [
        val_tf,
        _view(transforms.Resize((IMG_SIZE, IMG_SIZE)),
              transforms.RandomHorizontalFlip(p=1.0)),
        _view(transforms.Resize((IMG_SIZE, IMG_SIZE)),
              transforms.RandomVerticalFlip(p=1.0)),
        _view(transforms.Resize((256, 256)),
              transforms.CenterCrop(IMG_SIZE)),
        # degrees=(90, 90) makes the "random" rotation a fixed 90 degrees.
        _view(transforms.Resize((IMG_SIZE, IMG_SIZE)),
              transforms.RandomRotation(degrees=(90, 90))),
    ]

    # The loaders do not depend on the fold, so build them once.
    # shuffle=False keeps predictions aligned with the rows of test_df.
    loaders = [
        DataLoader(
            WBCDataset(test_df, TEST_DIR, tf, use_clahe=use_clahe, is_test=True),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
        for tf in tta_tfs[:n_tta]
    ]

    n_classes = len(le.classes_)
    all_probs = np.zeros((len(test_df), n_classes))
    n_passes = 0  # (fold, view) passes actually accumulated
    for fold in range(K_FOLDS):
        model_path = f"{SAVE_DIR}/best_{tag}_fold{fold+1}.pth"
        if not os.path.exists(model_path):
            print(f"Fold {fold+1} introuvable, ignoré.")
            continue
        model = model_class(NUM_GROUPS, n_classes).to(DEVICE)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.eval()
        for loader in loaders:
            fold_probs = []
            with torch.no_grad():
                for imgs, _ in tqdm(loader, desc=f"TTA Fold {fold+1}"):
                    # Only the second model output feeds the class prediction.
                    _, out_f = model(imgs.to(DEVICE))
                    fold_probs.append(torch.softmax(out_f, dim=1).cpu().numpy())
            all_probs += np.concatenate(fold_probs, axis=0)
            n_passes += 1

    # Without any pass all_probs is all zeros and arg-max would silently
    # label every row with class 0; fail loudly instead.
    if n_passes == 0:
        raise FileNotFoundError(
            f"Aucun modèle trouvé ({SAVE_DIR}/best_{tag}_fold*.pth) : "
            "impossible de générer la soumission.")

    # Divide by the passes really summed, not K_FOLDS * n_tta, so skipped
    # folds and n_tta > len(tta_tfs) still yield a true mean.
    final_preds = np.argmax(all_probs / n_passes, axis=1)

    test_df = test_df.copy()  # never mutate the caller's frame
    test_df['label'] = le.inverse_transform(final_preds)
    if out_csv is None:
        out_csv = f"{SAVE_DIR}/submission_{tag}_tta.csv"
    test_df[['ID', 'label']].to_csv(out_csv, index=False)
    print(f"Submission sauvegardée : {out_csv}")
    return test_df