  • Language: Python
  • License: MIT License

A graph-transformer for whole slide image classification

This work is published in IEEE Transactions on Medical Imaging (https://doi.org/10.1109/TMI.2022.3176598).

Introduction

This repository contains a PyTorch implementation of a deep-learning-based graph-transformer for whole slide image (WSI) classification. We propose a Graph-Transformer (GT) network that fuses a graph representation of a WSI with a transformer to generate WSI-level predictions in a computationally efficient fashion.
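As a rough illustration of the graph side of this design, the sketch below applies one graph-convolution step with the symmetric normalization D^-1/2 (A + I) D^-1/2 to patch features and then mean-pools the nodes into a single slide-level vector. The GCN-style layer, the mean pooling, and all function names are assumptions made for illustration; the sketch omits the transformer stage entirely and is not the repository's model.

```python
# Illustrative sketch only (not the repository's model): one GCN-style step,
# H' = ReLU(D^-1/2 (A + I) D^-1/2 H W), followed by mean pooling over nodes.
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def normalized_adjacency(adj: Matrix) -> Matrix:
    """Add self-loops, then apply symmetric degree normalization."""
    n = len(adj)
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_hat]
    return [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def gcn_layer(adj: Matrix, features: Matrix, weight: Matrix) -> Matrix:
    """Mix each patch's features with its neighbours', project, then ReLU."""
    propagated = matmul(matmul(normalized_adjacency(adj), features), weight)
    return [[max(0.0, v) for v in row] for row in propagated]


def mean_pool(node_features: Matrix) -> List[float]:
    """Average node features into one slide-level vector."""
    n = len(node_features)
    return [sum(col) / n for col in zip(*node_features)]
```

With no edges, the normalized adjacency reduces to the identity, so each patch keeps only its own features; edges let neighbouring patches share information before pooling.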

To demonstrate the applicability of our approach, we selected 3,024 hematoxylin and eosin (H&E) WSIs of lung tumors and of normal histology from the Clinical Proteomic Tumor Analysis Consortium (CPTAC), the National Lung Screening Trial (NLST), and The Cancer Genome Atlas (TCGA), and developed a model to distinguish adenocarcinoma (LUAD) and squamous cell carcinoma (LSCC) from WSIs with normal histology. To understand how our model processes WSI data and to visualize regions that are highly associated with the class label, we proposed a novel class activation mapping technique on graphs called GraphCAM.

Usage

1. Graph Construction

(a) Tiling Patches

    python src/tile_WSI.py -s 512 -e 0 -j 32 -B 50 -M 20 -o <full_path_to_output_folder> "full_path_to_input_slides/*/*.svs"

Mandatory parameters:

  • -s is the tile size: 512 (512x512-pixel tiles)
  • -e is the overlap: 0 (no overlap between adjacent tiles). Important: the overlap is defined as "the number of extra pixels to add to each interior edge of a tile", which means that the final tile size is s + 2e. So, to get a 512 px tile with 50% overlap, set s to 256 and e to 128. Also, tiles at the edges of the slide will be smaller (since up to two of their sides have no "interior" edge).
  • -j is the number of threads: 32
  • -B is the maximum percentage of background allowed: 50% (tiles are removed if their background percentage is above this value)
  • -o is the path where the output images are saved
  • -M is set to -1 by default to tile the image at all magnifications. Set it to the desired magnification to tile only at that magnification and save space
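The tiling parameters interact in a way that is easy to get wrong, so here is a small sketch of the arithmetic described above, assuming an interior tile spans s + 2e pixels and neighbouring tiles are s pixels apart. The background check is a hypothetical stand-in for how -B might be applied: it treats grayscale pixels brighter than an assumed threshold of 220 as background, which may not match src/tile_WSI.py.

```python
# Sketch of the -s / -e / -B arithmetic (assumed semantics, not code taken
# from src/tile_WSI.py).


def interior_tile_size(s: int, e: int) -> int:
    """An interior tile gets e extra pixels on each side: s + 2e."""
    return s + 2 * e


def overlap_fraction(s: int, e: int) -> float:
    """Neighbouring interior tiles are s pixels apart, so they share 2e pixels."""
    return 2 * e / interior_tile_size(s, e)


def params_for(tile: int, overlap: float) -> tuple:
    """Pick (s, e) so interior tiles are `tile` pixels with the given overlap."""
    e = round(tile * overlap / 2)
    return tile - 2 * e, e


def background_percentage(gray_pixels, threshold: int = 220) -> float:
    """Share of pixels brighter than `threshold` (assumed to be background)."""
    bright = sum(1 for p in gray_pixels if p > threshold)
    return 100.0 * bright / len(gray_pixels)


def keep_tile(gray_pixels, max_background: float = 50.0, threshold: int = 220) -> bool:
    """Drop a tile when its background share is above -B (max_background)."""
    return background_percentage(gray_pixels, threshold) <= max_background
```

For example, params_for(512, 0.5) gives (256, 128), matching the 50%-overlap setting described in the -e entry.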
(b) Training Patch Feature Extractor

Go to './feature_extractor' and configure 'config.yaml' before training. The trained feature extractor, which is based on contrastive learning, is saved in the folder './feature_extractor/runs'. We train the model with patches cropped at a single magnification (20X). Before training, put the paths to all patches in the 'all_patches.csv' file.
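A hypothetical helper for preparing 'all_patches.csv'. It assumes the patches are image files (.jpeg, .jpg, or .png) somewhere under one root folder and that the CSV holds one path per row with no header; check the data loader in './feature_extractor' for the layout it actually expects before using it.

```python
# Hypothetical helper (not part of the repository): collect patch paths into
# a CSV. Assumed layout: one path per row, no header.
import csv
from pathlib import Path


def write_patch_csv(patch_root, out_csv="all_patches.csv",
                    exts=(".jpeg", ".jpg", ".png")) -> int:
    """Write every patch image found under patch_root to out_csv; return the count."""
    paths = sorted(
        str(p) for p in Path(patch_root).rglob("*")
        if p.is_file() and p.suffix.lower() in exts
    )
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        for path in paths:
            writer.writerow([path])
    return len(paths)
```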

    python run.py
    

Alternatively, you can use the pretrained feature extractor: feature_extractor/model.pth.
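The feature extractor is trained with contrastive learning. As an illustration of that kind of objective, the sketch below implements the NT-Xent (normalized temperature-scaled cross-entropy) loss popularized by SimCLR in plain Python: the two augmented views of the same patch are pulled together, and all other patches in the batch are pushed apart. Whether the repository uses exactly this loss and temperature is an assumption; see the code in './feature_extractor' for the actual objective.

```python
# Minimal NT-Xent loss in plain Python (illustrative; the exact objective and
# temperature used by ./feature_extractor are assumptions).
import math


def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def nt_xent(z1, z2, temperature=0.5):
    """z1[i] and z2[i] are embeddings of two augmented views of patch i."""
    z = z1 + z2                  # 2N embeddings: first views, then second views
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)    # index of i's positive partner
        positive = cosine(z[i], z[j]) / temperature
        # Denominator: every other embedding in the batch, positive included.
        logits = [cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        total += math.log(sum(math.exp(l) for l in logits)) - positive
    return total / (2 * n)
```

Matching views yield a lower loss than mismatched ones, which is the signal that trains the extractor to embed different crops of the same tissue close together.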

(c) Constructing Graphs

Go to './feature_extractor' and build graphs from the patches:

    python build_graphs.py --weights "path_to_pretrained_feature_extractor" --dataset "path_to_patches" --output "../graphs"
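To make the graph-construction step concrete, here is a sketch of one common way to turn tiles into a graph: every tile becomes a node, and tiles that touch in the tiling grid (including diagonally) are connected. The "<col>_<row>.<ext>" file naming and the 8-neighbour rule are assumptions; build_graphs.py may name tiles or define edges differently.

```python
# Illustrative sketch: nodes are tiles, edges join 8-neighbours in the grid.
# Assumes tile files are named "<col>_<row>.<ext>" (hypothetical convention).
from pathlib import PurePath


def tile_coords(filenames):
    """Parse (col, row) grid coordinates from tile file names."""
    coords = []
    for name in filenames:
        col, row = PurePath(name).stem.split("_")[:2]
        coords.append((int(col), int(row)))
    return coords


def grid_adjacency(coords):
    """Dense 0/1 adjacency matrix connecting tiles that touch in the grid."""
    index = {c: i for i, c in enumerate(coords)}
    n = len(coords)
    adj = [[0] * n for _ in range(n)]
    for i, (x, y) in enumerate(coords):
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                if (dx, dy) == (0, 0):
                    continue  # no self-loops here
                j = index.get((x + dx, y + dy))
                if j is not None:
                    adj[i][j] = 1
    return adj
```

An isolated tile (for example, a fragment separated from the main tissue) ends up with no edges, so it only contributes its own features.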
    

2. Training the Graph-Transformer

Run the following script to train the model; the trained model and logging files are stored under "graph_transformer/saved_models" and "graph_transformer/runs":

    bash scripts/train.sh
    

To evaluate the model, run:

    bash scripts/test.sh

Split the data into training, validation, and testing sets and store each set in a text file with one tab-separated sample and label per line, for example:

    sample1 \t label1
    sample2 \t label2
    LUAD/C3N-00293-23 \t luad
    ...
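A sketch of producing and reading split files in the tab-separated format above. The helper names and the shuffled 70/15/15 split are illustrative assumptions, not part of the repository.

```python
# Illustrative helpers for "<sample>\t<label>" split files (hypothetical names).
import random


def write_split(path, samples):
    """Write (sample, label) pairs, one tab-separated pair per line."""
    with open(path, "w") as f:
        for sample, label in samples:
            f.write(f"{sample}\t{label}\n")


def read_split(path):
    """Read (sample, label) pairs back, skipping blank lines."""
    with open(path) as f:
        return [tuple(line.rstrip("\n").split("\t")) for line in f if line.strip()]


def split_samples(samples, val=0.15, test=0.15, seed=0):
    """Shuffle deterministically, then cut into train / val / test lists."""
    items = list(samples)
    random.Random(seed).shuffle(items)
    n_test = round(len(items) * test)
    n_val = round(len(items) * val)
    return items[n_test + n_val:], items[n_test:n_test + n_val], items[:n_test]
```

If several slides come from the same patient, splitting by patient rather than by slide avoids placing closely related tissue in both the training and testing sets.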
    

3. GraphCAM

Step 1. Generate the GraphCAM of the model on a WSI:

    bash scripts/get_graphcam.sh
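GraphCAM itself is defined in the paper. To illustrate the general idea of a class activation map over graph nodes, the sketch below transplants the generic Grad-CAM recipe onto node features: channel weights are the node-averaged gradients of the class score, and each node's relevance is the ReLU of its weighted activations. This is not GraphCAM, and the input shapes (N x C activations with matching gradients) are assumptions.

```python
# Generic Grad-CAM applied to graph nodes (illustration only, not GraphCAM).
# node_features: N x C activations; gradients: d(class score)/d(activations).


def gradcam_node_scores(node_features, gradients):
    n = len(node_features)
    channels = len(node_features[0])
    # Channel weights: gradient averaged over all nodes.
    weights = [sum(gradients[i][c] for i in range(n)) / n for c in range(channels)]
    # Weighted sum of each node's activations; ReLU keeps positive evidence only.
    raw = [max(0.0, sum(w * h for w, h in zip(weights, row))) for row in node_features]
    # Min-max normalize to [0, 1] for display.
    lo, hi = min(raw), max(raw)
    if hi == lo:
        return [0.0] * n
    return [(r - lo) / (hi - lo) for r in raw]
```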
    

Step 2. Visualize the GraphCAM:

    bash scripts/vis_graphcam.sh
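For visualization, node scores have to be placed back on the slide. The hypothetical sketch below assumes each node keeps the (col, row) tile coordinate it was built from and that tiles were cut with stride s (the -s value) at the tiling magnification; it ignores the -e overlap and any downsampling to a thumbnail.

```python
# Hypothetical sketch: put per-node scores back on the tile grid.
# coords[i] is assumed to be the (col, row) of node i's tile.
from typing import List, Optional, Sequence, Tuple


def scores_to_grid(coords: Sequence[Tuple[int, int]],
                   scores: Sequence[float]) -> List[List[Optional[float]]]:
    """Row-major grid of scores; cells without a tile stay None."""
    width = max(col for col, _ in coords) + 1
    height = max(row for _, row in coords) + 1
    grid: List[List[Optional[float]]] = [[None] * width for _ in range(height)]
    for (col, row), score in zip(coords, scores):
        grid[row][col] = score
    return grid


def tile_box(col: int, row: int, s: int) -> Tuple[int, int, int, int]:
    """Pixel box (x0, y0, x1, y1) of a tile, assuming stride s and no overlap."""
    return col * s, row * s, (col + 1) * s, (row + 1) * s
```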
    

Note: GraphCAM can currently be generated for only one WSI at a time.

More GraphCAM examples: GraphCAMs generated on WSIs by the models from each run of 5-fold cross-validation highlight the same set of WSI regions across the folds, indicating that the technique consistently highlights salient regions of interest.
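One way to put a number on this cross-fold consistency, offered only as an illustrative assumption rather than the paper's evaluation, is to take each fold's top-k highest-scoring nodes on a WSI and report their mean pairwise Jaccard overlap: 1.0 means every fold highlights exactly the same nodes.

```python
# Illustrative consistency measure (not from the paper): mean pairwise Jaccard
# overlap of each fold's top-k nodes on the same WSI.
from itertools import combinations


def top_k(scores, k):
    """Indices of the k highest-scoring nodes."""
    return set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])


def mean_pairwise_jaccard(fold_scores, k):
    """Average |A & B| / |A | B| over all pairs of folds (needs >= 2 folds)."""
    sets = [top_k(s, k) for s in fold_scores]
    pairs = list(combinations(sets, 2))
    return sum(len(a & b) / len(a | b) for a, b in pairs) / len(pairs)
```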

Requirements

  • WSI software: PixelView (deepPath, Inc.)
  • Major dependencies:
      • python
      • pytorch
      • openslide-python
      • Weights & Biases