Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions config/four_rooms.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
[base]
env_name = four_rooms

[vec]
total_agents = 4096
num_buffers = 2
num_threads = 8

[env]
size = 19
# If <= 0, max_steps defaults to 4 * size. Positive values override it.
max_steps = 0

[policy]
hidden_size = 128
num_layers = 2
expansion_factor = 1

[train]
total_timesteps = 100_000_000
gamma = 0.99
gae_lambda = 0.95
learning_rate = 0.005
minibatch_size = 32768
horizon = 64
ent_coef = 0.01

[sweep]
metric = score
metric_distribution = linear
goal = maximize
max_runs = 100
gpus = 1
downsample = 5
sweep_only = hidden_size,num_layers,total_timesteps,learning_rate

[sweep.policy.hidden_size]
distribution = uniform_pow2
min = 64
max = 1024
mean = 256
scale = auto

[sweep.policy.num_layers]
distribution = int_uniform
min = 1
max = 4
mean = 2
scale = auto

[sweep.train.total_timesteps]
distribution = log_normal
min = 20_000_000
max = 500_000_000
mean = 100_000_000
scale = auto

[sweep.train.learning_rate]
distribution = log_normal
min = 0.0005
max = 0.01
scale = auto
46 changes: 46 additions & 0 deletions ocean/four_rooms/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "four_rooms.h"

// Binding configuration macros. They are defined before vecenv.h is included;
// NOTE(review): presumably vecenv.h reads them to size buffers and to find the
// step callbacks -- confirm against vecenv.h.

// Elements in one observation: FOUR_ROOMS_VIEW_SIZE squared times
// FOUR_ROOMS_OBS_CHANNELS.
#define OBS_SIZE (FOUR_ROOMS_VIEW_SIZE * FOUR_ROOMS_VIEW_SIZE * FOUR_ROOMS_OBS_CHANNELS)
// A single action component whose size is FOUR_ROOMS_NUM_ACTIONS.
#define NUM_ATNS 1
#define ACT_SIZES {FOUR_ROOMS_NUM_ACTIONS}
// NOTE(review): looks like the observation tensor element type (bytes) -- confirm.
#define OBS_TENSOR_T ByteTensor

// Names of the vectorized step functions defined below in this file.
#define MY_VEC_STEP four_rooms_vec_step
#define MY_VEC_STEP_RANGE four_rooms_vec_step_range
// Env is an alias for the FourRooms environment struct.
#define Env FourRooms
#include "vecenv.h"

/*
 * Step every environment held by `vec` once.
 *
 * Before stepping, the rewards and terminals buffers are zeroed for
 * total_agents float-sized entries each; c_step is then called on each of
 * the vec->size environments in index order.
 */
void four_rooms_vec_step(StaticVec* vec) {
    const size_t buffer_bytes = vec->total_agents * sizeof(float);
    memset(vec->rewards, 0, buffer_bytes);
    memset(vec->terminals, 0, buffer_bytes);

    FourRooms* env_array = (FourRooms*)vec->envs;
    int idx = 0;
    while (idx < vec->size) {
        c_step(env_array + idx);
        idx++;
    }
}

/*
 * Step the environments with indices [env_start, env_start + env_count).
 *
 * Unlike four_rooms_vec_step, this does not clear the rewards or terminals
 * buffers. `num_workers` is accepted but unused.
 */
void four_rooms_vec_step_range(StaticVec* vec, int env_start, int env_count, int num_workers) {
    (void)num_workers;

    FourRooms* env_array = (FourRooms*)vec->envs;
    const int env_end = env_start + env_count;
    int idx = env_start;
    while (idx < env_end) {
        c_step(env_array + idx);
        idx++;
    }
}

/*
 * Initialize one environment from the keyword dictionary.
 *
 * Reads "size" and "max_steps" from `kwargs`. A max_steps value that is not
 * positive is replaced by 4 * size. The grid is allocated zero-filled with
 * size * size bytes; see_through_walls is switched off and num_agents is 1.
 */
