Skip to content

Commit c173739

Browse files
committed
fix(detect): 中心坐标和网格对应关系 (correspondence between center coordinates and grid cells)
1 parent 064f03f commit c173739

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

py/detector.py

Lines changed: 25 additions & 23 deletions
Original file line number · Diff line number · Diff line change
@@ -43,22 +43,32 @@ def load_data(img_path, xml_path):
4343
ratio_h = scale_h / h
4444
ratio_w = scale_w / w
4545

46-
scale_bndboxs = torch.from_numpy(bndboxs).float()
47-
scale_bndboxs[:, 0] = scale_bndboxs[:, 0] * ratio_w
48-
scale_bndboxs[:, 1] = scale_bndboxs[:, 1] * ratio_h
49-
scale_bndboxs[:, 2] = scale_bndboxs[:, 2] * ratio_w
50-
scale_bndboxs[:, 3] = scale_bndboxs[:, 3] * ratio_h
51-
scale_bndboxs = scale_bndboxs.int().numpy()
52-
46+
# [C, H, W] -> [N, C, H, W]
5347
img = img.unsqueeze(0)
48+
5449
data_dict = {}
5550
data_dict['src'] = src
5651
data_dict['src_size'] = (h, w)
5752
data_dict['bndboxs'] = bndboxs
53+
data_dict['name_list'] = name_list
54+
5855
data_dict['img'] = img
5956
data_dict['scale_size'] = (scale_h, scale_w)
6057
data_dict['ratio'] = (ratio_h, ratio_w)
61-
return img, scale_bndboxs, name_list, data_dict
58+
59+
return img, data_dict
60+
61+
62+
def load_model():
63+
model_path = './models/checkpoint_yolo_v1_24.pth'
64+
model = YOLO_v1(S=7, B=2, C=3)
65+
model.load_state_dict(torch.load(model_path))
66+
model.eval()
67+
for param in model.parameters():
68+
param.requires_grad = False
69+
model = model.to(device)
70+
71+
return model
6272

6373

6474
def deform_bboxs(pred_bboxs, data_dict):
@@ -76,8 +86,8 @@ def deform_bboxs(pred_bboxs, data_dict):
7686
col = int(i % S)
7787

7888
x_center, y_center, box_w, box_h = pred_bboxs[i]
79-
bboxs[i, 0] = (row + x_center) * grid_w
80-
bboxs[i, 1] = (col + y_center) * grid_h
89+
bboxs[i, 0] = (col + x_center) * grid_w
90+
bboxs[i, 1] = (row + y_center) * grid_h
8191
bboxs[i, 2] = box_w * scale_w
8292
bboxs[i, 3] = box_h * scale_h
8393
# (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax)
@@ -103,17 +113,9 @@ def deform_bboxs(pred_bboxs, data_dict):
103113
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
104114
# device = "cpu"
105115

106-
img, bndboxs, name_list, data_dict = load_data('../imgs/cucumber_9.jpg', '../imgs/cucumber_9.xml')
107-
108-
model_path = './models/checkpoint_yolo_v1_24.pth'
109-
model = YOLO_v1(S=7, B=2, C=3)
110-
model.load_state_dict(torch.load(model_path))
111-
model.eval()
112-
for param in model.parameters():
113-
param.requires_grad = False
114-
model = model.to(device)
115-
116-
# 缩放图像
116+
img, data_dict = load_data('../imgs/cucumber_9.jpg', '../imgs/cucumber_9.xml')
117+
model = load_model()
118+
# 计算
117119
outputs = model.forward(img.to(device)).cpu().squeeze(0)
118120
print(outputs.shape)
119121

@@ -140,5 +142,5 @@ def deform_bboxs(pred_bboxs, data_dict):
140142
# 预测边界框的缩放,回到原始图像
141143
pred_bboxs = deform_bboxs(pred_cate_bboxs, data_dict)
142144
# 在原图绘制标注边界框和预测边界框
143-
dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], name_list, pred_bboxs, pred_cates, pred_cate_probs)
144-
draw.show(dst)
145+
dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], data_dict['name_list'], pred_bboxs, pred_cates, pred_cate_probs)
146+
draw.show(dst)

0 commit comments

Comments (0)