Skip to content

Commit 30ec8f9

Browse files
yuecidengWaferLiliwenfeng
authored
Fix(gizmo): update to new Gizmo API (#213)
Co-authored-by: WaferLi <63717327+WaferLi@users.noreply.github.com> Co-authored-by: liwenfeng <liwenfeng@dexforce.top>
1 parent cf8e924 commit 30ec8f9

3 files changed

Lines changed: 71 additions & 58 deletions

File tree

embodichain/lab/sim/objects/gizmo.py

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def _create_proxy_cube(
195195
proxy_cube.set_location(position[0], position[1], position[2])
196196
proxy_cube.set_rotation_euler(euler[0], euler[1], euler[2])
197197

198-
# Connect gizmo to proxy cube
199-
self._gizmo.node.update_gizmo_follow(proxy_cube.node)
198+
# Connect gizmo to proxy cube.
199+
self._gizmo.follow(proxy_cube.node)
200200

201201
logger.log_info(f"{name} gizmo proxy created at position: {position}")
202202
return proxy_cube
@@ -212,40 +212,55 @@ def _setup_camera_gizmo(self):
212212
self._proxy_cube = self._create_proxy_cube(
213213
camera_pos, camera_rot_matrix, "Camera"
214214
)
215-
self._gizmo.node.set_flush_transform_callback(self._proxy_gizmo_callback)
215+
# New API uses set_flush_localpose_callback
216+
try:
217+
self._gizmo.set_flush_localpose_callback(self._proxy_gizmo_callback)
218+
except Exception as e:
219+
logger.log_warning(f"Failed to set gizmo callback for camera: {e}")
220+
221+
def _proxy_gizmo_callback(self, *args):
222+
"""Generic callback for proxy-based gizmo.
216223
217-
def _proxy_gizmo_callback(self, node, translation, rotation, flag):
218-
"""Generic callback for proxy-based gizmo: only updates proxy cube transform, defers actual updates"""
224+
Supports both old signature: (node, translation, rotation, flag)
225+
and new signature: (node, local_pose, flag) where local_pose is a 4x4 matrix.
226+
Updates the proxy cube transform and sets `_pending_target_transform`.
227+
"""
228+
# New API callback signature: (node, local_pose, flag)
229+
if len(args) != 3:
230+
return
231+
node, local_pose, flag = args
219232
if node is None:
220233
return
221234

222-
# Check if proxy cube still exists (not destroyed)
235+
# Check if proxy cube still exists
223236
if not hasattr(self, "_proxy_cube") or self._proxy_cube is None:
224237
return
225238

226-
# Update proxy cube transform
227-
if flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_T):
228-
node.set_translation(translation)
229-
elif flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_R):
230-
node.set_rotation_rpy(rotation)
239+
# convert to numpy 4x4 matrix
240+
if isinstance(local_pose, torch.Tensor):
241+
lp = local_pose.cpu().numpy()
242+
else:
243+
lp = np.asarray(local_pose)
244+
245+
if lp.shape != (4, 4):
246+
return
247+
248+
trans = lp[:3, 3]
249+
rot_mat = lp[:3, :3]
250+
euler = R.from_matrix(rot_mat).as_euler("xyz", degrees=False)
231251

232-
# Mark that target needs to be updated, save target transform
233-
proxy_pos = self._proxy_cube.get_location()
234-
proxy_rot = self._proxy_cube.get_rotation_euler()
252+
self._proxy_cube.set_location(float(trans[0]), float(trans[1]), float(trans[2]))
253+
self._proxy_cube.set_rotation_euler(
254+
float(euler[0]), float(euler[1]), float(euler[2])
255+
)
256+
257+
# Build pending target transform (1,4,4)
235258
target_transform = torch.eye(4, dtype=torch.float32)
236259
target_transform[:3, 3] = torch.tensor(
237-
[proxy_pos[0], proxy_pos[1], proxy_pos[2]], dtype=torch.float32
238-
)
239-
target_transform[:3, :3] = torch.tensor(
240-
R.from_euler("xyz", proxy_rot).as_matrix(), dtype=torch.float32
260+
[trans[0], trans[1], trans[2]], dtype=torch.float32
241261
)
242-
# Ensure _pending_target_transform is (1, 4, 4)
243-
if isinstance(target_transform, torch.Tensor) and target_transform.shape == (
244-
4,
245-
4,
246-
):
247-
target_transform = target_transform.unsqueeze(0)
248-
self._pending_target_transform = target_transform
262+
target_transform[:3, :3] = torch.tensor(rot_mat, dtype=torch.float32)
263+
self._pending_target_transform = target_transform.unsqueeze(0)
249264

