import timm
import os
import torch
from torchvision import transforms
from PIL import Image
import openslide
import numpy as np
from skimage.color import rgb2gray
import json
#### definition of the model
# ViT-L/16 backbone built through timm. num_classes=0 asks timm for a model
# without a classification head, so the forward pass yields feature embeddings.
model = timm.create_model(
    "vit_large_patch16_224",
    img_size=224,
    patch_size=16,
    init_values=1e-5,
    num_classes=0,
    dynamic_img_size=True,
)
# Deserialize the checkpoint on the CPU: the model is created on the CPU and
# load_state_dict copies the values into its parameters, so a hard-coded "mps"
# target only adds a dependency on Apple's Metal backend being available.
model.load_state_dict(
    torch.load("PATH.bin", map_location="cpu", weights_only=True),
    strict=True,
)
# The model is only used for feature extraction: disable training-time
# behaviour (dropout, stochastic depth) explicitly.
model.eval()

# Preprocessing applied to a tile before it is given to the model:
# resize, conversion to a tensor, then ImageNet mean/std normalization.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
def is_tissue(tile, extrem_val=0.1, tissue_threshold=0.02, homogeneity_threshold=0.005):
    """
    Decide whether a grayscale tile contains tissue.

    A pixel is counted as tissue when its grey value lies strictly between
    ``extrem_val`` and ``1 - extrem_val`` (i.e. it is neither near-black nor
    near-white). The tile is accepted when the fraction of such pixels is
    greater than ``tissue_threshold`` AND the variance of the whole tile is
    greater than ``homogeneity_threshold``.

    :param tile: grayscale tile to analyse (array-like); the thresholds assume
        grey values expressed on a 0-1 scale
    :param extrem_val: margin of extreme grey values to exclude at both ends
        of the scale (default: 0.1)
    :param tissue_threshold: minimum ratio of pixels that must be tissue for
        the tile to be kept (default: 0.02)
    :param homogeneity_threshold: minimum variance of the tile; tiles that are
        more uniform than this are rejected (default: 0.005)
    :return: True when both criteria are met, False otherwise. An empty tile
        is never considered tissue.
    """
    tile = np.asarray(tile)
    # An empty tile would otherwise divide by zero and take the variance of
    # nothing (NaN plus RuntimeWarnings); it simply holds no tissue.
    if tile.size == 0:
        return False

    mid_grey = (tile > extrem_val) & (tile < 1 - extrem_val)
    tissue_ratio = np.count_nonzero(mid_grey) / tile.size
    if tissue_ratio <= tissue_threshold:
        return False
    return bool(np.var(tile) > homogeneity_threshold)
def _tile_geojson_feature(global_x, global_y, tile_size, tile_id, slide_path):
    """Build the GeoJSON polygon feature describing one square tile."""
    right = global_x + tile_size
    bottom = global_y + tile_size
    return {
        "type": "Feature",
        "geometry": {
            "type": "Polygon",
            "coordinates": [[
                [global_x, global_y],  # Top-left corner
                [right, global_y],  # Top-right corner
                [right, bottom],  # Bottom-right corner
                [global_x, bottom],  # Bottom-left corner
                [global_x, global_y],  # Closing the polygon
            ]],
        },
        "properties": {
            "tile_id": tile_id,
            "tile_path": slide_path,
            "width": tile_size,
            "height": tile_size,
        },
    }


