Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions diffusion_planner/diffusion_planner/utils/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,94 @@ def heading_transform(heading, transform_mat):
).reshape(*shape)


def _cross2d(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""2D cross product along the last dimension: u × v = u.x*v.y - u.y*v.x"""
return u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]


def _rect_corners(rect: torch.Tensor) -> torch.Tensor:
"""
rect: [B, 6] — (x, y, cos_h, sin_h, length, width)
Returns [B, 4, 2] corner points.
"""
B = rect.shape[0]
xy, cos_h, sin_h, lw = rect[:, :2], rect[:, 2], rect[:, 3], rect[:, 4:]
rot = torch.stack([cos_h, -sin_h, sin_h, cos_h], dim=1).reshape(B, 2, 2)
signs = torch.tensor([[1.0, 1], [-1, 1], [-1, -1], [1, -1]], device=lw.device)
local = torch.einsum("bj,ij->bij", lw / 2, signs) # [B, 4, 2]
local = torch.einsum("bij,bkj->bik", local, rot) # [B, 4, 2]
return xy[:, None, :] + local


def _sat_signed_distance(c1: torch.Tensor, c2: torch.Tensor) -> torch.Tensor:
"""
SAT signed distance between two rectangles.
c1, c2: [B, 4, 2] corner points
Returns [B] — negative means overlap.
"""
nv = torch.stack(
[c1[:, 0] - c1[:, 1], c1[:, 1] - c1[:, 2],
c2[:, 0] - c2[:, 1], c2[:, 1] - c2[:, 2]],
dim=1,
) # [B, 4, 2]
nv = nv / torch.norm(nv, dim=2, keepdim=True).clamp(min=1e-6)
p1 = torch.einsum("bij,bkj->bik", nv, c1) # [B, 4, 4]
p2 = torch.einsum("bij,bkj->bik", nv, c2)
overlap = torch.cat(
[p1.min(2).values - p2.max(2).values, p2.min(2).values - p1.max(2).values],
dim=1,
) # [B, 8]
is_overlap = (overlap < 0).all(dim=1)
pos = torch.where(overlap < 0, torch.full_like(overlap, 1e5), overlap)
return torch.where(is_overlap, overlap.max(1).values, pos.min(1).values)


def _segments_intersect_rect(
seg_start: torch.Tensor,
seg_end: torch.Tensor,
rect_corners: torch.Tensor,
valid: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Returns [B] bool — True if any valid segment touches the rectangle.

seg_start, seg_end: [B, N, 2]
rect_corners: [B, 4, 2]
valid: [B, N] bool — True for valid segments
"""
hit = torch.zeros(seg_start.shape[:2], dtype=torch.bool, device=seg_start.device)
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]

# Proper segment–edge crossing: both pairs straddle each other's line
for i, j in edges:
C = rect_corners[:, i, :].unsqueeze(1) # [B, 1, 2]
D = rect_corners[:, j, :].unsqueeze(1) # [B, 1, 2]
AB = seg_end - seg_start # [B, N, 2]
CD = D - C # [B, 1, 2]
hit = hit | (
(_cross2d(AB, C - seg_start) * _cross2d(AB, D - seg_start) < 0)
& (_cross2d(CD, seg_start - C) * _cross2d(CD, seg_end - C) < 0)
)

# Endpoint inside polygon: all edge cross products share the same sign
for pt in (seg_start, seg_end):
crosses = torch.stack(
[
_cross2d(
(rect_corners[:, j, :] - rect_corners[:, i, :]).unsqueeze(1),
pt - rect_corners[:, i, :].unsqueeze(1),
)
for i, j in edges
],
dim=-1,
) # [B, N, 4]
hit = hit | (crosses > 0).all(-1) | (crosses < 0).all(-1)

if valid is not None:
hit = hit & valid
return hit.any(dim=1) # [B]


