Fine-tuning SAM: 'The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4.', ' got torch.Size([2, 4]).' #93

@leemorton

Hi,

I am trying to fine-tune SAM on custom images and masks, but I am stuck and hoping someone can point me in the right direction.

I have been using 331_fine_tune_SAM_mito.ipynb as a reference.

Training fails at the forward pass with this error:
'The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4.', ' got torch.Size([2, 4]).'
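The error's wording says input_points, but the shape it describes (batch_size, nb_boxes, 4) is the one input_boxes needs. A tensor of shape (2, 4) has the batch axis and the 4 coordinates but no nb_boxes axis. As a minimal sketch, assuming exactly one box per image, a hypothetical helper (add_box_axis is not part of transformers) can restore that axis; it works on NumPy arrays and torch tensors alike, since both support `.ndim` and `x[:, None, :]` indexing:

```python
import numpy as np


def add_box_axis(boxes):
    """Return boxes shaped (batch_size, nb_boxes, 4).

    Hypothetical helper: assumes one box per image when the nb_boxes axis
    is missing. Accepts NumPy arrays or torch tensors.
    """
    if boxes.ndim == 3 and boxes.shape[-1] == 4:
        return boxes  # already (batch_size, nb_boxes, 4)
    if boxes.ndim == 2 and boxes.shape[-1] == 4:
        return boxes[:, None, :]  # (batch_size, 4) -> (batch_size, 1, 4)
    raise ValueError(f"unexpected input_boxes shape: {tuple(boxes.shape)}")


# The shape reported in the error: 2 boxes with no nb_boxes axis.
boxes = np.zeros((2, 4), dtype=np.float32)
print(add_box_axis(boxes).shape)  # (2, 1, 4)
```

In a training loop this would wrap batch["input_boxes"] before it is moved to the device. It only treats the symptom; finding where the axis gets dropped is the better fix.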

I think input_boxes has the wrong shape: the model asks for a 3D tensor, but it receives torch.Size([2, 4]).

[screenshot]
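For reference, this is how the box shape is expected to evolve, assuming (as in the reference notebook) that the processor returns input_boxes as (1, 1, 4) for one image with one box, and that batches are built by PyTorch's default collation, which stacks items along a new first axis. NumPy stands in for torch here because squeeze and stacking behave the same way. With squeeze(0) the batch stays 3D; a squeeze() without an axis argument somewhere outside the snippet shown would produce the reported (2, 4):

```python
import numpy as np

# Assumed processor output for ONE image with ONE box (input_boxes=[[prompt]]):
# a leading batch axis of size 1, then nb_boxes = 1, then 4 coordinates.
processor_boxes = np.zeros((1, 1, 4), dtype=np.float32)

# squeeze(0) removes only the leading batch axis, as in SAMDataset.__getitem__.
per_item_ok = processor_boxes.squeeze(0)  # (1, 4)
# squeeze() with no argument removes EVERY size-1 axis, including nb_boxes.
per_item_bad = processor_boxes.squeeze()  # (4,)

# np.stack mimics default collation for a batch of 2 items.
batch_ok = np.stack([per_item_ok, per_item_ok])     # (2, 1, 4)
batch_bad = np.stack([per_item_bad, per_item_bad])  # (2, 4)

print(batch_ok.shape, batch_bad.shape)  # (2, 1, 4) (2, 4)
```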

My images are colour PNGs rather than the TIFF images in the reference notebook, and they show 3 channels here:
[screenshot]
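SAM's vision encoder takes 3-channel input, so 3 channels is the expected layout, and the image format alone should not change the shape of input_boxes. Colour PNGs are sometimes saved with an alpha channel, though, and some images may be single-channel. A small, hypothetical NumPy check (to_rgb_array is not a library function) that normalises both cases:

```python
import numpy as np


def to_rgb_array(image):
    """Return an (H, W, 3) array, preserving the input dtype.

    Hypothetical helper: grayscale (H, W) is repeated across 3 channels,
    RGBA (H, W, 4) has its alpha channel dropped, (H, W, 3) passes through.
    """
    arr = np.asarray(image)  # also accepts PIL images
    if arr.ndim == 2:
        return np.stack([arr, arr, arr], axis=-1)
    if arr.ndim == 3 and arr.shape[-1] == 4:
        return arr[..., :3]
    if arr.ndim == 3 and arr.shape[-1] == 3:
        return arr
    raise ValueError(f"unsupported image shape: {arr.shape}")


gray = np.zeros((8, 8), dtype=np.uint8)
rgba = np.zeros((8, 8, 4), dtype=np.uint8)
print(to_rgb_array(gray).shape, to_rgb_array(rgba).shape)  # (8, 8, 3) (8, 8, 3)
```

With PIL, Image.open(path).convert("RGB") does the same for the image files themselves.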

My SAMDataset code is:

```python
import numpy as np
from torch.utils.data import Dataset


class SAMDataset(Dataset):
    """
    This class is used to create a dataset that serves input images and masks.
    It takes a dataset and a processor as input and overrides the __len__
    and __getitem__ methods of the Dataset class.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        ground_truth_mask = np.array(item["label"])

        # get bounding box prompt: a flat [x_min, y_min, x_max, y_max] list
        # prompt = get_bounding_box(ground_truth_mask)
        prompt = item["bounding_box"]

        # prepare image and prompt for the model (one image, one box)
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

        # remove batch dimension which the processor adds by default,
        # e.g. input_boxes (1, 1, 4) -> (1, 4); the DataLoader later
        # stacks items back to (batch_size, 1, 4)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # add ground truth segmentation
        inputs["ground_truth_mask"] = ground_truth_mask

        return inputs
```
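Because prompt is read from item["bounding_box"] instead of being computed from the mask, it is worth checking that it holds a flat [x_min, y_min, x_max, y_max] list in pixel coordinates before it is wrapped as [[prompt]]. One way is to compare it with a box computed directly from the mask. The helper below is a hypothetical stand-in for the commented-out get_bounding_box(), not the notebook's code; it adds no random jitter:

```python
import numpy as np


def bbox_from_mask(mask):
    """Return [x_min, y_min, x_max, y_max] for the non-zero pixels of a 2D mask.

    Hypothetical helper for sanity-checking item["bounding_box"].
    """
    ys, xs = np.nonzero(mask)  # row (y) and column (x) indices of foreground
    if xs.size == 0:
        raise ValueError("mask has no foreground pixels")
    # Plain Python floats give a flat list of 4 numbers, ready to be
    # wrapped as input_boxes=[[box]] (one image, one box).
    return [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]


mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:5, 3:8] = 1  # rows 2-4, columns 3-7
print(bbox_from_mask(mask))  # [3.0, 2.0, 7.0, 4.0]
```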

And this is where I run into trouble:
[screenshot]
