-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreplay_buffer_test.py
More file actions
52 lines (43 loc) · 1.66 KB
/
replay_buffer_test.py
File metadata and controls
52 lines (43 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
import numpy as np
import replay_buffer
# Short aliases for the TF1-style test, logging and flags modules.
# `log` and `FLAGS` are not referenced by the tests in this file.
test = tf.test
log, FLAGS = tf.logging, tf.flags.FLAGS
class ReplayBufferTest(test.TestCase):
    """Unit tests for ``replay_buffer.ReplayBuffer``.

    Every transition added by these tests is derived from a single integer
    ``i``, so a sampled transition can be checked for internal consistency:

    * old state = [i, 2*i]
    * action    = [i % 5]
    * reward    = [i]
    * done      = False
    * new state = [2*i, 3*i]
    """

    @staticmethod
    def _fill(buffer, count):
        """Add ``count`` synthetic transitions (i = 0 .. count-1) to ``buffer``."""
        for i in range(count):
            buffer.add(np.array([i, 2 * i]), np.array([i % 5]), np.array([i]),
                       False, np.array([2 * i, 3 * i]))

    def testInit(self):
        """A freshly constructed buffer is empty."""
        with self.test_session():
            buffer = replay_buffer.ReplayBuffer()
            self.assertEqual(0, buffer.size())

    def testAdd(self):
        """size() counts every added transition while below capacity."""
        with self.test_session():
            buffer = replay_buffer.ReplayBuffer()
            self._fill(buffer, 1000)
            self.assertEqual(1000, buffer.size())

    def testSample(self):
        """Sampled batches are full-size, duplicate-free and self-consistent."""
        with self.test_session():
            buffer = replay_buffer.ReplayBuffer()
            self._fill(buffer, 1000)
            num_samples = 32
            for _ in range(50):
                batch = buffer.sample(num_samples)
                old_states, actions, rewards, dones, new_states = batch
                # Fail with a clear message, not an IndexError below, if the
                # buffer returns a short batch.
                self.assertEqual(num_samples, len(rewards))
                # The reward encodes the originating `i`; a repeated reward
                # within one batch means the same transition was drawn twice.
                reward_set = set()
                for s in range(num_samples):
                    i = rewards[s][0]
                    self.assertNotIn(i, reward_set)
                    reward_set.add(i)
                    self.assertFalse(dones[s])
                    self.assertTrue((actions[s] == i % 5).all())
                    self.assertTrue((old_states[s] == np.array([i, 2 * i])).all())
                    self.assertTrue((new_states[s] == np.array([2 * i, 3 * i])).all())

    def testOverWrite(self):
        """Adding past capacity caps size() at replay_buffer.MAX_SIZE."""
        with self.test_session():
            buffer = replay_buffer.ReplayBuffer()
            # Derive the count from MAX_SIZE so the buffer always wraps around
            # at least twice, even if the capacity grows beyond the original
            # hard-coded 100000. The original workload is kept as a floor.
            num_added = max(100000, 2 * replay_buffer.MAX_SIZE + 1)
            self._fill(buffer, num_added)
            self.assertEqual(replay_buffer.MAX_SIZE, buffer.size())
if __name__ == "__main__":
    # Show INFO-level TensorFlow log output while the test suite runs.
    log.set_verbosity(log.INFO)
    test.main()