Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ commands:
steps:
- run:
name: install macOS packages
command: HOMEBREW_NO_AUTO_UPDATE=1 brew install coreutils parallel gnu-getopt
command: HOMEBREW_NO_AUTO_UPDATE=1 brew install coreutils gnu-getopt parallel python@3.9 virtualenv

- checkout

Expand Down Expand Up @@ -138,11 +138,13 @@ commands:
# Download and cache dependencies
- restore_cache:
keys:
- v11win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
- v13win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}

- run:
name: install python
command: choco install --allow-downgrade -y python --version=3.8.10
# Use python3.9 in Windows instead of python3.8 because otherwise
# pytest-notebook's indirect dependency pywinpty will fail to build.
command: choco install --allow-downgrade -y python --version=3.9.13
shell: powershell.exe

- run:
Expand All @@ -163,14 +165,20 @@ commands:

- run:
name: install dependencies
# Only create venv if it's not been restored from cache
command: if (-not (Test-Path venv)) { .\ci\build_and_activate_venv.ps1 -venv venv }
# Only create venv if it's not been restored from cache.
# Need to throw error explicitly on error or else {} will get rid of
# the exit code.
command: |
if (-not (Test-Path venv)) {
.\ci\build_and_activate_venv.ps1 -venv venv
if ($LASTEXITCODE -ne 0) { throw "Failed to create venv" }
}
shell: powershell.exe

- save_cache:
paths:
- .\venv
key: v11win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
key: v13win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}

- run:
name: install imitation
Expand Down
2 changes: 1 addition & 1 deletion ci/build_and_activate_venv.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ If ($venv -eq $null) {
$venv = "venv"
}

virtualenv -p python3.8 $venv
virtualenv -p python3.9 $venv
& $venv\Scripts\activate
pip install ".[docs,parallel,test]"
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[mypy]
ignore_missing_imports = true
exclude = output

# torch had some type errors, we ignore them because they're not our fault
[mypy-torch._dynamo.*]
follow_imports = skip
follow_imports_for_stubs = True
2 changes: 1 addition & 1 deletion src/imitation/algorithms/adversarial/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(
self.debug_use_ground_truth = debug_use_ground_truth
self.venv = venv
self.gen_algo = gen_algo
self._reward_net = reward_net.to(gen_algo.device)
self._reward_net: reward_nets.RewardNet = reward_net.to(gen_algo.device)
self._log_dir = util.parse_path(log_dir)

# Create graph for optimising/recording stats on discriminator
Expand Down
2 changes: 1 addition & 1 deletion src/imitation/data/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
trajectories: The trajectories to save.
"""
p = util.parse_path(path)
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))
logging.info(f"Dumped demonstrations to {p}.")


Expand Down
2 changes: 1 addition & 1 deletion src/imitation/policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _predict(
):
np_actions = []
if isinstance(obs, dict):
np_obs = types.DictObs(
np_obs: Union[types.DictObs, np.ndarray] = types.DictObs(
{k: v.detach().cpu().numpy() for k, v in obs.items()},
)
else:
Expand Down
7 changes: 2 additions & 5 deletions src/imitation/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,7 @@ def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor:
Returns:
A PyTorch tensor with the same content as `array`.
"""
if isinstance(array, th.Tensor):
return array

if not array.flags.writeable:
if isinstance(array, np.ndarray) and not array.flags.writeable:
array = array.copy()

return th.as_tensor(array, **kwargs)
Expand Down Expand Up @@ -476,6 +473,6 @@ def split_in_half(x: int) -> Tuple[int, int]:
def clear_screen() -> None:
"""Clears the console screen."""
if os.name == "nt": # Windows
os.system("cls")
os.system("cls") # pragma: no cover
else:
os.system("clear")
2 changes: 1 addition & 1 deletion tests/algorithms/test_sqil.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_sqil_performance_continuous(
pytestconfig: pytest.Config,
pendulum_single_venv: vec_env.VecEnv,
rl_algo_class: Type[off_policy_algorithm.OffPolicyAlgorithm],
):
): # pragma: no cover
rl_kwargs = dict(
learning_starts=500,
learning_rate=0.001,
Expand Down
4 changes: 4 additions & 0 deletions tests/data/test_huggingface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def test_save_load_roundtrip(


@hypothesis.given(st.data(), h_strats.trajectories_list)
# the first run sometimes takes longer, so we give it more time
@hypothesis.settings(deadline=datetime.timedelta(milliseconds=300))
def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Trajectory]):
"""Test that slicing a TrajectoryDatasetSequence behaves as expected."""
# GIVEN
Expand All @@ -84,6 +86,8 @@ def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Traject


@hypothesis.given(st.data(), h_strats.trajectory)
# the first run sometimes takes longer, so we give it more time
@hypothesis.settings(deadline=datetime.timedelta(milliseconds=300))
def test_sliced_info_dict_access(
data: st.DataObject,
trajectory: types.Trajectory,
Expand Down
2 changes: 2 additions & 0 deletions tests/util/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def test_safe_to_numpy():
numpy = util.safe_to_numpy(tensor)
assert (numpy == tensor.numpy()).all()
assert util.safe_to_numpy(None) is None
with pytest.warns(UserWarning, match=".*performance.*"):
util.safe_to_numpy(tensor, warn=True)


def test_tensor_iter_norm():
Expand Down
Loading