@@ -43,22 +43,32 @@ def load_data(img_path, xml_path):
4343 ratio_h = scale_h / h
4444 ratio_w = scale_w / w
4545
46- scale_bndboxs = torch .from_numpy (bndboxs ).float ()
47- scale_bndboxs [:, 0 ] = scale_bndboxs [:, 0 ] * ratio_w
48- scale_bndboxs [:, 1 ] = scale_bndboxs [:, 1 ] * ratio_h
49- scale_bndboxs [:, 2 ] = scale_bndboxs [:, 2 ] * ratio_w
50- scale_bndboxs [:, 3 ] = scale_bndboxs [:, 3 ] * ratio_h
51- scale_bndboxs = scale_bndboxs .int ().numpy ()
52-
46+ # [C, H, W] -> [N, C, H, W]
5347 img = img .unsqueeze (0 )
48+
5449 data_dict = {}
5550 data_dict ['src' ] = src
5651 data_dict ['src_size' ] = (h , w )
5752 data_dict ['bndboxs' ] = bndboxs
53+ data_dict ['name_list' ] = name_list
54+
5855 data_dict ['img' ] = img
5956 data_dict ['scale_size' ] = (scale_h , scale_w )
6057 data_dict ['ratio' ] = (ratio_h , ratio_w )
61- return img , scale_bndboxs , name_list , data_dict
58+
59+ return img , data_dict
60+
61+
def load_model(model_path='./models/checkpoint_yolo_v1_24.pth'):
    """Load a trained YOLO v1 checkpoint and prepare it for inference.

    Args:
        model_path: Path to the saved ``state_dict`` checkpoint.
            Defaults to the epoch-24 checkpoint used by this script.

    Returns:
        A ``YOLO_v1`` model in eval mode, with gradients disabled,
        moved to the module-level ``device``.
    """
    # S/B/C must match the configuration the checkpoint was trained with
    # (7x7 grid, 2 boxes per cell, 3 classes).
    model = YOLO_v1(S=7, B=2, C=3)
    # map_location ensures a CUDA-trained checkpoint still loads when the
    # script falls back to CPU (the original call would raise on CPU-only hosts).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Inference only: freeze all parameters so no autograd state is tracked.
    for param in model.parameters():
        param.requires_grad = False
    return model.to(device)
6272
6373
6474def deform_bboxs (pred_bboxs , data_dict ):
@@ -76,8 +86,8 @@ def deform_bboxs(pred_bboxs, data_dict):
7686 col = int (i % S )
7787
7888 x_center , y_center , box_w , box_h = pred_bboxs [i ]
79- bboxs [i , 0 ] = (row + x_center ) * grid_w
80- bboxs [i , 1 ] = (col + y_center ) * grid_h
89+ bboxs [i , 0 ] = (col + x_center ) * grid_w
90+ bboxs [i , 1 ] = (row + y_center ) * grid_h
8191 bboxs [i , 2 ] = box_w * scale_w
8292 bboxs [i , 3 ] = box_h * scale_h
8393 # (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax)
@@ -103,17 +113,9 @@ def deform_bboxs(pred_bboxs, data_dict):
103113 device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
104114 # device = "cpu"
105115
106- img , bndboxs , name_list , data_dict = load_data ('../imgs/cucumber_9.jpg' , '../imgs/cucumber_9.xml' )
107-
108- model_path = './models/checkpoint_yolo_v1_24.pth'
109- model = YOLO_v1 (S = 7 , B = 2 , C = 3 )
110- model .load_state_dict (torch .load (model_path ))
111- model .eval ()
112- for param in model .parameters ():
113- param .requires_grad = False
114- model = model .to (device )
115-
116- # 缩放图像
116+ img , data_dict = load_data ('../imgs/cucumber_9.jpg' , '../imgs/cucumber_9.xml' )
117+ model = load_model ()
118+ # 计算
117119 outputs = model .forward (img .to (device )).cpu ().squeeze (0 )
118120 print (outputs .shape )
119121
@@ -140,5 +142,5 @@ def deform_bboxs(pred_bboxs, data_dict):
140142 # 预测边界框的缩放,回到原始图像
141143 pred_bboxs = deform_bboxs (pred_cate_bboxs , data_dict )
142144 # 在原图绘制标注边界框和预测边界框
143- dst = draw .plot_bboxs (data_dict ['src' ], data_dict ['bndboxs' ], name_list , pred_bboxs , pred_cates , pred_cate_probs )
144- draw .show (dst )
145+ dst = draw .plot_bboxs (data_dict ['src' ], data_dict ['bndboxs' ], data_dict [ ' name_list' ] , pred_bboxs , pred_cates , pred_cate_probs )
146+ draw .show (dst )
0 commit comments