Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mpx/config/config_aliengo_trot_two_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
initial_state = base.initial_state

cost = partial(mpc_objectives.quadruped_wb_obj, True, n_joints, n_contact, N)
cost_smooth = partial(mpc_objectives.quadruped_wb_smooth_cost, True, n_joints, n_contact, N)
inequalities = partial(mpc_objectives.quadruped_wb_inequalities, n_joints, n_contact, 0.5, 44.0, 10.0)
hessian_approx = base.hessian_approx
dynamics = base.dynamics

Expand All @@ -64,3 +66,20 @@
solver_mode = "fddp"
max_torque = base.max_torque
min_torque = base.min_torque

lipa_enforce_inequalities = True

def _lipa_settings():
from primal_dual_lipa.types import SolverSettings
return SolverSettings(
max_iterations=2000,
η0=1e9,
η_update_factor=1.0,
µ_update_factor=0.9,
cost_improvement_threshold=1e-3,
primal_violation_threshold=1e-5,
use_parallel_lqr=False,
num_parallel_line_search_steps=1,
)

lipa_settings = _lipa_settings()
23 changes: 22 additions & 1 deletion mpx/config/config_barrel_roll.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
)

cost = partial(mpc_objectives.quadruped_wb_obj, False, n_joints, n_contact, N)
cost_smooth = partial(mpc_objectives.quadruped_wb_smooth_cost, False, n_joints, n_contact, N)
hessian_approx = None

def dynamics(model, mjx_model, contact_id, body_id):
Expand All @@ -105,4 +106,24 @@ def dynamics(model, mjx_model, contact_id, body_id):
# dynamics = mpc_dyn_model.quadruped_wb_dynamics_learned_contact_model
# dynamics = mpc_dyn_model.quadruped_wb_dynamics_explicit_contact
max_torque = 40
min_torque = -40
min_torque = -40

inequalities = partial(
mpc_objectives.quadruped_wb_inequalities, n_joints, n_contact, 0.5, 50.0, 20.0
)
lipa_enforce_inequalities = True

def _lipa_settings():
from primal_dual_lipa.types import SolverSettings
return SolverSettings(
max_iterations=2000,
η0=1e9,
η_update_factor=1.1,
µ_update_factor=0.9,
cost_improvement_threshold=1e-3,
primal_violation_threshold=1e-5,
use_parallel_lqr=False,
num_parallel_line_search_steps=1,
)

lipa_settings = _lipa_settings()
36 changes: 36 additions & 0 deletions mpx/config/config_h1_jump_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
torque_limits = base.torque_limits

cost = partial(mpc_objectives.h1_kinodynamic_obj, n_joints, n_contact, N)
cost_smooth = partial(mpc_objectives.h1_kinodynamic_smooth_cost, n_joints, n_contact, N)
inequalities = partial(mpc_objectives.h1_kinodynamic_inequalities, n_joints, n_contact, 0.7)
hessian_approx = base.hessian_approx
dynamics = base.dynamics
MPCWrapper = base.MPCWrapper
Expand All @@ -58,3 +60,37 @@
solver_mode = "fddp"
max_torque = base.max_torque
min_torque = base.min_torque

lipa_enforce_inequalities = True

def _lipa_settings():
from primal_dual_lipa.types import SolverSettings
return SolverSettings(
max_iterations=2000,
η0=1e9,
η_update_factor=1.0,
µ_update_factor=0.9,
cost_improvement_threshold=1e-3,
primal_violation_threshold=1e-5,
num_iterative_refinement_steps=2,
use_parallel_lqr=False,
num_parallel_line_search_steps=1,
)

lipa_settings = _lipa_settings()

def _lipa_settings_enforce():
from primal_dual_lipa.types import SolverSettings
return SolverSettings(
max_iterations=500,
η0=1e5,
η_update_factor=2.0,
µ_update_factor=0.9,
cost_improvement_threshold=1e-3,
primal_violation_threshold=1e-5,
num_iterative_refinement_steps=2,
use_parallel_lqr=False,
num_parallel_line_search_steps=1,
)

lipa_settings_enforce = _lipa_settings_enforce()
2 changes: 1 addition & 1 deletion mpx/data/acrobot/scene.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<option timestep="0.01" gravity="0 0 -9.81"/>

