
VAL_CENTER index in camelyon17_dataset.py #165

@David-Drexlin

Hi everyone,

I hope this is the right place to ask a question about the Camelyon17 dataset. My question concerns the center-metadata indices TEST_CENTER and VAL_CENTER, as defined in camelyon17_dataset.py. According to that file, the centers are 0-indexed, with the test center at index 2 (TEST_CENTER) and the validation (OOD) center at index 1 (VAL_CENTER). My understanding is that these should correspond to the images shown in columns 5 and 4 of the figure in the paper (see the first image for reference). Is that correct?
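For reference, the center-to-split logic being asked about can be sketched as follows. This is a minimal re-implementation under assumptions, not the code from camelyon17_dataset.py: it assumes the split codes `train=0, id_val=1, test=2, val=3` and that every row from `VAL_CENTER` or `TEST_CENTER` is overridden to the corresponding OOD split regardless of its original `split` value. The helper `assign_ood_splits` and the toy rows are hypothetical.

```python
import pandas as pd

# Assumed constants, mirroring the 0-indexed centers discussed above.
TEST_CENTER = 2
VAL_CENTER = 1
# Assumed split codes; check split_dict in camelyon17_dataset.py before relying on them.
SPLIT_DICT = {"train": 0, "id_val": 1, "test": 2, "val": 3}


def assign_ood_splits(metadata_df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy in which rows from the OOD centers get the 'val'/'test' split codes."""
    df = metadata_df.copy()
    df.loc[df["center"] == VAL_CENTER, "split"] = SPLIT_DICT["val"]
    df.loc[df["center"] == TEST_CENTER, "split"] = SPLIT_DICT["test"]
    return df


# Hypothetical toy metadata: one row per center, all initially marked as train (0).
toy = pd.DataFrame({"center": [0, 1, 2, 3, 4], "split": [0, 0, 0, 0, 0]})
print(assign_ood_splits(toy)["split"].tolist())  # [0, 3, 2, 0, 0]
```

Under these assumptions, the split of a patch depends only on the `center` column, so if the images for center 1 look like training data, the mismatch would have to come from the `center` values in metadata.csv (or from the figure) rather than from the split logic itself.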

When I naively plot the images with one row per center label (see the second image), I would expect rows 1 and 2 (zero-indexed) to show the validation (OOD) and test slides, respectively. Instead, the validation center index seems to be switched: the test images do correspond to index 2, but the validation (OOD) images appear at index 4 instead of 1. Inspecting the images directly in the data/patches directory shows the same behaviour: for example, patient 96 from center 4 looks like a validation (OOD) slide, while patient 34 from center 1 looks like a training slide, at least visually to a layman.
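One way to check the suspicion without relying on visual impressions is to tabulate, directly from metadata.csv, which patients are recorded under which center, and then look up a suspicious patient. A minimal sketch, assuming metadata.csv has `patient` and `center` columns (as the plotting code reads them) with zero-padded string patient IDs; the helpers `patients_per_center` and `center_of_patient` and the toy rows are hypothetical.

```python
import pandas as pd


def patients_per_center(metadata_df: pd.DataFrame) -> dict:
    """Map each center index to the sorted list of patient IDs recorded for it."""
    grouped = metadata_df.groupby("center")["patient"].unique()
    return {int(center): sorted(patients) for center, patients in grouped.items()}


def center_of_patient(metadata_df: pd.DataFrame, patient: str) -> list:
    """Return every center index recorded for one patient (normally exactly one)."""
    rows = metadata_df[metadata_df["patient"] == patient]
    return sorted(int(c) for c in rows["center"].unique())


# Hypothetical toy rows; on the real data use
# pd.read_csv(METADATA_CSV, index_col=0, dtype={'patient': 'str'}).
toy = pd.DataFrame({
    "patient": ["034", "034", "096", "010"],
    "center": [1, 1, 4, 0],
})
print(patients_per_center(toy))       # {0: ['010'], 1: ['034'], 4: ['096']}
print(center_of_patient(toy, "096"))  # [4]
```

If a patient ever maps to more than one center, or the per-center patient lists differ from those described in the paper, that would point to a metadata issue rather than a plotting one.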

Did I misunderstand something about the indexing, or do you have any idea what could be wrong? Below are the images for reference and the code I used to generate them:

WILDS slides (image: camelyon_dataset)

Slides as per my code (image: slides_per_domain_class)

Thanks in advance for any clarification!

Code:

```python
import os
import pandas as pd
import torch
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict

# constants
DATA_DIR = '/data/camelyon17_v1.0'
PATCHES_DIR = os.path.join(DATA_DIR, 'patches')
METADATA_CSV = os.path.join(DATA_DIR, 'metadata.csv')
MAX_IMAGES_PER_COMBINATION = 5
NUM_DOMAINS = 5
NUM_CLASSES = 2

# Load the metadata
metadata_df = pd.read_csv(
    METADATA_CSV,
    index_col=0,
    dtype={'patient': 'str'}
)

# Get labels
y_array = torch.LongTensor(metadata_df['tumor'].values)

# Get input image paths
input_paths = [
    os.path.join(
        PATCHES_DIR,
        f'patient_{patient}_node_{node}',
        f'patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
    )
    for patient, node, x, y in metadata_df[['patient', 'node', 'x_coord', 'y_coord']].values
]

# Get domains (centers)
centers = metadata_df['center'].astype(int).values

# Organize images into a dictionary keyed by (domain, class),
# keeping at most MAX_IMAGES_PER_COMBINATION images per key
images_dict = defaultdict(list)

for img_path, label, domain in zip(input_paths, y_array, centers):
    key = (domain, label.item())
    if len(images_dict[key]) < MAX_IMAGES_PER_COMBINATION:
        try:
            img = Image.open(img_path).convert('RGB')
            images_dict[key].append(img)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")

# Plot: one row per domain (center), one block of columns per class
fig, axes = plt.subplots(nrows=NUM_DOMAINS, ncols=NUM_CLASSES * MAX_IMAGES_PER_COMBINATION, figsize=(24, 12))
plt.subplots_adjust(wspace=0.05, hspace=0.05)

for domain_idx in range(NUM_DOMAINS):
    for class_idx in range(NUM_CLASSES):
        key = (domain_idx, class_idx)
        images = images_dict.get(key, [])
        for img_idx in range(MAX_IMAGES_PER_COMBINATION):
            col_idx = class_idx * MAX_IMAGES_PER_COMBINATION + img_idx
            ax = axes[domain_idx, col_idx]
            if img_idx < len(images):
                ax.imshow(images[img_idx])
            ax.axis('off')

            if domain_idx == 0 and img_idx == 0:
                ax.set_title(f"Class {class_idx}")

        # Add domain labels to the first image in each row
        if class_idx == 0:
            ax = axes[domain_idx, 0]
            ax.text(-30, 32, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)

plt.tight_layout()
plt.savefig("slides_per_domain_class.png")
```

Or, more simply, save a few validation images and inspect them:

```python
import os

from wilds import get_dataset


def save_images():
    # Create the 'images' directory if it doesn't exist
    os.makedirs('images', exist_ok=True)

    # Load the camelyon17 dataset (downloaded on first use)
    dataset = get_dataset(dataset='camelyon17', download=True)

    # Get the validation (OOD) subset
    val_data = dataset.get_subset('val')

    # Save the first 10 images from the validation set
    for i in range(10):
        x, y, metadata = val_data[i]
        x.save('images/val{}.png'.format(i + 1))


if __name__ == '__main__':
    save_images()
```
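To tie the saved images back to their centers, the center of each item can be read from the subset's metadata instead of being inferred visually. A minimal sketch, assuming (as in WILDS) that `dataset.metadata_fields` lists the metadata column names, that for Camelyon17 the center is stored under `'hospital'`, and that the subset exposes a `metadata_array`; the helper `count_centers` is hypothetical and is demonstrated on a toy NumPy array so it runs without downloading the data.

```python
from collections import Counter

import numpy as np


def count_centers(metadata_array, metadata_fields, field="hospital"):
    """Count items per center, given an (N, F) metadata array and its field names."""
    col = list(metadata_fields).index(field)
    values = np.asarray(metadata_array)[:, col]  # also accepts a CPU torch tensor
    return dict(sorted(Counter(int(v) for v in values).items()))


# Toy stand-in for val_data.metadata_array, using the assumed
# ['hospital', 'slide', 'y'] column layout.
fields = ["hospital", "slide", "y"]
toy_meta = np.array([[1, 20, 0], [1, 21, 1], [1, 22, 0]])
print(count_centers(toy_meta, fields))  # {1: 3}

# With WILDS (not run here):
#   val_data = dataset.get_subset('val')
#   print(count_centers(val_data.metadata_array, dataset.metadata_fields))
```

If the real `val` subset reports only center 1 while its images look like those stored under center 4, the discrepancy would lie between the metadata and the patch files rather than in the split assignment.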

Cheers,
David

Originally posted by @David-Drexlin in #163
