Deep Learning CNNs for Medical Image Classification

Cyrus Kurd
Nov 6, 2024

In healthcare, convolutional neural networks (CNNs) have become indispensable for tasks like diagnosing fractures, identifying tumor cells, and analyzing microscopic tissue structures. CNNs can make medical image analysis much faster and, in many cases, more accurate, helping scientists make informed judgements and diagnoses. This project focuses on a CNN model tailored to classifying pathology images from PathMNIST, an open-source dataset that’s part of the MedMNIST collection.

Different Datasets from the MedMNIST Collection (Credit: MedMNIST)

The Dataset: PathMNIST

PathMNIST provides a collection of pathology images for deep learning, well suited to building models that recognize distinct tissue types. Each image shows a tiny section of tissue, which can carry information vital for diagnosis.

  • Data Size: PathMNIST includes over 100,000 labeled images.
  • Classes: The dataset contains nine classes, each representing a specific type of pathology tissue or condition: “Adipose Tissue”, “Background”, “Debris”, “Lymphocytes”, “Mucus”, “Normal Colon Mucosa”, “Cancer-Associated Stroma”, “Colorectal Adenocarcinoma Epithelium”, and “Smooth Muscle”.

The goal is to train our CNN model to differentiate between these classes in a supervised manner, using the labeled images in PathMNIST to learn which image features mark each class. The end result is a tool that could serve as a reference to assist medical diagnostics in similar applications.

Building the Model: CNN Architecture

For this project, I used a streamlined CNN architecture. CNNs are the go-to model for 2D images because they excel at detecting spatial patterns (edges, shapes, and textures) that are critical for medical imaging. They’re particularly effective at learning features at increasing levels of abstraction as more layers are added: initial layers capture simple features (e.g., edges and lines), while deeper layers identify more complex shapes and patterns, such as specific cell structures or tissue irregularities that may indicate disease. In contrast, plain fully connected networks and RNNs have no built-in notion of 2D spatial structure, so they tend to extract features poorly from images, where spatial hierarchies are essential.
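To make the idea of an early-layer edge detector concrete, here is a minimal pure-Python sketch of a single 2D convolution (strictly speaking a cross-correlation, which is what deep learning libraries usually compute). The Sobel-like kernel and the toy image are illustrative assumptions, not weights from the trained model; a real CNN learns its kernels from data.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and return the response map."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            acc = 0
            for di in range(kh):
                for dj in range(kw):
                    acc += image[i + di][j + dj] * kernel[di][dj]
            row.append(acc)
        out.append(row)
    return out


# Toy 5x5 grayscale patch: dark on the left, bright on the right.
image = [[0, 0, 1, 1, 1] for _ in range(5)]

# Hand-written vertical-edge kernel (Sobel-like), standing in for a learned filter.
vertical_edge = [
    [-1, 0, 1],
    [-2, 0, 2],
    [-1, 0, 1],
]

response = conv2d_valid(image, vertical_edge)
for row in response:
    print(row)
```

The response is large where the 3x3 window straddles the dark-to-bright boundary and zero where the window covers a uniform region, which is the sense in which a filter "detects" an edge.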

The architecture I used starts with three convolutional layers, each extracting progressively more complex features from the images. These layers are followed by batch normalization, which stabilizes training, and max-pooling, which downsamples the feature maps and keeps only the strongest responses. After feature extraction, the model uses fully connected layers to classify the images into specific pathology classes. A dropout layer is also included to reduce overfitting and help the model generalize to unseen data.
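A quick way to sanity-check an architecture like this is to trace how the image size shrinks from block to block. The sketch below does that arithmetic in plain Python. The hyperparameters are assumptions chosen for illustration (3x3 convolutions with padding 1, 2x2 max-pooling with stride 2, and 32/64/128 channels); the article does not state the actual values, so check them against the project code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(side=28, channels=(32, 64, 128)):
    """Return (channels, height, width) after each conv -> batch norm -> max-pool block."""
    shapes = [(3, side, side)]  # PathMNIST input: RGB, 28x28
    for out_channels in channels:
        side = conv_out(side, kernel=3, stride=1, padding=1)  # 3x3 conv keeps the size
        # Batch normalization does not change the shape.
        side = conv_out(side, kernel=2, stride=2)  # 2x2 max-pool halves it (rounding down)
        shapes.append((out_channels, side, side))
    return shapes


shapes = trace_shapes()
for shape in shapes:
    print(shape)

c, h, w = shapes[-1]
flattened = c * h * w
print("features into the first fully connected layer:", flattened)
```

Under these assumptions the spatial size goes 28, 14, 7, 3, so the first fully connected layer would receive 128 * 3 * 3 = 1152 features.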

Architecture Components

1. Feature Extraction Layers: Each convolutional layer detects unique patterns, refining the model’s understanding of the images.

2. Classification Layers: Fully connected layers at the end take the extracted features and classify the images. Dropout helps mitigate overfitting by reducing dependency on specific neurons.
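As a rough illustration of what dropout does to the extracted features, here is a minimal sketch of "inverted" dropout in plain Python. The drop probability of 0.5 and the feature values are hypothetical; the article does not state the rate used in the model.

```python
import random


def dropout(activations, p=0.5, training=True, rng=None):
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1-p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(activations)  # at evaluation time dropout is a no-op
    if rng is None:
        rng = random.Random()
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


features = [0.2, 1.5, 0.7, 3.0, 0.9, 1.1]

train_out = dropout(features, p=0.5, training=True, rng=random.Random(0))
eval_out = dropout(features, p=0.5, training=False)

print("training:", train_out)
print("evaluation:", eval_out)
```

Because a different random subset of values is zeroed on every training step, the classifier cannot rely on any single neuron. Scaling the survivors keeps the expected activation the same, so nothing needs to change at evaluation time.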

This architecture balances simplicity and performance: it aims to be expressive enough to handle pathology images without adding more layers or parameters than the task needs.

Training the Model

For training, I used cross-entropy loss (a standard choice for multi-class classification) and the Adam optimizer, which adapts the step size for each parameter automatically. With a batch size of 128, the model processes the data in mini-batches and updates its weights after each one, which helps both stability and speed.
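For readers who want to see what the loss actually measures, here is a small sketch of cross-entropy computed from raw class scores (logits) in plain Python. The logits are made up for illustration; in practice a deep learning library computes this for the whole mini-batch.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log softmax(logits)[target], computed stably."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def batch_loss(batch_logits, batch_targets):
    """Mean cross-entropy over a mini-batch."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, batch_targets)]
    return sum(losses) / len(losses)


num_classes = 9  # PathMNIST has nine tissue classes

uniform = [0.0] * num_classes    # the model has no preference
confident = [0.0] * num_classes
confident[3] = 10.0              # the model strongly favors class 3

print(cross_entropy(uniform, 3))    # equals ln(9): pure guessing
print(cross_entropy(confident, 3))  # close to zero: confident and correct
print(cross_entropy(confident, 5))  # large: confident and wrong
print(batch_loss([uniform, confident], [3, 3]))
```

The loss is small only when the model puts high probability on the correct class, and it grows quickly for confident mistakes, which is the signal the optimizer uses to adjust the weights.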

A critical part of this project was leveraging GPU acceleration whenever available. Running the model and its data operations on the GPU saved considerable time, including in steps like taking the maximum class score for each prediction. GPU-accelerated training suits deep learning projects well, providing faster feedback and allowing us to iterate more effectively. Note that matplotlib works on data in CPU memory, so anything living on the GPU has to be moved back to the CPU before plotting (see the code for details on how to switch back and forth).

Visualizing Predictions

To make the results tangible, I wrote a visualization function that displays test images alongside their true and predicted labels. By examining these images, we can directly assess where the model succeeds and where it might need further refinement. This type of visualization is invaluable in medical AI, where interpretability is as crucial as accuracy. Indeed, some models may have to go through their own regulatory processes (covering safety, efficacy, and transparency standards), particularly when they are intended for clinical use rather than for research, industry, or other settings. This visual feedback not only helps with model validation but also allows researchers and medical professionals to verify the model’s outputs.
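The labeling step behind such a visualization can be sketched in a few lines: take the highest-scoring class for each image and pair it with the true label. The function names, the three-class subset, and the scores below are hypothetical and only illustrate the idea; the project's own plotting code may differ.

```python
def predict_label(scores):
    """Index of the highest class score (argmax)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def caption(scores, true_index, class_names):
    """Build a 'true vs. predicted' title for one image, flagging mistakes."""
    pred_index = predict_label(scores)
    mark = "correct" if pred_index == true_index else "WRONG"
    return f"true: {class_names[true_index]} | pred: {class_names[pred_index]} ({mark})"


# Toy three-class subset; the real model outputs nine scores per image.
class_names = ["Adipose Tissue", "Background", "Debris"]
batch_scores = [[2.1, 0.3, -1.0], [0.2, 0.1, 1.7], [-0.5, 3.0, 0.4]]
true_labels = [0, 1, 1]

captions = [caption(s, t, class_names) for s, t in zip(batch_scores, true_labels)]
for line in captions:
    print(line)
```

Each string can then be used as the title of the corresponding image in a plot grid, which makes misclassified tiles easy to spot.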

Output comparing the predicted versus actual tissue classification (note: rendering is blurry since the images are stored as 3x28x28 arrays, i.e. RGB values for images only 28 pixels high and wide, which makes it all the more impressive that the model can achieve an AUC score of 0.9370!)
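For context on what an AUC score summarizes in a nine-class problem, here is a sketch of one common definition: a one-vs-rest AUC per class, averaged across classes. The article does not say exactly how its AUC was computed, so this averaging scheme is an assumption, and the probabilities below are toy values for three classes.

```python
def binary_auc(scores, positives):
    """Probability that a random positive outranks a random negative (ties count half)."""
    pos = [s for s, is_pos in zip(scores, positives) if is_pos]
    neg = [s for s, is_pos in zip(scores, positives) if not is_pos]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(probabilities, labels, num_classes):
    """Average the one-vs-rest AUC of every class."""
    aucs = []
    for c in range(num_classes):
        class_scores = [row[c] for row in probabilities]
        is_class = [label == c for label in labels]
        aucs.append(binary_auc(class_scores, is_class))
    return sum(aucs) / num_classes


# Toy predicted probabilities for four images and three classes.
probabilities = [
    [0.70, 0.20, 0.10],
    [0.10, 0.60, 0.30],
    [0.20, 0.30, 0.50],
    [0.15, 0.75, 0.10],
]
labels = [0, 1, 2, 0]

macro_auc = macro_ovr_auc(probabilities, labels, num_classes=3)
print(f"macro one-vs-rest AUC: {macro_auc:.4f}")
```

The pairwise loop is quadratic in the number of samples, so it is only meant to show the definition; a library routine is the practical choice for a full test set.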

Takeaways

This project demonstrates a practical application of CNNs in computer vision for medical imaging, using accessible datasets and a clear training workflow. The CNN model classified complex pathology images with an AUC of 0.9370, which points to the feasibility of AI-assisted diagnostics.

If you’re interested in more details, feel free to explore the code on GitHub or reach out!

Written by Cyrus Kurd

M.S. Data Science Student at Columbia University | linkedin.com/in/cykurd/
