Hi,
I am trying to fine-tune SAM on custom images and masks but am struggling, and I am hoping someone can point me in the right direction.
I have been referencing 331_fine_tune_SAM_mito.ipynb
I cannot get the training to work; the forward pass fails with this error:

```
The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4. got torch.Size([2, 4]).
```
I think the `input_boxes` prompt is ending up with the wrong shape somehow?
The images I am using are colour PNG images (they load with 3 channels) rather than the TIFF images in the reference code.
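For context, here is a minimal sketch of how I understand the prompt shape the processor expects, based on the Hugging Face `SamProcessor` documentation (the box coordinates and dummy image are made-up values):

```python
import numpy as np
from PIL import Image
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))  # dummy RGB image

box = [50.0, 50.0, 200.0, 200.0]  # [x_min, y_min, x_max, y_max]
inputs = processor(image, input_boxes=[[box]], return_tensors="pt")

# Expected shape: (batch_size, nb_boxes, 4) -> torch.Size([1, 1, 4])
print(inputs["input_boxes"].shape)
```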

My `SAMDataset` code is:

```python
import numpy as np
from torch.utils.data import Dataset


class SAMDataset(Dataset):
    """
    This class is used to create a dataset that serves input images and masks.
    It takes a dataset and a processor as input and overrides the __len__ and
    __getitem__ methods of the Dataset class.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        ground_truth_mask = np.array(item["label"])

        # get bounding box prompt
        # prompt = get_bounding_box(ground_truth_mask)
        prompt = item["bounding_box"]

        # prepare image and prompt for the model
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

        # remove batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # add ground truth segmentation
        inputs["ground_truth_mask"] = ground_truth_mask

        return inputs
```
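If it helps, this is how I am checking what actually reaches the model (a sketch, assuming a standard PyTorch `DataLoader` over the `SAMDataset` above, here named `train_dataset`):

```python
from torch.utils.data import DataLoader

# Hypothetical setup: train_dataset is an instance of the SAMDataset above.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

batch = next(iter(train_dataloader))
# The error suggests the model wants (batch_size, nb_boxes, 4),
# e.g. torch.Size([2, 1, 4]) here, but I see torch.Size([2, 4]).
print(batch["input_boxes"].shape)
```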

