-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_video.py
More file actions
34 lines (26 loc) · 1.02 KB
/
generate_video.py
File metadata and controls
34 lines (26 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import gymnasium as gym
from agents.dqn import DQNTrainer, QNetwork
import yaml
def generate_video(config_path='experiments/cartpole/configs/dqn_cartpole.yaml', video_folder='results/videos'):
    """Roll out one greedy episode of a trained DQN agent and record it to video.

    Args:
        config_path: Path to a YAML config file. The keys read here are
            ``env_name`` (passed to ``gym.make``) and ``save_path`` (a file
            holding the ``QNetwork`` state dict).
        video_folder: Directory passed to ``gym.wrappers.RecordVideo`` as the
            output location for the recording.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    env = gym.make(config['env_name'], render_mode='rgb_array')
    env = gym.wrappers.RecordVideo(env, video_folder)
    try:
        # NOTE(review): assumes a flat observation space (uses shape[0]) and a
        # discrete action space (uses .n) -- true for CartPole; confirm for others.
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.n

        qnetwork = QNetwork(state_size, action_size)
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only machine; the rollout below runs entirely on CPU.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        qnetwork.load_state_dict(torch.load(config['save_path'], map_location='cpu'))
        qnetwork.eval()

        state, _ = env.reset()
        done = False
        while not done:
            # Add a batch dimension of 1 for the network's forward pass.
            state_tensor = torch.from_numpy(state).float().unsqueeze(0)
            with torch.no_grad():
                action = qnetwork(state_tensor).argmax().item()
            # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
            # The episode ends on EITHER flag: ignoring `truncated` means an agent
            # that never fails keeps stepping past the time limit forever.
            state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
    finally:
        # Always close so the video recorder is finalized even if the rollout raises.
        env.close()
    print(f"Video saved in {video_folder}")
if __name__ == "__main__":
    # Script entry point: record a video using the default config and output folder.
    generate_video()