-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautomatic_sam.py
More file actions
206 lines (160 loc) · 6.88 KB
/
automatic_sam.py
File metadata and controls
206 lines (160 loc) · 6.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
# Input dataset layout: images live in raw_data/images and their YOLO label
# files (same stem, .txt extension) live in raw_data/labels.
_RAW_DATA_ROOT = 'raw_data'
images_folder = os.path.join(_RAW_DATA_ROOT, 'images')
labels_folder = os.path.join(_RAW_DATA_ROOT, 'labels')
def yolo_to_pixel(yolo_label, image_width, image_height):
    """Convert one YOLO box (normalized geometry) into pixel-space coordinates.

    Args:
        yolo_label: Five values ``[class_id, x_center, y_center, bbox_width,
            bbox_height]``; the four geometry values are fractions of the
            image size.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        A dict holding the original normalized fields, the pixel-space box
        centre, and the pixel-space corners ``x_min``/``y_min``/``x_max``/``y_max``.
    """
    class_id, cx_norm, cy_norm, w_norm, h_norm = yolo_label

    # Scale the centre and the size to pixels, then derive the corners from
    # half of the pixel size.
    cx = cx_norm * image_width
    cy = cy_norm * image_height
    half_w = (w_norm * image_width) / 2
    half_h = (h_norm * image_height) / 2

    return dict(
        class_id=class_id,
        x_center=cx_norm,
        y_center=cy_norm,
        bbox_width=w_norm,
        bbox_height=h_norm,
        x_center_pixel=cx,
        y_center_pixel=cy,
        x_min=cx - half_w,
        y_min=cy - half_h,
        x_max=cx + half_w,
        y_max=cy + half_h,
    )
def check_overlap(box1, box2):
    """Return True if two pixel-space boxes intersect.

    Each box is a dict with ``x_min``/``x_max``/``y_min``/``y_max``. Bounds
    are inclusive, so boxes that merely touch along an edge count as
    overlapping.
    """
    # Boxes are disjoint exactly when they are separated along some axis.
    for axis in ('x', 'y'):
        lo, hi = f'{axis}_min', f'{axis}_max'
        if box1[hi] < box2[lo] or box1[lo] > box2[hi]:
            return False
    return True
def check_edge(coord, *, margin=0.01):
    """Return True if a box reaches within *margin* of any image border.

    Uses the normalized YOLO fields of *coord* (``x_center``, ``y_center``,
    ``bbox_width``, ``bbox_height``), so *margin* is a fraction of the image
    size. The default of 0.01 flags boxes whose left/top side is <= 0.01 or
    whose right/bottom side is >= 0.99.

    Args:
        coord: Box dict as produced by ``yolo_to_pixel``.
        margin: Distance from the border (normalized units) that counts as
            touching the edge.
    """
    upper = 1 - margin  # 1 - 0.01 == 0.99 exactly in IEEE doubles
    half_w = coord['bbox_width'] / 2
    half_h = coord['bbox_height'] / 2
    return (
        coord['x_center'] + half_w >= upper
        or coord['x_center'] - half_w <= margin
        or coord['y_center'] + half_h >= upper
        or coord['y_center'] - half_h <= margin
    )
