Skip to content

Commit c6be7c8

Browse files
committed
Added head-pose estimation
Signed-off-by: Mpho Mphego <mpho112@gmail.com>
1 parent df0d551 commit c6be7c8

File tree

2 files changed

+115
-14
lines changed

2 files changed

+115
-14
lines changed

main.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def main(args):
142142

143143
for frame in video_feed.next_frame():
144144

145-
predict_end_time, face_bboxes = face_detection.predict(frame, draw=True)
145+
predict_end_time, face_bboxes = face_detection.predict(frame, show_bbox=True)
146146
text = f"Face Detection Inference time: {predict_end_time:.3f} s"
147147
face_detection.add_text(text, frame, (15, video_feed.source_height - 80))
148148

@@ -165,13 +165,22 @@ def main(args):
165165
if face_height < 20 or face_width < 20:
166166
continue
167167

168-
predict_end_time, eyes_coords = facial_landmarks.predict(face, draw=True)
168+
predict_end_time, eyes_coords = facial_landmarks.predict(
169+
face, show_bbox=True
170+
)
169171
text = f"Facial Landmarks Est. Inference time: {predict_end_time:.3f} s"
170172
facial_landmarks.add_text(
171173
text, frame, (15, video_feed.source_height - 60)
172174
)
173175

174-
176+
predict_end_time, head_pose_angles = head_pose_estimation.predict(
177+
face, show_bbox=True
178+
)
179+
text = f"Head Pose Est. Inference time: {predict_end_time:.3f} s"
180+
head_pose_estimation.add_text(
181+
text, frame, (15, video_feed.source_height - 40)
182+
)
183+
# print (f"head pose: {head_pose_angles}")
175184

176185
if args.debug:
177186
video_feed.show(video_feed.resize(frame))

src/model.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import os
3+
import math
34
import sys
45
import time
56
import subprocess
@@ -23,6 +24,10 @@
2324
]
2425

2526

27+
class InvalidModel(Exception):
    """Raised when a loaded model's outputs do not match the expected layout."""
29+
30+
2631
class Base(ABC):
2732
"""Model Base Class"""
2833

@@ -53,12 +58,19 @@ def __init__(
5358
self.input_name = next(iter(self.model.inputs))
5459
self.input_shape = self.model.inputs[self.input_name].shape
5560
self.output_name = next(iter(self.model.outputs))
61+
self._output_shape = None
5662
self.output_shape = self.model.outputs[self.output_name].shape
5763
self._init_image_w = source_width
5864
self._init_image_h = source_height
5965
self.exec_network = None
6066
self.load_model()
6167

68+
# @property
69+
# def output_shape(self):
70+
# if not self._output_shape:
71+
# self._output_shape =
72+
# return self._output_shape
73+
6274
def _get_model(self):
6375
"""Helper function for reading the network."""
6476
try:
@@ -91,7 +103,7 @@ def load_model(self):
91103
f"Model: {self.model_structure} took {self._model_load_time:.3f} ms to load."
92104
)
93105

94-
def predict(self, image, request_id=0, draw=False):
106+
def predict(self, image, request_id=0, show_bbox=False):
95107
if not isinstance(image, np.ndarray):
96108
raise IOError("Image not parsed correctly.")
97109

@@ -100,15 +112,17 @@ def predict(self, image, request_id=0, draw=False):
100112
request_id=request_id, inputs={self.input_name: p_image}
101113
)
102114
status = self.exec_network.requests[request_id].wait(-1)
103-
bbox = None
104115
if status == 0:
105116
predict_start_time = time.time()
106-
pred_result = self.exec_network.requests[request_id].outputs[
107-
self.output_name
108-
]
117+
pred_result = []
118+
for output_name, data_ptr in self.model.outputs.items():
119+
pred_result.append(
120+
self.exec_network.requests[request_id].outputs[output_name]
121+
)
109122
predict_end_time = float(time.time() - predict_start_time) * 1000
110-
if draw:
111-
bbox, _ = self.preprocess_output(pred_result, image, show_bbox=draw)
123+
bbox, _ = self.preprocess_output(
124+
pred_result, image, show_bbox=show_bbox
125+
)
112126
return (predict_end_time, bbox)
113127

114128
@abstractmethod
@@ -162,6 +176,8 @@ def preprocess_output(self, inference_results, image, show_bbox=False):
162176
"""Draw bounding boxes onto the Face Detection frame."""
163177
if not (self._init_image_w and self._init_image_h):
164178
raise RuntimeError("Initial image width and height cannot be None.")
179+
if len(inference_results) == 1:
180+
inference_results = inference_results[0]
165181

