-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathparameters.py
More file actions
77 lines (68 loc) · 2.59 KB
/
parameters.py
File metadata and controls
77 lines (68 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#######
# parameters for training.
#######
import os
class Args:
    """Configuration for one training run.

    All settings are class attributes so they can be read without an
    instance; instantiating ``Args()`` additionally validates the paths,
    creates the per-task output folder and writes a ``config.txt``
    snapshot of every setting into it (see :meth:`verify_args`).
    """
    # give a name to this task or a set of parameters
    # models will be saved in this folder (under output_dir)
    # name = "top20classes_loss_weight_sgd_moreepochs"
    name = "remainingclasses_loss_weight_sgd_moreepochs"
    batch_size = 16
    data_dir = "D:\\Projects\\Datasets\\fashion-larger"
    train_csv = "remainingclasses_set.csv"
    learning_rate = 0.0007
    momentum = 0.9
    epochs = 5
    optimizer = "sgd"
    # checkpoint to start training from (None = train from scratch)
    checkpoint = "D:\\Projects\\Datasets\\fashion-larger\\model_output\\remainingclasses_loss_weight_sgd\\fashion-4.pth"
    # checkpoint = None
    # save model after every X epochs
    save_model = 1
    output_dir = "D:\\Projects\\Datasets\\fashion-larger\\model_output"
    base_model = "resnet50"
    # num class for model
    num_classes = 88  # top 20 classes
    # num classes different than the pre-trained model (checkpoint)
    diff_num_class = False  # default = False
    # calculate avg loss for X minibatches
    avg_loss_batch = 100
    oversample = False  # over sample to handle class imbalance
    weighted_loss = True  # oversampling takes priority over weighted loss.
    # weighted loss schemes
    # 1. weight = 1/class_size
    # 2. max(class frequency)/freq of given class
    weighted_loss_scheme = 2
    # if True then only throws warning whenever possible
    # NOTE(review): flag is currently unused by the code in this file —
    # presumably consumed by the training script; confirm before removing.
    be_nice = False
    tensorboard = True

    def __init__(self):
        # Validate paths and prepare the output folder immediately.
        self.verify_args()

    def verify_args(self):
        """Validate configured paths and prepare the task output folder.

        Raises:
            ValueError: if ``data_dir`` does not exist.

        Side effects:
            Creates ``output_dir/name`` (writing ``config.txt``) when it
            does not exist yet. When it exists and already contains .pth
            checkpoints, interactively asks whether to resume from the
            latest checkpoint found there.
        """
        if not os.path.exists(self.data_dir):
            raise ValueError("Invalid path for data_dir. Directory does not exist!")
        task_dir = os.path.join(self.output_dir, self.name)
        if not os.path.exists(task_dir):
            os.makedirs(task_dir)
            self.save_settings()
        else:
            # Existing folder but no saved models yet: safe to (re)save settings.
            if not any(fname.endswith('.pth') for fname in os.listdir(task_dir)):
                self.save_settings()
                return
            # Folder already holds checkpoints: let the user decide whether
            # to resume from them instead of the configured `checkpoint`.
            while True:
                user_choice = input(
                    "Directory with the chosen task name already exists. "
                    "Do you want to use checkpoint from this directory (y/n)?"
                ).strip().lower()
                if user_choice == "y":
                    # overwrite checkpoint with the latest one in this dir
                    self.checkpoint = self.find_latest_checkpoint()
                    break
                elif user_choice == "n":
                    break
                else:
                    print("please enter a valid option.(y/n)")

    def find_latest_checkpoint(self):
        """Return the most recently modified ``.pth`` file in the task folder.

        Falls back to the currently configured ``checkpoint`` when no
        ``.pth`` file is present (should not happen on the call path in
        :meth:`verify_args`, which only calls this when one exists).
        """
        task_dir = os.path.join(self.output_dir, self.name)
        pth_files = [
            os.path.join(task_dir, fname)
            for fname in os.listdir(task_dir)
            if fname.endswith(".pth")
        ]
        if not pth_files:
            return self.checkpoint
        return max(pth_files, key=os.path.getmtime)

    def save_settings(self):
        """Write every (non-callable, non-dunder) setting to config.txt.

        The file is written into ``output_dir/name`` so each task folder
        carries a record of the parameters it was trained with.
        """
        with open(os.path.join(self.output_dir, self.name, "config.txt"), "w") as cfg:
            cfg.write("####################################################\n")
            cfg.write("# Settings for Task: {} #\n\n".format(self.name))
            for var in dir(Args):
                # Skip dunders and methods; keep only plain settings.
                if not var.startswith("__") and not callable(getattr(Args, var)):
                    cfg.write("{} : {} \n".format(var, getattr(Args, var)))
_ = Args()