jgeorg11/Acne-Classification-Rust

Acne Classification in Rust (Candle)

This project trains a 5-class acne classifier in pure Rust using Candle.
The goal is not only to train a model, but to make each design choice explicit and teachable. This repo was written to run on a MacBook in an air-gapped environment. Adapt it as needed.

What Makes This Project Unique

  • Rust-native deep learning (candle-core, candle-nn)
  • Online augmentations (flip, rotation, color jitter)
  • Custom CNN with progressive downsampling
  • Early stopping to prevent overtraining on small data
  • Best-checkpoint saving (.safetensors) based on validation accuracy
  • Efficient batched tensor upload path

Architecture and Why It Looks This Way

Current AcneModel:

  • Conv(3->32) + GroupNorm + ReLU + MaxPool
  • Conv(32->64) + GroupNorm + ReLU + MaxPool
  • Conv(64->128) + GroupNorm + ReLU + MaxPool
  • Conv(128->256) + GroupNorm + ReLU + MaxPool
  • Extra MaxPool (14x14 -> 7x7)
  • Flatten
  • Linear(256*7*7 -> 256) + ReLU
  • Linear(256 -> num_classes)
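
To make the shape bookkeeping concrete, here is a small std-only sketch that walks a feature-map shape through the blocks listed above. It assumes a 224x224 RGB input, 3x3 convolutions with padding 1 (so convolutions keep height and width), and 2x2 max-pooling with stride 2. None of these hyperparameters are stated in this README; 224 is simply the input size that yields the 14x14 map mentioned below under those assumptions, so check them against src/acne_model.rs.

```rust
/// Shape of one feature map as (channels, height, width).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Shape {
    c: usize,
    h: usize,
    w: usize,
}

/// Assumed: 3x3 conv, stride 1, padding 1, so only the channel count changes.
fn conv_same(s: Shape, out_channels: usize) -> Shape {
    Shape { c: out_channels, ..s }
}

/// Assumed: 2x2 max-pool with stride 2 (floor division on odd sizes).
fn max_pool2(s: Shape) -> Shape {
    Shape { c: s.c, h: s.h / 2, w: s.w / 2 }
}

/// Walks the AcneModel stack and returns the shape after every stage.
fn trace(input: Shape) -> Vec<(&'static str, Shape)> {
    let mut stages = Vec::new();
    let mut s = input;
    for &out in &[32, 64, 128, 256] {
        // GroupNorm and ReLU do not change the shape, so only conv + pool matter.
        s = max_pool2(conv_same(s, out));
        stages.push(("conv block", s));
    }
    s = max_pool2(s);
    stages.push(("extra pool", s));
    stages
}

fn main() {
    let input = Shape { c: 3, h: 224, w: 224 };
    for (name, s) in trace(input) {
        println!("{name:>10}: {}x{}x{}", s.c, s.h, s.w);
    }
    let last = trace(input).last().unwrap().1;
    println!("flattened features: {}", last.c * last.h * last.w);
}
```

Under these assumptions the final 256x7x7 map flattens to 12,544 features, which is the input width of Linear(256*7*7 -> 256).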

Why this structure:

  • Convolutions learn local texture/shape patterns useful for skin lesions.
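  • GroupNorm normalizes each sample over groups of channels, so unlike BatchNorm it does not depend on batch statistics and stays stable at small batch sizes.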
  • MaxPool progressively reduces spatial size and compute.
  • A small dense head keeps training simple and fast.
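
GroupNorm's batch-size independence comes from computing mean and variance per sample, over groups of channels, instead of across the batch. The std-only sketch below normalizes one CHW sample in place; the function name, the group count in the example, the epsilon, and the omission of GroupNorm's learnable per-channel scale and shift are simplifications for illustration, not values taken from this repo.

```rust
/// Normalizes one sample stored as a flat CHW buffer, in place.
/// Channels are split into `groups` contiguous groups; each group is
/// normalized with its own mean and variance over (channels-in-group, H, W).
/// The learnable per-channel gamma/beta are omitted for brevity.
fn group_norm(x: &mut [f32], channels: usize, hw: usize, groups: usize, eps: f32) {
    assert_eq!(x.len(), channels * hw, "buffer must be C*H*W");
    assert_eq!(channels % groups, 0, "channels must divide evenly into groups");
    // In CHW layout, consecutive channels are contiguous, so a group is one slice.
    let group_len = (channels / groups) * hw;
    for group in x.chunks_mut(group_len) {
        let n = group.len() as f32;
        let mean = group.iter().sum::<f32>() / n;
        let var = group.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        let inv_std = 1.0 / (var + eps).sqrt();
        for v in group.iter_mut() {
            *v = (*v - mean) * inv_std;
        }
    }
}

fn main() {
    // 4 channels of a 2x2 image (16 values), split into 2 groups of 2 channels.
    let mut x: Vec<f32> = (0..16).map(|i| i as f32).collect();
    group_norm(&mut x, 4, 4, 2, 1e-5);
    println!("{x:?}");
}
```

Nothing in the computation looks at other samples, so an image's output is the same whether the batch holds 1 image or 64.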

Why the extra pooling to 7x7?

I explicitly added one more pooling stage after the 4th conv block so the spatial map goes 14x14 -> 7x7. Why this helps:

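  • Far fewer dense-layer parameters: the first Linear layer drops from 256*14*14*256 (about 12.8M) to 256*7*7*256 (about 3.2M) weights, a 4x reduction
  • Lower memory use and faster training
  • Fewer parameters to fit, which often regularizes better on small datasets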
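
The parameter savings can be checked with simple arithmetic. The sketch below counts weights plus biases of the first dense layer for both spatial sizes, using only the layer widths listed in the architecture above; `linear_params` is an illustrative helper, not a function from this repo.

```rust
/// Parameters of a fully connected layer: weight matrix plus bias vector.
fn linear_params(in_features: usize, out_features: usize) -> usize {
    in_features * out_features + out_features
}

fn main() {
    let hidden = 256;
    let without_extra_pool = linear_params(256 * 14 * 14, hidden);
    let with_extra_pool = linear_params(256 * 7 * 7, hidden);
    println!("14x14 head: {without_extra_pool} params");
    println!(" 7x7 head: {with_extra_pool} params");
    println!("ratio: {:.2}", without_extra_pool as f64 / with_extra_pool as f64);
}
```

Halving each spatial side quarters the flattened width, and the first dense layer's parameter count scales with it.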

Data Pipeline Decisions

Why convert HWC u8 -> CHW f32?

  • Images are read as HWC (height, width, channel) bytes.
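  • Candle's convolution ops expect NCHW layout (batch, channel, height, width), so each image is rearranged to CHW (channel, height, width) before batching.
  • Gradient-based optimization needs floating-point inputs; f32 is the standard precision for stable training math.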
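
A minimal std-only sketch of the layout conversion, assuming the decoded image arrives as a flat RGB byte buffer in row-major HWC order (interleaved R, G, B per pixel). The function name is hypothetical and the repo's helpers may combine this with normalization; scaling here stops at [0, 1].

```rust
/// Converts an interleaved HWC u8 buffer (R,G,B,R,G,B,...) into a planar
/// CHW f32 buffer with values scaled to [0, 1].
fn hwc_u8_to_chw_f32(hwc: &[u8], height: usize, width: usize, channels: usize) -> Vec<f32> {
    assert_eq!(hwc.len(), height * width * channels, "buffer size mismatch");
    let mut chw = vec![0.0f32; hwc.len()];
    for y in 0..height {
        for x in 0..width {
            for c in 0..channels {
                let src = (y * width + x) * channels + c; // HWC index
                let dst = c * height * width + y * width + x; // CHW index
                chw[dst] = hwc[src] as f32 / 255.0;
            }
        }
    }
    chw
}

fn main() {
    // A 1x2 image: one red pixel, then one blue pixel.
    let hwc = [255u8, 0, 0, 0, 0, 255];
    let chw = hwc_u8_to_chw_f32(&hwc, 1, 2, 3);
    // Planes come out as R = [1, 0], G = [0, 0], B = [0, 1].
    println!("{chw:?}");
}
```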

Why normalize to [-1, 1] with 2x - 1?

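  • Raw [0, 255] values are large and not centered, which leads to large activations and gradients early in training.
  • Dividing by 255 first maps pixels to [0, 1]; then 2x - 1 maps them to [-1, 1].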
  • Centering around zero helps optimization (more balanced activations/gradients).
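
Written out, the two steps compose to x = 2 * (v / 255) - 1 for a raw byte v. The sketch below applies that mapping to a single value; `normalize_byte` is an illustrative name, not a function from this repo.

```rust
/// Maps a raw byte in [0, 255] to [-1, 1]: first v / 255 -> [0, 1], then 2x - 1.
fn normalize_byte(v: u8) -> f32 {
    let unit = v as f32 / 255.0; // [0, 1]
    2.0 * unit - 1.0 // [-1, 1]
}

fn main() {
    for v in [0u8, 128, 255] {
        println!("{v:>3} -> {:+.4}", normalize_byte(v));
    }
}
```

Because the mapping is affine, black (0) becomes -1, white (255) becomes 1, and mid-gray lands near 0.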

Why batch and upload as one tensor?

  • Uploading one image at a time causes repeated device transfer overhead.
  • Building one contiguous batch buffer and uploading once is faster.
  • It improves throughput and keeps training loops cleaner.
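
A std-only sketch of the batching step: per-image CHW buffers are copied into one contiguous N*C*H*W vector so a single host-to-device copy is enough. In Candle, such a vector could then be uploaded once with `Tensor::from_vec(buffer, (n, c, h, w), &device)`; that call is mentioned as a pointer rather than copied from src/batching.rs, and it is left out of the code so the example runs without dependencies. `assemble_batch` is a hypothetical name.

```rust
/// Concatenates per-image CHW buffers into one contiguous NCHW buffer.
/// Returns the buffer together with its (n, c, h, w) shape.
fn assemble_batch(
    images: &[Vec<f32>],
    c: usize,
    h: usize,
    w: usize,
) -> (Vec<f32>, (usize, usize, usize, usize)) {
    let per_image = c * h * w;
    // One allocation for the whole batch instead of one per image.
    let mut buffer = Vec::with_capacity(images.len() * per_image);
    for (i, img) in images.iter().enumerate() {
        assert_eq!(img.len(), per_image, "image {i} has the wrong size");
        buffer.extend_from_slice(img);
    }
    (buffer, (images.len(), c, h, w))
}

fn main() {
    let images = vec![vec![0.0f32; 3 * 4 * 4], vec![1.0f32; 3 * 4 * 4]];
    let (buffer, shape) = assemble_batch(&images, 3, 4, 4);
    println!("shape {shape:?}, {} values in one buffer", buffer.len());
}
```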

Augmentation Decisions

Applied only during training:

  • Random horizontal flip
  • Random rotation (up to +/- 15 degrees)
  • Random brightness/contrast jitter

Why augment:

  • Small medical image datasets are easy to overfit.
  • Augmentations simulate realistic variation (angle, lighting, framing).
  • This improves generalization to unseen photos.
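
Two of the three augmentations are easy to express directly on a CHW buffer with values in [0, 1]. The sketch below shows a horizontal flip and a brightness/contrast jitter with fixed factors; during training, the flip decision and the factors would be drawn from the seeded RNG. The jitter formula (scale around the image mean, add an offset, clamp) is a common choice assumed here, not necessarily the one in src/data_augentation.rs, and rotation is omitted because it needs interpolation.

```rust
/// Mirrors each row of a CHW buffer left-to-right, in place.
fn hflip(chw: &mut [f32], c: usize, h: usize, w: usize) {
    assert_eq!(chw.len(), c * h * w);
    // In CHW layout every row of width `w` is contiguous.
    for row in chw.chunks_mut(w) {
        row.reverse();
    }
}

/// Brightness/contrast jitter on values in [0, 1]: contrast scales each value
/// around the image mean, brightness adds an offset, then the result is clamped.
fn jitter(chw: &mut [f32], brightness: f32, contrast: f32) {
    let mean = chw.iter().sum::<f32>() / chw.len() as f32;
    for v in chw.iter_mut() {
        *v = ((*v - mean) * contrast + mean + brightness).clamp(0.0, 1.0);
    }
}

fn main() {
    // One channel, 1x3 image.
    let mut img = vec![0.1f32, 0.5, 0.9];
    hflip(&mut img, 1, 1, 3);
    println!("flipped: {img:?}");
    jitter(&mut img, 0.05, 1.2);
    println!("jittered: {img:?}");
}
```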

Toggle at runtime:

  • Augmentation ON (default): cargo run --release --features metal
  • Augmentation OFF: cargo run --release --features metal -- --no-aug
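
Cargo forwards everything after the `--` separator to the program, so the switch arrives as an ordinary argument. A std-only sketch of reading it follows; `augmentation_enabled` is a hypothetical name, and main.rs may parse its arguments differently.

```rust
use std::env;

/// Returns true unless `--no-aug` appears among the program arguments.
fn augmentation_enabled<I, S>(args: I) -> bool
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    !args.into_iter().any(|a| a.as_ref() == "--no-aug")
}

fn main() {
    // Skip argv[0], which is the binary path.
    let augment = augmentation_enabled(env::args().skip(1));
    println!("augmentation: {}", if augment { "ON" } else { "OFF" });
}
```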

Training Decisions

Why cross-entropy loss?

  • This is a multi-class classification task (one class per image).
  • Cross-entropy directly compares class probabilities against true class labels.
  • It provides strong gradients for separating classes.
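
For one image with logits z and true class y, cross-entropy is -log(softmax(z)[y]) = logsumexp(z) - z[y]. The std-only sketch below computes it with the max logit subtracted inside the log-sum-exp to avoid overflow. In Candle the batched equivalent is `candle_nn::loss::cross_entropy`; whether this repo calls that function is not stated here.

```rust
/// Cross-entropy for one example: -log(softmax(logits)[target]).
/// Uses log-sum-exp with the max logit subtracted to avoid overflow.
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
    log_sum_exp - logits[target]
}

fn main() {
    let logits = [2.0f32, 0.5, -1.0, 0.0, 0.1]; // one score per acne class
    println!("loss if class 0 is correct: {:.4}", cross_entropy(&logits, 0));
    println!("loss if class 2 is correct: {:.4}", cross_entropy(&logits, 2));
}
```

Equal logits over five classes give ln 5 (about 1.609), a handy sanity check for the loss of an untrained 5-class model.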

Why AdamW?

  • Adam gives fast, robust convergence in early training.
  • AdamW supports decoupled weight decay, which is useful if you want stronger regularization than plain Adam.
  • Good default optimizer for CNNs on moderate-size datasets.

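The current setup uses AdamW with lr=1e-3 and weight_decay=0.0. With zero weight decay, the AdamW update is identical to plain Adam, so raising weight_decay is the way to turn on decoupled regularization.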
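
To show what "decoupled" means, here is a one-parameter AdamW step in plain Rust: the weight-decay term shrinks the weight directly instead of being added to the gradient, so it never passes through Adam's adaptive scaling. The betas and epsilon are the common defaults (0.9, 0.999, 1e-8), assumed here rather than read from this repo.

```rust
/// AdamW state and hyperparameters for a single scalar parameter.
struct AdamW {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    weight_decay: f64,
    m: f64, // first-moment (mean) estimate
    v: f64, // second-moment (uncentered variance) estimate
    t: i32, // step count, used for bias correction
}

impl AdamW {
    fn new(lr: f64, weight_decay: f64) -> Self {
        Self { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay, m: 0.0, v: 0.0, t: 0 }
    }

    /// Returns the updated parameter value after one step.
    fn step(&mut self, theta: f64, grad: f64) -> f64 {
        self.t += 1;
        self.m = self.beta1 * self.m + (1.0 - self.beta1) * grad;
        self.v = self.beta2 * self.v + (1.0 - self.beta2) * grad * grad;
        let m_hat = self.m / (1.0 - self.beta1.powi(self.t));
        let v_hat = self.v / (1.0 - self.beta2.powi(self.t));
        // Decoupled weight decay: shrink theta directly, independent of the gradient.
        let decayed = theta - self.lr * self.weight_decay * theta;
        decayed - self.lr * m_hat / (v_hat.sqrt() + self.eps)
    }
}

fn main() {
    let mut opt = AdamW::new(1e-3, 0.0); // lr=1e-3, weight_decay=0.0 as in this project
    let mut theta = 1.0;
    for _ in 0..3 {
        theta = opt.step(theta, 0.5);
    }
    println!("theta after 3 steps: {theta:.6}");
}
```

Setting `weight_decay` to 0.0 removes the `decayed` adjustment entirely, leaving exactly the Adam update.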

Why early stopping?

  • Validation can plateau or degrade while training loss still drops.
  • Early stopping halts when no validation improvement is seen for 7 epochs.
  • This avoids wasting epochs and reduces overfitting risk.
  • It also saves time when training on a CPU or a small GPU, such as the MacBook used for this demonstration.
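
The patience rule can be sketched as a small counter. This version assumes "improvement" means a strictly higher validation accuracy, mirroring the checkpoint criterion in the next section; the struct name and the exact comparison are assumptions, not taken from src/training.rs.

```rust
/// Stops training after `patience` consecutive epochs without a new best score.
struct EarlyStopping {
    patience: usize,
    best: f32,
    epochs_without_improvement: usize,
}

impl EarlyStopping {
    fn new(patience: usize) -> Self {
        Self { patience, best: f32::NEG_INFINITY, epochs_without_improvement: 0 }
    }

    /// Records one epoch's validation accuracy; returns true when training should stop.
    fn should_stop(&mut self, val_acc: f32) -> bool {
        if val_acc > self.best {
            self.best = val_acc;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}

fn main() {
    let mut stopper = EarlyStopping::new(7);
    // Accuracy peaks at epoch 4, then plateaus.
    let history: [f32; 12] = [0.40, 0.55, 0.61, 0.63, 0.62, 0.63, 0.60, 0.61, 0.62, 0.59, 0.60, 0.58];
    for (epoch, &acc) in history.iter().enumerate() {
        if stopper.should_stop(acc) {
            println!("early stop after epoch {} (best {:.2})", epoch + 1, stopper.best);
            break;
        }
    }
}
```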

Why best-checkpoint saving?

  • The final epoch is not always the best model.
  • Saving whenever validation accuracy improves preserves the strongest checkpoint.
  • This project writes that checkpoint to acne-best.safetensors.
  • After training, main.rs reloads acne-best.safetensors and prints a best-model summary on both train and validation sets (loss + accuracy + training duration).
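
The save-on-improvement rule can be sketched without ML dependencies by passing the save step in as a closure. In a Candle program, that closure would typically serialize the model's `candle_nn::VarMap` (its `save` method writes safetensors) to acne-best.safetensors; here it only records the epoch so the example stays self-contained. `track_best` is a hypothetical name.

```rust
/// Runs through per-epoch validation accuracies, calling `save` whenever a new
/// best is reached. Returns the 1-based epoch of the best checkpoint.
fn track_best<F: FnMut(usize, f32)>(val_accs: &[f32], mut save: F) -> Option<usize> {
    let mut best_acc = f32::NEG_INFINITY;
    let mut best_epoch = None;
    for (i, &acc) in val_accs.iter().enumerate() {
        if acc > best_acc {
            best_acc = acc;
            best_epoch = Some(i + 1);
            save(i + 1, acc); // in the real pipeline: write acne-best.safetensors
        }
    }
    best_epoch
}

fn main() {
    let accs: [f32; 5] = [0.52, 0.60, 0.66, 0.64, 0.63];
    let mut saves = Vec::new();
    let best = track_best(&accs, |epoch, acc| saves.push((epoch, acc)));
    println!("checkpoints at {saves:?}; best epoch {best:?}, final epoch {}", accs.len());
}
```

In this example the best checkpoint comes from epoch 3 even though training ran for 5 epochs, which is the case the bullets above guard against.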

Project Structure

  • src/main.rs: pipeline orchestration (scan, split, train, eval, single-image predict)
  • src/acne_model.rs: CNN model definition
  • src/data.rs: class discovery + path labeling
  • src/data_augentation.rs: image transforms + preprocessing helpers
  • src/batching.rs: batch iterator + tensor assembly/upload
  • src/training.rs: training loop, evaluation loop, checkpointing, inference helper

Dataset

Expected local layout:

  • train/<class_name>/*.jpg

The current code expects exactly 5 classes.
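
A std-only sketch of class discovery for the `train/<class_name>/*.jpg` layout: subdirectory names are sorted so each class gets a stable index across runs, and the count is checked against the expected 5. The function name and the sorting choice are assumptions about how src/data.rs behaves.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Lists subdirectories of `root` as class names, sorted so that the
/// name -> index mapping is the same on every run and platform.
fn discover_classes(root: &Path, expected: usize) -> io::Result<Vec<String>> {
    let mut classes = Vec::new();
    for entry in fs::read_dir(root)? {
        let entry = entry?;
        // Only folders count as classes; stray files are ignored.
        if entry.file_type()?.is_dir() {
            classes.push(entry.file_name().to_string_lossy().into_owned());
        }
    }
    classes.sort();
    if classes.len() != expected {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("expected {expected} class folders, found {}", classes.len()),
        ));
    }
    Ok(classes)
}

fn main() -> io::Result<()> {
    let classes = discover_classes(Path::new("train"), 5)?;
    for (index, name) in classes.iter().enumerate() {
        println!("{index}: {name}");
    }
    Ok(())
}
```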

Dataset used for this project: Acne Image Dataset (Kaggle)

Run

cargo run --release --features metal

Output Summary

After training completes, the program prints:

  • per-epoch training and validation loss/accuracy
  • early stopping notice (if triggered)
  • a final "Best Model Summary" block using acne-best.safetensors:
    • total training time
    • train loss/accuracy (no augmentation)
    • validation loss/accuracy (no augmentation)
  • a single-image prediction example (true label vs predicted label)

Reproducibility

  • Seed is fixed to 42069 for shuffling and stochastic transforms.
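
Fixing the seed makes shuffling deterministic: the same seed always produces the same permutation. The sketch below uses a small SplitMix64 generator and a Fisher-Yates shuffle to stay dependency-free; it is a stand-in for illustration, and the RNG the repo actually uses is not described in this README.

```rust
/// SplitMix64: a tiny, well-known 64-bit PRNG, used here only as a stand-in.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Fisher-Yates shuffle driven by the seeded generator.
fn shuffle<T>(items: &mut [T], rng: &mut SplitMix64) {
    for i in (1..items.len()).rev() {
        // Modulo adds a negligible bias for small slices; acceptable for a sketch.
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        items.swap(i, j);
    }
}

fn main() {
    let mut a: Vec<usize> = (0..10).collect();
    let mut b = a.clone();
    shuffle(&mut a, &mut SplitMix64(42069));
    shuffle(&mut b, &mut SplitMix64(42069));
    println!("{a:?}\n{b:?}");
    assert_eq!(a, b, "same seed, same order");
}
```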

About

This repo is a Rust-native deep learning project that trains and evaluates a 5-class acne image classifier using Hugging Face’s Candle framework. It’s designed to be teachable and explicit.
