wound_classification_jax

Wound Classification with JAX / Flax


📌 Overview

This project implements multiple deep-learning models for wound image classification, built with JAX and Flax. The system includes several model architectures, data-processing tools, a full training pipeline, and inference/evaluation scripts.


📁 Project Structure

WOUND_CLASSIFICATION_JAX
│  requirements.txt
│  terminal_commands.txt
│
├─data
│   └─dataset                 # Cleaned dataset (after processing)
│
├─nets                        # Model architectures
│   └─ BaselineCNN.py
│      CNN.py
│      Hybrid.py
│      Mamba.py               # Vision Mamba variant; its implementation differs from VisionMamba.py
│      ResNet.py
│      VisionMamba.py
│
├─references
│   └─ Hatamizadeh_MambaVision_CVPR2025.pdf
│
└─scripts                     # Training / Testing / Data Processing
    └─ dataset.py
       data_clean.py
       download_data.py
       test.py
       train.py

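🚀 Features

1. Multiple Model Architectures: CNN, baseline CNN, ResNet-18/34, Mamba, Vision Mamba, and hybrid Mamba-CNN / Mamba-ResNet models

2. Data Processing Toolkit: Kaggle download, removal of corrupted images, and train/test splitting

3. Full Training Pipeline: configurable from the command line (model, batch size, epochs, learning rate, augmentation)

4. Inference / Evaluation: checkpoint-based testing with scripts/test.py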


📦 Installation

pip install -r requirements.txt

📂 Dataset Preparation

Download the dataset from Kaggle:

python scripts/download_data.py

Clean the dataset and remove corrupted images:

python scripts/data_clean.py

This creates:

data/dataset/
    000001_ClassA.jpg
    000002_ClassB.jpg
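The cleaning step can be pictured with the minimal sketch below. It assumes a hypothetical raw layout `raw_dir/<ClassName>/*.jpg` and uses a cheap JPEG marker check (SOI `FF D8` at the start, EOI `FF D9` at the end) as a stand-in for corruption detection; the actual `data_clean.py` may decode each image instead. The helper names `looks_like_valid_jpeg` and `clean_and_rename` are illustrative only.

```python
from __future__ import annotations

import shutil
from pathlib import Path

# JPEG files start with the SOI marker (FF D8) and end with the EOI marker (FF D9).
JPEG_SOI = b"\xff\xd8"
JPEG_EOI = b"\xff\xd9"


def looks_like_valid_jpeg(path: Path) -> bool:
    """Heuristic check that the header and trailer markers are present.

    This does not decode the image, so a file with damaged data between
    the markers would still pass.
    """
    data = path.read_bytes()
    return len(data) >= 4 and data.startswith(JPEG_SOI) and data.endswith(JPEG_EOI)


def clean_and_rename(raw_dir: Path, out_dir: Path) -> list[Path]:
    """Copy valid JPEGs from raw_dir/<ClassName>/ to out_dir as 000001_ClassName.jpg."""
    out_dir.mkdir(parents=True, exist_ok=True)
    written: list[Path] = []
    index = 1
    for class_dir in sorted(p for p in raw_dir.iterdir() if p.is_dir()):
        for img in sorted(class_dir.glob("*.jpg")):
            if not looks_like_valid_jpeg(img):
                continue  # skip corrupted or non-JPEG files
            # Zero-padded running index, then the class name taken from the folder
            dst = out_dir / f"{index:06d}_{class_dir.name}.jpg"
            shutil.copyfile(img, dst)
            written.append(dst)
            index += 1
    return written
```

Encoding the label in the filename keeps the cleaned dataset flat, so a loader only needs to parse names rather than walk class folders.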

Split the dataset into train and test sets:

python scripts/data_clean.py --build_split

This generates:

data/dataset_split/train/
data/dataset_split/test/
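A stratified split keeps each class's train/test proportion roughly equal, which matters when some wound classes have far fewer images than others. The sketch below assumes the class name is everything after the first underscore in names like `000001_ClassA.jpg`, a 20% test ratio, and a flat output layout; `stratified_split` and `copy_split` are hypothetical helpers, not the actual behavior or options of `--build_split`.

```python
from __future__ import annotations

import random
import shutil
from collections import defaultdict
from pathlib import Path


def class_from_name(filename: str) -> str:
    """'000001_ClassA.jpg' -> 'ClassA' (everything after the first underscore)."""
    return Path(filename).stem.split("_", 1)[1]


def stratified_split(filenames, test_ratio=0.2, seed=0):
    """Split filenames per class so each class keeps roughly the same test ratio."""
    by_class = defaultdict(list)
    for name in sorted(filenames):
        by_class[class_from_name(name)].append(name)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    train, test = [], []
    for cls in sorted(by_class):
        items = by_class[cls]
        rng.shuffle(items)
        n_test = int(round(len(items) * test_ratio))
        test.extend(items[:n_test])
        train.extend(items[n_test:])
    return train, test


def copy_split(src_dir: Path, dst_dir: Path, test_ratio=0.2, seed=0):
    """Copy files into dst_dir/train and dst_dir/test according to the split."""
    names = [p.name for p in src_dir.glob("*.jpg")]
    train, test = stratified_split(names, test_ratio, seed)
    for subset, files in (("train", train), ("test", test)):
        out = dst_dir / subset
        out.mkdir(parents=True, exist_ok=True)
        for name in files:
            shutil.copyfile(src_dir / name, out / name)
    return train, test
```

Splitting on the first underscore only means class names that themselves contain underscores are preserved intact.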

🧠 Model Training

Example:

python scripts/train.py \
    --model mamba \
    --batch_size 16 \
    --num_epochs 50 \
    --learning_rate 5e-5 \
    --use_augmentation True
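In JAX, a training step is usually a jitted pure function that maps parameters and a batch to updated parameters and a loss. The minimal sketch below assumes a hypothetical linear `apply_fn` in place of the Flax models in `nets/` and uses plain SGD rather than whatever optimizer `train.py` actually uses; `lr` plays the role of `--learning_rate`.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy; labels are integer class indices."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -jnp.mean(picked)


def apply_fn(params, x):
    """Hypothetical stand-in model: a linear classifier on flattened NHWC images."""
    x = x.reshape((x.shape[0], -1))
    return x @ params["w"] + params["b"]


@jax.jit
def train_step(params, x, y, lr):
    """One SGD step: compute loss and gradients, then update every parameter leaf."""
    loss, grads = jax.value_and_grad(lambda p: cross_entropy(apply_fn(p, x), y))(params)
    # Plain SGD applied leaf by leaf over the parameter pytree
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

With a Flax model, `apply_fn(p, x)` would become `model.apply(p, x)`, and an optimizer library would typically replace the hand-written update.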

Supported models (select with --model):

cnn
baseline_cnn
resnet18
resnet34
mamba
vision_mamba
hybrid_mamba_cnn
hybrid_mamba_resnet
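The flags above map naturally onto `argparse`. One pitfall with `--use_augmentation True` is that `type=bool` converts any non-empty string, including `"False"`, to `True`, so a small string-to-bool parser is safer. The parser below is a hypothetical mirror of the `train.py` CLI, not its actual code; the defaults are assumptions taken from the example command.

```python
import argparse

# Model names as listed in this README
SUPPORTED_MODELS = [
    "cnn", "baseline_cnn", "resnet18", "resnet34",
    "mamba", "vision_mamba", "hybrid_mamba_cnn", "hybrid_mamba_resnet",
]


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings; type=bool would treat 'False' as True."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train a wound classifier (sketch)")
    # choices= rejects unknown model names with a usage error
    parser.add_argument("--model", choices=SUPPORTED_MODELS, default="cnn")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--use_augmentation", type=str2bool, default=False)
    return parser
```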

🧪 Model Testing / Evaluation

Example:

python scripts/test.py \
    --model mamba \
    --ckpt_path ../checkpoints/mamba/best.pkl

Outputs:
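Typical classification metrics can be computed directly from logits with `jax.numpy`; the sketch below shows accuracy and a confusion matrix. The function names are illustrative and not necessarily what `test.py` reports.

```python
import jax.numpy as jnp


def accuracy(logits, labels):
    """Fraction of samples whose arg-max prediction equals the true label."""
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)


def confusion_matrix(logits, labels, num_classes):
    """cm[i, j] counts samples with true class i predicted as class j."""
    preds = jnp.argmax(logits, axis=-1)
    cm = jnp.zeros((num_classes, num_classes), dtype=jnp.int32)
    # Scatter-add: repeated (label, pred) pairs accumulate instead of overwriting
    return cm.at[labels, preds].add(1)
```

The confusion matrix is often more informative than accuracy alone for wound images, since it shows which classes are mistaken for each other.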


🧩 Key Files

File                    Description
scripts/train.py        Full training pipeline
scripts/test.py         Inference and evaluation
scripts/dataset.py      Dataset loader and augmentation
scripts/data_clean.py   Dataset cleaning and train/test split
nets/                   All neural network architectures

💡 Model Highlights

Vision Mamba

Implements Mamba state-space blocks for vision tasks.
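At its core, a Mamba block runs a selective state-space recurrence over a token sequence (for vision, a sequence of image patches). Per channel, h_t = exp(Δ_t A) ⊙ h_{t-1} + Δ_t B_t x_t and y_t = C_t h_t + D x_t, where the step size Δ_t and the projections B_t, C_t depend on the input. The sketch below implements only that scan with `jax.lax.scan`; it omits the convolution, gating, and low-rank Δ projection of a full Mamba block, and the parameter names (`W_delta`, `W_B`, `W_C`, `A_log`, `D`) are assumptions rather than those used in `nets/Mamba.py` or `nets/VisionMamba.py`.

```python
import jax
import jax.numpy as jnp


def selective_scan(x, A, B, C, delta, D):
    """Sequential selective SSM scan over one sequence.

    x:     (L, D_in)  input tokens, e.g. flattened image patches
    A:     (D_in, N)  diagonal state matrix per channel (negative for stability)
    B, C:  (L, N)     input-dependent input/output projections
    delta: (L, D_in)  input-dependent positive step sizes
    D:     (D_in,)    skip connection
    returns y: (L, D_in)
    """
    def step(h, inputs):
        x_t, B_t, C_t, dt = inputs
        A_bar = jnp.exp(dt[:, None] * A)          # (D_in, N) zero-order hold for A
        Bx = (dt * x_t)[:, None] * B_t[None, :]   # (D_in, N) simplified Euler step for B
        h = A_bar * h + Bx                        # state update
        y_t = h @ C_t + D * x_t                   # (D_in,) readout plus skip
        return h, y_t

    h0 = jnp.zeros_like(A)
    _, y = jax.lax.scan(step, h0, (x, B, C, delta))
    return y


def mamba_ssm(params, x):
    """Derive the input-dependent delta, B, C from x, then run the scan."""
    delta = jax.nn.softplus(x @ params["W_delta"] + params["b_delta"])  # (L, D_in) > 0
    B = x @ params["W_B"]                 # (L, N)
    C = x @ params["W_C"]                 # (L, N)
    A = -jnp.exp(params["A_log"])         # parameterised so A stays negative
    return selective_scan(x, A, B, C, delta, params["D"])
```

`jax.lax.scan` compiles the recurrence as a single loop; optimized Mamba implementations replace it with a parallel or hardware-fused scan, but the arithmetic per step is the same.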

Hybrid Models

Hybrid models fuse Mamba features with CNN/ResNet outputs.
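One common way to fuse two branches is late fusion: each branch is pooled to a vector, the vectors are concatenated, and a linear head produces class logits. The sketch below assumes that scheme, a CNN feature map of shape `(B, H, W, C1)`, Mamba tokens of shape `(B, L, C2)`, and a hypothetical `params` dict; the repository's hybrid models may fuse their features differently.

```python
import jax.numpy as jnp


def global_avg_pool(feature_map):
    """(B, H, W, C) CNN feature map -> (B, C) by averaging over space."""
    return jnp.mean(feature_map, axis=(1, 2))


def fuse_and_classify(params, cnn_features, mamba_tokens):
    """Late fusion by concatenation, followed by a linear classification head.

    cnn_features: (B, H, W, C1), mamba_tokens: (B, L, C2)
    params["w"]: (C1 + C2, num_classes), params["b"]: (num_classes,)
    """
    cnn_vec = global_avg_pool(cnn_features)                  # (B, C1)
    mamba_vec = jnp.mean(mamba_tokens, axis=1)               # (B, C2) mean over tokens
    fused = jnp.concatenate([cnn_vec, mamba_vec], axis=-1)   # (B, C1 + C2)
    return fused @ params["w"] + params["b"]                 # (B, num_classes)
```

Concatenation keeps both feature sets intact and lets the head learn how much to weight each branch; alternatives such as summation or attention-based fusion require matching or projected channel sizes.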


▶ Example Code

Loading an image manually

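from scripts.dataset import data_loader

# Build the dataset loader with augmentation enabled
loader = data_loader(data_path="../data/dataset", use_augmentation=True)

# Indexing returns (image, class label index, image index)
img, label_idx, img_idx = loader[0]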

Running a forward pass

import jax
import jax.numpy as jnp

from nets.CNN import SimpleCNN
from scripts.dataset import data_loader

loader = data_loader(data_path="../data/dataset", use_augmentation=False)
model = SimpleCNN(num_classes=loader.num_classes)

# Dummy NHWC batch: one 224x224 RGB image
dummy = jnp.zeros((1, 224, 224, 3))

# Initialise parameters with a fixed PRNG key, then run a forward pass
params = model.init(jax.random.PRNGKey(0), dummy)
logits = model.apply(params, dummy)  # expected shape: (1, num_classes)

📌 Requirements

See requirements.txt.


📄 License

See LICENSE.