def process_yolo_labels(label_file_path, image_width, image_height):
    """Parse a YOLO label file into a list of pixel-space box dicts.

    Each usable line must have exactly five whitespace-separated fields:
    ``class_id x_center y_center bbox_width bbox_height``. Lines with a
    different field count, or whose fields are not numeric, are skipped, so
    one bad line no longer discards every other box in the file.

    Args:
        label_file_path: Path to the ``.txt`` label file.
        image_width: Image width in pixels, used to de-normalize boxes.
        image_height: Image height in pixels, used to de-normalize boxes.

    Returns:
        A list of dicts from ``yolo_to_pixel``, in file order. The list is
        empty if the file does not exist (a message is printed).
    """
    coordinates_list = []
    try:
        with open(label_file_path, 'r', encoding='utf-8') as file:
            for line_number, line in enumerate(file, start=1):
                parts = line.split()
                if len(parts) != 5:
                    continue
                try:
                    class_id = int(parts[0])
                    x_center, y_center, bbox_width, bbox_height = map(float, parts[1:])
                except ValueError:
                    print(f"Skipping malformed line {line_number} in {label_file_path}: {line.strip()!r}")
                    continue
                yolo_label = [class_id, x_center, y_center, bbox_width, bbox_height]
                coordinates_list.append(yolo_to_pixel(yolo_label, image_width, image_height))
    except FileNotFoundError:
        print(f"Label file not found: {label_file_path}")
    return coordinates_list
def get_image_dimensions(image_path):
    """Return ``(width, height)`` of an image as decoded by OpenCV.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot load the file (it returns
            None for missing or unreadable images).
    """
    pixels = cv2.imread(image_path)
    if pixels is not None:
        height, width = pixels.shape[:2]
        return width, height
    raise FileNotFoundError(f"Image file not found: {image_path}")
def process_all_files(images_folder, labels_folder, extensions=('.jpg', '.png')):
    """Pair each image with its YOLO label file and parse the labels.

    Images are matched to ``<labels_folder>/<stem>.txt``. Images without a
    label file are skipped with a message. Per-file errors are printed and
    the file is skipped.

    Args:
        images_folder: Directory containing the images.
        labels_folder: Directory containing the ``.txt`` label files.
        extensions: File suffixes (case-insensitive) treated as images.

    Returns:
        A list of ``(image_path, coordinates)`` tuples sorted by file name,
        where ``coordinates`` is the list returned by ``process_yolo_labels``.

    Raises:
        FileNotFoundError: if either directory does not exist.
    """
    if not os.path.isdir(images_folder):
        raise FileNotFoundError(f"Images directory does not exist: {images_folder}")
    if not os.path.isdir(labels_folder):
        raise FileNotFoundError(f"Labels directory does not exist: {labels_folder}")

    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_coords_list = []
    # os.listdir order is arbitrary. Sorting makes the processing order, and
    # therefore the sequential output names assigned downstream, reproducible.
    for file_name in sorted(os.listdir(images_folder)):
        if not file_name.lower().endswith(suffixes):
            continue
        image_path = os.path.join(images_folder, file_name)
        label_path = os.path.join(labels_folder, file_name.rsplit('.', 1)[0] + '.txt')
        if not os.path.isfile(label_path):
            print(f"Label file not found for {file_name}")
            continue
        try:
            image_width, image_height = get_image_dimensions(image_path)
            coordinates = process_yolo_labels(label_path, image_width, image_height)
        except Exception as e:
            # Best-effort batch: report the failing file and keep going.
            print(f"Error processing file {file_name}: {e}")
            continue
        image_coords_list.append((image_path, coordinates))
    return image_coords_list
# Parse every image/label pair up front, before the model is loaded.
all_image_coords = process_all_files(images_folder, labels_folder)