250265
def _update_camera_pose(self, target_transform: torch.Tensor):
251266
"""Update camera pose to match target transform"""
@@ -283,9 +298,9 @@ def _setup_robot_gizmo(self):
283298
ee_pos = ee_pose[:3, 3].cpu().numpy()
284299
ee_rot_matrix = ee_pose[:3, :3].cpu().numpy()
285300

286-
# Create proxy cube and set callback
301+
# Create proxy cube and set callback (use new callback API)
287302
self._proxy_cube = self._create_proxy_cube(ee_pos, ee_rot_matrix, "Robot")
288-
self._gizmo.node.set_flush_transform_callback(self._proxy_gizmo_callback)
303+
self._gizmo.set_flush_localpose_callback(self._proxy_gizmo_callback)
289304

290305
def _update_robot_ik(self, target_transform: torch.Tensor):
291306
"""Update robot joints using IK to reach target transform"""
@@ -343,9 +358,12 @@ def _update_robot_ik(self, target_transform: torch.Tensor):
343358
def _setup_gizmo_follow(self):
344359
"""Setup gizmo based on target type"""
345360
if self._target_type == "rigidobject":
346-
# RigidObject: direct node access through MeshObject
347-
self._gizmo.node.update_gizmo_follow(self.target._entities[0].node)
348-
self._gizmo.node.set_flush_transform_callback(create_gizmo_callback())
361+
# RigidObject: direct node access through MeshObject — use follow/attach
362+
tgt_node = self.target._entities[0].node
363+
self._gizmo.follow(tgt_node)
364+
# set callback (localpose-style)
365+
self._gizmo.set_flush_localpose_callback(create_gizmo_callback())
366+
349367
elif self._target_type == "robot":
350368
# Robot: create proxy object at end-effector position
351369
self._setup_robot_gizmo()
@@ -362,48 +380,45 @@ def attach(self, target: BatchEntity):
362380
def detach(self):
363381
"""Detach gizmo from current element."""
364382
self.target = None
365-
# Use detach_parent to properly disconnect gizmo
366-
try:
367-
self._gizmo.node.detach_parent()
368-
except Exception as e:
369-
logger.log_warning(f"Failed to detach gizmo parent: {e}")
383+
# Detach gizmo using new API
384+
self._gizmo.detach_parent()
370385

371386
def set_transform_callback(self, callback: Callable):
372387
"""Set callback for gizmo transform events (translation/rotation)."""
373388
self._callback = callback
374-
self._gizmo.node.set_flush_transform_callback(callback)
389+
self._gizmo.set_transform_flush_callback(callback)
375390

376391
def set_world_pose(self, pose):
377392
"""Set gizmo's world pose."""
378-
self._gizmo.node.set_world_pose(pose)
393+
self._gizmo.set_world_pose(pose)
379394

380395
def set_local_pose(self, pose):
381396
"""Set gizmo's local pose."""
382-
self._gizmo.node.set_local_pose(pose)
397+
self._gizmo.set_local_pose(pose)
383398

384399
def set_line_width(self, width: float):
385400
"""Set gizmo line width."""
386-
self._gizmo.node.set_line_width(width)
401+
self._gizmo.set_line_width(width)
387402

388403
def enable_collision(self, enabled: bool):
389404
"""Enable or disable gizmo collision."""
390-
self._gizmo.node.enable_collision(enabled)
405+
self._gizmo.enable_collision(enabled)
391406

