  • Stars: 162
  • Rank 232,284 (Top 5 %)
  • Language: Jupyter Notebook
  • License: BSD 2-Clause "Sim...
  • Created almost 7 years ago
  • Updated over 5 years ago

Repository Details

🏥 Visualizing Convolutional Networks for MRI-based Diagnosis of Alzheimer’s Disease

Johannes Rieke, Fabian Eitel, Martin Weygandt, John-Dylan Haynes and Kerstin Ritter

Our paper was presented at the MLCN workshop at MICCAI 2018 in Granada (Slides).

Preprint: http://arxiv.org/abs/1808.02874

Abstract: Visualizing and interpreting convolutional neural networks (CNNs) is an important task to increase trust in automatic medical decision making systems. In this study, we train a 3D CNN to detect Alzheimer’s disease based on structural MRI scans of the brain. Then, we apply four different gradient-based and occlusion-based visualization methods that explain the network’s classification decisions by highlighting relevant areas in the input image. We compare the methods qualitatively and quantitatively. We find that all four methods focus on brain regions known to be involved in Alzheimer’s disease, such as inferior and middle temporal gyrus. While the occlusion-based methods focus more on specific regions, the gradient-based methods pick up distributed relevance patterns. Additionally, we find that the distribution of relevance varies across patients, with some having a stronger focus on the temporal lobe, whereas for others more cortical areas are relevant. In summary, we show that applying different visualization methods is important to understand the decisions of a CNN, a step that is crucial to increase clinical impact and trust in computer-based decision support systems.

Heatmaps

Quickstart

You can use the visualization methods in this repo on your own model (PyTorch; for other frameworks see below) like this:

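from interpretation import sensitivity_analysis
from utils import plot_slices

# load_model() and load_scan() are placeholders for your own loading code:
# cnn is a trained PyTorch model, mri_scan is the input image (2D or 3D).
cnn = load_model()
mri_scan = load_scan()

# Compute the relevance heatmap (cuda=True runs the computation on the GPU)
# and overlay it on the scan.
heatmap = sensitivity_analysis(cnn, mri_scan, cuda=True)
plot_slices(mri_scan, overlay=heatmap)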

heatmap is a numpy array containing the relevance heatmap. The methods should work for 2D and 3D images alike. Currently, four methods are implemented and tested: sensitivity_analysis, guided_backprop, occlusion, area_occlusion. There is also a rough implementation of grad_cam, which seems to work on 2D photos, but not on brain scans. Please look at interpretation.py for further documentation.
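For intuition about the difference between the gradient-based and the occlusion-based methods, here is a rough, framework-agnostic NumPy sketch. It is not the code from interpretation.py: model_fn is a hypothetical function that maps an image array to a scalar class score, the gradient is approximated with central finite differences instead of backpropagation, and the patch size and fill value are arbitrary illustrative choices, not the settings used in the paper.

```python
import itertools

import numpy as np


def sensitivity_map(model_fn, image, eps=1e-3):
    """Gradient-based relevance |d score / d voxel|, estimated by central differences.

    model_fn is a hypothetical callable: image array -> scalar class score.
    """
    image = np.asarray(image, dtype=float)
    relevance = np.zeros_like(image)
    for idx in np.ndindex(*image.shape):
        plus, minus = image.copy(), image.copy()
        plus[idx] += eps
        minus[idx] -= eps
        relevance[idx] = abs(model_fn(plus) - model_fn(minus)) / (2 * eps)
    return relevance


def occlusion_map(model_fn, image, patch=2, fill=0.0):
    """Occlusion-based relevance: drop in score when a patch is replaced by `fill`.

    Uses non-overlapping patches for simplicity; works for 2D and 3D arrays.
    """
    image = np.asarray(image, dtype=float)
    baseline = model_fn(image)
    relevance = np.zeros_like(image)
    corners = itertools.product(*(range(0, n, patch) for n in image.shape))
    for corner in corners:
        region = tuple(slice(c, c + patch) for c in corner)
        occluded = image.copy()
        occluded[region] = fill
        relevance[region] = baseline - model_fn(occluded)
    return relevance


# Toy "network": a linear score over a 4x4x4 volume, standing in for a trained CNN.
weights = np.random.default_rng(0).normal(size=(4, 4, 4))


def model_fn(x):
    return float(np.sum(weights * x))


scan = np.ones((4, 4, 4))
sens = sensitivity_map(model_fn, scan)        # equals |weights| for a linear model
occ = occlusion_map(model_fn, scan, patch=2)  # one value per 2x2x2 block
```

The finite-difference version needs two model evaluations per voxel, so it is only practical for toy inputs; a PyTorch implementation gets the whole gradient from a single backward pass. Occlusion, in contrast, needs one forward pass per patch, which is why it tends to produce coarser, more region-focused maps.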

Code Structure

The codebase uses PyTorch and Jupyter notebooks. The main files for the paper are:

  • training.ipynb is the notebook to train the model and perform cross validation.
  • interpretation-mri.ipynb contains the code to create relevance heatmaps with different visualization methods. It also includes the code to reproduce all figures and tables from the paper.
  • All *.py files contain methods that are imported in the notebooks above.

Additionally, there are two other notebooks:

  • interpretation-photos.ipynb uses the same visualization methods as in the paper but applies them to 2D photos. This might be an easier introduction to the topic.
  • small-dataset.ipynb contains some old code to run a similar experiment on a smaller dataset.

Trained Model and Heatmaps

If you don't want to train the model and/or compute the heatmaps yourself, you can download my results: Here is the final model that I used to produce all heatmaps in the paper (as a PyTorch state dict; see the paper or code for details on how the model was trained). And here are the NumPy arrays that contain all average relevance heatmaps (as a compressed NumPy .npz file). Please have a look at interpretation-mri.ipynb for instructions on how to load and use these files.
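The model file is a state dict, so it would be loaded into an instance of the same architecture with model.load_state_dict(torch.load(path)). The heatmap archive can be inspected with plain NumPy. The array names inside the archive are not documented here, so this minimal sketch simply lists whatever it contains; summarize_npz and the file name in the usage comment are hypothetical.

```python
import numpy as np


def summarize_npz(path_or_file):
    """Return {array name: (shape, dtype)} for every array stored in an .npz file."""
    # For .npz archives np.load returns an NpzFile; "with" closes it afterwards.
    with np.load(path_or_file) as archive:
        return {name: (archive[name].shape, archive[name].dtype) for name in archive.files}


# Hypothetical usage with the downloaded file (the file name is a placeholder):
# for name, (shape, dtype) in summarize_npz("heatmaps.npz").items():
#     print(name, shape, dtype)
```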

Data

The MRI scans used for training are from the Alzheimer's Disease Neuroimaging Initiative (ADNI). The data is free, but you need to apply for access at http://adni.loni.usc.edu/. Once you have an account, go here and log in.

Tables

This repo includes CSV tables with metadata for all images we used (data/ADNI/ADNI_tables). These tables were created by combining several data tables from ADNI. There is one table for 1.5 Tesla scans and one for 3 Tesla scans; in the paper, we trained only on the 1.5 Tesla images.

Images

To download the corresponding images, log in on the ADNI page, go to "Download" -> "Image Collections" -> "Data Collections". In the box on the left, select "Other shared collections" -> "ADNI" -> "ADNI1:Annual 2 Yr 1.5T" (or the corresponding collection for 3T) and download all images. We preprocessed all images by non-linear registration to a 1 mm isotropic ICBM template via ANTs with default parameters, using the quick registration script from here.

To be consistent with the codebase, put the images into the folders data/ADNI/ADNI_2Yr_15T_quick_preprocessed (for the 1.5 Tesla images) or data/ADNI/ADNI_2Yr_3T_preprocessed (for the 3 Tesla images). Within these folders, each image should have the following path: <PTID>/<Visit (spaces removed)>/<PTID>_<Scan.Date (/ replaced by -)>_<Visit (spaces removed)>_<Image.ID>_<DX>_Warped.nii.gz. If you want to use a different directory structure, you need to change the method get_image_filepath and/or the filenames in datasets.py.
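Because the naming scheme is easy to get wrong, here is a small sketch that assembles the expected path from the fields of one metadata row. build_image_path is a hypothetical helper, not the repo's get_image_filepath, and the example values (PTID, visit, date format, image ID, diagnosis label) are made up for illustration.

```python
import os


def build_image_path(root, ptid, visit, scan_date, image_id, dx):
    """Assemble the expected path of a preprocessed scan (hypothetical helper).

    Layout: <PTID>/<Visit without spaces>/
            <PTID>_<Scan.Date with / -> ->_<Visit without spaces>_<Image.ID>_<DX>_Warped.nii.gz
    """
    visit_clean = visit.replace(" ", "")
    date_clean = scan_date.replace("/", "-")
    filename = "{}_{}_{}_{}_{}_Warped.nii.gz".format(
        ptid, date_clean, visit_clean, image_id, dx)
    return os.path.join(root, ptid, visit_clean, filename)


# Made-up example row for the 1.5 Tesla folder.
path = build_image_path("data/ADNI/ADNI_2Yr_15T_quick_preprocessed",
                        ptid="002_S_0000", visit="Month 12", scan_date="1/2/2007",
                        image_id=12345, dx="CN")
# On POSIX: data/ADNI/ADNI_2Yr_15T_quick_preprocessed/002_S_0000/Month12/
#           002_S_0000_1-2-2007_Month12_12345_CN_Warped.nii.gz
```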

Users from Ritter/Haynes lab

If you're working in the Ritter/Haynes lab at Charité Berlin, you don't need to download any data, but simply uncomment the correct ADNI_DIR variable in datasets.py.

Requirements

  • Python 2 (mostly compatible with Python 3 syntax, but not tested)
  • Scientific packages (included with Anaconda): numpy, scipy, matplotlib, pandas, jupyter, scikit-learn
  • Other packages: tqdm, tabulate
  • PyTorch: torch, torchvision (tested with 0.3.1, but mostly compatible with 0.4)
  • torchsample: I made a custom fork of torchsample which fixes some bugs. You can download it from https://github.com/jrieke/torchsample or install directly via pip install git+https://github.com/jrieke/torchsample. Please use this fork instead of the original package, otherwise the code will break.

Non-pytorch Models

If your model is not in PyTorch but you still want to use the visualization methods, you can try to convert the model to PyTorch (overview of conversion tools).

For converting from Keras to PyTorch, I recommend nn-transfer. If you use it, keep in mind that by default PyTorch uses the channels-first format and Keras the channels-last format for images. Even though nn-transfer handles this difference for the orientation of the convolution kernels, you may still need to permute the dimensions in the PyTorch model between the convolutional and the fully-connected stage (for 3D images, I used x = x.permute(0, 2, 3, 4, 1).contiguous()). The safest option is to switch Keras to channels-first as well; then nn-transfer should handle everything by itself.
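Why the permute is needed becomes clear when you flatten the same activations in both layouts: the element order differs, so dense-layer weights transferred from Keras would be applied to the wrong features. This NumPy sketch mirrors the permute described above; flatten_like_keras and the toy shapes are illustrative assumptions, not code from the repo.

```python
import numpy as np


def flatten_like_keras(x_ncdhw):
    """Flatten a channels-first 3D feature map in Keras' channels-last order.

    NumPy stand-in for x.permute(0, 2, 3, 4, 1).contiguous().view(x.size(0), -1)
    in PyTorch, placed between the last conv layer and the first dense layer.
    """
    x_ndhwc = np.transpose(x_ncdhw, (0, 2, 3, 4, 1))
    return np.ascontiguousarray(x_ndhwc).reshape(x_ncdhw.shape[0], -1)


# Toy activations as Keras stores them: (batch, depth, height, width, channels).
keras_act = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
# The same activations in PyTorch layout: (batch, channels, depth, height, width).
torch_act = np.transpose(keras_act, (0, 4, 1, 2, 3))

naive = torch_act.reshape(2, -1)       # wrong order for transferred dense weights
fixed = flatten_like_keras(torch_act)  # matches what Keras' Flatten() produces
print(np.array_equal(naive, keras_act.reshape(2, -1)))  # False
print(np.array_equal(fixed, keras_act.reshape(2, -1)))  # True
```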

Citation

If you use our code, please cite our paper:

@inproceedings{rieke2018,
  title={Visualizing Convolutional Networks for MRI-based Diagnosis of Alzheimer's Disease},
  author={Rieke, Johannes and Eitel, Fabian and Weygandt, Martin and Haynes, John-Dylan and Ritter, Kerstin},
  booktitle={Machine Learning in Clinical Neuroimaging (MLCN)},
  year={2018}
}
