Skip to content

Commit fc0327c

Browse files
committed
style(train): 加载LocationDataset时输入类别列表
1 parent fad8f85 commit fc0327c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

py/lib/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@
2323
B = 2
2424
C = 3
2525

26+
cate_list = ['cucumber', 'eggplant', 'mushroom']
2627

27-
def load_data(data_root_dir, S=7, B=2, C=20):
28+
29+
def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
2830
transform = transforms.Compose([
2931
transforms.ToPILImage(),
3032
transforms.Resize((448, 448)),
3133
transforms.ToTensor(),
3234
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
3335
])
3436

35-
data_set = LocationDataset(data_root_dir, transform=transform, S=S, B=B, C=C)
37+
data_set = LocationDataset(data_root_dir, cate_list, transform=transform, S=S, B=B, C=C)
3638
data_loader = DataLoader(data_set, batch_size=8, shuffle=True, num_workers=8)
3739

3840
return data_loader
@@ -108,7 +110,7 @@ def train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epoc
108110
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
109111
# device = "cpu"
110112

111-
data_loader = load_data('../data/training_images', S=S, B=B, C=C)
113+
data_loader = load_data('../data/training_images', cate_list, S=S, B=B, C=C)
112114
# print(len(data_loader))
113115

114116
model = YOLO_v1(S=S, B=B, C=C)

0 commit comments

Comments
 (0)