This project implements multiple deep-learning models for wound image classification, built using JAX and Flax. The system includes:
```
WOUND_CLASSIFICATION_JAX
│  requirements.txt
│  terminal_commands.txt
│
├─data
│  └─dataset               # Cleaned dataset (after processing)
│
├─nets                     # Model architectures
│      BaselineCNN.py
│      CNN.py
│      Hybrid.py
│      Mamba.py            # Vision Mamba implementation that differs from VisionMamba.py
│      ResNet.py
│      VisionMamba.py
│
├─references
│      Hatamizadeh_MambaVision_CVPR2025.pdf
│
└─scripts                  # Training / testing / data processing
       dataset.py
       data_clean.py
       download_data.py
       test.py
       train.py
```
```bash
pip install -r requirements.txt
python scripts/download_data.py
python scripts/data_clean.py
```
This creates:
```
data/dataset/
    000001_ClassA.jpg
    000002_ClassB.jpg
```
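The cleaned files appear to follow an `<index>_<ClassName>.jpg` naming pattern. A minimal sketch of recovering the index and class label from such a name, assuming the index is everything before the first underscore and the class name is everything between it and the extension (`parse_name` is a hypothetical helper, not part of `scripts/`):

```python
from pathlib import Path


def parse_name(filename: str) -> tuple[int, str]:
    """Split '<index>_<ClassName>.<ext>' into (index, class_name).

    Hypothetical helper: assumes the naming scheme shown for data/dataset/.
    """
    stem = Path(filename).stem                   # '000001_ClassA'
    index_str, class_name = stem.split("_", 1)   # split on the first underscore only
    return int(index_str), class_name


print(parse_name("000001_ClassA.jpg"))  # (1, 'ClassA')
```

Splitting only on the first underscore keeps class names that themselves contain underscores intact.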
```bash
python scripts/data_clean.py --build_split
```
This generates:
```
data/dataset_split/train/
data/dataset_split/test/
```
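The rule `--build_split` uses to assign images is not described here. One common approach, shown as a hedged sketch, is a seeded per-class (stratified) split so every class appears in both `train/` and `test/`. The 80/20 ratio, the seed, and the name `stratified_split` are assumptions, and the function only computes the assignment rather than copying files:

```python
import random
from collections import defaultdict


def stratified_split(filenames, test_ratio=0.2, seed=0):
    """Return (train, test) filename lists, splitting each class separately.

    Hypothetical helper; assumes names look like '<index>_<ClassName>.<ext>'.
    """
    by_class = defaultdict(list)
    for name in filenames:
        stem = name.rsplit(".", 1)[0]          # drop the extension
        class_name = stem.split("_", 1)[1]     # text after the first underscore
        by_class[class_name].append(name)

    rng = random.Random(seed)  # local RNG so the split is reproducible
    train, test = [], []
    for class_name in sorted(by_class):
        names = sorted(by_class[class_name])
        rng.shuffle(names)
        if len(names) < 2:
            n_test = 0  # a single image cannot be split; keep it for training
        else:
            # at least one test image, but always leave one for training
            n_test = min(len(names) - 1, max(1, round(len(names) * test_ratio)))
        test.extend(names[:n_test])
        train.extend(names[n_test:])
    return train, test
```

Sorting before shuffling makes the result independent of the order in which files were listed on disk.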
Example:
```bash
python scripts/train.py \
  --model mamba \
  --batch_size 16 \
  --num_epochs 50 \
  --learning_rate 5e-5 \
  --use_augmentation True
```
Available models (`--model`):

- `cnn`
- `baseline_cnn`
- `resnet18`
- `resnet34`
- `mamba`
- `vision_mamba`
- `hybrid_mamba_cnn`
- `hybrid_mamba_resnet`
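The README does not show how `train.py` parses these flags. The sketch below is one way to declare them with `argparse`, assuming the flag names from the example above; the defaults are placeholders, not the script's. Because `type=bool` would turn the string `"False"` into `True`, a small string-to-bool converter is used for `--use_augmentation`:

```python
import argparse

MODELS = [
    "cnn", "baseline_cnn", "resnet18", "resnet34", "mamba",
    "vision_mamba", "hybrid_mamba_cnn", "hybrid_mamba_resnet",
]


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (bool('False') would be True)."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the training example."""
    parser = argparse.ArgumentParser(description="Train a wound classifier (sketch)")
    parser.add_argument("--model", choices=MODELS, default="cnn")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--use_augmentation", type=str2bool, default=True)
    return parser


args = build_parser().parse_args(
    ["--model", "mamba", "--batch_size", "16", "--num_epochs", "50",
     "--learning_rate", "5e-5", "--use_augmentation", "True"]
)
print(args.model, args.batch_size, args.learning_rate, args.use_augmentation)
# mamba 16 5e-05 True
```

Using `choices=MODELS` makes argparse reject a misspelled model name before any training starts.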
Example:
```bash
python scripts/test.py \
  --model mamba \
  --ckpt_path ../checkpoints/mamba/best.pkl
```
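The `.pkl` extension suggests parameters are stored with Python's `pickle`, but the exact format `train.py` writes is not documented here. A minimal sketch, assuming the checkpoint is a picklable pytree (for example nested dicts of arrays) and using hypothetical `save_checkpoint` / `load_checkpoint` helpers:

```python
import pickle
from pathlib import Path


def save_checkpoint(params, path):
    """Pickle a parameter pytree to `path`, creating parent directories."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(params, f)


def load_checkpoint(path):
    """Load a pickled parameter pytree. Only unpickle files you trust."""
    with Path(path).open("rb") as f:
        return pickle.load(f)
```

If avoiding `pickle` is preferred, Flax's `flax.serialization.to_bytes` / `from_bytes` is an alternative for serializing parameter pytrees.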
Outputs:
| File | Description |
|---|---|
| `scripts/train.py` | Full training pipeline |
| `scripts/test.py` | Inference and evaluation |
| `scripts/dataset.py` | Dataset loader + augmentation |
| `scripts/data_clean.py` | Dataset cleaning and train/test split |
| `nets/` | All neural network architectures |
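`scripts/test.py` handles inference and evaluation, but the metrics it reports are not listed here. As an illustration only, accuracy and a confusion matrix can be computed from integer class indices like this (`evaluate` is a hypothetical helper):

```python
def evaluate(y_true, y_pred, num_classes):
    """Return (accuracy, confusion), where confusion[t][p] counts samples
    whose true class is t and predicted class is p."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        confusion[t][p] += 1
    correct = sum(confusion[i][i] for i in range(num_classes))  # diagonal = hits
    accuracy = correct / len(y_true) if y_true else 0.0
    return accuracy, confusion


acc, cm = evaluate([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
print(acc)  # 0.75
print(cm)   # [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
```

Off-diagonal entries show which wound classes are confused with each other, which a single accuracy number hides.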
`nets/Mamba.py` and `nets/VisionMamba.py` implement Mamba state-space blocks for vision tasks.
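At the core of a Mamba block is a discretized state-space recurrence, `h_t = exp(Δ_t·A) ⊙ h_{t-1} + Δ_t·B_t·x_t` and `y_t = C_t · h_t`, where `Δ`, `B` and `C` depend on the input (the "selective" part). The pure-Python sketch below runs that scan for a single channel with a diagonal `A`. It follows the simplified discretization used in the reference Mamba code (`exp(ΔA)` for `A`, `Δ·B` for `B`) and omits the convolution, gating, and projections, so it is an illustration of the technique, not the code in `nets/`:

```python
import math


def selective_scan(x, delta, A, B, C):
    """Single-channel selective SSM scan (simplified sketch).

    x:     length-T input sequence
    delta: length-T positive step sizes (input-dependent in Mamba)
    A:     length-N diagonal state matrix (negative entries for stability)
    B, C:  T x N input/output projections (input-dependent in Mamba)
    Returns the length-T output sequence y.
    """
    n_state = len(A)
    h = [0.0] * n_state
    y = []
    for t, x_t in enumerate(x):
        for n in range(n_state):
            a_bar = math.exp(delta[t] * A[n])                 # zero-order-hold A
            h[n] = a_bar * h[n] + delta[t] * B[t][n] * x_t    # simplified B term
        y.append(sum(C[t][n] * h[n] for n in range(n_state)))
    return y


# With A = 0 and delta = B = C = 1 the scan reduces to a running sum.
print(selective_scan([1.0, 2.0, 3.0], [1.0] * 3, [0.0], [[1.0]] * 3, [[1.0]] * 3))
# [1.0, 3.0, 6.0]
```

A strongly negative `A` makes the state forget almost immediately, while `A` near zero makes it accumulate; input-dependent `Δ` lets the model choose between the two per token.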
The hybrid models (`hybrid_mamba_cnn`, `hybrid_mamba_resnet`, in `nets/Hybrid.py`) fuse Mamba features with CNN/ResNet outputs.
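How `nets/Hybrid.py` combines the two branches is not described here. One common pattern, sketched below under that assumption, is late fusion: concatenate the Mamba feature vector with the CNN/ResNet feature vector and feed the result to a linear classification head. `fuse_and_classify`, the weights, and the feature sizes are illustrative placeholders:

```python
def fuse_and_classify(mamba_feat, cnn_feat, weights, bias):
    """Late fusion by concatenation, followed by a linear head.

    mamba_feat: length-Dm feature vector from the Mamba branch
    cnn_feat:   length-Dc feature vector from the CNN/ResNet branch
    weights:    num_classes x (Dm + Dc) matrix
    bias:       length-num_classes vector
    Returns a list of num_classes logits.
    """
    fused = list(mamba_feat) + list(cnn_feat)  # concatenate along the feature axis
    if any(len(row) != len(fused) for row in weights):
        raise ValueError("each weight row must match the fused feature size")
    return [sum(w * f for w, f in zip(row, fused)) + b
            for row, b in zip(weights, bias)]


logits = fuse_and_classify([1.0, 2.0], [3.0], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]], [0.5, 0.0])
print(logits)  # [4.5, 2.0]
```

In Flax, the same idea would typically be `jnp.concatenate([mamba_feat, cnn_feat], axis=-1)` followed by `nn.Dense(num_classes)`.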
```python
from scripts.dataset import data_loader
loader = data_loader(data_path="../data/dataset", use_augmentation=True)
# Each item is (image, class index, image index)
img, label_idx, img_idx = loader[0]
```
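The loader is indexed like a sequence (`loader[i]` returns `(img, label_idx, img_idx)`). Assuming it also supports `len()`, a shuffled mini-batch iterator for the `--batch_size` flag could look like the following; `iterate_batches` is a hypothetical helper, and stacking the images into an array (for example with `jnp.stack`) is left to the caller:

```python
import random


def iterate_batches(loader, batch_size, seed=0, drop_last=False):
    """Yield lists of (img, label_idx, img_idx) tuples in shuffled order.

    Assumes `loader` supports len() and integer indexing.
    """
    indices = list(range(len(loader)))
    random.Random(seed).shuffle(indices)  # reproducible order for a given seed
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # skip the final short batch
        yield [loader[i] for i in batch_idx]
```

Passing a different `seed` each epoch (for example the epoch number) reshuffles the order between epochs.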
```python
import jax
import jax.numpy as jnp

from nets.CNN import SimpleCNN
from scripts.dataset import data_loader

loader = data_loader(data_path="../data/dataset", use_augmentation=False)
model = SimpleCNN(num_classes=loader.num_classes)

# Dummy NHWC batch: one 224x224 RGB image
dummy = jnp.zeros((1, 224, 224, 3))
params = model.init(jax.random.PRNGKey(0), dummy)
logits = model.apply(params, dummy)  # expected shape: (1, num_classes)
```
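`logits` holds one unnormalized score per class. Turning a row into class probabilities and a predicted index is a softmax plus an argmax; in JAX that would be `jax.nn.softmax(logits, axis=-1)` and `jnp.argmax(logits, axis=-1)`. A dependency-free sketch of the same arithmetic for one row, using the max-subtraction trick for numerical stability (`softmax` and `predict` are illustrative helpers):

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of logits."""
    m = max(row)                            # subtract the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict(row):
    """Return (predicted_class_index, probabilities) for one logit row."""
    probs = softmax(row)
    return max(range(len(probs)), key=probs.__getitem__), probs


idx, probs = predict([2.0, 1.0, 0.1])
print(idx)  # 0
```

The predicted index can be mapped back to a class name with whatever index-to-label mapping the loader builds.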
See requirements.txt.
See LICENSE.