Skip to content
9 changes: 9 additions & 0 deletions Vision/classification/image/resnet50/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,15 @@ def parse_args(ignore_unknown_args=False):
dest="print_interval",
help="print loss every n iteration",
)

parser.add_argument(
"--exit-num",
type=int,
default=301,
dest="exit_num",
help="exit iter",
)

parser.add_argument(
"--print-timestamp", action="store_true", dest="print_timestamp",
)
Expand Down
5 changes: 4 additions & 1 deletion Vision/classification/image/resnet50/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def train(self):
acc = 0

save_dir = f"epoch_{self.cur_epoch}_val_acc_{acc}"
self.save(save_dir)
# self.save(save_dir)
self.cur_epoch += 1
self.cur_iter = 0

Expand All @@ -249,6 +249,9 @@ def train_one_epoch(self):

self.cur_iter += 1

if self.cur_iter == self.exit_num:
exit(0)

loss = tol(loss, self.metric_local)
if pred is not None and label is not None:
pred = tol(pred, self.metric_local)
Expand Down