Skip to content

Commit d5be4fc

Browse files
authored
Merge pull request #52 from Paper-Chart-Extraction-Project/add-nms-sorting-parameter
Add a Custom Sorting Function for NMS
2 parents b83fbd6 + dd9619f commit d5be4fc

5 files changed

Lines changed: 24 additions & 5 deletions

File tree

ChartExtractor/extraction/checkboxes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def detect_checkboxes(
133133
detections=detections,
134134
threshold=0.8,
135135
overlap_comparator=intersection_over_minimum,
136+
sorting_fn=lambda det: det.annotation.area * det.annotation.confidence,
136137
)
137138
return [det.annotation for det in detections]
138139

ChartExtractor/extraction/extraction.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def make_document_landmark_detections(
331331
MODEL_CONFIG["intraoperative_document_landmarks"]["vert_overlap_proportion"],
332332
)
333333
detections = non_maximum_suppression(
334-
detections, overlap_comparator=intersection_over_minimum
334+
detections,
335+
overlap_comparator=intersection_over_minimum,
336+
sorting_fn=lambda det: det.annotation.area * det.confidence,
335337
)
336338
del document_model
337339
return detections
@@ -479,13 +481,22 @@ def tile_predict(
479481
)
480482

481483
sys_dets: List[Detection] = non_maximum_suppression(
482-
sys_dets, 0.5, intersection_over_minimum
484+
sys_dets,
485+
0.5,
486+
intersection_over_minimum,
487+
lambda det: det.annotation.area * det.confidence,
483488
)
484489
dia_dets: List[Detection] = non_maximum_suppression(
485-
dia_dets, 0.5, intersection_over_minimum
490+
dia_dets,
491+
0.5,
492+
intersection_over_minimum,
493+
lambda det: det.annotation.area * det.confidence,
486494
)
487495
hr_dets: List[Detection] = non_maximum_suppression(
488-
hr_dets, 0.5, intersection_over_minimum
496+
hr_dets,
497+
0.5,
498+
intersection_over_minimum,
499+
lambda det: det.annotation.area * det.confidence,
489500
)
490501

491502
dets: List[Detection] = sys_dets + dia_dets + hr_dets

ChartExtractor/extraction/extraction_utilities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def detect_numbers(
146146
detections=detections,
147147
threshold=0.5,
148148
overlap_comparator=intersection_over_minimum,
149+
sorting_fn=lambda det: det.annotation.area * det.confidence,
149150
)
150151
return detections
151152

ChartExtractor/utilities/annotations.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ def box(self) -> List[int]:
223223
"""A list containing this `BoundingBox`'s [left, top, right, bottom]."""
224224
return [self.left, self.top, self.right, self.bottom]
225225

226+
@property
227+
def area(self) -> float:
228+
"""The area of the box."""
229+
return (self.right - self.left) * (self.bottom - self.top)
230+
226231
def set_box(self, new_left: int, new_top: int, new_right: int, new_bottom: int):
227232
"""Sets this BoundingBox's values for left, top, right, bottom.
228233

ChartExtractor/utilities/detection_reassembly.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def non_maximum_suppression(
9595
overlap_comparator: Callable[
9696
[Detection, Detection], float
9797
] = intersection_over_union,
98+
sorting_fn: Callable[[Detection], bool] = lambda d: d.confidence,
9899
) -> List[Detection]:
99100
"""Applies Non-Maximum Suppression (NMS) to a list of detections.
100101
@@ -117,7 +118,7 @@ def non_maximum_suppression(
117118
Returns:
118119
A list of `Detection` objects containing the filtered detections after applying NMS.
119120
"""
120-
detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
121+
detections = sorted(detections, key=sorting_fn, reverse=True)
121122
ix = 0
122123
while ix < len(detections):
123124
jx = ix + 1

0 commit comments

Comments
 (0)