@@ -60,7 +60,7 @@ def load_data(img_path, xml_path):
6060
6161
6262def load_model (device ):
63- model_path = './models/checkpoint_yolo_v1_49 .pth'
63+ model_path = './models/checkpoint_yolo_v1 .pth'
6464 model = YOLO_v1 (S = 7 , B = 2 , C = 3 )
6565 model .load_state_dict (torch .load (model_path ))
6666 model .eval ()
@@ -110,8 +110,8 @@ def deform_bboxs(pred_bboxs, data_dict):
110110
111111
112112if __name__ == '__main__' :
113- device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
114- # device = "cpu"
113+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
114+ device = "cpu"
115115
116116 img , data_dict = load_data ('../imgs/cucumber_9.jpg' , '../imgs/cucumber_9.xml' )
117117 model = load_model (device )
@@ -142,6 +142,9 @@ def deform_bboxs(pred_bboxs, data_dict):
142142 # 预测边界框的缩放,回到原始图像
143143 pred_bboxs = deform_bboxs (pred_cate_bboxs , data_dict )
144144 # 在原图绘制标注边界框和预测边界框
145- dst = draw .plot_bboxs (data_dict ['src' ], data_dict ['bndboxs' ], data_dict ['name_list' ], pred_bboxs , pred_cates , pred_cate_probs )
146- # cv2.imwrite('./detect.png', dst)
145+ dst = draw .plot_bboxs (data_dict ['src' ], data_dict ['bndboxs' ], data_dict ['name_list' ], pred_bboxs , pred_cates ,
146+ pred_cate_probs )
147+ cv2 .imwrite ('./detect.png' , dst )
148+ # BGR -> RGB
149+ dst = cv2 .cvtColor (dst , cv2 .COLOR_BGR2RGB )
147150 draw .show (dst )
0 commit comments