Skip to content

Commit 4406158

Browse files
committed
fix(dataset): 边界框中心对应的网格计算
1 parent c173739 commit 4406158

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

py/lib/models/location_dataset.py

Lines changed: 39 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,13 @@
1313
from torch.utils.data import DataLoader
1414
from torch.utils.data import Dataset
1515
from utils import file
16+
from utils import util
1617
import torchvision.transforms as transforms
1718

18-
cate_list = ['cucumber', 'eggplant', 'mushroom']
19-
2019

2120
class LocationDataset(Dataset):
2221

23-
def __init__(self, root_dir, transform=None, S=7, B=2, C=20):
22+
def __init__(self, root_dir, cate_list, transform=None, S=7, B=2, C=20):
2423
"""
2524
保存图像以及标注框性能
2625
:param root_dir: 根目录
@@ -34,6 +33,7 @@ def __init__(self, root_dir, transform=None, S=7, B=2, C=20):
3433
self.S = S
3534
self.B = B
3635
self.C = C
36+
self.cate_list = cate_list
3737

3838
jpeg_path_list = []
3939
xml_path_list = []
@@ -54,11 +54,13 @@ def __getitem__(self, index):
5454
assert index < len(self.jpeg_path_list), 'image length: %d' % len(self.jpeg_path_list)
5555

5656
# print(self.jpeg_path_list[index])
57-
image = cv2.imread(self.jpeg_path_list[index])
57+
img_path = self.jpeg_path_list[index]
58+
image = cv2.imread(img_path)
5859
img_h, img_w = image.shape[:2]
59-
ratio_h = 1
60-
ratio_w = 1
60+
ratio_h = 1.0
61+
ratio_w = 1.0
6162
if self.transform:
63+
# [H, W, C] -> [C, H, W]
6264
image = self.transform(image)
6365
# 计算图像缩放比例
6466
dst_img_h, dst_img_w = image.shape[1:3]
@@ -73,7 +75,8 @@ def __getitem__(self, index):
7375

7476
target = torch.zeros((self.S * self.S, self.C + self.B * 5))
7577
bndboxs, name_list = file.parse_location_xml(self.xml_path_list[index])
76-
# 缩放边界框坐标(x, y, w, h)
78+
bndboxs = util.bbox_corner_to_center(bndboxs)
79+
# 缩放边界框坐标(x_center, y_center, w, h)
7780
bndboxs[:, 0] = bndboxs[:, 0] * ratio_w
7881
bndboxs[:, 1] = bndboxs[:, 1] * ratio_h
7982
bndboxs[:, 2] = bndboxs[:, 2] * ratio_w
@@ -89,6 +92,8 @@ def __getitem__(self, index):
8992
# 边界框中心位于哪个网格
9093
grid_x = int(box_x / grid_w)
9194
grid_y = int(box_y / grid_h)
95+
# 行/列从0开始计数
96+
print(grid_x + 1, grid_y + 1)
9297
# 边界框中心相对于网格的比例(0,1)
9398
x = (box_x % grid_w) / grid_w
9499
y = (box_y % grid_h) / grid_h
@@ -97,22 +102,24 @@ def __getitem__(self, index):
97102
h = box_h / img_h
98103
# 该网格内是否已填充(每个网格1个标注边界框)
99104
if grid_nums[grid_x, grid_y] > 1:
100-
print('网格(%d, %d)已填充:%s' % (grid_x, grid_y, str(target[grid_x, grid_y])))
105+
print('网格(%d, %d)已填充:%s' % (grid_x, grid_y, img_path))
101106
else:
107+
grid_nums[grid_x, grid_y] = 1
108+
102109
# 转换类别和标签
103-
cate_idx = cate_list.index(name)
110+
cate_idx = self.cate_list.index(name)
111+
# 指定网格
112+
grid_idx = self.S * grid_y + grid_x
104113
# 指定类别概率为1
105-
target[grid_x * grid_y, cate_idx] = 1
114+
target[grid_idx, cate_idx] = 1
106115
for j in range(self.B):
107116
# 置信度
108-
target[grid_x * grid_y, self.C + j] = 1
117+
target[grid_idx, self.C + j] = 1
109118
# 相应的边界框坐标
110-
target[grid_x * grid_y, self.C + self.B + 4 * j] = x
111-
target[grid_x * grid_y, self.C + self.B + 4 * j + 1] = y
112-
target[grid_x * grid_y, self.C + self.B + 4 * j + 2] = w
113-
target[grid_x * grid_y, self.C + self.B + 4 * j + 3] = h
114-
115-
grid_nums[grid_x, grid_y] += 1
119+
target[grid_idx, self.C + self.B + 4 * j] = x
120+
target[grid_idx, self.C + self.B + 4 * j + 1] = y
121+
target[grid_idx, self.C + self.B + 4 * j + 2] = w
122+
target[grid_idx, self.C + self.B + 4 * j + 3] = h
116123

117124
return image, target
118125

@@ -129,17 +136,18 @@ def __len__(self):
129136
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
130137
])
131138

132-
data_set = LocationDataset(root_dir, transform, 7, 2, 3)
133-
# print(data_set)
134-
# print(len(data_set))
135-
#
136-
# image, target = data_set.__getitem__(3)
137-
# print(image.shape)
138-
# print(target.shape)
139-
# print(target)
140-
141-
data_loader = DataLoader(data_set, shuffle=True, batch_size=8, num_workers=8)
142-
items = next(iter(data_loader))
143-
inputs, labels = items
144-
print(inputs.shape)
145-
print(labels.shape)
139+
cate_list = ['cucumber', 'eggplant', 'mushroom']
140+
data_set = LocationDataset(root_dir, cate_list, transform, 7, 2, 3)
141+
print(data_set)
142+
print(len(data_set))
143+
144+
image, target = data_set.__getitem__(3)
145+
print(image.shape)
146+
print(target.shape)
147+
print(target)
148+
149+
# data_loader = DataLoader(data_set, shuffle=True, batch_size=8, num_workers=8)
150+
# items = next(iter(data_loader))
151+
# inputs, labels = items
152+
# print(inputs.shape)
153+
# print(labels.shape)

py/lib/utils/file.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,7 @@ def parse_location_xml(xml_path):
6969
else:
7070
pass
7171

72-
bndboxs = np.array(bndboxs)
73-
return bndboxs, name_list
72+
return np.array(bndboxs), name_list
7473

7574

7675
def save_model(model, model_save_path):

0 commit comments

Comments
 (0)