Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions scripts/reinforcement_learning/rsl_rl/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

"""Check for installed RSL-RL version."""

import importlib.metadata as metadata

from packaging import version

installed_version = metadata.version("rsl-rl-lib")

"""Rest everything follows."""

import os
Expand All @@ -70,7 +78,13 @@
from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict

from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
from isaaclab_rl.rsl_rl import (
RslRlBaseRunnerCfg,
RslRlVecEnvWrapper,
export_policy_as_jit,
export_policy_as_onnx,
handle_deprecated_rsl_rl_cfg,
)
from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint

import isaaclab_tasks # noqa: F401
Expand All @@ -91,6 +105,9 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs

# handle deprecated configurations
agent_cfg = handle_deprecated_rsl_rl_cfg(agent_cfg, installed_version)

# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
Expand Down Expand Up @@ -150,27 +167,31 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# obtain the trained policy for inference
policy = runner.get_inference_policy(device=env.unwrapped.device)

# extract the neural network module
# we do this in a try-except to maintain backwards compatibility.
try:
# version 2.3 onwards
policy_nn = runner.alg.policy
except AttributeError:
# version 2.2 and below
policy_nn = runner.alg.actor_critic

# extract the normalizer
if hasattr(policy_nn, "actor_obs_normalizer"):
normalizer = policy_nn.actor_obs_normalizer
elif hasattr(policy_nn, "student_obs_normalizer"):
normalizer = policy_nn.student_obs_normalizer
else:
normalizer = None

# export policy to onnx/jit
# export the trained policy to JIT and ONNX formats
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")

if version.parse(installed_version) >= version.parse("4.0.0"):
# use the new export functions for rsl-rl >= 4.0.0
runner.export_policy_to_jit(path=export_model_dir, filename="policy.pt")
runner.export_policy_to_onnx(path=export_model_dir, filename="policy.onnx")
else:
# extract the neural network for rsl-rl < 4.0.0
if version.parse(installed_version) >= version.parse("2.3.0"):
policy_nn = runner.alg.policy
else:
policy_nn = runner.alg.actor_critic

# extract the normalizer
if hasattr(policy_nn, "actor_obs_normalizer"):
normalizer = policy_nn.actor_obs_normalizer
elif hasattr(policy_nn, "student_obs_normalizer"):
normalizer = policy_nn.student_obs_normalizer
else:
normalizer = None

# export to JIT and ONNX
export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")

dt = env.unwrapped.step_dt

Expand All @@ -187,7 +208,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# env stepping
obs, _, dones, _ = env.step(actions)
# reset recurrent states for episodes that have terminated
policy_nn.reset(dones)
if version.parse(installed_version) >= version.parse("4.0.0"):
policy.reset(dones)
Comment thread
ClemensSchwarke marked this conversation as resolved.
else:
policy_nn.reset(dones)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
Expand Down
5 changes: 4 additions & 1 deletion scripts/reinforcement_learning/rsl_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import dump_yaml

from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, handle_deprecated_rsl_rl_cfg

import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path
Expand All @@ -121,6 +121,9 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations
)

# handle deprecated configurations
agent_cfg = handle_deprecated_rsl_rl_cfg(agent_cfg, installed_version)

# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
Expand Down
2 changes: 1 addition & 1 deletion source/isaaclab_rl/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.4.7"
version = "0.5.0"

# Description
title = "Isaac Lab RL"
Expand Down
25 changes: 25 additions & 0 deletions source/isaaclab_rl/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
Changelog
---------

0.5.0 (2026-03-04)
~~~~~~~~~~~~~~~~~~

Added
^^^^^
* Added function to handle deprecated RSL-RL configurations and automatically convert them to the new format compatible
with RSL-RL 4.0 and 5.0.
* Added new configuration classes "MLPModelCfg", "RNNModelCfg", "CNNModelCfg", and "DistributionCfg" for the new
versions of RSL-RL.
* Added "check_for_nan" and "share_cnn_encoders" parameters to the configuration classes for RSL-RL 5.0.
* Added recurrent configurations for the "Isaac-Velocity-Flat-Anymal-D-v0" task for RSL-RL to run the RSL-RL CI.

Changed
^^^^^^^
* Adapted RSL-RL's train.py and play.py scripts to work with both the old and new versions of RSL-RL.

Deprecated
^^^^^^^^^^
* Deprecated old configuration classes "RslRlDistillationStudentTeacherCfg",
"RslRlDistillationStudentTeacherRecurrentCfg", "RslRlPpoActorCriticCfg", and "RslRlPpoActorCriticRecurrentCfg" in
favor of the new "MLPModelCfg", "RNNModelCfg", and "CNNModelCfg" configuration classes for RSL-RL 4.0.
* Deprecated old parameters "stochastic", "init_noise_std", "noise_std_type", and "state_dependent_std" in favor of the
new "DistributionCfg" configuration class for RSL-RL 5.0.


