Image Segmentation - SAM

transformers
image
segmentation
ml
python
Author

David McGaughey

Published

October 28, 2025

1 Motivation

The work we do involves a lot of imaging, and much of the quantitative work is done with hand measurements. I am currently on an unpaid surprise sabbatical, and I have always wanted to learn how to use machine learning approaches to segment / cluster / specify image features but haven't managed to scrounge up the spare time - so here we go!

2 Notes

  1. I want to segment images
  2. Ideally much of each image (all of it?)
  3. This is referred to as “semantic segmentation”. Each pixel gets a classification.
  4. To make my life far more difficult, I’m going to try to segment fundus imagery.
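As a toy illustration of the "each pixel gets a classification" idea (the class codes here are made up for the example), a semantic segmentation is just an integer label per pixel:

```python
import numpy as np

# Hypothetical label codes: 0 = background, 1 = optic disk, 2 = vessel
labels = {0: "background", 1: "optic disk", 2: "vessel"}

# A tiny 4x4 "image" where every pixel has been classified
seg = np.array([
    [0, 0, 1, 1],
    [0, 2, 1, 1],
    [2, 2, 0, 0],
    [2, 0, 0, 0],
])

# Per-class pixel counts - the kind of quantification hand measurement replaces
counts = {labels[k]: int((seg == k).sum()) for k in labels}
print(counts)  # {'background': 8, 'optic disk': 4, 'vessel': 4}
```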

3 Retina Fundus

Retina

The fundus is the back of the eye (opposite the lens). That’s where the retina is. The retina is the part of your eye which turns light into signal for your brain to turn into images. Fundus photography is a non-invasive imaging technique which is used to help assess retinal health.

4 Main features

Retina Annotated

The macula / fovea is the cone-rich region where humans get their high visual acuity (this structure is "missing" in most other species). The optic disk is where the retinal neurons exit the eye to send the visual information to the brain.

5 ML Approach

  1. Transformers are big right now
  2. They are the model architecture behind the LLM / GPT type systems
  3. So, let’s try out a transformer-based image segmentation system
  4. Wait, isn’t the second L in LLM for language?
  5. Yes
  6. But adaptations are being made to the original transformer setup to handle other kinds of input (e.g. a picture)
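The core adaptation (the "vision transformer" trick) is to chop the image into fixed-size patches and treat each flattened patch as a token, the way words are tokens for an LLM. A minimal numpy sketch, with sizes chosen arbitrarily for illustration:

```python
import numpy as np

# A fake 64x64 RGB "image"
img = np.random.rand(64, 64, 3)

# Split it into 8x8 patches, then flatten each patch into one "token"
P = 8
H, W, C = img.shape
patches = img.reshape(H // P, P, W // P, P, C).swapaxes(1, 2)  # (8, 8, 8, 8, 3)
tokens = patches.reshape(-1, P * P * C)                        # (64, 192)

print(tokens.shape)  # each row is one patch-token fed to the transformer
```

A real model then projects each token into the transformer's hidden dimension and adds positional embeddings, but the reshape above is the whole "picture as input" idea.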

6 SAM

“Segment Anything Model,” from Meta. It claims it can zero-shot (no training required) segment any object in any image given either a bounding box or a click.
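In the transformers port used below, those click/box prompts are passed as nested lists whose nesting encodes (batch, object, point) - the same convention the comments in the segmenter function below spell out. A sketch of the shapes, checked with numpy (the coordinates are invented):

```python
import numpy as np

# Two positive clicks and one negative click on one object in one image
input_points = [[[[500, 700], [520, 710], [300, 300]]]]  # (batch, objects, points, xy)
input_labels = [[[1, 1, 0]]]                             # 1 = IS the feature, 0 = is NOT
input_boxes  = [[[450, 650, 600, 800]]]                  # (batch, objects, xmin ymin xmax ymax)

assert np.array(input_points).shape == (1, 1, 3, 2)
assert np.array(input_labels).shape == (1, 1, 3)
assert np.array(input_boxes).shape == (1, 1, 4)
```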

7 R and python

So…I’m coding this in RStudio in a Quarto document, which can handle both Python and R in the same document and share (some) data between them.

8 Step 1

I set up a conda environment, “image_segmentation”, with the required Python libraries. Here I force this document to use that Python.

9 Did it load the right python?

import sys
print(f"Python Version: {sys.version}")
Python Version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:11:29) [Clang 20.1.8 ]
print(f"Python Executable Path: {sys.executable}")
Python Executable Path: /opt/homebrew/Caskroom/miniconda/base/envs/image_segmentation/bin/python
import pandas as pd
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from transformers import Sam2Processor, Sam2Model, infer_device
from PIL import Image
import requests

10 R part.

As I’m in RStudio, I might as well use a bit of R to load an arbitrary image and click to identify features to segment out. This is set up as a function.

library(png)
library(jpeg)

# multiple positive/negative clicks AND/OR a bounding box
image_prompter <- function(path){
  # 1. Read the image
  img_path <- file.path(path)
  img <- readJPEG(img_path)
  
  # 2. Get the image dimensions
  img_dims <- dim(img)
  img_height <- img_dims[1]
  img_width <- img_dims[2]
  
  # 3. Set up the plot window
  par(mar = c(0, 0, 0, 0)) 
  plot(c(1, img_width), c(1, img_height), 
       type = "n", ann = FALSE, axes = FALSE, 
       ylim = c(img_height, 1)) # Inverts Y-axis
  
  # 4. Display the image
  rasterImage(img, 1, 1, img_width, img_height)
  
  # 5. Get POSITIVE click coordinates
  cat("Please click POSITIVE points (green *)...\n")
  cat("Right-click or press 'Esc' in the 'Plots' window when done.\n")
  pos_clicks <- locator(type = "p", pch = "*", col = "green", cex = 2) 
  
  # 6. Get NEGATIVE click coordinates
  cat("Please click NEGATIVE points (red *)...\n")
  cat("Right-click or press 'Esc' in the 'Plots' window when done.\n")
  neg_clicks <- locator(type = "p", pch = "*", col = "red", cex = 2)
  
  # 7. Get BOUNDING BOX coordinates (max 2 points)
  cat("Please click two corners for a BOUNDING BOX (blue box)...\n")
  cat("Right-click or press 'Esc' in the 'Plots' window to skip.\n")
  box_clicks <- locator(n = 2, type = "p", pch = "+", col = "blue", cex = 1.5)
  
  # 8. Process and combine coordinates
  all_coords_df <- data.frame(x = integer(), y = integer(), label = integer())
  box_df <- data.frame(xmin = integer(), ymin = integer(), xmax = integer(), ymax = integer())
  
  if (!is.null(pos_clicks) && length(pos_clicks$x) > 0) {
    pos_df <- data.frame(
      x = round(pos_clicks$x), 
      y = round(pos_clicks$y), 
      label = 1
    )
    all_coords_df <- rbind(all_coords_df, pos_df)
  }
  
  if (!is.null(neg_clicks) && length(neg_clicks$x) > 0) {
    neg_df <- data.frame(
      x = round(neg_clicks$x), 
      y = round(neg_clicks$y), 
      label = 0
    )
    all_coords_df <- rbind(all_coords_df, neg_df)
  }
  
  # 9. Process box coordinates
  if (!is.null(box_clicks) && length(box_clicks$x) == 2) {
    xmin <- round(min(box_clicks$x))
    xmax <- round(max(box_clicks$x))
    ymin <- round(min(box_clicks$y))
    ymax <- round(max(box_clicks$y))
    
    # Draw the box on the R plot for feedback
    rect(xmin, ymin, xmax, ymax, border = "blue", lty = "dashed", lwd = 2)
    
    box_df <- data.frame(xmin = xmin, ymin = ymin, xmax = xmax, ymax = ymax)
  }
  
  if (nrow(all_coords_df) == 0 && nrow(box_df) == 0) {
    cat("No prompts received. Returning NULL.\n")
    return(NULL)
  }
  
  # 10. Return a list containing both data frames
  return(list(clicks = all_coords_df, box = box_df))
}

11 Image Path

img_path <- "../data/Fundus_photograph_of_normal_left_eye.jpg"

12 RUN IN CONSOLE

You cannot run the commented-out code below inline in RStudio, as the inline editor cannot handle user clicks. You have to copy/paste and run it in the Console.

# click_data_df <- image_prompter(img_path)
# saveRDS(click_data_df, file = 'posts/data/sam_segmentation_opticdisk.rds')

# click_data_df <- image_prompter(img_path)
# saveRDS(click_data_df, file = 'posts/data/sam_segmentation_macula.rds')

# click_data_df <- image_prompter(img_path)
# saveRDS(click_data_df, file = 'posts/data/sam_segmentation_vessel.rds')

13 Load SAM2 model and import image

This is now in Python.

device = infer_device()
model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-large").to(device)

processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-large")


raw_image = Image.open(r.img_path).convert("RGB")

14 Load in hand-generated clicks

click_data_df <- readRDS("../data/sam_segmentation_opticdisk.rds")
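Via reticulate, that R list arrives in Python (as r.click_data_df) as a dict of pandas DataFrames, which is why the segmenter function below checks .empty. A mock of the structure, with invented coordinates:

```python
import pandas as pd

# Mimics what readRDS + the R/Python bridge hand over
prompt_data = {
    "clicks": pd.DataFrame({"x": [700, 650], "y": [500, 520], "label": [1, 0]}),
    "box": pd.DataFrame({"xmin": [600], "ymin": [420], "xmax": [820], "ymax": [610]}),
}

print(prompt_data["clicks"].empty, prompt_data["box"].empty)  # False False
```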

15 Helper and segmentation functions

def show_mask_overlay(mask, ax, random_color=False):
    """Helper function to display a mask overlay on an image."""
    if random_color:
        color = np.concatenate([np.random.rand(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6]) # Dodger blue
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=150):
    """Helper function to display points on an image."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', 
    s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', 
    s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Helper function to display a box on an image."""
    x0, y0, x1, y1 = box
    w, h = x1 - x0, y1 - y0
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', 
    facecolor='none', lw=2))



def segmenter(prompt_data, image):
    clicks_df = prompt_data['clicks']
    box_df = prompt_data['box']

    # 1. Bail out if there is nothing to prompt with
    if clicks_df.empty and box_df.empty:
        print("No prompts (clicks or box) were provided. Stopping execution.")
        return None

    # 2. Dynamically build processor arguments
    processor_args = {"images": image}

    input_points = None
    input_labels = None
    input_boxes = None

    if not clicks_df.empty:
        # (batch_size=1, num_prompts=1, num_points=N, 2)
        input_points = [[clicks_df[['x', 'y']].values.tolist()]]
        # (batch_size=1, num_prompts=1, num_points=N)
        input_labels = [[clicks_df['label'].values.tolist()]]

        processor_args["input_points"] = input_points
        processor_args["input_labels"] = input_labels

    if not box_df.empty:
        # (batch_size=1, num_prompts=1, 4) -> [[[xmin, ymin, xmax, ymax]]]
        input_boxes = [[box_df.iloc[0].values.tolist()]]
        processor_args["input_boxes"] = input_boxes

    # 3. Process inputs and run the model
    inputs = processor(**processor_args, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
    print(f"Generated {masks.shape[1]} masks with shape {masks.shape}")

    # -------------------------------------------------
    # Matplotlib visualization
    # -------------------------------------------------

    # Get the number of masks and IoU scores
    num_masks = masks.shape[1]
    iou_scores = outputs.iou_scores[0, 0].cpu().numpy()

    # Create a figure to display the results
    fig, axes = plt.subplots(1, num_masks + 1, figsize=(16, 4))

    # Plot 1: Original image + input prompts (points AND box)
    axes[0].imshow(image)

    # Show points if they exist
    if input_points:
        point_coords_np = np.array(input_points[0][0])
        point_labels_np = np.array(input_labels[0][0])
        show_points(point_coords_np, point_labels_np, axes[0])

    # Show box if it exists
    if input_boxes:
        box_np = np.array(input_boxes[0][0])
        show_box(box_np, axes[0])

    axes[0].set_title("Original Image + Prompt(s)")
    axes[0].axis('off')

    # Remaining plots: image + each predicted mask
    for i in range(num_masks):
        mask_np = masks[0, i].numpy()
        iou = iou_scores[i]

        ax = axes[i + 1]
        ax.imshow(image)
        show_mask_overlay(mask_np, ax)
        ax.set_title(f"Mask {i + 1} (IoU: {iou:.3f})")
        ax.axis('off')

    plt.tight_layout()
    return plt

16 Fairly successful at identifying the optic disk

Green is positive (as in this is the feature), red is negative (this is NOT the feature), and the box surrounds the feature

click_data_df <- readRDS("../data/sam_segmentation_opticdisk.rds")
segmented_plt = segmenter(r.click_data_df, raw_image)
Generated 3 masks with shape torch.Size([1, 3, 1411, 1411])
plt.show()
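Since the whole motivation is replacing hand measurement, a mask can be reduced to numbers rather than just a picture. A sketch of turning a boolean mask into an area and a fraction-of-image, using a synthetic circular mask here as a stand-in for one of the SAM2 masks above:

```python
import numpy as np

# Synthetic stand-in for a mask: a filled circle on a 1411x1411 canvas
h = w = 1411
yy, xx = np.mgrid[0:h, 0:w]
mask = (yy - 700) ** 2 + (xx - 700) ** 2 < 100 ** 2  # boolean mask, radius 100

area_px = int(mask.sum())   # pixel count of the segmented feature
frac = area_px / mask.size  # fraction of the image it covers
print(f"area: {area_px} px, {frac:.2%} of image")
```

With a known physical scale (microns per pixel), area_px converts directly to a physical area.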

17 Fairly successful at identifying the macula

Green is positive (as in this is the feature), red is negative (this is NOT the feature), and the box surrounds the feature

click_data_df <- readRDS("../data/sam_segmentation_macula.rds")
segmented_plt = segmenter(r.click_data_df, raw_image)
Generated 3 masks with shape torch.Size([1, 3, 1411, 1411])
plt.show()

18 Terrible at the vessels

Green is positive (as in this is the feature), red is negative (this is NOT the feature), and the box surrounds the feature

click_data_df <- readRDS("../data/sam_segmentation_vessel.rds")
segmented_plt = segmenter(r.click_data_df, raw_image)
Generated 3 masks with shape torch.Size([1, 3, 1411, 1411])
plt.show()

18.1 Trying to up the contrast a bit

Nope, doesn’t help. A bit disappointing, though not surprising, as I doubt there is much vessel segmentation of fundus imaging in the training data for this model. Perhaps next I will see if I can use a LoRA approach to tune the SAM2 model…after doing a little bit of Googling with “SAM2 LoRA” I now see this paper.

click_data_df <- readRDS("../data/sam_segmentation_vessel.rds")
def change_contrast(img, level):
    factor = (259 * (level + 255)) / (255 * (259 - level))
    def contrast(c):
        return 128 + factor * (c - 128)
    return img.point(contrast)

raw_image_hic = change_contrast(Image.open(r.img_path), 50)

segmented_plt = segmenter(r.click_data_df, raw_image_hic)
Generated 3 masks with shape torch.Size([1, 3, 1411, 1411])
plt.show()
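For the record, the factor formula in change_contrast above is the standard 8-bit contrast correction: it stretches pixel values away from the midpoint 128 (PIL clamps to 0-255 when writing bytes back). The numbers at level=50:

```python
level = 50
factor = (259 * (level + 255)) / (255 * (259 - level))

def contrast(c):
    return 128 + factor * (c - 128)

print(round(factor, 3))      # ~1.482: bright pixels get brighter, dark get darker
print(round(contrast(200)))  # 235
print(round(contrast(100)))  # 86
```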

19 Conclusion

I’d like to wrap up this post so it doesn’t take me 2-75 weeks to finish. I’ve sort of got something that kind of does something. For my next try, I’ll take a gander at whether I can add a LoRA model into this workflow, which will require hand labelling my own images so I can build a train/test dataset.
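For reference before that next attempt, the LoRA idea in a nutshell: instead of fine-tuning a full weight matrix W, freeze it and learn a low-rank update BA, so only a tiny number of parameters train. A bare numpy sketch of the arithmetic (shapes invented; B starts at zero so the adapted layer initially matches the pretrained one):

```python
import numpy as np

d, r = 1024, 8             # hidden size and LoRA rank (r << d)
W = np.random.randn(d, d)  # frozen pretrained weight
B = np.zeros((d, r))       # trainable, initialized to zero
A = np.random.randn(r, d)  # trainable

x = np.random.randn(d)
y = x @ (W + B @ A).T      # adapted layer; equals x @ W.T until B is trained

trainable = B.size + A.size
print(f"trainable params: {trainable} vs full: {W.size}")  # 16384 vs 1048576
```

That parameter ratio (here 64x fewer) is why LoRA makes tuning a large model like SAM2 feasible on modest hardware.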