This project trains a 5-class acne classifier in pure Rust using Candle.
The goal is not only to train a model, but to make each design choice explicit and teachable.
This repo was written to run on a MacBook in an air-gapped environment. Adapt it as needed.
- Rust-native deep learning (`candle-core`, `candle-nn`)
- Online augmentations (flip, rotation, color jitter)
- Custom CNN with progressive downsampling
- Early stopping to prevent overtraining on small data
- Best-checkpoint saving (`.safetensors`) based on validation accuracy
- Efficient batched tensor upload path
Current AcneModel:
- `Conv(3->32) + GroupNorm + ReLU + MaxPool`
- `Conv(32->64) + GroupNorm + ReLU + MaxPool`
- `Conv(64->128) + GroupNorm + ReLU + MaxPool`
- `Conv(128->256) + GroupNorm + ReLU + MaxPool`
- Extra `MaxPool` (14x14 -> 7x7)
- `Flatten`
- `Linear(256*7*7 -> 256) + ReLU`
- `Linear(256 -> num_classes)`
Why this structure:
- Convolutions learn local texture/shape patterns useful for skin lesions.
- GroupNorm is stable for smaller batch sizes.
- MaxPool progressively reduces spatial size and compute.
- A small dense head keeps training simple and fast.
I explicitly added one more pooling stage after the 4th conv block so the spatial map goes 14x14 -> 7x7.
Why this helps:
- Far fewer dense-layer parameters (the first linear layer sees 256*7*7 inputs instead of 256*14*14, a 4x reduction)
- Lower memory + faster training
- Often better regularization on small datasets
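The effect on the dense head can be checked with a little arithmetic. The sketch below assumes a 224x224 input and size-preserving convolutions; neither is stated above, they are chosen only because they reproduce the 14x14 map after four pooling stages. The function names are illustrative and not taken from the project.

```rust
/// Spatial side length after `pools` 2x2 max-pool stages with stride 2.
/// Assumes the side is divisible by 2 at every stage.
fn side_after_pools(input_side: usize, pools: u32) -> usize {
    input_side >> pools
}

/// Weights plus biases of a fully connected layer.
fn linear_params(in_features: usize, out_features: usize) -> usize {
    in_features * out_features + out_features
}

/// Parameter count of the first dense layer for a square feature map
/// with `channels` channels and side length `side`.
fn head_params(channels: usize, side: usize, hidden: usize) -> usize {
    linear_params(channels * side * side, hidden)
}
```

Under these assumptions, `head_params(256, 14, 256)` is 12,845,312 and `head_params(256, 7, 256)` is 3,211,520, so the extra pool removes roughly 9.6 million parameters from the first linear layer.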
- Images are read as HWC (`height, width, channel`) bytes.
- Candle/CNN ops expect NCHW/CHW layout (`batch, channel, height, width`).
- `u8` is not suitable for gradient-based optimization; `f32` is standard for stable math.
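The layout change is just an index remapping. Here is a minimal sketch in plain Rust, without Candle, that turns an interleaved HWC byte buffer into a planar CHW `f32` buffer. The function name is hypothetical and the project's own helper may be organized differently.

```rust
/// Convert an interleaved HWC `u8` image into a planar CHW `f32` buffer.
/// Values are only cast here, not rescaled.
fn hwc_u8_to_chw_f32(hwc: &[u8], height: usize, width: usize, channels: usize) -> Vec<f32> {
    assert_eq!(hwc.len(), height * width * channels, "buffer size mismatch");
    let mut chw = vec![0.0f32; hwc.len()];
    for y in 0..height {
        for x in 0..width {
            for c in 0..channels {
                // HWC: channels of one pixel are adjacent in memory.
                let src = (y * width + x) * channels + c;
                // CHW: each channel is stored as one full height*width plane.
                let dst = (c * height + y) * width + x;
                chw[dst] = hwc[src] as f32;
            }
        }
    }
    chw
}
```

For a 1x2 RGB image with bytes `[1, 2, 3, 4, 5, 6]`, the result is `[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]`: all red values first, then green, then blue.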
- Raw `[0, 255]` scales are too large and inconsistent.
- First map to `[0, 1]`, then to `[-1, 1]`.
- Centering around zero helps optimization (more balanced activations/gradients).
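The two-step mapping can be written as a pair of small functions. This is a sketch of the arithmetic described above; the function names are illustrative.

```rust
/// Map one byte from [0, 255] to [0, 1], then to [-1, 1].
fn normalize_pixel(p: u8) -> f32 {
    let unit = p as f32 / 255.0; // [0, 255] -> [0, 1]
    unit * 2.0 - 1.0 // [0, 1] -> [-1, 1]
}

/// Normalize a whole buffer of pixel bytes.
fn normalize(pixels: &[u8]) -> Vec<f32> {
    pixels.iter().map(|&p| normalize_pixel(p)).collect()
}
```

Black (0) maps to -1.0, white (255) maps to 1.0, and mid-gray lands close to zero.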
- Uploading one image at a time causes repeated device transfer overhead.
- Building one contiguous batch buffer and uploading once is faster.
- It improves throughput and keeps training loops cleaner.
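A sketch of the "one buffer, one upload" idea in plain Rust follows. It assumes every image has already been converted to a CHW `f32` buffer of the same size; the function name is hypothetical.

```rust
/// Concatenate equally sized per-image CHW buffers into one contiguous
/// NCHW buffer, returning the data together with its shape.
fn assemble_batch(
    images: &[Vec<f32>],
    channels: usize,
    height: usize,
    width: usize,
) -> (Vec<f32>, [usize; 4]) {
    let per_image = channels * height * width;
    // Allocate once for the whole batch instead of once per image.
    let mut batch = Vec::with_capacity(images.len() * per_image);
    for img in images {
        assert_eq!(img.len(), per_image, "image size mismatch");
        batch.extend_from_slice(img);
    }
    (batch, [images.len(), channels, height, width])
}
```

With Candle, the returned buffer could then be moved to the device in a single call such as `Tensor::from_vec(batch, (n, c, h, w), &device)`, rather than creating and stacking one tensor per image.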
Applied only during training:
- Random horizontal flip
- Random rotation (up to +/- 15 degrees)
- Random brightness/contrast jitter
Why augment:
- Small medical image datasets are easy to overfit.
- Augmentations simulate realistic variation (angle, lighting, framing).
- This improves generalization to unseen photos.
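Two of these transforms are simple enough to sketch without an image library. The code below operates on an interleaved HWC byte buffer and takes the jitter amounts as explicit parameters, so the random sampling is left to the caller. The contrast formula (scaling around mid-gray 128) is one common choice and an assumption here, not necessarily what the project uses; rotation is omitted because it needs interpolation.

```rust
/// Mirror an interleaved HWC image left to right.
fn flip_horizontal(hwc: &[u8], height: usize, width: usize, channels: usize) -> Vec<u8> {
    assert_eq!(hwc.len(), height * width * channels, "buffer size mismatch");
    let mut out = vec![0u8; hwc.len()];
    for y in 0..height {
        for x in 0..width {
            let src = (y * width + x) * channels;
            let dst = (y * width + (width - 1 - x)) * channels;
            // Copy the whole pixel so channel order is preserved.
            out[dst..dst + channels].copy_from_slice(&hwc[src..src + channels]);
        }
    }
    out
}

/// Scale contrast around mid-gray, add a brightness offset, clamp to a byte.
fn jitter_pixel(p: u8, contrast: f32, brightness: f32) -> u8 {
    let v = (p as f32 - 128.0) * contrast + 128.0 + brightness;
    v.round().clamp(0.0, 255.0) as u8
}
```

In a training loop, a seeded random generator would decide whether to flip and would draw `contrast` and `brightness` from small ranges around 1.0 and 0.0.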
Toggle at runtime:
- Augmentation ON (default): `cargo run --release --features metal`
- Augmentation OFF: `cargo run --release --features metal -- --no-aug`
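One way such a flag could be read with only the standard library is sketched below. How `main.rs` actually parses its arguments is not shown here, so treat the function as a hypothetical illustration.

```rust
/// Returns true unless `--no-aug` appears among the arguments.
fn augmentation_enabled<I, S>(args: I) -> bool
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    !args.into_iter().any(|a| a.as_ref() == "--no-aug")
}
```

A caller would typically pass `std::env::args().skip(1)` so that the program name itself is not inspected.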
- This is a multi-class classification task (one class per image).
- Cross-entropy directly compares class probabilities against true class labels.
- It provides strong gradients for separating classes.
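For reference, here is the loss for a single sample written out in plain Rust. It is a sketch of the standard softmax cross-entropy formula with the usual max-subtraction for numerical stability, not the Candle call used in the training loop.

```rust
/// Cross-entropy for one sample: -log(softmax(logits)[target]).
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    assert!(target < logits.len(), "target out of range");
    // Subtract the max logit before exponentiating to avoid overflow.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&z| (z - max).exp()).sum();
    let log_sum_exp = max + sum_exp.ln();
    log_sum_exp - logits[target]
}
```

With five equal logits the model is guessing uniformly and the loss is ln(5), about 1.609. This is a useful sanity check for the first epoch of a 5-class model.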
- Adam gives fast, robust convergence in early training.
- AdamW supports decoupled weight decay, which is useful if you want stronger regularization than plain Adam.
- Good default optimizer for CNNs on moderate-size datasets.
Current setup uses AdamW with lr=1e-3 and weight_decay=0.0. With zero weight decay, AdamW performs the same updates as plain Adam.
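To make the "decoupled" part concrete, here is a sketch of one AdamW update on a flat parameter vector. The values `beta1 = 0.9`, `beta2 = 0.999` and `eps = 1e-8` are common defaults assumed for illustration; the project's actual optimizer settings beyond `lr` and `weight_decay` are not stated above.

```rust
/// Minimal AdamW state for a flat parameter vector.
struct AdamW {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,      // number of steps taken so far
    m: Vec<f32>, // running mean of gradients
    v: Vec<f32>, // running mean of squared gradients
}

impl AdamW {
    fn new(n_params: usize, lr: f32, weight_decay: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,   // assumed default
            beta2: 0.999, // assumed default
            eps: 1e-8,    // assumed default
            weight_decay,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }

    fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        assert_eq!(params.len(), grads.len());
        assert_eq!(params.len(), self.m.len());
        self.t += 1;
        let bias1 = 1.0 - self.beta1.powi(self.t);
        let bias2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            // Decoupled weight decay: shrink the weight directly,
            // independent of the gradient statistics.
            params[i] -= self.lr * self.weight_decay * params[i];
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grads[i];
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grads[i] * grads[i];
            let m_hat = self.m[i] / bias1;
            let v_hat = self.v[i] / bias2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}
```

On the very first step the bias-corrected update has magnitude close to `lr` regardless of the gradient's scale, which is part of why Adam-style optimizers are forgiving about learning-rate choice early in training.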
- Validation can plateau or degrade while training loss still drops.
- Early stopping halts when no validation improvement is seen for 7 epochs.
- This avoids wasting epochs and reduces overfitting risk.
- It also saves time when training on a CPU or a small GPU, such as the MacBook used for this demonstration.
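The patience logic fits in a small state machine. The sketch below assumes validation accuracy is the monitored metric and uses hypothetical type names; the project's training loop may track this differently.

```rust
/// What the training loop should do after a validation pass.
#[derive(Debug, PartialEq)]
enum EpochOutcome {
    /// New best score: reset the counter (a natural point to save a checkpoint).
    Improved,
    /// No improvement, but patience is not yet exhausted.
    NoImprovement,
    /// Patience exhausted: stop training.
    Stop,
}

struct EarlyStopping {
    patience: usize,
    best_acc: f32,
    epochs_without_improvement: usize,
}

impl EarlyStopping {
    fn new(patience: usize) -> Self {
        Self {
            patience,
            best_acc: f32::NEG_INFINITY,
            epochs_without_improvement: 0,
        }
    }

    /// Record one epoch's validation accuracy and decide what to do next.
    fn observe(&mut self, val_acc: f32) -> EpochOutcome {
        if val_acc > self.best_acc {
            self.best_acc = val_acc;
            self.epochs_without_improvement = 0;
            EpochOutcome::Improved
        } else {
            self.epochs_without_improvement += 1;
            if self.epochs_without_improvement >= self.patience {
                EpochOutcome::Stop
            } else {
                EpochOutcome::NoImprovement
            }
        }
    }
}
```

With `EarlyStopping::new(7)`, training stops on the seventh consecutive epoch that fails to beat the best validation accuracy seen so far.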
- The final epoch is not always the best model.
- Saving whenever validation accuracy improves preserves the strongest checkpoint.
- This project writes that checkpoint to `acne-best.safetensors`.
- After training, `main.rs` reloads `acne-best.safetensors` and prints a best-model summary on both train and validation sets (loss + accuracy + training duration).
- `src/main.rs`: pipeline orchestration (scan, split, train, eval, single-image predict)
- `src/acne_model.rs`: CNN model definition
- `src/data.rs`: class discovery + path labeling
- `src/data_augentation.rs`: image transforms + preprocessing helpers
- `src/batching.rs`: batch iterator + tensor assembly/upload
- `src/training.rs`: training loop, evaluation loop, checkpointing, inference helper
Expected local layout:
train/<class_name>/*.jpg
The current code expects exactly 5 classes.
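With this layout, a label can be derived from the directory that contains each image. The sketch below shows one way to do that on paths alone, without touching the filesystem. The function names are hypothetical, and sorting class names to assign indices is an assumption rather than confirmed behavior of `src/data.rs`.

```rust
use std::collections::BTreeSet;
use std::path::Path;

/// Class name = name of the directory that directly contains the image.
fn class_of(path: &Path) -> Option<String> {
    path.parent()?.file_name()?.to_str().map(str::to_owned)
}

/// Distinct class names found in `paths`, in sorted order.
fn class_index(paths: &[&Path]) -> Vec<String> {
    let names: BTreeSet<String> = paths.iter().filter_map(|p| class_of(p)).collect();
    names.into_iter().collect()
}

/// Integer label of `path`: the position of its class in `classes`.
fn label_of(classes: &[String], path: &Path) -> Option<usize> {
    let name = class_of(path)?;
    classes.iter().position(|c| *c == name)
}
```

A caller could then compare `class_index(...).len()` against the expected count of 5 and fail early if the dataset directory does not match.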
Dataset used for this project: Acne Image Dataset (Kaggle)
`cargo run --release --features metal`

After training completes, the program prints:
- per-epoch training and validation loss/accuracy
- early stopping notice (if triggered)
- a final "Best Model Summary" block using `acne-best.safetensors`:
  - total training time
  - train loss/accuracy (no augmentation)
  - validation loss/accuracy (no augmentation)
- a single-image prediction example (`true label` vs `predicted label`)
- Seed is fixed to `42069` for shuffling and stochastic transforms.
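To illustrate why a fixed seed makes runs repeatable, here is a dependency-free sketch of a seeded shuffle. It uses a small SplitMix64 generator as a stand-in; which random number generator the project actually uses is not stated, so both the generator and the function name are assumptions.

```rust
/// Minimal SplitMix64 generator, used only as a stand-in for the
/// project's real random number generator.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Fisher-Yates shuffle driven by a generator seeded with `seed`.
/// The modulo introduces a tiny bias, which is acceptable for a sketch.
fn seeded_shuffle<T>(items: &mut [T], seed: u64) {
    let mut rng = SplitMix64(seed);
    for i in (1..items.len()).rev() {
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        items.swap(i, j);
    }
}
```

Calling `seeded_shuffle(&mut indices, 42069)` on the same input always yields the same order, so the train/validation split and batch order can be reproduced across runs.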