def extract_tiles_with_tissue(slide_path, geojson_path, magnification=20, tile_size=224, low_mag=5, extract_features=False):
    """
    Extract tiles with tissue from a whole-slide image.

    The slide is first scanned in large regions at a low-magnification level;
    every region that passes ``is_tissue`` is then split into
    ``tile_size`` x ``tile_size`` tiles read from level 0, and each of those
    tiles is checked with ``is_tissue`` again. The accepted tiles are written
    to ``geojson_path`` as a GeoJSON FeatureCollection of square polygons in
    level-0 pixel coordinates.

    :param slide_path: Path to the whole-slide image (.mrxs, .svs, .ndpi)
    :param geojson_path: Path of the GeoJSON file to write (.geojson)
    :param magnification: Desired magnification for tile extraction (e.g., 20x).
        NOTE: currently unused - tiles are always read from level 0, whatever
        the objective power of the slide is.
    :param tile_size: Size of the square tile in pixels (default: 224x224)
    :param low_mag: Low magnification used for tissue detection (default: 5x)
    :param extract_features: when True, run the module-level ``model`` on each
        accepted tile and collect the embeddings (default: False)
    :return: tuple ``(feat, features)``: the list of extracted embeddings
        (empty when ``extract_features`` is False) and the list of GeoJSON
        features, one per accepted tile
    :raises KeyError: if the slide does not expose its objective power
    """
    features = []
    feat = []
    tile_id = 0

    # The context manager closes the slide handle even if reading fails.
    with openslide.OpenSlide(slide_path) as slide:
        # Get the best level to use for tissue detection (closest to low_mag)
        objective_power = float(slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER])
        low_mag_level = slide.get_best_level_for_downsample(objective_power / low_mag)
        low_mag_downsample_factor = slide.level_downsamples[low_mag_level]
        # Dimensions of the slide at this low magnification level
        low_mag_width, low_mag_height = slide.level_dimensions[low_mag_level]
        # Side, in low-magnification pixels, of one tissue-detection region
        scaled_tile_size = int(tile_size * low_mag_downsample_factor)
        # Side, in level-0 pixels, covered by one tissue-detection region
        high_mag_extent = scaled_tile_size * int(low_mag_downsample_factor)

        # Iterate over the slide at low magnification level and detect tissue
        for y in range(0, low_mag_height, scaled_tile_size):
            for x in range(0, low_mag_width, scaled_tile_size):
                # read_region expects its location in level-0 coordinates
                region_x = int(x * low_mag_downsample_factor)
                region_y = int(y * low_mag_downsample_factor)
                low_mag_region = slide.read_region(
                    (region_x, region_y),
                    low_mag_level,
                    (scaled_tile_size, scaled_tile_size),
                )
                # Drop the alpha channel before the grayscale conversion
                low_mag_tile = rgb2gray(np.array(low_mag_region)[:, :, 0:3])
                if not is_tissue(low_mag_tile):
                    continue

                for high_y in range(0, high_mag_extent, tile_size):
                    for high_x in range(0, high_mag_extent, tile_size):
                        global_x = region_x + high_x
                        global_y = region_y + high_y
                        # Read the corresponding tile at full resolution
                        high_mag_tile = slide.read_region(
                            (global_x, global_y),
                            0,  # Full-resolution level
                            (tile_size, tile_size),
                        )
                        if not is_tissue(rgb2gray(np.array(high_mag_tile)[:, :, 0:3])):
                            continue

                        if extract_features:
                            # Preprocess only when the embedding is requested.
                            # Tensor of shape [1, 3, 224, 224] after resizing and
                            # normalization (ImageNet parameters)
                            image = transform(high_mag_tile.convert("RGB")).unsqueeze(dim=0)
                            with torch.inference_mode():
                                feat.append(model(image))

                        features.append(
                            _tile_geojson_feature(global_x, global_y, tile_size, tile_id, slide_path)
                        )
                        tile_id += 1

    # Save GeoJSON
    geojson_data = {
        "type": "FeatureCollection",
        "features": features,
    }
    with open(geojson_path, 'w', encoding="utf-8") as geojson_file:
        json.dump(geojson_data, geojson_file, indent=4)
    print(f"Tiles extraction complete! Extracted {tile_id} tiles with tissue.")
    return feat, features
# Path to your whole-slide image (.mrxs, .svs, or .ndpi)
slide_path = 'PATH.ndpi'
# Output file where the GeoJSON will be saved
geojson_path = 'PATH.geojson'
# Detect tissue at 10x, extract the 224x224 tiles with tissue and save them;
# tt is the (embeddings, GeoJSON features) tuple returned by the extraction.
tt = extract_tiles_with_tissue(
    slide_path,
    geojson_path,
    magnification=20,
    tile_size=224,
    low_mag=10,
    extract_features=False,
)