Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ Manifest.toml
examples/**/.CondaPkg/*
*.bson
*.err
*.tsv
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ JuMP = "1.29.4"
MathOptInterface = "1.48.0"
ParametricOptInterface = "0.14.1"
Zygote = "0.6.77"
julia = "~1.9, ~1.10"
julia = "~1.9, 1.10"

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
Expand Down
4 changes: 4 additions & 0 deletions examples/Atlas/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
DecisionRules = "47937410-f832-486f-8300-12c95b225dfc"
DiffOpt = "930fe3bc-9c6b-11ea-2d94-6184641e85e7"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326"
HSL_jll = "017b0a0e-03f4-516a-9b91-836bbd1904dd"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Expand All @@ -17,4 +20,5 @@ RigidBodyDynamics = "366cf18f-59d5-5db9-a4de-86a9f6786172"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108"
187 changes: 187 additions & 0 deletions examples/Atlas/atlas_visualization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

using MeshCat
using MeshCatMechanisms
using GeometryBasics: Point, Vec, HyperSphere
using RigidBodyDynamics: findbody, findjoint, successor, frame_after, default_frame, MechanismState, relative_transform, translation
using CoordinateTransformations: Translation
using Colors
using LinearAlgebra: norm, cross

const URDFPATH = joinpath(@__DIR__, "urdf", "atlas_all.urdf")

Expand All @@ -29,3 +34,185 @@ function animate!(model::Atlas, mvis::MechanismVisualizer, qs; Δt=0.001)

return anim
end

"""
animate_with_perturbation_cause!(model, mvis, qs, perturbations; kwargs...)

Animate Atlas and overlay an illustrative perturbation-cause arrow in MeshCat.
The arrow points in perturbation direction and its length scales with magnitude.

Arguments:
- `qs`: state trajectory (length `T`)
- `perturbations`: per-stage perturbations (typically length `T-1`)

Keyword arguments:
- `Δt`: frame step in seconds
- `arrow_scale`: converts perturbation magnitude to arrow length
- `min_arrow_length`: minimum visible arrow length when active
- `show_threshold`: hide arrow when `abs(perturbation) <= show_threshold`
- `linger_seconds`: how long the cause arrow remains visible after each perturbation
- `perturbation_state_index`: state index (in `x`) where perturbation is injected;
if it maps to a velocity state, the arrow is attached to that joint/body.
- `arrow_base`: local-frame base offset from the selected anchor frame
- `impact_distance`: local x-distance from anchor to initial contact point (keeps marker outside robot)
- `retreat_distance`: local x-distance the marker retreats after impact
- `shaft_radius`: thickness of the arrow shaft
"""
function animate_with_perturbation_cause!(
model::Atlas,
mvis::MechanismVisualizer,
qs,
perturbations;
Δt=0.001,
arrow_scale=1.0,
min_arrow_length=0.12,
show_threshold=1e-6,
linger_seconds=0.35,
perturbation_state_index=nothing,
arrow_base=Point(0.0, 0.0, 0.12),
impact_distance=0.18,
retreat_distance=0.35,
shaft_radius=0.03,
)
vis = mvis.visualizer
if isnothing(perturbation_state_index)
perturbation_state_index = model.nq + 5
end

velocity_idx = perturbation_state_index - model.nq
anchor_body = nothing
anchor_origin = Point(0.0, 0.0, 0.0)
perturbation_dir_local = Vec(0.0, 1.0, 0.0)
anchor_description = ""

if 1 <= velocity_idx <= length(model.joint_names)
joint_name = model.joint_names[velocity_idx]
joint = findjoint(model.mech, joint_name)
anchor_body = successor(joint, model.mech)
state0 = MechanismState(model.mech)
joint_in_body = translation(relative_transform(state0, default_frame(anchor_body), frame_after(joint)))
anchor_origin = Point(joint_in_body[1], joint_in_body[2], joint_in_body[3])
joint_type = getfield(joint, :joint_type)
if hasfield(typeof(joint_type), :axis)
axis = collect(getfield(joint_type, :axis))
axis_norm = norm(axis)
if axis_norm > 1e-8
axis ./= axis_norm
# Build a direction orthogonal to the joint axis so the effect reads as a lateral collision.
dir = cross(axis, [0.0, 0.0, 1.0])
if norm(dir) < 1e-8
dir = cross(axis, [1.0, 0.0, 0.0])
end
if norm(dir) > 1e-8
dir ./= norm(dir)
perturbation_dir_local = Vec(dir[1], dir[2], dir[3])
end
end
end
anchor_description =
"joint=$(joint_name), body=$(getfield(anchor_body, :name)), dir_local=$(collect(perturbation_dir_local))"
else
anchor_body = findbody(model.mech, "pelvis")
anchor_description =
"fallback=pelvis (state index $perturbation_state_index not mapped to velocity DOF), dir_local=$(collect(perturbation_dir_local))"
end

cause_arrow_parent = mvis[anchor_body]
cause_arrow = ArrowVisualizer(cause_arrow_parent[:perturbation_cause_arrow])
setobject!(
cause_arrow;
shaft_material=MeshLambertMaterial(color=colorant"red"),
head_material=MeshLambertMaterial(color=colorant"yellow"),
)
cause_impactor = cause_arrow_parent[:perturbation_cause_impactor]
setobject!(
cause_impactor,
HyperSphere(Point(0.0, 0.0, 0.0), 0.055),
MeshLambertMaterial(color=colorant"orange"),
)
linger_frames = max(1, round(Int, linger_seconds / Δt))
head_radius = 2.2 * shaft_radius
head_length = 2.8 * shaft_radius

anim = MeshCat.Animation(vis; fps=convert(Int, floor(1.0 / Δt)))
last_event_frame = 0
last_event_value = 0.0
last_event_sign = 1.0
event_count = count(p -> abs(p) > show_threshold, perturbations)
max_abs_pert = isempty(perturbations) ? 0.0 : maximum(abs.(perturbations))
println(
"Perturbation-cause overlay: events=$event_count, max_abs=$(round(max_abs_pert, digits=6)), " *
"perturb_state_idx=$perturbation_state_index, anchor={$anchor_description}, " *
"impact_distance=$impact_distance, retreat_distance=$retreat_distance"
)
for (frame, q) in enumerate(qs)
MeshCat.atframe(anim, frame) do
set_configuration!(mvis, q[1:model.nq])

p = frame <= length(perturbations) ? perturbations[frame] : 0.0
if abs(p) > show_threshold
last_event_frame = frame
last_event_value = p
last_event_sign = sign(p) == 0 ? 1.0 : sign(p)
end

frames_since_event = frame - last_event_frame
if last_event_frame > 0 && frames_since_event <= linger_frames
progress = frames_since_event / linger_frames
decay = 1.0 - progress
outward_dir = Vec(
last_event_sign * perturbation_dir_local[1],
last_event_sign * perturbation_dir_local[2],
last_event_sign * perturbation_dir_local[3],
)

# Contact happens just outside the body, then marker backs away.
contact_point = Point(
anchor_origin[1] + arrow_base[1] + outward_dir[1] * impact_distance,
anchor_origin[2] + arrow_base[2] + outward_dir[2] * impact_distance,
anchor_origin[3] + arrow_base[3] + outward_dir[3] * impact_distance,
)
impactor_point = Point(
contact_point[1] + outward_dir[1] * retreat_distance * progress,
contact_point[2] + outward_dir[2] * retreat_distance * progress,
contact_point[3] + outward_dir[3] * retreat_distance * progress,
)
settransform!(
cause_impactor,
Translation(impactor_point[1], impactor_point[2], impactor_point[3]),
)

effective_p = last_event_value * decay
arrow_length = max(min_arrow_length * decay, abs(effective_p) * arrow_scale)
# Arrow points from impactor toward robot (collision cause direction).
direction = Vec(
-outward_dir[1] * arrow_length,
-outward_dir[2] * arrow_length,
-outward_dir[3] * arrow_length,
)
settransform!(
cause_arrow,
impactor_point,
direction;
shaft_radius=shaft_radius,
max_head_radius=head_radius,
max_head_length=head_length,
)
else
# "Hide" by shrinking to zero length (more robust than animating visibility).
settransform!(
cause_arrow,
anchor_origin,
Vec(0.0, 0.0, 0.0);
shaft_radius=shaft_radius,
max_head_radius=head_radius,
max_head_length=head_length,
)
# Keep impactor out of view when there is no active perturbation event.
settransform!(cause_impactor, Translation(1000.0, 1000.0, 1000.0))
end
end
end
MeshCat.setanimation!(mvis, anim)
return anim
end
19 changes: 14 additions & 5 deletions examples/Atlas/build_atlas_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,20 @@ function build_atlas_subproblems(;

# Default optimizer
if isnothing(optimizer)
optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Optimizer,
"print_level" => 0,
"hsllib" => HSL_jll.libhsl_path,
"linear_solver" => "ma27"
))
optimizer = () -> begin
m = DiffOpt.diff_optimizer(
optimizer_with_attributes(
Ipopt.Optimizer,
"print_level" => 0,
"hsllib" => HSL_jll.libhsl_path,
"linear_solver" => "ma27",
),
)
# Atlas dynamics are encoded with a VectorNonlinearOracle constraint.
# Force the nonlinear DiffOpt backend so reverse differentiation works.
MOI.set(m, DiffOpt.ModelConstructor(), DiffOpt.NonLinearProgram.Model)
return m
end
end

if perturbation_frequency < 1
Expand Down
Loading