Merged
43 commits
0112c69
update vln yaml; fix import agent
kew6688 Nov 3, 2025
05ea2a3
update habitat, using evaluator and config; env and agent is WIP
kew6688 Nov 10, 2025
9902401
add distributed_base evaluator
kew6688 Nov 11, 2025
0d00014
Habitat env applied, distributed evaluator applied; clean evaluator a…
kew6688 Nov 12, 2025
7e25e72
fix observation issues
kew6688 Nov 12, 2025
2b0eb8b
update new register name; tiny fix on style
kew6688 Nov 12, 2025
b414ba3
latest tested
kew6688 Nov 12, 2025
99adf73
delete temp agent; rename default evaluator for habitat
kew6688 Nov 12, 2025
75b38a7
update slurm bash
kew6688 Nov 12, 2025
dcf7ee5
merge to main
kew6688 Nov 12, 2025
08bb9c3
update readme
kew6688 Nov 12, 2025
cde84b3
fix init dist print
kew6688 Nov 13, 2025
c89723d
fix eval config; fix local rank to rank
kew6688 Nov 13, 2025
7836276
update init distributed mode if condition
kew6688 Nov 13, 2025
dac13e1
update dist for dlc
kew6688 Nov 13, 2025
d8734c7
fix bug in evaluator
kew6688 Nov 13, 2025
fb21071
[test] dialog+object
Dec 10, 2025
4ecb613
merge dev
Dec 10, 2025
cd00d1e
[Feature] Add testing code for VLLN
Dec 10, 2025
ee69a31
[Feature] Add testing code for VLLN
Dec 11, 2025
1d5a16a
[Feature] Add testing code for VLLN
Dec 11, 2025
8607b2d
fix bugs; refactor env
kew6688 Dec 16, 2025
7aa020d
Merge branch 'dev' into code_review
kew6688 Dec 16, 2025
c1e59fb
update code, merge dev; fix bug
kew6688 Dec 16, 2025
7058317
add raise flag in agent
kew6688 Dec 16, 2025
f6955a1
update a readme for vlln
kew6688 Dec 16, 2025
018549d
fix bug in dialog evaluator
kew6688 Dec 17, 2025
133a25b
add back save video
kew6688 Dec 17, 2025
bc76714
fix video save path
kew6688 Dec 17, 2025
fec6ec2
[FIX] vlln bugs
Dec 19, 2025
1f82a20
Merge pull request #1 from kew6688/code_review
0309hws Dec 19, 2025
57f9677
[FIX] modify the annotations and some small bugs
Dec 19, 2025
abe6b63
[FIX] change some annotations
Dec 19, 2025
a7b15eb
update folder structure
kew6688 Dec 22, 2025
895710d
Merge pull request #2 from kew6688/fix_habitat_extensions
0309hws Dec 22, 2025
8445356
[FIX] delete files
Dec 22, 2025
0c589aa
[FIX] docstrings
Dec 23, 2025
403c2d6
[FIX] fix docstrings
Dec 23, 2025
b4a02db
[FIX] fix docstrings
Dec 23, 2025
9f69e71
[FIX] fix docstrings
Dec 23, 2025
aa5873d
[FIX] numpy version
Dec 23, 2025
a3803ea
[FIX] fix docstrings.
Dec 24, 2025
24e616a
[FIX] fix docstrings.
Dec 25, 2025
4 changes: 2 additions & 2 deletions README.md
@@ -256,13 +256,13 @@ If you use the specific pretrained models and benchmarks, please kindly cite the
booktitle={arXiv},
}
@misc{wei2025groundslowfastdualsystem,
title={Ground Slow, Move Fast: A Dual-System Foundation Model for Generalizable Vision-and-Language Navigation},
title={Ground Slow, Move Fast: A Dual-System Foundation Model for Generalizable Vision-and-Language Navigation},
author={Meng Wei and Chenyang Wan and Jiaqi Peng and Xiqian Yu and Yuqiang Yang and Delin Feng and Wenzhe Cai and Chenming Zhu and Tai Wang and Jiangmiao Pang and Xihui Liu},
year={2025},
eprint={2512.08186},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2512.08186},
url={https://arxiv.org/abs/2512.08186},
}
```

11 changes: 3 additions & 8 deletions internnav/agent/__init__.py
@@ -1,13 +1,8 @@
from internnav.agent.base import Agent
from internnav.agent.cma_agent import CmaAgent
from internnav.agent.dialog_agent import DialogAgent
from internnav.agent.internvla_n1_agent import InternVLAN1Agent
from internnav.agent.rdp_agent import RdpAgent
from internnav.agent.seq2seq_agent import Seq2SeqAgent
from internnav.agent.internvla_n1_agent import InternVLAN1Agent

__all__ = [
'Agent',
'CmaAgent',
'RdpAgent',
'Seq2SeqAgent',
'InternVLAN1Agent'
]
__all__ = ['Agent', 'DialogAgent', 'CmaAgent', 'RdpAgent', 'Seq2SeqAgent', 'InternVLAN1Agent']
1 change: 1 addition & 0 deletions internnav/agent/base.py
@@ -25,6 +25,7 @@ def decorator(agent_class):
if agent_type in cls.agents:
raise ValueError(f"Agent {agent_type} already registered.")
cls.agents[agent_type] = agent_class
return agent_class

return decorator

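The one-line addition above is more than cosmetic: a class decorator rebinds the decorated name to whatever the decorator returns, so without `return agent_class` a statement like `@Agent.register('dialog')` would leave the class name bound to `None`. A minimal standalone sketch of the same registry pattern (illustrative names, not the project's real `Agent` class):

```python
# Minimal sketch of the registry-decorator pattern from the hunk above.
# The class and key names here are illustrative, not the project's real ones.
class Registry:
    agents = {}

    @classmethod
    def register(cls, agent_type):
        def decorator(agent_class):
            if agent_type in cls.agents:
                raise ValueError(f"Agent {agent_type} already registered.")
            cls.agents[agent_type] = agent_class
            return agent_class  # keep the decorated name bound to the class itself
        return decorator


@Registry.register("dialog")
class DialogAgentSketch:
    pass


assert Registry.agents["dialog"] is DialogAgentSketch  # class is still usable after decoration
```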
477 changes: 477 additions & 0 deletions internnav/agent/dialog_agent.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions internnav/configs/evaluator/__init__.py
@@ -39,8 +39,8 @@ class MetricCfg(BaseModel):

class TaskCfg(BaseModel):
task_name: Optional[str] = None
task_settings: Dict[str, Any]
scene: SceneCfg
task_settings: Dict[str, Any] = None
scene: SceneCfg = None
robot_name: Optional[str] = None
robot: Optional[RobotCfg] = None
robot_flash: Optional[bool] = None
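Relaxing `task_settings` and `scene` to default to `None` lets a `TaskCfg` be built without them, which the Habitat path appears to rely on. A standalone sketch of the same idea, assuming pydantic and using hypothetical model names; spelling the defaults as `Optional[...] = None` is the more explicit equivalent of the bare `= None` in the hunk above:

```python
# Standalone sketch of optional config fields; hypothetical models, not the
# project's actual TaskCfg/SceneCfg definitions.
from typing import Any, Dict, Optional

from pydantic import BaseModel


class SceneCfgSketch(BaseModel):
    scene_asset_path: Optional[str] = None  # illustrative field


class TaskCfgSketch(BaseModel):
    task_name: Optional[str] = None
    # Optional[...] = None makes the "may be omitted" intent explicit.
    task_settings: Optional[Dict[str, Any]] = None
    scene: Optional[SceneCfgSketch] = None


cfg = TaskCfgSketch(task_name="habitat_dialog")  # constructs without scene/task_settings
print(cfg.task_settings is None, cfg.scene is None)  # True True
```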
3 changes: 2 additions & 1 deletion internnav/env/__init__.py
@@ -1,4 +1,5 @@
from internnav.env.base import Env
from internnav.env.habitat_env import HabitatEnv
from internnav.env.internutopia_env import InternutopiaEnv

__all__ = ['Env', 'InternutopiaEnv']
__all__ = ['Env', 'InternutopiaEnv', 'HabitatEnv']
1 change: 1 addition & 0 deletions internnav/env/base.py
@@ -42,6 +42,7 @@ def decorator(env_class):
if env_type in cls.envs:
raise ValueError(f"Env {env_type} already registered.")
cls.envs[env_type] = env_class
return env_class

return decorator

internnav/env/habitat_env.py
@@ -8,13 +8,13 @@

@base.Env.register('habitat')
class HabitatEnv(base.Env):
def __init__(self, env_config: EnvCfg, task_config: TaskCfg):
"""
env_settings include:
- habitat_config: loaded from get_habitat_config
- rank: int, rank index for sharding
- world_size: int, total number of ranks
"""
"""A lightweight wrapper around `habitat.Env` that adapts Habitat to the project's `base.Env` interface.

Args:
env_config (EnvCfg): Environment configuration.
task_config (TaskCfg): Optional task configuration passed to the base environment.
"""
def __init__(self, env_config: EnvCfg, task_config: TaskCfg = None):
try:
from habitat import Env
except ImportError as e:
@@ -23,7 +23,6 @@ def __init__(self, env_config: EnvCfg, task_config: TaskCfg):
) from e

super().__init__(env_config, task_config)

self.config = env_config.env_settings['habitat_config']
self._env = Env(self.config)

@@ -36,16 +35,14 @@ def __init__(self, env_config: EnvCfg, task_config: TaskCfg):
self.output_path = env_config.env_settings.get('output_path', './output')

# generate episodes
# self._env.episodes = self._env.episodes[0:1] # for debug
self.episodes = self.generate_episodes()
# print(self.episodes)

def generate_episodes(self) -> List[Any]:
"""
Generate list of episodes for the current split, already:
- grouped by scene
- filtered by done_res (the path is self.output_path/progress.json)
- sharded by (rank, world_size)
Generate list of episodes for the current split.

Returns:
List[Any]: A list of episode objects for the current split.
"""
all_episodes = []

@@ -80,9 +77,6 @@ def generate_episodes(self) -> List[Any]:
return all_episodes

def reset(self):
"""
load next episode and return first observation
"""
# no more episodes
if not (0 <= self._current_episode_index < len(self.episodes)):
self.is_running = False
@@ -94,17 +88,9 @@ def reset(self):

# Habitat reset
self._last_obs = self._env.reset()

return self._last_obs

def step(self, action: List[Any]):
"""
step the environment with given action

Args: action: List[Any], action for each env in the batch

Return: obs, reward, done, info
"""
obs = self._env.step(action)
done = self._env.episode_over
info = self._env.get_metrics()
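The removed docstring summarised what `generate_episodes` does: group episodes by scene, filter out those already recorded in `progress.json`, and shard the remainder by `(rank, world_size)`. A self-contained sketch of that sharding logic, with illustrative episode dictionaries and a simplified grouping policy:

```python
# Standalone sketch of the episode sharding that generate_episodes describes:
# group by scene, drop already-finished episodes, then split across ranks.
# The episode structure and grouping policy here are illustrative.
from itertools import groupby


def shard_episodes(episodes, rank, world_size, done_ids=()):
    # keep episodes of the same scene together to minimise scene reloads
    episodes = sorted(episodes, key=lambda ep: ep["scene_id"])
    grouped = [ep for _, scene_eps in groupby(episodes, key=lambda ep: ep["scene_id"])
               for ep in scene_eps]
    # drop episodes already recorded as done (e.g. read from progress.json)
    remaining = [ep for ep in grouped if ep["episode_id"] not in set(done_ids)]
    # round-robin shard across (rank, world_size)
    return remaining[rank::world_size]


episodes = [{"scene_id": s, "episode_id": f"{s}-{i}"} for s in ("sceneA", "sceneB") for i in range(3)]
print(shard_episodes(episodes, rank=0, world_size=2))
print(shard_episodes(episodes, rank=1, world_size=2))
```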
196 changes: 196 additions & 0 deletions internnav/env/utils/dialog_mp3d.py
@@ -0,0 +1,196 @@
import cv2
import numpy as np


def fill_small_holes(depth_img: np.ndarray, area_thresh: int) -> np.ndarray:
"""
Identifies regions in the depth image that have a value of 0 and fills them in
with 1 if the region is smaller than a given area threshold.

Args:
depth_img (np.ndarray): The input depth image
area_thresh (int): The area threshold for filling in holes

Returns:
filled_depth_img (np.ndarray): The depth image with small holes filled in
"""
# Create a binary image where holes are 1 and the rest is 0
binary_img = np.where(depth_img == 0, 1, 0).astype("uint8")

# Find contours in the binary image
contours, _ = cv2.findContours(binary_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

filled_holes = np.zeros_like(binary_img)

for cnt in contours:
# If the area of the contour is smaller than the threshold
if cv2.contourArea(cnt) < area_thresh:
# Fill the contour
cv2.drawContours(filled_holes, [cnt], 0, 1, -1)

# Create the filled depth image
filled_depth_img = np.where(filled_holes == 1, 1, depth_img)

return filled_depth_img


class MP3DGTPerception:
"""
Ground-truth perception utility for projecting MP3D object 3D bounding boxes
into the current camera view to produce per-target semantic masks.

Args:
max_depth (float): Maximum metric depth (used for depth rescaling and masking).
min_depth (float): Minimum metric depth (used for depth rescaling).
fx (float): Camera focal length in pixels along x.
fy (float): Camera focal length in pixels along y.
"""

def __init__(self, max_depth, min_depth, fx, fy):
self.max_depth = max_depth
self.min_depth = min_depth
self.fx = fx
self.fy = fy

def predict(self, depth, targets, tf_camera_to_ply, area_threshold=2500):
"""
Get ground-truth semantic masks for target objects by projecting 3D bboxes into the image.

Args:
depth (np.ndarray): Depth image of shape (H, W). Values are assumed to be normalized to [0, 1] and will be rescaled to metric depth using ``depth * (max_depth - min_depth) + min_depth``.
targets (np.ndarray): Target 3D axis-aligned bounding boxes of shape (N, 6), formatted as ``[min_x, min_y, min_z, max_x, max_y, max_z]`` in the PLY/world frame.
tf_camera_to_ply (np.ndarray): Homogeneous 4x4 transform from camera frame to the PLY/world frame.
area_threshold (int): Area threshold used by the hole-filling routine for both the depth map and the output masks.

Returns:
semantic_images (np.ndarray): Binary semantic masks of shape (N, H, W) with dtype ``np.uint8`` where 1 indicates pixels belonging to the corresponding target and 0 otherwise. If no targets are provided, returns an all-zero array of shape (1, H, W).
"""
# get the point clouds of current frame
filled_depth = fill_small_holes(depth, area_threshold)
scaled_depth = filled_depth * (self.max_depth - self.min_depth) + self.min_depth
mask = scaled_depth < self.max_depth
point_cloud_camera_frame = get_point_cloud(scaled_depth, mask, self.fx, self.fy)
point_cloud_ply_frame = transform_points(tf_camera_to_ply, point_cloud_camera_frame)

# mark the points in the target objects' bboxes
semantic_images = []
for target in targets:
min_x, min_y, min_z = target[:3]
max_x, max_y, max_z = target[3:]

in_bbox = (
(point_cloud_ply_frame[:, 0] >= min_x)
& (point_cloud_ply_frame[:, 0] <= max_x)
& (point_cloud_ply_frame[:, 1] >= min_y)
& (point_cloud_ply_frame[:, 1] <= max_y)
& (point_cloud_ply_frame[:, 2] >= min_z)
& (point_cloud_ply_frame[:, 2] <= max_z)
)
in_bbox_points = point_cloud_ply_frame[in_bbox]
semantic_image = np.zeros(depth.shape, dtype=np.uint8)
if len(in_bbox_points) > 0:
# map the marked points back to the image to get the semantic map
in_bbox_camera_frame = inverse_transform_points(tf_camera_to_ply, in_bbox_points)
in_box_image_coords = project_points_to_image(in_bbox_camera_frame, self.fx, self.fy, depth.shape)
try:
mask = [
in_box_image_coords[i, 0] < 480 and in_box_image_coords[i, 1] < 640
for i in range(len(in_box_image_coords))
]
in_box_image_coords = in_box_image_coords[mask]
semantic_image[in_box_image_coords[:, 0], in_box_image_coords[:, 1]] = 1
except Exception as e:
print(e)
semantic_image = fill_small_holes(semantic_image, area_threshold)
semantic_images.append(semantic_image)
if len(semantic_images) > 0:
semantic_images = np.stack(semantic_images, axis=0)
else:
semantic_images = np.zeros((1, depth.shape[0], depth.shape[1]), dtype=np.uint8)
return semantic_images


def transform_points(transformation_matrix: np.ndarray, points: np.ndarray) -> np.ndarray:
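"""Apply a 4x4 homogeneous transformation to a point cloud (used here to map camera-frame points into the PLY/world frame).

Args:
transformation_matrix (np.ndarray): 4x4 homogeneous transformation matrix.
points (np.ndarray): Point cloud coordinates (N, 3).

Returns:
transformed_points (np.ndarray): Transformed point cloud coordinates (N, 3).
"""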
# Add a homogeneous coordinate of 1 to each point for matrix multiplication
homogeneous_points = np.hstack((points, np.ones((points.shape[0], 1))))

# Apply the transformation matrix to the points
transformed_points = np.dot(transformation_matrix, homogeneous_points.T).T

# Remove the added homogeneous coordinate and divide by the last coordinate
return transformed_points[:, :3] / transformed_points[:, 3:]


def get_point_cloud(depth_image: np.ndarray, mask: np.ndarray, fx: float, fy: float) -> np.ndarray:
"""Calculates the 3D coordinates (x, y, z) of points in the depth image based on
the horizontal field of view (HFOV), the image width and height, the depth values,
and the pixel x and y coordinates.

Args:
depth_image (np.ndarray): 2D depth image.
mask (np.ndarray): 2D binary mask identifying relevant pixels.
fx (float): Focal length in the x direction.
fy (float): Focal length in the y direction.

Returns:
cloud (np.ndarray): Array of 3D coordinates (x, y, z) of the points in the image plane.
"""
v, u = np.where(mask)
z = depth_image[v, u]
x = (u - depth_image.shape[1] // 2) * z / fx
y = (v - depth_image.shape[0] // 2) * z / fy
cloud = np.stack((x, -y, -z), axis=-1)

return cloud


def inverse_transform_points(transformation_matrix: np.ndarray, points: np.ndarray) -> np.ndarray:
"""Convert point cloud from episodic coordinate system to camera coordinate system

Args:
transformation_matrix (np.ndarray): 4x4 transformation matrix
points (np.ndarray): Point cloud coordinates (N, 3)

Returns:
result_points (np.ndarray): Point cloud coordinates in camera coordinate system (N, 3)
"""
# Calculate the inverse of the transformation matrix
inv_matrix = np.linalg.inv(transformation_matrix)

# Add a homogeneous coordinate of 1 to each point for matrix multiplication
homogeneous_points = np.hstack((points, np.ones((points.shape[0], 1))))

# Apply the inverse transformation
transformed_points = np.dot(inv_matrix, homogeneous_points.T).T

# Remove the added homogeneous coordinate
result_points = transformed_points[:, :3] / transformed_points[:, 3:]
return result_points


def project_points_to_image(points: np.ndarray, fx: float, fy: float, image_shape: tuple) -> np.ndarray:
"""Project points from camera coordinate system to image plane

Args:
points (np.ndarray): Points in camera coordinate system (N, 3)
fx (float): x-axis focal length
fy (float): y-axis focal length
image_shape (tuple): Image dimensions (height, width)

Returns:
result_points (np.ndarray): Image coordinates (N, 2)
"""
points = np.stack((points[:, 0], -points[:, 1], -points[:, 2]), axis=-1)
# Ensure points are in front of the camera
valid_mask = points[:, 2] > 0 # z > 0

# Calculate image coordinates
u = points[:, 0] * fx / points[:, 2] + image_shape[1] // 2
v = points[:, 1] * fy / points[:, 2] + image_shape[0] // 2

# Combine coordinates
image_coords = np.stack((v, u), axis=-1)
image_coords = image_coords.astype(np.int32)
# Return valid points only
result_points = image_coords[valid_mask]
return result_points
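A hedged usage sketch for `MP3DGTPerception` with synthetic inputs. The intrinsics, depth range, and box coordinates below are illustrative values chosen so the projected box lands in view, not the project's defaults; note that `get_point_cloud` stores points as `(x, -y, -z)`, so with an identity camera-to-PLY transform visible points sit at negative z.

```python
# Usage sketch for MP3DGTPerception with synthetic inputs. Intrinsics, depth
# range and the target box are illustrative values, not project defaults.
import numpy as np

from internnav.env.utils.dialog_mp3d import MP3DGTPerception

H, W = 480, 640
perception = MP3DGTPerception(max_depth=10.0, min_depth=0.0, fx=320.0, fy=240.0)

depth = np.full((H, W), 0.3, dtype=np.float32)  # normalized depth in [0, 1] -> 3 m after rescaling
# One axis-aligned box [min_x, min_y, min_z, max_x, max_y, max_z] in the PLY frame.
# With an identity camera-to-PLY transform, visible points sit at z = -depth,
# so the box spans negative z to intersect them.
targets = np.array([[-1.0, -1.0, -4.0, 1.0, 1.0, -1.0]])
tf_camera_to_ply = np.eye(4)

masks = perception.predict(depth, targets, tf_camera_to_ply)
print(masks.shape, int(masks.max()))  # (1, 480, 640) 1 -> pixels inside the projected box
```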
17 changes: 14 additions & 3 deletions internnav/evaluator/__init__.py
@@ -4,9 +4,20 @@

# register habitat
try:
import internnav.habitat_extensions # noqa: F401 # isort: skip
import internnav.habitat_extensions.vlln # noqa: F401 # isort: skip
except Exception as e:
print(f"Warning: ({e}), Habitat Evaluation is not loaded in this runtime. Ignore this if not using Habitat.")
print(f"Warning: ({e}), Habitat vlln is not loaded in this runtime. Ignore this if not using Habitat vlln.")

try:
import internnav.habitat_extensions.vln # noqa: F401 # isort: skip
except Exception as e:
print(f"Warning: ({e}), Habitat vln is not loaded in this runtime. Ignore this if not using Habitat vln.")


__all__ = ['Evaluator', 'DistributedEvaluator', 'VLNDistributedEvaluator', 'HabitatVLNEvaluator']
__all__ = [
'Evaluator',
'DistributedEvaluator',
'VLNDistributedEvaluator',
'HabitatVLNEvaluator',
'HabitatDialogEvaluator',
]
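Splitting the single `habitat_extensions` import into separate `vlln` and `vln` try/except blocks means a failure in one extension no longer silently disables the other. The same pattern, as a small standalone sketch (the loop and list here are illustrative, not the project's code):

```python
# Standalone sketch of the optional-import registration pattern above: each
# extension registers its evaluators as a side effect of being imported, and a
# missing dependency only disables that one extension.
import importlib

OPTIONAL_EXTENSIONS = [
    "internnav.habitat_extensions.vlln",
    "internnav.habitat_extensions.vln",
]

for mod_name in OPTIONAL_EXTENSIONS:
    try:
        importlib.import_module(mod_name)  # registration happens on import
    except Exception as exc:  # any failure just disables this one extension
        print(f"Warning: ({exc}), {mod_name} is not loaded in this runtime. "
              "Ignore this if you are not using it.")
```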
1 change: 1 addition & 0 deletions internnav/evaluator/distributed_base.py
@@ -64,6 +64,7 @@ def __init__(self, eval_cfg: EvalCfg, init_env: bool = True, init_agent: bool =
from internnav.agent import Agent

eval_cfg.agent.model_settings['local_rank'] = self.local_rank
eval_cfg.agent.model_settings['task_name'] = eval_cfg.task.task_name
self.agent = Agent.init(eval_cfg.agent)

def eval(self):
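The added line threads the task name from the evaluator config into the agent's `model_settings` before `Agent.init(...)`. A hedged sketch of how an agent might consume it (the real agents may behave differently; everything here other than the `local_rank` and `task_name` keys is an assumption):

```python
# Hedged sketch of an agent reading the task_name injected by the evaluator;
# illustrative class, not the project's actual agent implementation.
class AgentSketch:
    def __init__(self, model_settings: dict):
        self.local_rank = model_settings.get("local_rank", 0)
        # Set by the distributed evaluator from eval_cfg.task.task_name.
        self.task_name = model_settings.get("task_name", "")

    def uses_dialog(self) -> bool:
        # e.g. switch prompting / observation handling for dialog-style tasks
        return "dialog" in self.task_name


agent = AgentSketch({"local_rank": 0, "task_name": "dialog"})
print(agent.uses_dialog())  # True
```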
2 changes: 0 additions & 2 deletions internnav/habitat_extensions/__init__.py

This file was deleted.
