-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
85 lines (69 loc) · 2.37 KB
/
train.py
File metadata and controls
85 lines (69 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from parl.algorithms import DQN
from pad import Paddle
from model import PadModel
from agent import PadAgent
from parl.utils import logger
from replay_memory import ReplayMemory
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES']=''  # hide all GPUs so training runs on CPU
LEARN_FREQ = 5 # update parameters every 5 steps
MEMORY_SIZE = 20000 # replay memory size
MEMORY_WARMUP_SIZE = 200 # store some experiences in the replay memory in advance
BATCH_SIZE = 32  # mini-batch size sampled from the replay memory per learn step
LEARNING_RATE = 0.001  # optimizer learning rate for the DQN algorithm
GAMMA = 0.999 # discount factor of reward
def run_eposide(agent,env,rpm):
    """Run one training episode and return its total reward.

    Plays a full episode with the exploratory policy (``agent.sample``),
    appends every transition to the replay memory, and — once the memory
    is warmed up — trains the agent on a sampled mini-batch every
    ``LEARN_FREQ`` steps.

    Args:
        agent: agent exposing ``sample(obs)`` and ``learn(...)``.
        env: environment whose ``step(action)`` returns
            ``(reward, next_obs, done)``.
        rpm: replay memory supporting ``append``, ``len`` and ``sample``.

    Returns:
        The undiscounted sum of rewards collected in the episode.
    """
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)  # exploratory (e.g. epsilon-greedy) action
        reward, next_obs, done = env.step(action)
        rpm.append((obs, action, reward, next_obs, done))
        # BUG FIX: original tested `step % LEARN_FREQ`, which is truthy on
        # the steps that are NOT multiples of LEARN_FREQ, so the agent
        # learned on 4 of every 5 steps.  The LEARN_FREQ comment ("update
        # parameters every 5 steps") requires `== 0`.
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward,
             batch_next_obs, batch_done) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                                     batch_next_obs, batch_done)
        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward
def evaluate(agent,env):
    """Score the greedy policy: mean total reward over 5 evaluation episodes.

    Uses ``agent.predict`` (no exploration) and the environment's
    ``step(action) -> (reward, next_obs, done)`` protocol.
    """
    episode_returns = []
    for _ in range(5):
        obs = env.reset()
        episode_total = 0
        done = False
        while not done:
            chosen = agent.predict(obs)
            # Rebind obs directly from step(); equivalent to obs = next_obs.
            reward, obs, done = env.step(chosen)
            episode_total += reward
        episode_returns.append(episode_total)
    return np.mean(episode_returns)
def main():
    """Entry point: build env/model/agent, warm up the replay memory,
    then alternate training episodes with periodic evaluation and
    checkpointing."""
    env = Paddle()
    action_dims = 3  # size of the discrete action space
    obs_dims = 5     # length of the observation vector
    rpm = ReplayMemory(MEMORY_SIZE)

    # Standard PARL layering: model -> algorithm -> agent.
    model = PadModel(action_dims)
    algorithm = DQN(model, action_dims, GAMMA, LEARNING_RATE)
    agent = PadAgent(algorithm, obs_dim=obs_dims, act_dim=action_dims)
    # use this to restore your model
    # agent.restore('./dqn_model.ckpt')

    # Pre-fill the replay memory before any learning happens.
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_eposide(agent, env, rpm)

    max_eposide = 1000
    for eposide in range(1, max_eposide + 1):
        print(run_eposide(agent, env, rpm))
        if eposide % 50 == 0:
            # Periodically report greedy-policy performance and checkpoint.
            eval_reward = evaluate(agent, env)
            logger.info('eposide:{},test_reward:{}'.format(eposide, eval_reward))
            agent.save('./dqn_model.ckpt')


if __name__=='__main__':
    main()