void my_init(Env* env, Dict* kwargs) {
    env->num_agents = 1;
    env->see_through_walls = 0;

    env->size = (int)dict_get(kwargs, "size")->value;
    env->max_steps = (int)dict_get(kwargs, "max_steps")->value;
    // Non-positive max_steps selects the default episode limit.
    env->max_steps = (env->max_steps <= 0) ? 4 * env->size : env->max_steps;

    env->grid = (unsigned char*)calloc(env->size * env->size, sizeof(unsigned char));
}

// Copies log->field into `out` under a key spelled exactly like the field name.
#define FOUR_ROOMS_EXPORT_LOG_FIELD(field) dict_set(out, #field, log->field)

/*
 * Export the aggregated log fields (perf, score, episode_return,
 * episode_length) into the output dictionary, in that order.
 */
void my_log(Log* log, Dict* out) {
    FOUR_ROOMS_EXPORT_LOG_FIELD(perf);
    FOUR_ROOMS_EXPORT_LOG_FIELD(score);
    FOUR_ROOMS_EXPORT_LOG_FIELD(episode_return);
    FOUR_ROOMS_EXPORT_LOG_FIELD(episode_length);
}

#undef FOUR_ROOMS_EXPORT_LOG_FIELD
38 changes: 38 additions & 0 deletions ocean/four_rooms/four_rooms.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "four_rooms.h"

/*
 * Standalone interactive demo for the FourRooms environment.
 *
 * Builds a single 19x19 environment, allocates its buffers, and loops until
 * the window is closed. While left shift is held, the arrow keys / WASD pick
 * the action (no key = no-op); otherwise a random action among the first
 * three is taken each frame.
 *
 * Returns 0 on normal exit, 1 if any buffer could not be allocated.
 */
int main(void) {
    // `{0}` zero-initializes portably; the empty `{}` form is only valid from C23.
    FourRooms env = {0};
    env.size = 19;
    // Mirror the binding's default: a non-positive max_steps means 4 * size.
    // NOTE(review): assumes four_rooms.h does not treat 0 as "unlimited" -- confirm.
    env.max_steps = 4 * env.size;
    env.num_agents = 1;
    env.rng = 0;
    env.observations = (unsigned char*)calloc(
        FOUR_ROOMS_VIEW_SIZE * FOUR_ROOMS_VIEW_SIZE * FOUR_ROOMS_OBS_CHANNELS,
        sizeof(unsigned char)
    );
    env.actions = (float*)calloc(1, sizeof(float));
    env.rewards = (float*)calloc(1, sizeof(float));
    env.terminals = (float*)calloc(1, sizeof(float));
    env.grid = (unsigned char*)calloc(env.size * env.size, sizeof(unsigned char));

    // Bail out before touching the environment if any allocation failed;
    // free(NULL) is a no-op, so every pointer can be released unconditionally.
    if (!env.observations || !env.actions || !env.rewards || !env.terminals || !env.grid) {
        free(env.observations);
        free(env.actions);
        free(env.rewards);
        free(env.terminals);
        free(env.grid);
        return 1;
    }

    c_reset(&env);
    c_render(&env);
    while (!WindowShouldClose()) {
        if (IsKeyDown(KEY_LEFT_SHIFT)) {
            env.actions[0] = 7; // Invalid action = no-op
            if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = FORWARD;
            if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env.actions[0] = LEFT;
            if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env.actions[0] = RIGHT;
        } else {
            env.actions[0] = four_rooms_rand(&env, 3); // Only use left, right, forward
        }
        c_step(&env);
        c_render(&env);
    }

    free(env.observations);
    free(env.actions);
    free(env.rewards);
    free(env.terminals);
    // env.grid is not freed here, as in the original.
    // NOTE(review): presumably c_close releases it -- confirm in four_rooms.h.
    c_close(&env);
    return 0;
}
Loading