33 changes: 22 additions & 11 deletions squiRL/common/agents.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Tuple
+from typing import Tuple, List
 from squiRL.common.data_stream import Experience


@@ -17,25 +17,34 @@ class Agent:
         env: training environment
 
     Attributes:
-        env (gym.Env): OpenAI gym training environment
+        env (List[gym.Env]): List of OpenAI gym training environments
         obs (list): Latest observation for each environment
         replay_buffer (RolloutCollector): Data collector for saving experience
     """
-    def __init__(self, env: gym.Env, replay_buffer) -> None:
+    def __init__(self, env: List[gym.Env], replay_buffer) -> None:
         """Initializes agent class
 
         Args:
-            env (gym.Env): OpenAI gym training environment
+            env (List[gym.Env]): List of OpenAI gym training environments
             replay_buffer (RolloutCollector): Data collector for saving experience
         """
-        self.env = env
+        self.envs = env
+        self.obs = [None] * len(self.envs)
         self.replay_buffer = replay_buffer
-        self.reset()
+
+        for i in range(len(self.envs)):
+            self.env_idx = i
+            self.reset()
+
+        self.env_idx = 0
 
     def reset(self) -> None:
         """Resets the current environment and updates its obs entry
         """
-        self.obs = self.env.reset()
+        self.obs[self.env_idx] = self.envs[self.env_idx].reset()
+
+    def next_env(self) -> None:
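+        """Advances env_idx round-robin to the next environment."""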
+        self.env_idx = (self.env_idx + 1) % len(self.envs)

     def process_obs(self, obs: int) -> torch.Tensor:
         """Converts obs np.array to torch.Tensor for passing through NN
@@ -61,7 +70,9 @@ def get_action(
         Returns:
             action (int): Action to be carried out
         """
-        obs = self.process_obs(self.obs)
+        obs = self.obs[self.env_idx]
+        assert obs is not None
+        obs = self.process_obs(obs)
 
         action_logit = net(obs)
         probs = F.softmax(action_logit, dim=-1)
@@ -89,11 +100,11 @@ def play_step(
         action = self.get_action(net)
 
         # do step in the environment
-        new_obs, reward, done, _ = self.env.step(action)
-        exp = Experience(self.obs, action, reward, done, new_obs)
+        new_obs, reward, done, _ = self.envs[self.env_idx].step(action)
+        exp = Experience(self.obs[self.env_idx], action, reward, done, new_obs)
         self.replay_buffer.append(exp)
 
-        self.obs = new_obs
+        self.obs[self.env_idx] = new_obs
         if done:
             self.reset()
         return reward, done
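
A minimal usage sketch of the round-robin agent API above (the environment name, network construction, and buffer arguments are illustrative assumptions, not part of the diff):

    import gym
    from squiRL.common.agents import Agent
    from squiRL.common.data_stream import RolloutCollector
    from squiRL.common.policies import MLP

    envs = [gym.make("CartPole-v1") for _ in range(4)]
    buffer = RolloutCollector(capacity=10000,
                              state_shape=envs[0].observation_space.shape,
                              action_shape=(),  # scalar discrete actions
                              should_share=False)
    agent = Agent(envs, buffer)
    net = MLP(envs[0].observation_space.shape[0],
              envs[0].action_space.n)  # assumed MLP signature

    for _ in range(100):
        reward, done = agent.play_step(net)  # steps envs[agent.env_idx]
        if done:
            agent.next_env()  # rotate to the next environment
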
90 changes: 64 additions & 26 deletions squiRL/common/data_stream.py
@@ -3,14 +3,16 @@
 Attributes:
     Experience (namedtuple): An environment step experience
 """
-import numpy as np
-from torch.utils.data.dataset import IterableDataset
-from collections import deque
 from collections import namedtuple
-from squiRL.common.policies import MLP
 import gym
 from typing import Tuple
 
+from torch.utils.data.dataset import IterableDataset
+import torch.multiprocessing as mp
+import torch
+
+from squiRL.common.policies import MLP
 
 Experience = namedtuple('Experience',
                         ('state', 'action', 'reward', 'done', 'last_state'))

@@ -27,50 +29,85 @@ class RolloutCollector:
         capacity (int): Size of the buffer
         count (torch.Tensor): Number of experiences currently stored
         states, actions, rewards, dones, next_states (torch.Tensor):
             Preallocated experience buffers
     """
-    def __init__(self, capacity: int) -> None:
+    def __init__(self, capacity: int, state_shape: tuple,
+                 action_shape: tuple, should_share: bool = False) -> None:
         """Initializes the preallocated rollout buffer
 
         Args:
             capacity (int): Size of the buffer
+            state_shape (tuple): Shape of a single observation
+            action_shape (tuple): Shape of a single action
+            should_share (bool): If True, move the buffers to shared memory
         """
+
+        state_shape = [capacity] + list(state_shape)
+        action_shape = [capacity] + list(action_shape)
+
         self.capacity = capacity
-        self.replay_buffer = deque(maxlen=self.capacity)
+        self.count = torch.tensor([0], dtype=torch.int64)
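+        # a one-element tensor (rather than a plain int) so the counter
+        # can live in shared memory and be updated in place across processes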
+        self.states = torch.zeros(state_shape, dtype=torch.float32)
+        self.actions = torch.zeros(action_shape, dtype=torch.float32)
+        self.rewards = torch.zeros(capacity, dtype=torch.float32)
+        self.dones = torch.zeros(capacity, dtype=torch.bool)
+        self.next_states = torch.zeros(state_shape, dtype=torch.float32)
+
+        if should_share:
+            self.count.share_memory_()
+            self.states.share_memory_()
+            self.actions.share_memory_()
+            self.next_states.share_memory_()
+            self.rewards.share_memory_()
+            self.dones.share_memory_()
+
+        self.lock = mp.Lock()

     def __len__(self) -> int:
         """Calculates length of buffer
 
         Returns:
             int: Length of buffer
         """
-        return len(self.replay_buffer)
+        return int(self.count.item())

     def append(self, experience: Experience) -> None:
         """
         Add experience to the buffer
 
         Args:
             experience (Experience): Tuple (state, action, reward, done,
-            new_state)
+            last_state)
         """
-        self.replay_buffer.append(experience)
+        with self.lock:
+            if self.count[0] < self.capacity:
+                self.count[0] += 1
+
+                # count holds the exact number of stored items, while
+                # indexing starts at 0, so the newest slot is count - 1
+                nr = self.count[0] - 1
+
+                self.states[nr] = torch.tensor(experience.state, dtype=torch.float32)
+                self.actions[nr] = torch.tensor(experience.action, dtype=torch.float32)
+                self.rewards[nr] = torch.tensor(experience.reward, dtype=torch.float32)
+                self.dones[nr] = torch.tensor(experience.done, dtype=torch.bool)
+                self.next_states[nr] = torch.tensor(experience.last_state, dtype=torch.float32)
+            else:
+                raise RuntimeError(
+                    "RolloutCollector: buffer is full but samples are still being added")


     def sample(self) -> Tuple:
         """Sample experience from buffer
 
         Returns:
             Tuple: Sampled experience
         """
-        states, actions, rewards, dones, next_states = zip(
-            *[self.replay_buffer[i] for i in range(len(self.replay_buffer))])
-
-        return (np.array(states), np.array(actions),
-                np.array(rewards, dtype=np.float32),
-                np.array(dones, dtype=np.bool), np.array(next_states))
+        # count is the number of stored items; since a slice end is
+        # exclusive, [:count] returns exactly the filled entries
+        nr = self.count[0]
+        return (self.states[:nr], self.actions[:nr], self.rewards[:nr],
+                self.dones[:nr], self.next_states[:nr])

     def empty_buffer(self) -> None:
-        """Empty replay buffer
+        """Empty the buffer by resetting the count (old entries are
+        simply overwritten by subsequent appends)
         """
-        self.replay_buffer.clear()
+        with self.lock:
+            # indexing with [0] writes into the shared tensor in place;
+            # rebinding self.count would replace it with a plain int that
+            # is no longer shared across processes
+            self.count[0] = 0
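
A short lifecycle sketch of the collector introduced above (capacity, shapes, and values are illustrative assumptions, not part of the diff):

    import numpy as np
    from squiRL.common.data_stream import Experience, RolloutCollector

    buf = RolloutCollector(capacity=3, state_shape=(4,), action_shape=(),
                           should_share=True)  # buffers go to shared memory

    obs = np.zeros(4, dtype=np.float32)
    buf.append(Experience(obs, 1, 0.5, False, obs))
    assert len(buf) == 1

    states, actions, rewards, dones, next_states = buf.sample()
    assert states.shape == (1, 4)  # only the filled slice is returned

    buf.empty_buffer()  # resets the count; storage is reused, not freed
    assert len(buf) == 0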


class RLDataset(IterableDataset):
@@ -84,24 +121,23 @@ class RLDataset(IterableDataset):
 
     Attributes:
         agent (Agent): Agent that interacts with env
-        env (gym.Env): OpenAI gym environment
         net (nn.Module): Policy network
         replay_buffer: Replay buffer
     """
-    def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP,
-                 agent) -> None:
+
+    def __init__(self, replay_buffer: RolloutCollector, net: MLP,
+                 agent, episodes_per_batch: int = 1) -> None:
        """Initializes the dataset
 
         Args:
             replay_buffer (RolloutCollector): Data collector for saving experience
-            env (gym.Env): OpenAI gym environment
             net (nn.Module): Policy network
             agent (Agent): Agent that interacts with env
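+            episodes_per_batch (int): Number of episodes collected per batch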
"""
self.replay_buffer = replay_buffer
self.env = env
self.net = net
self.agent = agent
self.episodes_per_batch = episodes_per_batch

     def populate(self) -> None:
         """
@@ -119,8 +155,10 @@ def __iter__(self):
         Yields:
             Tuple: Sampled experience
         """
-        for i in range(1):
-            self.populate()
-            states, actions, rewards, dones, new_states = self.replay_buffer.sample(
-            )
-            yield (states, actions, rewards, dones, new_states)
+        for i in range(self.episodes_per_batch):
+            for j in range(len(self.agent.envs)):
+                self.agent.env_idx = j
+                self.populate()
+                states, actions, rewards, dones, new_states = \
+                    self.replay_buffer.sample()
+                yield (states, actions, rewards, dones, new_states)
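
A sketch of how the reworked dataset might be consumed, reusing the buffer, net, and agent objects from the earlier agents.py sketch (the loop body is an illustrative assumption):

    from squiRL.common.data_stream import RLDataset

    dataset = RLDataset(buffer, net, agent, episodes_per_batch=2)

    for states, actions, rewards, dones, new_states in dataset:
        # with 4 envs and episodes_per_batch=2 this yields 8 batches,
        # one per (episode, environment) pair
        print(states.shape, rewards.sum().item())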