class StatePerturbation:
"""
Data augmentation that perturbs the current ego position and generates a feasible trajectory that
Expand Down Expand Up @@ -149,8 +237,90 @@ def augment(self, inputs):
ego_current_state[:, 8] = steering_angle
ego_current_state[:, 9] = new_yaw_rate

# Discard augmentations that cause collisions
collision = self._check_aug_validity(ego_current_state, inputs)
aug_flag = aug_flag & ~collision

return aug_flag, ego_current_state

def _check_aug_validity(
self, aug_ego_state: torch.Tensor, inputs: dict
) -> torch.Tensor:
"""
Returns [B] bool — True where the augmented ego position is invalid.

Invalid conditions:
1. Ego polygon overlaps with a neighbour agent polygon.
2. Ego polygon intersects a lane left or right boundary segment.
"""
B = aug_ego_state.shape[0]
device = aug_ego_state.device
dtype = aug_ego_state.dtype

# ego_shape: [B, 3] = (wheelbase, length, width)
ego_shape = inputs["ego_shape"].to(device=device, dtype=dtype)
ego_length = ego_shape[:, 1:2] # [B, 1]
ego_width = ego_shape[:, 2:3] # [B, 1]

ego_rect = torch.cat(
[aug_ego_state[:, :4], ego_length, ego_width],
dim=-1,
) # [B, 6]
ego_corners = _rect_corners(ego_rect) # [B, 4, 2]

collision = torch.zeros(B, dtype=torch.bool, device=device)

# ── 1. Neighbour agent polygon collision ──────────────────────────────
if "neighbor_agents_past" in inputs:
nbr = inputs["neighbor_agents_past"][:, :, -1, :] # [B, N, 11]
N = nbr.shape[1]
valid = torch.sum(torch.ne(nbr[:, :, :4], 0), dim=-1) > 0 # [B, N]
if valid.any():
# neighbor_agents_past layout: x,y,cos,sin (0:4), width (6), length (7)
nbr_rect = torch.cat(
[nbr[:, :, :4], nbr[:, :, 7:8], nbr[:, :, 6:7]], dim=-1
) # [B, N, 6] — (x,y,cos,sin,length,width)
dists = _sat_signed_distance(
_rect_corners(ego_rect.unsqueeze(1).expand(-1, N, -1).reshape(B * N, 6)),
_rect_corners(nbr_rect.reshape(B * N, 6)),
).reshape(B, N)
collision = collision | ((dists < 0) & valid).any(dim=1)

# ── 2. Lane boundary segment collision ───────────────────────────────
if "lanes" in inputs:
lanes = inputs["lanes"] # [B, L, P, 33]
left_offset = lanes[..., 4:6] # [B, L, P, 2]
right_offset = lanes[..., 6:8] # [B, L, P, 2]

# Absolute boundary positions
left_pts = lanes[..., :2] + left_offset # [B, L, P, 2]
right_pts = lanes[..., :2] + right_offset # [B, L, P, 2]

# A waypoint is valid when its first 8 features are not all zero.
# Additionally, only include a boundary side when its offset is
# non-trivial; a near-zero offset means no boundary data.
lane_valid = torch.sum(torch.ne(lanes[..., :8], 0), dim=-1) > 0 # [B, L, P]
left_bound_valid = (torch.norm(left_offset, dim=-1) > 0.01) & lane_valid
right_bound_valid = (torch.norm(right_offset, dim=-1) > 0.01) & lane_valid

def _boundary_segs(pts, point_valid):
s = pts[:, :, :-1, :].reshape(B, -1, 2)
e = pts[:, :, 1:, :].reshape(B, -1, 2)
v = (point_valid[:, :, :-1] & point_valid[:, :, 1:]).reshape(B, -1)
return s, e, v

ls, le, lv = _boundary_segs(left_pts, left_bound_valid)
rs, re, rv = _boundary_segs(right_pts, right_bound_valid)

collision = collision | _segments_intersect_rect(
torch.cat([ls, rs], dim=1),
torch.cat([le, re], dim=1),
ego_corners,
torch.cat([lv, rv], dim=1),
)

return collision

def normalize_angle(self, angle: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
return (angle + np.pi) % (2 * np.pi) - np.pi

Expand Down
Loading