We plan to move the JAX AI Stack examples to the NNX docs:

- [ ] (Low prio) Part 2: Debug a variational autoencoder (VAE)
- [x] Part 3: Train a diffusion model for image generation (@samanklesaria, https://github.com/google/flax/pull/5403)
- [x] Visualize JAX model metrics with TensorBoard (@samanklesaria, https://github.com/google/flax/pull/5425)
- [x] Introduction to Data Loaders on CPU with JAX / Introduction to Data Loaders on GPU with JAX (@samanklesaria, #5454)
- [x] JAX for PyTorch users / Porting a PyTorch model to JAX (@vfdev-5, https://github.com/google/flax/pull/5408)
- [x] Train a miniGPT language model with JAX (@samanklesaria, https://github.com/google/flax/pull/5405)
- [ ] Text classification with a transformer language model using JAX
- [x] Machine Translation with encoder-decoder transformer model (@samanklesaria, https://github.com/google/flax/pull/5431)
- [ ] Image segmentation with UNETR model (@vfdev-5, https://github.com/google/flax/pull/5463)
- [ ] ~Image Captioning with Vision Transformer (ViT) model (@vfdev-5, ...)~
- [ ] Object Detection with DETR, inference (@vfdev-5)
- [ ] Training a more recent version of DETR, e.g. RF-DETR (but more code to provide)
- [x] Train a Vision Transformer (ViT) for image classification with JAX (@vfdev-5, https://github.com/google/flax/pull/5455)
- [ ] Time series classification with CNN