<visual>
<global azimuth="135" elevation="-20"/>
<global azimuth="135" elevation="-20" offwidth="1920" offheight="1080"/>
<headlight ambient="0.35 0.35 0.35" diffuse="0.75 0.75 0.75" specular="0.2 0.2 0.2"/>
<rgba haze="1 1 1 1"/>
</visual>
Expand Down
2 changes: 1 addition & 1 deletion mpx/data/aliengo/scene_flat.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<visual>
<headlight diffuse="0.6 0.6 0.6" ambient="0.3 0.3 0.3" specular="0 0 0"/>
<rgba haze="0.15 0.25 0.35 1"/>
<global azimuth="-130" elevation="-20"/>
<global azimuth="-130" elevation="-20" offwidth="1920" offheight="1080"/>
</visual>

<asset>
Expand Down
2 changes: 1 addition & 1 deletion mpx/data/unitree_h1/mjx_scene_h1_walk.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<visual>
<headlight diffuse="0.6 0.6 0.6" ambient="0.3 0.3 0.3" specular="0 0 0"/>
<rgba haze="0.15 0.25 0.35 1"/>
<global azimuth="160" elevation="-20"/>
<global azimuth="160" elevation="-20" offwidth="1920" offheight="1080"/>
</visual>

<asset>
Expand Down
2 changes: 1 addition & 1 deletion mpx/examples/acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,6 @@ def step_controller(viewer=None):
parser = argparse.ArgumentParser()
parser.add_argument("--headless", action="store_true")
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--solver", choices=("primal_dual", "fddp"), default="primal_dual")
parser.add_argument("--solver", choices=("primal_dual", "fddp", "lipa"), default="primal_dual")
args = parser.parse_args()
main(headless=args.headless, steps=args.steps, solver_mode=args.solver)
46 changes: 43 additions & 3 deletions mpx/examples/mjx_h1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")

if "--video" in sys.argv:
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

import jax
import jax.numpy as jnp
import mujoco
Expand Down Expand Up @@ -36,7 +40,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc


def main(steps=500):
def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False):
model = mujoco.MjModel.from_xml_path(
dir_path + "/../data/unitree_h1/mjx_scene_h1_walk.xml"
)
Expand All @@ -45,7 +49,7 @@ def main(steps=500):
model.opt.timestep = 1 / sim_frequency

mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True)
command_handle = sim_utils.KeyboardVelocityCommand()
command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)

Expand Down Expand Up @@ -102,6 +106,27 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1

if headless or video is not None:
recorder = None
capture_period = max(1, int(round(sim_frequency / fps)))
if video is not None:
os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
recorder = sim_utils.VideoRecorder(model, video, fps=fps)
p_start = np.asarray(data.qpos[:3]).copy()
try:
for i in range(steps):
step_controller()
if recorder is not None and i % capture_period == 0:
recorder.capture(data)
finally:
if recorder is not None:
recorder.close()
print(f"Wrote video: {video}")
p_end = np.asarray(data.qpos[:3])
delta = p_end - p_start
print(f"Base position: start={p_start} end={p_end} delta={delta}")
return

with mujoco.viewer.launch_passive(
model,
data,
Expand All @@ -119,5 +144,20 @@ def step_controller():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--headless", action="store_true")
parser.add_argument("--video", type=str, default=None,
help="Write an mp4 of the run to this path (forces headless).")
parser.add_argument("--vx", type=float, default=0.0)
parser.add_argument("--vy", type=float, default=0.0)
parser.add_argument("--wz", type=float, default=0.0)
parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
main(steps=args.steps)
main(
steps=args.steps,
video=args.video,
vx=args.vx,
vy=args.vy,
wz=args.wz,
fps=args.fps,
headless=args.headless,
)
46 changes: 43 additions & 3 deletions mpx/examples/mjx_h1_kinodynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")

if "--video" in sys.argv:
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

import jax
import jax.numpy as jnp
import mujoco
Expand Down Expand Up @@ -35,7 +39,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc


def main(steps=500):
def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False):
model = mujoco.MjModel.from_xml_path(
dir_path + "/../data/unitree_h1/mjx_scene_h1_walk.xml"
)
Expand All @@ -44,7 +48,7 @@ def main(steps=500):
model.opt.timestep = 1 / sim_frequency

