1313from torch .utils .data import DataLoader
1414from torch .utils .data import Dataset
1515from utils import file
16+ from utils import util
1617import torchvision .transforms as transforms
1718
18- cate_list = ['cucumber' , 'eggplant' , 'mushroom' ]
19-
2019
2120class LocationDataset (Dataset ):
2221
23- def __init__ (self , root_dir , transform = None , S = 7 , B = 2 , C = 20 ):
22+ def __init__ (self , root_dir , cate_list , transform = None , S = 7 , B = 2 , C = 20 ):
2423 """
2524 保存图像以及标注框性能
2625 :param root_dir: 根目录
@@ -34,6 +33,7 @@ def __init__(self, root_dir, transform=None, S=7, B=2, C=20):
3433 self .S = S
3534 self .B = B
3635 self .C = C
36+ self .cate_list = cate_list
3737
3838 jpeg_path_list = []
3939 xml_path_list = []
@@ -54,11 +54,13 @@ def __getitem__(self, index):
5454 assert index < len (self .jpeg_path_list ), 'image length: %d' % len (self .jpeg_path_list )
5555
5656 # print(self.jpeg_path_list[index])
57- image = cv2 .imread (self .jpeg_path_list [index ])
57+ img_path = self .jpeg_path_list [index ]
58+ image = cv2 .imread (img_path )
5859 img_h , img_w = image .shape [:2 ]
59- ratio_h = 1
60- ratio_w = 1
60+ ratio_h = 1.0
61+ ratio_w = 1.0
6162 if self .transform :
63+ # [H, W, C] -> [C, H, W]
6264 image = self .transform (image )
6365 # 计算图像缩放比例
6466 dst_img_h , dst_img_w = image .shape [1 :3 ]
@@ -73,7 +75,8 @@ def __getitem__(self, index):
7375
7476 target = torch .zeros ((self .S * self .S , self .C + self .B * 5 ))
7577 bndboxs , name_list = file .parse_location_xml (self .xml_path_list [index ])
76- # 缩放边界框坐标(x, y, w, h)
78+ bndboxs = util .bbox_corner_to_center (bndboxs )
79+ # 缩放边界框坐标(x_center, y_center, w, h)
7780 bndboxs [:, 0 ] = bndboxs [:, 0 ] * ratio_w
7881 bndboxs [:, 1 ] = bndboxs [:, 1 ] * ratio_h
7982 bndboxs [:, 2 ] = bndboxs [:, 2 ] * ratio_w
@@ -89,6 +92,8 @@ def __getitem__(self, index):
8992 # 边界框中心位于哪个网格
9093 grid_x = int (box_x / grid_w )
9194 grid_y = int (box_y / grid_h )
95+ # 行/列从0开始计数
96+ print (grid_x + 1 , grid_y + 1 )
9297 # 边界框中心相对于网格的比例(0,1)
9398 x = (box_x % grid_w ) / grid_w
9499 y = (box_y % grid_h ) / grid_h
@@ -97,22 +102,24 @@ def __getitem__(self, index):
97102 h = box_h / img_h
98103 # 该网格内是否已填充(每个网格1个标注边界框)
99104 if grid_nums [grid_x , grid_y ] > 1 :
100- print ('网格(%d, %d)已填充:%s' % (grid_x , grid_y , str ( target [ grid_x , grid_y ]) ))
105+ print ('网格(%d, %d)已填充:%s' % (grid_x , grid_y , img_path ))
101106 else :
107+ grid_nums [grid_x , grid_y ] = 1
108+
102109 # 转换类别和标签
103- cate_idx = cate_list .index (name )
110+ cate_idx = self .cate_list .index (name )
111+ # 指定网格
112+ grid_idx = self .S * grid_y + grid_x
104113 # 指定类别概率为1
105- target [grid_x * grid_y , cate_idx ] = 1
114+ target [grid_idx , cate_idx ] = 1
106115 for j in range (self .B ):
107116 # 置信度
108- target [grid_x * grid_y , self .C + j ] = 1
117+ target [grid_idx , self .C + j ] = 1
109118 # 相应的边界框坐标
110- target [grid_x * grid_y , self .C + self .B + 4 * j ] = x
111- target [grid_x * grid_y , self .C + self .B + 4 * j + 1 ] = y
112- target [grid_x * grid_y , self .C + self .B + 4 * j + 2 ] = w
113- target [grid_x * grid_y , self .C + self .B + 4 * j + 3 ] = h
114-
115- grid_nums [grid_x , grid_y ] += 1
119+ target [grid_idx , self .C + self .B + 4 * j ] = x
120+ target [grid_idx , self .C + self .B + 4 * j + 1 ] = y
121+ target [grid_idx , self .C + self .B + 4 * j + 2 ] = w
122+ target [grid_idx , self .C + self .B + 4 * j + 3 ] = h
116123
117124 return image , target
118125
@@ -129,17 +136,18 @@ def __len__(self):
129136 transforms .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))
130137 ])
131138
132- data_set = LocationDataset (root_dir , transform , 7 , 2 , 3 )
133- # print(data_set)
134- # print(len(data_set))
135- #
136- # image, target = data_set.__getitem__(3)
137- # print(image.shape)
138- # print(target.shape)
139- # print(target)
140-
141- data_loader = DataLoader (data_set , shuffle = True , batch_size = 8 , num_workers = 8 )
142- items = next (iter (data_loader ))
143- inputs , labels = items
144- print (inputs .shape )
145- print (labels .shape )
139+ cate_list = ['cucumber' , 'eggplant' , 'mushroom' ]
140+ data_set = LocationDataset (root_dir , cate_list , transform , 7 , 2 , 3 )
141+ print (data_set )
142+ print (len (data_set ))
143+
144+ image , target = data_set .__getitem__ (3 )
145+ print (image .shape )
146+ print (target .shape )
147+ print (target )
148+
149+ # data_loader = DataLoader(data_set, shuffle=True, batch_size=8, num_workers=8)
150+ # items = next(iter(data_loader))
151+ # inputs, labels = items
152+ # print(inputs.shape)
153+ # print(labels.shape)
0 commit comments