-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfeature_extractor.py
More file actions
131 lines (109 loc) · 4.28 KB
/
feature_extractor.py
File metadata and controls
131 lines (109 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Feature extraction utilities for IBSV experiments."""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
try: # pragma: no cover
import cv2
except ImportError: # pragma: no cover
cv2 = None # type: ignore[assignment]
def _to_grayscale(image: np.ndarray) -> np.ndarray:
    """Collapse an image array to a single-channel ``float32`` array.

    * Two-dimensional input is only cast to ``float32``.
    * Input whose third axis has length one has its last axis dropped.
    * Otherwise channels 0, 1 and 2 of the last axis (treated as red, green
      and blue) are blended with the weights 0.2989, 0.5870 and 0.1140; any
      further channels are ignored.
    """
    if image.ndim == 2:
        luminance = image
    elif image.shape[2] == 1:
        luminance = image[..., 0]
    else:
        red = image[..., 0]
        green = image[..., 1]
        blue = image[..., 2]
        luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    return luminance.astype(np.float32)
@dataclass
class FeatureExtractor:
    """Detect point features in an image.

    OpenCV's ``goodFeaturesToTrack`` is used when ``cv2`` could be imported;
    otherwise a plain gradient-magnitude ranking is used. Both paths return
    an ``(N, 2)`` ``float32`` array, including when ``N`` is zero.
    """

    # Upper bound on the number of features returned per image.
    max_features: int = 200
    # Forwarded to OpenCV as ``qualityLevel``; not used by the fallback.
    quality_level: float = 0.01
    # Forwarded to OpenCV as ``minDistance``; not used by the fallback.
    min_distance: float = 5.0

    def extract(self, image: np.ndarray) -> np.ndarray:
        """Return feature pixel coordinates as rows of ``(u, v)``.

        ``u`` is the column (x) index and ``v`` the row (y) index. The image
        is converted to single-channel ``float32`` before detection.
        """
        gray = _to_grayscale(image)
        if cv2 is not None:
            return self._extract_cv(gray)
        return self._extract_fallback(gray)

    def _extract_cv(self, gray: np.ndarray) -> np.ndarray:
        """Extract corners with OpenCV's ``goodFeaturesToTrack``.

        That routine implements Shi-Tomasi corner detection, which scores
        pixels by the eigenvalues of the gradient covariance matrix.
        """
        corners = cv2.goodFeaturesToTrack(
            gray,
            maxCorners=self.max_features,
            qualityLevel=self.quality_level,
            minDistance=self.min_distance,
        )
        if corners is None:
            # Same dtype as the non-empty result so callers can concatenate
            # per-frame outputs without an implicit upcast.
            return np.zeros((0, 2), dtype=np.float32)
        return corners.reshape(-1, 2)

    def _extract_fallback(self, gray: np.ndarray) -> np.ndarray:
        """Pick the pixels with the largest squared gradient magnitude.

        Expects a 2-D array. Returns at most ``max_features`` rows of
        ``(column, row)`` as ``float32``, strongest response first. An empty
        ``(0, 2)`` array is returned when ``max_features`` is not positive or
        when the image is too small for a gradient to be computed.
        """
        limit = min(self.max_features, gray.size)
        # np.gradient needs at least two samples along every axis, so bail
        # out before calling it on degenerate images.
        if limit <= 0 or min(gray.shape) < 2:
            return np.zeros((0, 2), dtype=np.float32)
        gy, gx = np.gradient(gray)
        score = (gx**2 + gy**2).ravel()
        # argpartition selects the top ``limit`` scores without a full sort.
        top = np.argpartition(score, -limit)[-limit:]
        # Order the selection strongest-first so truncating callers keep the
        # best responses; the stable sort keeps the result deterministic.
        top = top[np.argsort(-score[top], kind="stable")]
        rows, cols = np.unravel_index(top, gray.shape)
        return np.column_stack((cols, rows)).astype(np.float32)
def _plot_frame(image: np.ndarray, features: np.ndarray, title: str) -> None:
    """Redraw the current matplotlib axes with one frame and its features.

    The axes are cleared, ``image`` is shown, and each row of ``features``
    (column 0 as x, column 1 as y) is marked with a red cross when there is
    at least one row. The title is set, axis decorations are hidden, and the
    figure is redrawn with a 0.1 s pause.
    """
    marker_style = {"s": 18, "c": "tab:red", "marker": "x"}

    plt.cla()
    plt.imshow(image)
    if len(features) > 0:
        xs = features[:, 0]
        ys = features[:, 1]
        plt.scatter(xs, ys, **marker_style)
    plt.title(title)
    plt.axis("off")
    plt.draw()
    plt.pause(0.1)
def main() -> None:
    """Command-line entry point: extract features from a recorded camera log.

    Reads the ``image`` and ``time`` arrays from an ``.npz`` file, runs a
    ``FeatureExtractor`` on every ``--stride``-th frame, and either displays
    each frame with its features or, with ``--headless``, prints a per-frame
    feature count. A summary line with the number of processed frames and
    the average feature count is always printed.

    Exits through ``parser.error`` when the log file does not exist or lacks
    one of the required arrays.
    """
    parser = argparse.ArgumentParser(
        description="Extract image features from a camera log generated by run_bullet_simulation."
    )
    parser.add_argument("log_path", type=Path, help="Path to camera log (.npz).")
    parser.add_argument("--stride", type=int, default=5, help="Process one every N frames.")
    parser.add_argument("--max-features", type=int, default=200)
    parser.add_argument("--quality-level", type=float, default=0.01)
    parser.add_argument("--min-distance", type=float, default=5.0)
    parser.add_argument("--headless", action="store_true", help="Skip visualization, only print stats.")
    args = parser.parse_args()

    if not args.log_path.exists():
        parser.error(f"{args.log_path} does not exist")

    # The archive object keeps a file handle open; indexing it reads the
    # arrays into memory, so the handle can be closed right after.
    with np.load(args.log_path) as log:
        try:
            images = log["image"]
            times = log["time"]
        except KeyError as exc:
            parser.error(f"{args.log_path} is missing a required array: {exc}")

    extractor = FeatureExtractor(
        max_features=args.max_features,
        quality_level=args.quality_level,
        min_distance=args.min_distance,
    )

    # Guard against zero or negative strides, which would break range().
    stride = max(1, args.stride)
    frame_indices = range(0, len(images), stride)
    show = not args.headless
    total = 0

    if show:
        plt.ion()
        plt.figure(figsize=(4.5, 3.5))

    for idx in frame_indices:
        feats = extractor.extract(images[idx])
        total += len(feats)
        if show:
            _plot_frame(
                images[idx],
                feats,
                f"t={times[idx]:.2f}s | features={len(feats)}",
            )
        else:
            print(f"Frame {idx}: {len(feats)} features")

    num_frames = len(frame_indices)
    # max(1, ...) avoids dividing by zero when the log holds no frames.
    print(f"Processed {num_frames} frames, avg features {total / max(1, num_frames):.1f}")

    if show:
        # Leave interactive mode and block so the last frame stays visible.
        plt.ioff()
        plt.show()


if __name__ == "__main__":
    main()