Skip to content

Commit 0aa8da3

Browse files
committed
refactor(detect): 修改模型名
1 parent 8966736 commit 0aa8da3

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

py/batch_detect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def parse_data(img_path, xml_path, transform):
7777

7878

7979
def load_model(device):
80-
model_path = './models/checkpoint_yolo_v1_49.pth'
80+
model_path = './models/checkpoint_yolo_v1.pth'
8181
model = YOLO_v1(S=7, B=2, C=3)
8282
model.load_state_dict(torch.load(model_path))
8383
model.eval()

py/detector.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def load_data(img_path, xml_path):
6060

6161

6262
def load_model(device):
63-
model_path = './models/checkpoint_yolo_v1_49.pth'
63+
model_path = './models/checkpoint_yolo_v1.pth'
6464
model = YOLO_v1(S=7, B=2, C=3)
6565
model.load_state_dict(torch.load(model_path))
6666
model.eval()
@@ -110,8 +110,8 @@ def deform_bboxs(pred_bboxs, data_dict):
110110

111111

112112
if __name__ == '__main__':
113-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
114-
# device = "cpu"
113+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
114+
device = "cpu"
115115

116116
img, data_dict = load_data('../imgs/cucumber_9.jpg', '../imgs/cucumber_9.xml')
117117
model = load_model(device)
@@ -142,6 +142,9 @@ def deform_bboxs(pred_bboxs, data_dict):
142142
# 预测边界框的缩放,回到原始图像
143143
pred_bboxs = deform_bboxs(pred_cate_bboxs, data_dict)
144144
# 在原图绘制标注边界框和预测边界框
145-
dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], data_dict['name_list'], pred_bboxs, pred_cates, pred_cate_probs)
146-
# cv2.imwrite('./detect.png', dst)
145+
dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], data_dict['name_list'], pred_bboxs, pred_cates,
146+
pred_cate_probs)
147+
cv2.imwrite('./detect.png', dst)
148+
# BGR -> RGB
149+
dst = cv2.cvtColor(dst, cv2.COLOR_BGR2RGB)
147150
draw.show(dst)

0 commit comments

Comments
 (0)