166182
coords = []
167183
for box in inference_results[0][0]: # Output shape is 1x1xNx7
@@ -282,11 +298,86 @@ def __init__(
282298
model_name, source_width, source_height, device, threshold, extensions,
283299
)
284300

285-
def preprocess_output(self, inference_results, image):
286-
pass
301+
def preprocess_output(self, inference_results, image, show_bbox=False):
    """Convert raw head-pose inference results into named angles (degrees).

    Example
    -------
    Model: head-pose-estimation-adas-0001

    Output layer names in Inference Engine format:

    name: "angle_y_fc", shape: [1, 1] - Estimated yaw (in degrees).
    name: "angle_p_fc", shape: [1, 1] - Estimated pitch (in degrees).
    name: "angle_r_fc", shape: [1, 1] - Estimated roll (in degrees).

    Parameters
    ----------
    inference_results : list
        One array per model output layer.
    image : np.ndarray
        Cropped face; annotated in place when ``show_bbox`` is True.
    show_bbox : bool, optional
        When True, overlay the pose axes on ``image`` via ``draw_output``.
        Defaults to False for consistency with the other detectors'
        ``preprocess_output`` signatures.

    Returns
    -------
    tuple(dict, np.ndarray)
        ``{"yaw": ..., "pitch": ..., "roll": ...}`` and the image.

    Raises
    ------
    InvalidModel
        If the model does not expose exactly three output layers.
    """
    if len(inference_results) != 3:
        msg = (
            f"The model:{self.model_structure} does not contain expected output "
            "shape as per the docs."
        )
        self.logger.error(msg)
        raise InvalidModel(msg)

    # NOTE(review): results arrive in whatever order the runtime iterates
    # the model's outputs; confirm that order really is yaw, pitch, roll
    # for the target model before trusting these labels.
    output_layer_names = ["yaw", "pitch", "roll"]
    flattened_predictions = np.vstack(inference_results).ravel()
    head_pose_angles = dict(zip(output_layer_names, flattened_predictions))
    if show_bbox:
        # cv2 drawing mutates `image` in place; no need to capture the return.
        self.draw_output(head_pose_angles, image)

    return head_pose_angles, image
287331

332+
@staticmethod
def draw_output(coords, image):
    """Overlay the three head-pose axes on *image* (drawn in place).

    Ref: https://github.com/natanielruiz/deep-head-pose/blob/master/code/utils.py#L86+L117
    """
    yaw, pitch, roll = coords.values()

    # Degrees -> radians; yaw is negated to match the drawing convention.
    pitch = pitch * np.pi / 180
    yaw = -(yaw * np.pi / 180)
    roll = roll * np.pi / 180

    height, width = image.shape[:2]
    center_x = width / 2
    center_y = height / 2
    axis_length = 1000

    # Hoist the trig terms so each is computed exactly once.
    cos_y, sin_y = math.cos(yaw), math.sin(yaw)
    cos_p, sin_p = math.cos(pitch), math.sin(pitch)
    cos_r, sin_r = math.cos(roll), math.sin(roll)

    # X-Axis pointing to right, drawn in red.
    x1 = axis_length * (cos_y * cos_r) + center_x
    y1 = axis_length * (cos_p * sin_r + cos_r * sin_p * sin_y) + center_y

    # Y-Axis (pointing down), drawn in green.
    x2 = axis_length * (-cos_y * sin_r) + center_x
    y2 = axis_length * (cos_p * cos_r - sin_p * sin_y * sin_r) + center_y

    # Z-Axis (out of the screen), drawn in blue.
    x3 = axis_length * (sin_y) + center_x
    y3 = axis_length * (-cos_y * sin_p) + center_y

    origin = (int(center_x), int(center_y))
    cv2.line(image, origin, (int(x1), int(y1)), (0, 0, 255), 3)
    cv2.line(image, origin, (int(x2), int(y2)), (0, 255, 0), 3)
    cv2.line(image, origin, (int(x3), int(y3)), (255, 0, 0), 2)

    return image
290381

291382

292383
class Gaze_Estimation(Base):
@@ -305,8 +396,9 @@ def __init__(
305396
model_name, source_width, source_height, device, threshold, extensions,
306397
)
307398

308-
def preprocess_output(self, inference_results, image):
399+
def preprocess_output(self, inference_results, image, show_bbox):
    """Gaze post-processing — not implemented yet; returns None."""
310401

402+
@staticmethod
311403
def draw_output(coords, image):
312404
pass

0 commit comments

Comments
 (0)