-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathatari.py
More file actions
69 lines (57 loc) · 2.2 KB
/
atari.py
File metadata and controls
69 lines (57 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from agent import Agent
from savgol_filter import savgol_filter
from environments.gym import GymEnvironment
# from environments.flappy_bird import FlappyBird
# from environments.coin_collector import CoinCollector
if __name__ == '__main__':
    # Train a DQN-style Agent on a chosen environment, checkpointing the model
    # periodically and plotting per-episode scores at the end.
    env = GymEnvironment(gym.make('LunarLander-v2'))
    # env = FlappyBird()
    # env = CoinCollector()
    n_games = int(input('How many games should the AI train on? '))
    agent = Agent(
        gamma=0.99,
        epsilon=1,
        alpha=0.0005,
        input_dims=env.len_of_state(),
        num_of_actions=env.num_of_actions(),
        mem_size=1000000,
        batch_size=64,
        epsilon_decay=0.999,
        epsilon_min=0.01,
        random_action_func=env.random_action,
        model_file="./models/" + input('Name the file the AI should save its brain? ') + '.h5'
    )
    # Optionally resume from a previously saved model.
    brain_file = input("What file should the AI load in a brain from a file? ")
    if brain_file != "" and brain_file.lower() != "none":
        agent.load_model("./models/" + brain_file + ".h5")

    scores = []
    for i in range(1, n_games + 1):
        done = False
        score = 0
        observation = env.reset()
        while not done:
            action = agent.choose_action(observation)
            next_observation, reward, done, info = env.step(action)
            score += reward
            agent.remember(observation, action, reward, next_observation, done)
            observation = next_observation
            agent.learn()
            # if i % 10 == 1:
            #     env.render()
        scores.append(score)
        # Trailing-window average over the last (up to) 100 episodes.
        avg_score = np.mean(scores[-100:])
        print('Episode ', i, 'Score %.2f' % score, 'Average Score %.2f' % avg_score)
        if i % 10 == 0:
            agent.save_model()
        # Early stop once the agent scores at least 500 three episodes running.
        # NOTE: the original compared floats with `== 500`, which essentially
        # never matches accumulated rewards; `>= 500` expresses the intent.
        if len(scores) > 10 and all(s >= 500 for s in scores[-3:]):
            break
    agent.save_model()

    plt.plot(scores, label='Scores Over Iterations')
    # savgol_filter requires window_length to be a positive odd integer that is
    # larger than polyorder and no larger than the data length. The original
    # passed the float `n_games / 2`, which raises at runtime, and ignored the
    # fact that an early break can leave len(scores) < n_games.
    polyorder = 4
    window = min(len(scores), max(polyorder + 1, n_games // 2))
    if window % 2 == 0:
        window -= 1  # force odd
    if window > polyorder:
        plt.plot(savgol_filter(scores, window, polyorder), label='Savgol Filter Smoothing')
    plt.legend()
    # Derive the graph filename from the model filename ("./models/<name>.h5").
    model_name = os.path.splitext(os.path.basename(agent.model_file))[0]
    plt.savefig("./graphs/" + model_name + '-scores.png')
    plt.show()