392407
def get_world_pose(self):
393408
"""Get gizmo's world pose."""
394-
return self._gizmo.node.get_world_pose()
409+
return self._gizmo.get_world_pose()
395410

396411
def get_local_pose(self):
397412
"""Get gizmo's local pose."""
398-
return self._gizmo.node.get_local_pose()
413+
return self._gizmo.get_local_pose()
399414

400415
def get_name(self):
401416
"""Get gizmo node name."""
402-
return self._gizmo.node.get_name()
417+
return self._gizmo.get_name()
403418

404419
def get_parent(self):
405420
"""Get gizmo's parent node."""
406-
return self._gizmo.node.get_parent()
421+
return self._gizmo.get_parent()
407422

408423
def toggle_visibility(self) -> bool:
409424
"""
@@ -419,8 +434,8 @@ def toggle_visibility(self) -> bool:
419434
self._is_visible = not self._is_visible
420435

421436
# Apply the visibility setting to the gizmo node
422-
if self._gizmo and hasattr(self._gizmo, "node"):
423-
self._gizmo.node.set_visible(self._is_visible)
437+
if self._gizmo:
438+
self._gizmo.set_visible(self._is_visible)
424439

425440
return self._is_visible
426441

@@ -434,8 +449,8 @@ def set_visible(self, visible: bool):
434449
self._is_visible = visible
435450

436451
# Apply the visibility setting to the gizmo node
437-
if self._gizmo and hasattr(self._gizmo, "node"):
438-
self._gizmo.node.set_visible(self._is_visible)
452+
if self._gizmo:
453+
self._gizmo.set_visible(self._is_visible)
439454

440455
def is_visible(self) -> bool:
441456
"""
@@ -449,7 +464,9 @@ def is_visible(self) -> bool:
449464
def update(self):
450465
"""Synchronize gizmo with target's current transform, and handle IK solving here."""
451466
if self._target_type == "rigidobject":
452-
self._gizmo.node.update_gizmo_follow(self.target._entities[0].node)
467+
tgt_node = self.target._entities[0].node
468+
self._gizmo.follow(tgt_node)
469+
453470
elif self._target_type == "robot":
454471
# If there is a pending target, solve IK and clear it
455472
if (
@@ -514,7 +531,7 @@ def destroy(self):
514531
and self._gizmo
515532
and hasattr(self._gizmo, "node")
516533
):
517-
self._gizmo.node.detach_parent()
534+
self._gizmo.detach_parent()
518535
# Then remove the proxy cube
519536
self._env.remove_actor(self._proxy_cube)
520537
logger.log_info("Successfully removed proxy cube from environment")

embodichain/lab/sim/utility/gizmo_utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,16 @@
3131
def create_gizmo_callback() -> Callable:
3232
"""Create a standard gizmo transform callback function.
3333
34-
This callback handles basic translation and rotation operations for gizmo controls.
34+
This callback handles local pose for gizmo controls.
3535
It applies transformations directly to the node when gizmo controls are manipulated.
3636
3737
Returns:
3838
Callable: A callback function that can be used with gizmo.node.set_flush_transform_callback()
3939
"""
4040

41-
def gizmo_transform_callback(node, translation, rotation, flag):
41+
def gizmo_transform_callback(node, local_pose, flag):
4242
if node is not None:
43-
if flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_T):
44-
# Handle translation changes
45-
node.set_translation(translation)
46-
elif flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_R):
47-
# Handle rotation changes
48-
node.set_rotation_rpy(rotation)
43+
node.set_transform(local_pose, flag)
4944

5045
return gizmo_transform_callback
5146

examples/sim/gizmo/gizmo_object.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def run_simulation(sim: SimulationManager):
126126
sim.init_gpu_physics()
127127

128128
step_count = 0
129+
gizmo_enabled = True
129130
try:
130131
last_time = time.time()
131132
last_step = 0

0 commit comments

Comments
 (0)