2525class Base (abc .ABC ):
2626 """Model Base Class"""
2727
28- def __init__ (self , model_name , device = "CPU" , threshold = 0.60 , extensions = None ):
28+ def __init__ (
29+ self ,
30+ model_name ,
31+ source_width = None ,
32+ source_height = None ,
33+ device = "CPU" ,
34+ threshold = 0.60 ,
35+ extensions = None ,
36+ ):
2937 self .model_weights = f"{ model_name } .bin"
3038 self .model_structure = f"{ model_name } .xml"
3139 assert (
@@ -45,8 +53,8 @@ def __init__(self, model_name, device="CPU", threshold=0.60, extensions=None):
4553 self .input_shape = self .model .inputs [self .input_name ].shape
4654 self .output_name = next (iter (self .model .outputs ))
4755 self .output_shape = self .model .outputs [self .output_name ].shape
48- self ._init_image_w = None
49- self ._init_image_h = None
56+ self ._init_image_w = source_width
57+ self ._init_image_h = source_height
5058 self .exec_network = None
5159 self .load_model ()
5260
@@ -91,15 +99,16 @@ def predict(self, image, request_id=0, draw=False):
9199 request_id = request_id , inputs = {self .input_name : p_image }
92100 )
93101 status = self .exec_network .requests [request_id ].wait (- 1 )
102+ bbox = None
94103 if status == 0 :
95104 predict_start_time = time .time ()
96105 pred_result = self .exec_network .requests [request_id ].outputs [
97106 self .output_name
98107 ]
99- predict_end_time = (time .time () - predict_start_time ) * 1000
108+ predict_end_time = float (time .time () - predict_start_time ) * 1000
100109 if draw :
101- self .preprocess_output (pred_result , image , show_bbox = draw )
102- return (predict_end_time , pred_result )
110+ bbox , _ = self .preprocess_output (pred_result , image , show_bbox = draw )
111+ return (predict_end_time , pred_result , bbox )
103112
104113 @abc .abstractmethod
105114 def preprocess_output (self , inference_results , image , show_bbox = False ):
@@ -128,8 +137,18 @@ def preprocess_input(self, image):
128137class Face_Detection (Base ):
129138 """Class for the Face Detection Model."""
130139
131- def __init__ (self , model_name , device = "CPU" , threshold = 0.60 , extensions = None ):
132- super ().__init__ (model_name , device = "CPU" , threshold = 0.60 , extensions = None )
140+ def __init__ (
141+ self ,
142+ model_name ,
143+ source_width = None ,
144+ source_height = None ,
145+ device = "CPU" ,
146+ threshold = 0.60 ,
147+ extensions = None ,
148+ ):
149+ super ().__init__ (
150+ model_name , source_width , source_height , device , threshold , extensions ,
151+ )
133152
134153 def preprocess_output (self , inference_results , image , show_bbox = False ):
135154 """Draw bounding boxes onto the frame."""
@@ -199,8 +218,18 @@ def draw_output(
199218class Head_Pose_Estimation (Base ):
200219 """Class for the Head Pose Estimation Model."""
201220
202- def __init__ (self , model_name , device = "CPU" , threshold = 0.60 , extensions = None ):
203- super ().__init__ (model_name , device = "CPU" , threshold = 0.60 , extensions = None )
221+ def __init__ (
222+ self ,
223+ model_name ,
224+ source_width = None ,
225+ source_height = None ,
226+ device = "CPU" ,
227+ threshold = 0.60 ,
228+ extensions = None ,
229+ ):
230+ super ().__init__ (
231+ model_name , source_width , source_height , device , threshold , extensions ,
232+ )
204233
205234 def preprocess_output (self , inference_results , image ):
206235 pass
@@ -212,8 +241,18 @@ def draw_output(coords, image):
212241class Facial_Landmarks (Base ):
213242 """Class for the Facial Landmarks Detection Model."""
214243
215- def __init__ (self , model_name , device = "CPU" , threshold = 0.60 , extensions = None ):
216- super ().__init__ (model_name , device = "CPU" , threshold = 0.60 , extensions = None )
244+ def __init__ (
245+ self ,
246+ model_name ,
247+ source_width = None ,
248+ source_height = None ,
249+ device = "CPU" ,
250+ threshold = 0.60 ,
251+ extensions = None ,
252+ ):
253+ super ().__init__ (
254+ model_name , source_width , source_height , device , threshold , extensions ,
255+ )
217256
218257 def preprocess_output (self , inference_results , image ):
219258 pass
@@ -225,8 +264,18 @@ def draw_output(coords, image):
225264class Gaze_Estimation (Base ):
226265 """Class for the Gaze Estimation Detection Model."""
227266
228- def __init__ (self , model_name , device = "CPU" , threshold = 0.60 , extensions = None ):
229- super ().__init__ (model_name , device = "CPU" , threshold = 0.60 , extensions = None )
267+ def __init__ (
268+ self ,
269+ model_name ,
270+ source_width = None ,
271+ source_height = None ,
272+ device = "CPU" ,
273+ threshold = 0.60 ,
274+ extensions = None ,
275+ ):
276+ super ().__init__ (
277+ model_name , source_width , source_height , device , threshold , extensions ,
278+ )
230279
231280 def preprocess_output (self , inference_results , image ):
232281 pass
0 commit comments