-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
67 lines (51 loc) · 2.38 KB
/
train.py
File metadata and controls
67 lines (51 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Train script.
Usage:
train.py <hparams>
"""
import os
import torchvision.datasets as dset
import torch
import torchvision.utils as vutils
from docopt import docopt
from torchvision import transforms
from glow.builder import build
from glow.trainer import Trainer
from glow.config import JsonConfig
from backup_glow.utils import load_obj
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
if __name__ == "__main__":
    args = docopt(__doc__)
    hparams_dir = args["<hparams>"]
    # Fix: the original message formatted the undefined name `hparams`
    # (bound only below), so a missing path raised NameError instead of
    # the intended assertion message. Also fixes the "josn" typo.
    assert os.path.exists(hparams_dir), (
        "Failed to find hparams json `{}`".format(hparams_dir))
    hparams = JsonConfig(hparams_dir)
    dataset = hparams.Data.dataset
    dataset_root = hparams.Data.dataset_root

    def worker(label):
        """Train one per-class model on the data subset for `label`.

        Loads the pre-split subset pickle for this class, saves a sample
        image grid for visual inspection, dumps an evaluation config JSON
        once, then builds and trains the model. All output goes under
        `<log_root>/classfier<label>` (directory name kept as-is for
        compatibility with existing runs).
        """
        # Fresh config per worker so each run gets its own log_root.
        local_hparams = JsonConfig(hparams_dir)
        local_hparams.Dir.log_root = os.path.join(
            local_hparams.Dir.log_root, "classfier{}".format(label))
        dataset = load_obj(os.path.join(
            dataset_root, "classSets/" + "subset{}".format(label)))
        # Grab one batch and save it as a sanity-check image grid.
        tmp_dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=64, shuffle=True, num_workers=int(2))
        img = next(iter(tmp_dataloader))
        if not os.path.exists(local_hparams.Dir.log_root):
            os.makedirs(local_hparams.Dir.log_root)
        # NOTE(review): +0.5 presumably undoes a [-0.5, 0.5] normalization
        # done by the dataset — confirm against the data pipeline.
        vutils.save_image(
            img.data.add(0.5),
            os.path.join(local_hparams.Dir.log_root,
                         "img_under_evaluation.png"))
        # Dump the json config used for performance evaluation, once.
        eval_json = os.path.join(local_hparams.Dir.log_root,
                                 local_hparams.Data.dataset + ".json")
        if not os.path.exists(eval_json):
            get_hparams = JsonConfig(hparams_dir)
            data_dir = get_hparams.Data.dataset_root
            # Evaluation runs against the full ("all") dataset rather
            # than the per-class ("separate") split.
            get_hparams.Data.dataset_root = data_dir.replace(
                "separate", "all")
            get_hparams.dump(dir_path=get_hparams.Dir.log_root,
                             json_name=get_hparams.Data.dataset + ".json")
        # Build model and train.
        built = build(local_hparams, True)
        print(hparams.Dir.log_root)
        trainer = Trainer(**built, dataset=dataset, hparams=local_hparams)
        trainer.train()

    # Two workers at a time; threading backend is adequate because the
    # heavy work runs inside native torch code.
    Parallel(n_jobs=2, pre_dispatch="all", backend="threading")(
        map(delayed(worker), list(range(hparams.Data.num_classes))))