-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
38 lines (29 loc) · 1.45 KB
/
train.py
File metadata and controls
38 lines (29 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from conditional_gan import make_generator, make_discriminator, CGAN
import cmd
from gan.train import Trainer
import keras.backend as K
from pose_dataset import PoseHMDataset
def _nb_additional_inputs(args):
    """Return the value passed as the last positional argument of
    ``make_discriminator``, derived from ``args.fusion_type``.

    * ``"avg"`` / ``"att_simple"``: 0, and ``args.nb_rec`` must be 0.
    * ``"att_dec"``: ``args.nb_inputs``, and ``args.nb_rec`` must be 1.
    * anything else: ``args.nb_inputs + (args.nb_rec - 1)``.

    Raises:
        ValueError: if ``args.nb_rec`` does not match the value required by
            the selected fusion type. (These checks used to be ``assert``
            statements, which are skipped when Python runs with ``-O``.)
    """
    fusion_type = args.fusion_type

    if fusion_type in ("avg", "att_simple"):
        if args.nb_rec != 0:
            raise ValueError(
                "fusion_type %r requires nb_rec == 0, got %r"
                % (fusion_type, args.nb_rec)
            )
        return 0

    if fusion_type == "att_dec":
        if args.nb_rec != 1:
            raise ValueError(
                "fusion_type %r requires nb_rec == 1, got %r"
                % (fusion_type, args.nb_rec)
            )
        return args.nb_inputs

    return args.nb_inputs + (args.nb_rec - 1)


def main():
    """Build the generator and discriminator from the command-line options and
    run training.

    Steps, in order:
      1. Read the options with ``cmd.args()``.
      2. Create the generator; if ``generator_checkpoint`` is set, load its
         weights with ``by_name=True``.
      3. Validate the fusion settings and create the discriminator; if
         ``discriminator_checkpoint`` is set, load its weights.
      4. Create the training dataset (``test_phase=False``), wrap both models
         in ``CGAN`` and hand everything to ``Trainer.train()``.

    Raises:
        ValueError: if ``nb_rec`` is incompatible with ``fusion_type``.
    """
    # NOTE(review): `cmd` must resolve to the module that provides `args()`;
    # the standard library also ships a module named `cmd` -- confirm the
    # import path picks up the intended one.
    args = cmd.args()

    generator = make_generator(
        args.image_size,
        args.nb_inputs,
        args.use_input_pose,
        args.warp_skip,
        args.disc_type,
        args.warp_agg,
        args.use_bg,
        args.pose_rep_type,
        args.fusion_type,
        args.return_att,
        args.nb_rec,
        args.dmax,
        args.kernel_size_last,
        args.use3D,
    )
    if args.generator_checkpoint is not None:
        generator.load_weights(args.generator_checkpoint, by_name=True)

    # Checked after the generator is built, matching the original ordering.
    nb_additional = _nb_additional_inputs(args)

    discriminator = make_discriminator(
        args.image_size,
        args.nb_inputs,
        args.use_input_pose,
        args.warp_skip,
        args.disc_type,
        args.warp_agg,
        args.use_bg,
        args.pose_rep_type,
        args.return_att,
        nb_additional,
    )
    if args.discriminator_checkpoint is not None:
        discriminator.load_weights(args.discriminator_checkpoint)

    # Every parsed option is forwarded as a keyword argument to the dataset,
    # the GAN wrapper and the trainer.
    options = vars(args)
    dataset = PoseHMDataset(test_phase=False, **options)
    gan = CGAN(generator, discriminator, **options)
    trainer = Trainer(dataset, gan, **options)

    trainer.train()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()