# --- SAM model setup ---
# Relative checkpoint path, so it is resolved against the working directory.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# nn.Module.to() moves the parameters in place and returns the same module.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
predictor = SamPredictor(sam)
def runRandomlySegment(filename, coords, name, confidence_threshold):
    """Segment the object under one prompt point with SAM and save it as an RGBA PNG.

    The module-level ``predictor`` is prompted with a single point labelled 1,
    using ``multimask_output=True``. Among the candidate masks whose score is
    >= ``confidence_threshold``, the highest-scoring one is kept. Pixels
    outside it become black and fully transparent. The result is written to
    ``segmented/images/{name}.png``.

    Args:
        filename: Path of the source image (read with OpenCV).
        coords: ``[x, y]`` pixel coordinate of the prompt point.
        name: Output file stem.
        confidence_threshold: Minimum SAM score a candidate mask must reach.

    Returns:
        True if a mask was saved. False if the image could not be read or no
        candidate reached the threshold; in that case nothing is written.
    """
    image = cv2.imread(filename)
    if image is None:
        # cv2.imread signals missing/unreadable files by returning None.
        print(f"Could not read image {filename}; skipping {name}.")
        return False
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Running the image encoder is the expensive step, and the caller prompts
    # the same image once per label. Re-encode only when the image path
    # changes or nothing has been encoded yet.
    if (not predictor.is_image_set
            or getattr(runRandomlySegment, "_encoded_path", None) != filename):
        predictor.set_image(image)
        runRandomlySegment._encoded_path = filename

    masks, scores, _ = predictor.predict(
        point_coords=np.array([coords]),
        point_labels=np.array([1]),
        multimask_output=True,
    )

    for score in scores:
        print(score)

    # Indices (into the SAM candidates) of masks meeting the threshold.
    candidates = [i for i, score in enumerate(scores) if score >= confidence_threshold]
    if not candidates:
        print(f"No valid masks found for {name} with confidence threshold {confidence_threshold}.")
        return False

    # max() keeps the first of equal scores, matching np.argmax tie-breaking.
    best = max(candidates, key=lambda i: scores[i])
    print(f'Using mask {best} with score {scores[best]}')
    mask = np.asarray(masks[best], dtype=bool)

    # Black out the background and use the mask as the alpha channel.
    rgb = image.copy()
    rgb[~mask] = 0
    alpha = mask.astype(np.uint8) * 255
    rgba = np.dstack((rgb, alpha))

    output_path = f'segmented/images/{name}.png'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    Image.fromarray(rgba).save(output_path)
    return True
# Minimum SAM score a mask must reach to be saved.
CONFIDENCE_THRESHOLD = 0.95
# Boxes of this YOLO class id are never segmented (logged as "SHARK").
SHARK_CLASS_ID = 6

# Sequential id used for output names: image{counter}.png / image{counter}.txt.
counter = 0
# Boxes already segmented from the current image; later overlapping boxes are skipped.
processed_coords = []
for image_path, coords in all_image_coords:
    print(f"Image Path: {image_path}")
    for coord in coords:
        if any(check_overlap(coord, processed_coord) for processed_coord in processed_coords):
            print(f"Skipping overlapping fish at {coord['x_center_pixel']}, {coord['y_center_pixel']}")
            continue
        if coord['class_id'] == SHARK_CLASS_ID:
            print("SHARK")
            continue
        if check_edge(coord):
            print('Label touching edge.')
            continue

        save_name = f'image{counter}'
        # Must match the output path runRandomlySegment writes to. That file
        # is only written when a mask passes the threshold, so delete any
        # leftover from an earlier run: its presence afterwards then reliably
        # means this call succeeded.
        image_out_path = f'segmented/images/{save_name}.png'
        if os.path.isfile(image_out_path):
            os.remove(image_out_path)

        input_coord = [coord['x_center_pixel'], coord['y_center_pixel']]
        runRandomlySegment(image_path, input_coord, save_name, confidence_threshold=CONFIDENCE_THRESHOLD)
        if not os.path.isfile(image_out_path):
            # No mask saved: don't write an orphan label, and reuse this
            # counter value for the next fish.
            print(f'No mask saved for {save_name}; not writing a label.')
            continue

        label_path = f'segmented/labels/{save_name}.txt'
        os.makedirs(os.path.dirname(label_path), exist_ok=True)
        print(f'Saving as {save_name} with label path {label_path}')
        with open(label_path, 'w') as file:
            file.write(f'{coord["class_id"]} {coord["x_center"]} {coord["y_center"]} {coord["bbox_width"]} {coord["bbox_height"]}\n')
        processed_coords.append(coord)
        counter += 1
    print('Resetting Overlap Coords')
    processed_coords = []