# -*- coding: utf-8 -*-

"""
@date: 2020/4/19 3:07 PM
@file: batch_detect.py
@author: zj
@description: Batch detection over a dataset (for mAP evaluation)
"""

import os
import glob
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms

from utils import file
from utils import util
from utils import draw
from models.yolo_v1 import YOLO_v1

# YOLOv1 hyper-parameters: S x S grid, B boxes per cell, C classes
S = 7
B = 2
C = 3

cate_list = ['cucumber', 'eggplant', 'mushroom']
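
# Output layout (matching how `outputs` is sliced in __main__ below): for each
# of the S*S grid cells the network predicts C class probabilities, B box
# confidences and B boxes of 4 coordinates, i.e. one row of length
# C + B + 4*B = 13 per cell.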


def get_transform():
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        # scale each channel from [0, 1] to [-1, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    return transform
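
# Note: images are fed to this transform as BGR arrays straight from
# cv2.imread(); this presumably mirrors the preprocessing used at training
# time, so the channel order is deliberately left untouched here.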


def load_data(root_dir):
    """Collect every .jpg under root_dir together with its same-named .xml annotation."""
    img_path_list = glob.glob(os.path.join(root_dir, '*.jpg'))
    annotation_path_list = [os.path.join(root_dir, os.path.splitext(os.path.basename(img_path))[0] + ".xml")
                            for img_path in img_path_list]

    return img_path_list, annotation_path_list


def parse_data(img_path, xml_path, transform):
    """Load an image and its XML annotation; return the batched input tensor plus metadata."""
    src = cv2.imread(img_path)  # BGR, as loaded by OpenCV
    bndboxs, name_list = file.parse_location_xml(xml_path)
    # dst = draw.plot_box(src, bndboxs, name_list)
    # draw.show(dst)

    h, w = src.shape[:2]
    img = transform(src)
    scale_h, scale_w = img.shape[1:]
    ratio_h = scale_h / h
    ratio_w = scale_w / w

    # [C, H, W] -> [N, C, H, W]
    img = img.unsqueeze(0)

    data_dict = {}
    data_dict['src'] = src
    data_dict['src_size'] = (h, w)
    data_dict['bndboxs'] = bndboxs
    data_dict['name_list'] = name_list

    data_dict['img'] = img
    data_dict['scale_size'] = (scale_h, scale_w)
    data_dict['ratio'] = (ratio_h, ratio_w)

    return img, data_dict


def load_model(device):
    model_path = './models/checkpoint_yolo_v1_49.pth'
    model = YOLO_v1(S=S, B=B, C=C)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    return model


def deform_bboxs(pred_bboxs, data_dict):
    """
    Map predicted boxes from grid-cell coordinates back to the original image.
    :param pred_bboxs: [S*S, 4] boxes as (x_center, y_center, w, h); centers are
        relative to their grid cell, w/h are relative to the whole image
    :param data_dict: metadata produced by parse_data()
    :return: [S*S, 4] integer boxes as (xmin, ymin, xmax, ymax) in source-image pixels
    """
    scale_h, scale_w = data_dict['scale_size']
    grid_w = scale_w / S
    grid_h = scale_h / S

    bboxs = np.zeros(pred_bboxs.shape)
    for i in range(S * S):
        row = i // S
        col = i % S

        x_center, y_center, box_w, box_h = pred_bboxs[i]
        bboxs[i, 0] = (col + x_center) * grid_w
        bboxs[i, 1] = (row + y_center) * grid_h
        bboxs[i, 2] = box_w * scale_w
        bboxs[i, 3] = box_h * scale_h
    # (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax)
    bboxs = util.bbox_center_to_corner(bboxs)

    ratio_h, ratio_w = data_dict['ratio']
    bboxs[:, 0] /= ratio_w
    bboxs[:, 1] /= ratio_h
    bboxs[:, 2] /= ratio_w
    bboxs[:, 3] /= ratio_h

    # clamp to the original image bounds
    h, w = data_dict['src_size']
    bboxs[:, 0] = np.maximum(bboxs[:, 0], 0)
    bboxs[:, 1] = np.maximum(bboxs[:, 1], 0)
    bboxs[:, 2] = np.minimum(bboxs[:, 2], w)
    bboxs[:, 3] = np.minimum(bboxs[:, 3], h)

    return bboxs.astype(int)
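
# Worked example of the decoding above (448x448 input, S=7 => 64-px cells):
# a prediction in cell row=2, col=3 with (x_center, y_center, w, h) =
# (0.5, 0.5, 0.25, 0.5) decodes to center (224, 160) and size (112, 224) on
# the scaled image, and is then divided by the resize ratios and clamped.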


def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs, pred_bboxs):
    """
    Save detection results.
    :param img_name: image name (without extension)
    :param img: source image
    :param target_cates: class names of the ground-truth boxes
    :param target_bboxs: coordinates of the ground-truth boxes
    :param pred_cates: class indices of the predicted boxes
    :param pred_probs: confidences of the predicted boxes
    :param pred_bboxs: coordinates of the predicted boxes
    """
    dst_root_dir = './data/outputs'
    dst_target_dir = os.path.join(dst_root_dir, 'targets')
    dst_pred_dir = os.path.join(dst_root_dir, 'preds')
    dst_img_dir = os.path.join(dst_root_dir, 'imgs')

    file.check_dir(dst_root_dir)
    file.check_dir(dst_target_dir)
    file.check_dir(dst_pred_dir)
    file.check_dir(dst_img_dir)

    img_path = os.path.join(dst_img_dir, img_name + ".png")
    cv2.imwrite(img_path, img)
    annotation_path = os.path.join(dst_target_dir, img_name + ".txt")
    with open(annotation_path, 'w') as f:
        for i in range(len(target_cates)):
            target_cate_name = target_cates[i]
            xmin, ymin, xmax, ymax = target_bboxs[i]

            f.write('%s %d %d %d %d' % (target_cate_name, xmin, ymin, xmax, ymax))
            if i != (len(target_cates) - 1):
                f.write('\n')
    pred_path = os.path.join(dst_pred_dir, img_name + ".txt")
    with open(pred_path, 'w') as f:
        for i in range(len(pred_cates)):
            pred_cate_idx = pred_cates[i]
            pred_prob = pred_probs[i]
            xmin, ymin, xmax, ymax = pred_bboxs[i]

            f.write('%s %.3f %d %d %d %d' % (cate_list[pred_cate_idx], pred_prob, xmin, ymin, xmax, ymax))
            if i != (len(pred_cates) - 1):
                f.write('\n')
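
# The two text layouts above, "class xmin ymin xmax ymax" for ground truth and
# "class confidence xmin ymin xmax ymax" for detections, match the per-image
# files expected by common mAP evaluation scripts (e.g. Cartucho/mAP), which
# is presumably how these outputs are consumed.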


if __name__ == '__main__':
    # device = util.get_device()
    device = "cpu"
    model = load_model(device)

    transform = get_transform()
    img_path_list, annotation_path_list = load_data('./data/training_images')
    # print(img_path_list)

    N = len(img_path_list)
    for i in range(N):
        img_path = img_path_list[i]
        print(i, img_path)
        annotation_path = annotation_path_list[i]

        img, data_dict = parse_data(img_path, annotation_path, transform)

        # forward pass
        with torch.no_grad():
            outputs = model(img.to(device)).cpu().squeeze(0).numpy()

        # class probabilities: columns [0, C) -> (S*S, C)
        pred_probs = outputs[:, :C]
        # box confidences: columns [C, C+B) -> (S*S, B)
        pred_confidences = outputs[:, C:(C + B)]
        # box coordinates: columns [C+B, C+5B) -> (S*S, 4*B)
        pred_bboxs = outputs[:, (C + B):]

        # predicted class per grid cell
        pred_cates = np.argmax(pred_probs, axis=1).astype(int)
        # class-specific confidence: class probability times the best box confidence
        pred_confidences_idxs = np.argmax(pred_confidences, axis=1)
        pred_cate_probs = pred_probs[range(S * S), pred_cates] \
                          * pred_confidences[range(S * S), pred_confidences_idxs]
        # keep the coordinates of the higher-confidence box in each cell
        pred_cate_bboxs = np.zeros((S * S, 4))
        pred_cate_bboxs[:, 0] = pred_bboxs[range(S * S), pred_confidences_idxs * 4]
        pred_cate_bboxs[:, 1] = pred_bboxs[range(S * S), pred_confidences_idxs * 4 + 1]
        pred_cate_bboxs[:, 2] = pred_bboxs[range(S * S), pred_confidences_idxs * 4 + 2]
        pred_cate_bboxs[:, 3] = pred_bboxs[range(S * S), pred_confidences_idxs * 4 + 3]
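        # An equivalent vectorized gather (a sketch; assumes the B boxes are
        # stored contiguously as 4 coordinates each, as sliced above):
        #   pred_cate_bboxs = pred_bboxs.reshape(S * S, B, 4)[
        #       np.arange(S * S), pred_confidences_idxs]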

        # rescale the predicted boxes back to the original image
        pred_bboxs = deform_bboxs(pred_cate_bboxs, data_dict)
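        # Note: no non-maximum suppression is applied here; all S*S boxes are
        # written out per image, and any filtering by confidence threshold is
        # left to the downstream mAP evaluation.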

        # save the image / ground-truth boxes / predicted boxes
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        save_data(img_name, data_dict['src'], data_dict['name_list'], data_dict['bndboxs'],
                  pred_cates, pred_cate_probs, pred_bboxs)
    print('done')