mpc = config.MPCWrapper(config, limited_memory=True)
command_handle = sim_utils.KeyboardVelocityCommand()
command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)

Expand Down Expand Up @@ -100,6 +104,27 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1

if headless or video is not None:
recorder = None
capture_period = max(1, int(round(sim_frequency / fps)))
if video is not None:
os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
recorder = sim_utils.VideoRecorder(model, video, fps=fps)
p_start = np.asarray(data.qpos[:3]).copy()
try:
for i in range(steps):
step_controller()
if recorder is not None and i % capture_period == 0:
recorder.capture(data)
finally:
if recorder is not None:
recorder.close()
print(f"Wrote video: {video}")
p_end = np.asarray(data.qpos[:3])
delta = p_end - p_start
print(f"Base position: start={p_start} end={p_end} delta={delta}")
return

with mujoco.viewer.launch_passive(
model,
data,
Expand All @@ -117,5 +142,20 @@ def step_controller():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--headless", action="store_true")
parser.add_argument("--video", type=str, default=None,
help="Write an mp4 of the run to this path (forces headless).")
parser.add_argument("--vx", type=float, default=0.0)
parser.add_argument("--vy", type=float, default=0.0)
parser.add_argument("--wz", type=float, default=0.0)
parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
main(steps=args.steps)
main(
steps=args.steps,
video=args.video,
vx=args.vx,
vy=args.vy,
wz=args.wz,
fps=args.fps,
headless=args.headless,
)
56 changes: 51 additions & 5 deletions mpx/examples/mjx_quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")

# Headless video recording uses `mujoco.Renderer`, which requires an OpenGL
# backend to be configured before the first `import mujoco` in the process.
if "--video" in sys.argv:
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

import jax
import jax.numpy as jnp
import mujoco
Expand Down Expand Up @@ -37,7 +43,16 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc


def main(headless=False, steps=500, scene="flat"):
def main(
headless=False,
steps=500,
scene="flat",
video=None,
vx=0.0,
vy=0.0,
wz=0.0,
fps=30,
):
model = mujoco.MjModel.from_xml_path(
dir_path + f"/../data/aliengo/scene_{scene}.xml"
)
Expand All @@ -47,7 +62,9 @@ def main(headless=False, steps=500, scene="flat"):

contact_ids = sim_utils.geom_ids(model, config.contact_frame)
mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True)
command_handle = sim_utils.KeyboardVelocityCommand()
# Headless+video: scripted velocity (no keyboard); viewer mode keeps the
# interactive arrow-key handle.
command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)

Expand Down Expand Up @@ -112,9 +129,25 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1

if headless:
for _ in range(steps):
step_controller()
if headless or video is not None:
recorder = None
capture_period = max(1, int(round(sim_frequency / fps)))
if video is not None:
os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
recorder = sim_utils.VideoRecorder(model, video, fps=fps)
p_start = np.asarray(data.qpos[:3]).copy()
try:
for i in range(steps):
step_controller()
if recorder is not None and i % capture_period == 0:
recorder.capture(data)
finally:
if recorder is not None:
recorder.close()
print(f"Wrote video: {video}")
p_end = np.asarray(data.qpos[:3])
delta = p_end - p_start
print(f"Base position: start={p_start} end={p_end} delta={delta}")
return

with mujoco.viewer.launch_passive(
Expand All @@ -141,9 +174,22 @@ def step_controller():
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--scene", type=str, default="flat")
parser.add_argument("--headless", action="store_true")
parser.add_argument("--video", type=str, default=None,
help="Write an mp4 of the run to this path (forces headless).")
parser.add_argument("--vx", type=float, default=0.0,
help="Forward velocity command (m/s) for headless/video runs.")
parser.add_argument("--vy", type=float, default=0.0)
parser.add_argument("--wz", type=float, default=0.0,
help="Yaw-rate command (rad/s).")
parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
main(
headless=args.headless,
steps=args.steps,
scene=args.scene,
video=args.video,
vx=args.vx,
vy=args.vy,
wz=args.wz,
fps=args.fps,
)
Loading