Skip to content

Commit 619bcf1

Browse files
committed
feat(detect): 批量检测并保存检测结果
1 parent 3a38a54 commit 619bcf1

File tree

1 file changed

+220
-0
lines changed

1 file changed

+220
-0
lines changed

py/batch_detect.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
@date: 2020/4/19 下午3:07
5+
@file: batch_detect.py.py
6+
@author: zj
7+
@description: 批量检测数据(for mAP)
8+
"""
9+
10+
import os
11+
import glob
12+
import time
13+
import shutil
14+
import cv2
15+
import numpy as np
16+
import torch
17+
import torchvision.transforms as transforms
18+
19+
from utils import file
20+
from utils import util
21+
from utils import draw
22+
from models.location_dataset import LocationDataset
23+
from models.yolo_v1 import YOLO_v1
24+
25+
S = 7
26+
B = 2
27+
C = 3
28+
29+
cate_list = ['cucumber', 'eggplant', 'mushroom']
30+
31+
32+
def get_transform():
33+
transform = transforms.Compose([
34+
transforms.ToPILImage(),
35+
transforms.Resize((448, 448)),
36+
transforms.ToTensor(),
37+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
38+
])
39+
40+
return transform
41+
42+
43+
def load_data(root_dir):
44+
img_path_list = glob.glob(os.path.join(root_dir, '*.jpg'))
45+
annotation_path_list = [os.path.join(root_dir, os.path.splitext(os.path.basename(img_path))[0] + ".xml")
46+
for img_path in img_path_list]
47+
48+
return img_path_list, annotation_path_list
49+
50+
51+
def parse_data(img_path, xml_path, transform):
52+
src = cv2.imread(img_path)
53+
bndboxs, name_list = file.parse_location_xml(xml_path)
54+
# dst = draw.plot_box(src, bndboxs, name_list)
55+
# draw.show(dst)
56+
57+
h, w = src.shape[:2]
58+
img = transform(src)
59+
scale_h, scale_w = img.shape[1:]
60+
ratio_h = scale_h / h
61+
ratio_w = scale_w / w
62+
63+
# [C, H, W] -> [N, C, H, W]
64+
img = img.unsqueeze(0)
65+
66+
data_dict = {}
67+
data_dict['src'] = src
68+
data_dict['src_size'] = (h, w)
69+
data_dict['bndboxs'] = bndboxs
70+
data_dict['name_list'] = name_list
71+
72+
data_dict['img'] = img
73+
data_dict['scale_size'] = (scale_h, scale_w)
74+
data_dict['ratio'] = (ratio_h, ratio_w)
75+
76+
return img, data_dict
77+
78+
79+
def load_model(device):
80+
model_path = './models/checkpoint_yolo_v1_49.pth'
81+
model = YOLO_v1(S=7, B=2, C=3)
82+
model.load_state_dict(torch.load(model_path))
83+
model.eval()
84+
for param in model.parameters():
85+
param.requires_grad = False
86+
model = model.to(device)
87+
88+
return model
89+
90+
91+
def deform_bboxs(pred_bboxs, data_dict):
92+
"""
93+
:param pred_bboxs: [S*S, 4]
94+
:return:
95+
"""
96+
scale_h, scale_w = data_dict['scale_size']
97+
grid_w = scale_w / S
98+
grid_h = scale_h / S
99+
100+
bboxs = np.zeros(pred_bboxs.shape)
101+
for i in range(S * S):
102+
row = int(i / S)
103+
col = int(i % S)
104+
105+
x_center, y_center, box_w, box_h = pred_bboxs[i]
106+
bboxs[i, 0] = (col + x_center) * grid_w
107+
bboxs[i, 1] = (row + y_center) * grid_h
108+
bboxs[i, 2] = box_w * scale_w
109+
bboxs[i, 3] = box_h * scale_h
110+
# (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax)
111+
bboxs = util.bbox_center_to_corner(bboxs)
112+
113+
ratio_h, ratio_w = data_dict['ratio']
114+
bboxs[:, 0] /= ratio_w
115+
bboxs[:, 1] /= ratio_h
116+
bboxs[:, 2] /= ratio_w
117+
bboxs[:, 3] /= ratio_h
118+
119+
# 最大最小值
120+
h, w = data_dict['src_size']
121+
bboxs[:, 0] = np.maximum(bboxs[:, 0], 0)
122+
bboxs[:, 1] = np.maximum(bboxs[:, 1], 0)
123+
bboxs[:, 2] = np.minimum(bboxs[:, 2], w)
124+
bboxs[:, 3] = np.minimum(bboxs[:, 3], h)
125+
126+
return bboxs.astype(int)
127+
128+
129+
def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs, pred_bboxs):
130+
"""
131+
保存检测结果
132+
:param img_name: 图像名
133+
:param img: 原始图像
134+
:param target_cates: 标注边界框所属类别
135+
:param target_bboxs: 标注边界框坐标
136+
:param pred_cates: 预测边界框类别
137+
:param pred_probs: 预测边界框置信度
138+
:param pred_bboxs: 预测边界框坐标
139+
"""
140+
dst_root_dir = './data/outputs'
141+
dst_target_dir = os.path.join(dst_root_dir, 'targets')
142+
dst_pred_dir = os.path.join(dst_root_dir, 'preds')
143+
dst_img_dir = os.path.join(dst_root_dir, 'imgs')
144+
145+
file.check_dir(dst_root_dir)
146+
file.check_dir(dst_target_dir)
147+
file.check_dir(dst_pred_dir)
148+
file.check_dir(dst_img_dir)
149+
150+
img_path = os.path.join(dst_img_dir, img_name + ".png")
151+
cv2.imwrite(img_path, img)
152+
annotation_path = os.path.join(dst_target_dir, img_name + ".txt")
153+
with open(annotation_path, 'w') as f:
154+
for i in range(len(target_cates)):
155+
target_cate_name = target_cates[i]
156+
xmin, ymin, xmax, ymax = target_bboxs[i]
157+
158+
f.write('%s %d %d %d %d' % (target_cate_name, xmin, ymin, xmax, ymax))
159+
if i != (len(target_cates) - 1):
160+
f.write('\n')
161+
pred_path = os.path.join(dst_pred_dir, img_name + ".txt")
162+
with open(pred_path, 'w') as f:
163+
for i in range(len(pred_cates)):
164+
pred_cate_idx = pred_cates[i]
165+
pred_prob = pred_probs[i]
166+
xmin, ymin, xmax, ymax = pred_bboxs[i]
167+
168+
f.write('%s %.3f %d %d %d %d' % (cate_list[pred_cate_idx], pred_prob, xmin, ymin, xmax, ymax))
169+
if i != (len(pred_cates) - 1):
170+
f.write('\n')
171+
172+
173+
if __name__ == '__main__':
174+
# device = util.get_device()
175+
device = "cpu"
176+
model = load_model(device)
177+
178+
transform = get_transform()
179+
img_path_list, annotation_path_list = load_data('./data/training_images')
180+
# print(img_path_list)
181+
182+
N = len(img_path_list)
183+
for i in range(N):
184+
img_path = img_path_list[i]
185+
print(i, img_path)
186+
annotation_path = annotation_path_list[i]
187+
188+
img, data_dict = parse_data(img_path, annotation_path, transform)
189+
190+
# 计算
191+
outputs = model.forward(img.to(device)).cpu().squeeze(0).numpy()
192+
193+
# (S*S, C)
194+
pred_probs = outputs[:, :C]
195+
# (S*S, C:(C+B))
196+
pred_confidences = outputs[:, C:(C + B)]
197+
# (S*S, (C+B):(C+5B))
198+
pred_bboxs = outputs[:, (C + B):]
199+
200+
# 计算类别
201+
pred_cates = np.argmax(pred_probs, axis=1).astype(int)
202+
# 计算分类概率
203+
pred_confidences_idxs = np.argmax(pred_confidences, axis=1)
204+
pred_cate_probs = pred_probs[range(S * S), pred_cates] \
205+
* pred_confidences[range(S * S), pred_confidences_idxs]
206+
# 计算预测边界框
207+
pred_cate_bboxs = np.zeros((S * S, 4))
208+
pred_cate_bboxs[:, 0] = pred_bboxs[range(S * S), pred_confidences_idxs * 4]
209+
pred_cate_bboxs[:, 1] = pred_bboxs[range(S * S), pred_confidences_idxs * 4 + 1]
210+
pred_cate_bboxs[:, 2] = pred_bboxs[range(S * S), pred_confidences_idxs * 4 + 2]
211+
pred_cate_bboxs[:, 3] = pred_bboxs[range(S * S), pred_confidences_idxs * 4 + 3]
212+
213+
# 预测边界框的缩放,回到原始图像
214+
pred_bboxs = deform_bboxs(pred_cate_bboxs, data_dict)
215+
216+
# 保存图像/标注边界框/预测边界框
217+
img_name = os.path.splitext(os.path.basename(img_path))[0]
218+
save_data(img_name, data_dict['src'], data_dict['name_list'], data_dict['bndboxs'],
219+
pred_cates, pred_cate_probs, pred_bboxs)
220+
print('done')

0 commit comments

Comments
 (0)