0.4.7 (2025-12-29)
~~~~~~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .rnd_cfg import RslRlRndCfg
from .symmetry_cfg import RslRlSymmetryCfg
from .vecenv_wrapper import RslRlVecEnvWrapper
from .utils import handle_deprecated_rsl_rl_cfg
124 changes: 70 additions & 54 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/distillation_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,79 @@

from isaaclab.utils import configclass

from .rl_cfg import RslRlBaseRunnerCfg
from .rl_cfg import RslRlBaseRunnerCfg, RslRlMLPModelCfg

############################
# Algorithm configurations #
############################


@configclass
class RslRlDistillationAlgorithmCfg:
    """Configuration for the distillation algorithm.

    Hyper-parameters for student-teacher distillation training. Fields marked
    ``MISSING`` must be supplied by the task-specific configuration.
    """

    class_name: str = "Distillation"
    """The algorithm class name. Default is Distillation."""

    num_learning_epochs: int = MISSING
    """The number of updates performed with each sample."""

    learning_rate: float = MISSING
    """The learning rate for the student policy."""

    gradient_length: int = MISSING
    """The number of environment steps the gradient flows back."""

    max_grad_norm: None | float = None
    """The maximum norm the gradient is clipped to. If None, no clipping is applied."""

    optimizer: Literal["adam", "adamw", "sgd", "rmsprop"] = "adam"
    """The optimizer to use for the student policy."""

    loss_type: Literal["mse", "huber"] = "mse"
    """The loss type to use for the student policy."""


#########################
# Policy configurations #
# Runner configurations #
#########################


@configclass
class RslRlDistillationRunnerCfg(RslRlBaseRunnerCfg):
    """Configuration of the runner for distillation algorithms.

    For rsl-rl >= 4.0.0, use the ``student`` and ``teacher`` model configurations;
    the legacy ``policy`` field is kept only for backwards compatibility.
    """

    class_name: str = "DistillationRunner"
    """The runner class name. Default is DistillationRunner."""

    student: RslRlMLPModelCfg = MISSING
    """The student configuration."""

    teacher: RslRlMLPModelCfg = MISSING
    """The teacher configuration."""

    algorithm: RslRlDistillationAlgorithmCfg = MISSING
    """The algorithm configuration."""

    # NOTE(review): `RslRlDistillationStudentTeacherCfg` is defined later in this
    # module — this forward reference presumably relies on postponed annotation
    # evaluation; confirm the file imports `annotations` from `__future__`.
    policy: RslRlDistillationStudentTeacherCfg = MISSING
    """The policy configuration.

    For rsl-rl >= 4.0.0, this configuration is deprecated. Please use `student` and `teacher` model configurations
    instead.
    """


#############################
# Deprecated configurations #
#############################


@configclass
class RslRlDistillationStudentTeacherCfg:
"""Configuration for the distillation student-teacher networks."""
"""Configuration for the distillation student-teacher networks.

For rsl-rl >= 4.0.0, this configuration is deprecated. Please use `RslRlMLPModelCfg` instead.
"""

class_name: str = "StudentTeacher"
"""The policy class name. Default is StudentTeacher."""
Expand Down Expand Up @@ -48,7 +111,10 @@ class RslRlDistillationStudentTeacherCfg:

@configclass
class RslRlDistillationStudentTeacherRecurrentCfg(RslRlDistillationStudentTeacherCfg):
"""Configuration for the distillation student-teacher recurrent networks."""
"""Configuration for the distillation student-teacher recurrent networks.

For rsl-rl >= 4.0.0, this configuration is deprecated. Please use `RslRlRNNModelCfg` instead.
"""

class_name: str = "StudentTeacherRecurrent"
"""The policy class name. Default is StudentTeacherRecurrent."""
Expand All @@ -64,53 +130,3 @@ class RslRlDistillationStudentTeacherRecurrentCfg(RslRlDistillationStudentTeache

teacher_recurrent: bool = MISSING
"""Whether the teacher network is recurrent too."""


############################
# Algorithm configurations #
############################


@configclass
class RslRlDistillationAlgorithmCfg:
"""Configuration for the distillation algorithm."""

class_name: str = "Distillation"
"""The algorithm class name. Default is Distillation."""

num_learning_epochs: int = MISSING
"""The number of updates performed with each sample."""

learning_rate: float = MISSING
"""The learning rate for the student policy."""

gradient_length: int = MISSING
"""The number of environment steps the gradient flows back."""

max_grad_norm: None | float = None
"""The maximum norm the gradient is clipped to."""

optimizer: Literal["adam", "adamw", "sgd", "rmsprop"] = "adam"
"""The optimizer to use for the student policy."""

loss_type: Literal["mse", "huber"] = "mse"
"""The loss type to use for the student policy."""


#########################
# Runner configurations #
#########################


@configclass
class RslRlDistillationRunnerCfg(RslRlBaseRunnerCfg):
"""Configuration of the runner for distillation algorithms."""

class_name: str = "DistillationRunner"
"""The runner class name. Default is DistillationRunner."""

policy: RslRlDistillationStudentTeacherCfg = MISSING
"""The policy configuration."""

algorithm: RslRlDistillationAlgorithmCfg = MISSING
"""The algorithm configuration."""
Loading
Loading