import gym
import d4rl  # noqa: F401 (importing d4rl registers its environments with gym)
from gymnasium.spaces import Box
import numpy as np
from imitation.algorithms.bc import BC
from imitation.data.types import TransitionsMinimal

# Load the D4RL expert dataset for HalfCheetah as a dict of NumPy arrays.
env = gym.make("halfcheetah-expert-v2")
dataset = env.get_dataset()

# TransitionsMinimal has no next_obs field.
transitions = TransitionsMinimal(
    obs=dataset["observations"],
    acts=dataset["actions"],
    infos=dataset["infos/qpos"],
)

# imitation works with gymnasium spaces, so rebuild the legacy gym Box spaces.
observation_space = Box(env.observation_space.low, env.observation_space.high, dtype=float)
action_space = Box(env.action_space.low, env.action_space.high, dtype=float)

bc = BC(
    observation_space=observation_space,
    action_space=action_space,
    demonstrations=transitions,
    rng=np.random.default_rng(0),
)
# Fails during data loading: the collate function expects next_obs.
bc.train(n_epochs=1)
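Bug description

While `TransitionsMinimal` does not have a `next_obs` field, the collate function called during data loading (`imitation/src/imitation/data/types.py`, line 473 at commit a8b079c) expects `next_obs` to be present. This makes `TransitionsMinimal` unusable with, e.g., `BC`.

Steps to reproduce

Run the script at the top of this report; it fails once `bc.train(n_epochs=1)` starts loading batches.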
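The failure mode can be illustrated without `imitation` or D4RL. The sketch below is a minimal illustration, assuming the collate step stacks a fixed set of keys from each sample. `MinimalTransitions`, `collate_expecting_next_obs`, and `collate_present_fields` are hypothetical names, not `imitation`'s actual classes or its collate function. The sketch only shows why a sample lacking `next_obs` breaks a collate function that hardcodes that key, and what a field-agnostic alternative could look like.

```python
import dataclasses
from typing import Any, Dict, Mapping, Sequence

import numpy as np


@dataclasses.dataclass(frozen=True)
class MinimalTransitions:
    """Hypothetical stand-in for TransitionsMinimal: obs, acts, infos only."""

    obs: np.ndarray
    acts: np.ndarray
    infos: np.ndarray

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        # One sample as a dict holding only the fields this class defines.
        return {f.name: getattr(self, f.name)[i] for f in dataclasses.fields(self)}


def collate_expecting_next_obs(batch: Sequence[Mapping[str, Any]]) -> Dict[str, np.ndarray]:
    # Mimics the reported bug: hardcodes a key that minimal samples lack.
    return {k: np.stack([s[k] for s in batch]) for k in ("obs", "acts", "next_obs")}


def collate_present_fields(batch: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
    # Field-agnostic: stack whichever array fields the samples actually have.
    keys = [k for k in batch[0] if k != "infos"]
    out: Dict[str, Any] = {k: np.stack([s[k] for s in batch]) for k in keys}
    # infos holds dicts, so keep it as a plain list instead of stacking.
    out["infos"] = [s["infos"] for s in batch]
    return out


if __name__ == "__main__":
    # Shapes match HalfCheetah: 17-dim observations, 6-dim actions.
    data = MinimalTransitions(
        obs=np.zeros((4, 17)),
        acts=np.zeros((4, 6)),
        infos=np.array([{} for _ in range(4)], dtype=object),
    )
    batch = [data[i] for i in range(2)]
    try:
        collate_expecting_next_obs(batch)
    except KeyError as e:
        print(f"KeyError: {e}")  # prints: KeyError: 'next_obs'
    print(sorted(collate_present_fields(batch)))  # prints: ['acts', 'infos', 'obs']
```

Keeping `infos` as a list avoids trying to stack per-sample dicts into a numeric array.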
Environment
pip freeze --all: https://pastebin.com/9pKRZibK