Skip to content

Commit dfa83fa

Browse files
committed
perf(model): 在cpu中保存模型权重
1 parent 619bcf1 commit dfa83fa

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

py/lib/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ def train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epoc
8989
# deep copy the model
9090
if epoch_loss < best_loss:
9191
best_loss = epoch_loss
92-
best_model_weights = copy.deepcopy(model.state_dict())
92+
best_model_weights = copy.deepcopy(model.cpu().state_dict())
93+
model = model.to(device)
9394

9495
file.check_dir('../models')
95-
# util.save_checkpoint('../models/checkpoint_yolo_v1_%d.pth' % (epoch), epoch, model, optimizer, loss)
96-
file.save_model(model, '../models/checkpoint_yolo_v1_%d.pth' % (epoch))
96+
file.save_model(best_model_weights, '../models/checkpoint_yolo_v1_%d.pth' % (epoch))
9797
print('save model')
9898

9999
print()

py/lib/utils/file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def parse_location_xml(xml_path):
7272
return np.array(bndboxs), name_list
7373

7474

75-
def save_model(model, model_save_path):
76-
torch.save(model.state_dict().cpu(), model_save_path)
75+
def save_model(model_weights, model_save_path):
76+
torch.save(model_weights, model_save_path)
7777

7878

7979
def save_checkpoint(model_save_path, epoch, model, optimizer, loss):

0 commit comments

Comments
 (0)