
iacopo97/functa_final


Functa

This repository contains code for the ICML 2022 paper "From data to functa: Your data point is a function and you can treat it like one" by Emilien Dupont*, Hyunjik Kim*, Ali Eslami, Danilo Rezende and Dan Rosenbaum. *Denotes joint first authorship.

The codebase contains the meta-learning experiments for CelebA-HQ-64 and SRN CARS, along with a Colab that creates a modulation dataset for CelebA-HQ-64.

Setup

To set up a Python virtual environment with the required dependencies, run:

# create virtual environment
conda create -n env python=3.8
conda activate env
# update pip, setuptools and wheel
pip3 install --upgrade pip setuptools wheel
# install all required packages
pip3 install -r requirements.txt
pip3 install --upgrade jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install git+https://github.com/deepmind/jaxline

To activate the right environment, once everything is installed:

conda activate env

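Note that the directory containing this repository must be included in the PYTHONPATH environment variable. This can be done, e.g., as follows (replace the path with the location of your own checkout); the remaining lines add the conda environment's libraries to LD_LIBRARY_PATH and change JAX's default GPU memory settings: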

export PYTHONPATH='/media/data2/icurti/projects/functa'
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/
## change the default settings of jax
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
export XLA_PYTHON_CLIENT_MEM_FRACTION=.90
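If you launch the scripts from Python (for example from a notebook) instead of exporting the variables in the shell, the same settings can be applied in-process. A minimal sketch, assuming the variables only need to be in place before JAX initializes its GPU backend (setting them before `import jax` is the safe choice); the helpers configure_jax_memory and repo_on_path are hypothetical and not part of this repository.

```python
import os
import sys

# Same values as the shell exports above.
JAX_MEMORY_FLAGS = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".90",
}


def configure_jax_memory(overwrite=False):
    """Set the XLA client memory flags; call this before `import jax`."""
    for key, value in JAX_MEMORY_FLAGS.items():
        if overwrite or key not in os.environ:
            os.environ[key] = value
    return {key: os.environ[key] for key in JAX_MEMORY_FLAGS}


def repo_on_path(repo_dir):
    """Return True if repo_dir is on sys.path (e.g. added via PYTHONPATH)."""
    target = os.path.realpath(repo_dir)
    # An empty sys.path entry stands for the current working directory.
    return any(os.path.realpath(p or os.getcwd()) == target for p in sys.path)
```

With overwrite=False, values already exported in the shell take precedence, so the shell setup above and this helper do not fight each other.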

To check that the deep learning libraries are installed correctly and can see the GPU, use the following commands:

python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
#python3 -c "import tensorflow as tf; tf.config.experimental.set_visible_devices([], 'GPU');"

python3 -c "import jax; print(jax.devices())"

python3 -c "import torch; print(torch.cuda.is_available())"

Once done with the virtual environment, deactivate it (and, if you no longer need it, remove it) with:

conda deactivate
conda env remove -n env

Setup celeb_a_hq_custom dataset as Tensorflow dataset (TFDS)

The publicly available celeb_a_hq dataset in TFDS at https://www.tensorflow.org/datasets/catalog/celeb_a_hq requires manual preparation, for which there are some known issues: tensorflow/datasets#1496. Alternatively, there are publicly available zip files for download. We convert the 128x128 resolution version into a TensorFlow dataset (TFDS) so that we can readily load the data into our jax/haiku models with the various data processing options that come with tfds. Note that the resulting dataset has a different ordering from the tfds version, hence any train/test split further down the line may differ from the one used in our paper, and the downsampling algorithm may also differ. We use tf.image.resize with its default bilinear interpolation to resize to 64x64 resolution.
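For an exact 2x reduction (128x128 to 64x64), bilinear resizing with half-pixel centers and no antialiasing (the tf.image.resize defaults) should place each output pixel exactly between four input pixels, so its value is their mean. The NumPy sketch below reproduces that by averaging 2x2 blocks, which can be handy for sanity-checking converted images without TensorFlow. It is an illustration under that assumption, not the conversion code in celeb_a_hq_custom, and the name downsample_2x is hypothetical.

```python
import numpy as np


def downsample_2x(images):
    """Average each 2x2 block: (..., H, W, C) -> (..., H/2, W/2, C).

    Intended to match bilinear resizing with half-pixel centers for an
    exact factor-2 reduction; H and W must be even.
    """
    images = np.asarray(images, dtype=np.float32)
    *batch, h, w, c = images.shape
    if h % 2 or w % 2:
        raise ValueError(f"height and width must be even, got {h}x{w}")
    # Split each spatial axis into (blocks, 2) and average over the pairs.
    blocks = images.reshape(*batch, h // 2, 2, w // 2, 2, c)
    return blocks.mean(axis=(-4, -2))
```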

To set up the tfds, run:

cd celeb_a_hq_custom
tfds build --register_checksums --data_dir='/media/data2/icurti/tensorflow_datasets/'

This should be quick to run (a few seconds).

Setup srn_cars and the other 3D datasets as Tensorflow datasets (TFDS) (Optional)

The publicly available srn_cars dataset exists as a zip file in the official PixelNeRF codebase. We convert this into a tensorflow dataset (tfds) so that we can readily load the data into our jax/haiku models with the various data processing options that come with tfds. The other dataset folders (modelnet40_2c, modelnet40_1c, modelnet40, shapenet, scannet, manifold40, shapenet_occ) are built the same way.

To set up the datasets, run the following from the repository root (each build runs in its own subshell, so one cd does not carry over to the next line):

(cd modelnet40_2c && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd modelnet40_1c && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd modelnet40 && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd shapenet && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd scannet && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd srn_cars && tfds build --register_checksums --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd manifold40 && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
(cd shapenet_occ && tfds build --data_dir='/media/data2/icurti/tensorflow_datasets/')
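The same builds can be scripted so that each tfds build runs inside its own dataset folder, without chaining cd commands. A minimal standard-library sketch, assuming the dataset folders sit directly under the repository root and that only srn_cars needs --register_checksums (as in the commands listed for this section); build_commands and build_all are hypothetical helpers, and dry_run=True only prints what would run.

```python
import subprocess
from pathlib import Path

# Dataset folders, in the order listed for this section.
DATASETS = ["modelnet40_2c", "modelnet40_1c", "modelnet40", "shapenet",
            "scannet", "srn_cars", "manifold40", "shapenet_occ"]
NEEDS_CHECKSUMS = {"srn_cars"}


def build_commands(data_dir, datasets=DATASETS):
    """Return (folder, argv) pairs for `tfds build`, one per dataset."""
    commands = []
    for name in datasets:
        argv = ["tfds", "build", f"--data_dir={data_dir}"]
        if name in NEEDS_CHECKSUMS:
            argv.insert(2, "--register_checksums")
        commands.append((name, argv))
    return commands


def build_all(repo_root, data_dir, dry_run=True):
    """Run each build with its dataset folder as the working directory."""
    for name, argv in build_commands(data_dir):
        cwd = Path(repo_root) / name
        if dry_run:
            print(f"[{cwd}] {' '.join(argv)}")
        else:
            subprocess.run(argv, cwd=cwd, check=True)
```

Passing cwd to subprocess.run keeps the caller's working directory unchanged, and check=True stops at the first failing build.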

Building srn_cars can take a while (~1 hr), as we convert the views of each scene into an array with shape (num_views, H, W, C), so set it running and enjoy some ☕
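Conceptually, that conversion stacks every view of a scene along a new leading axis. A minimal NumPy sketch of this step, assuming each view is already decoded into an (H, W, C) array; stack_views is a hypothetical helper, not the builder code in srn_cars.

```python
import numpy as np


def stack_views(views):
    """Stack per-view images of one scene into (num_views, H, W, C)."""
    views = [np.asarray(v) for v in views]
    if not views:
        raise ValueError("a scene needs at least one view")
    first = views[0].shape
    if any(v.shape != first for v in views):
        raise ValueError("all views must share the same (H, W, C) shape")
    # np.stack adds the new view axis in front and keeps the dtype.
    return np.stack(views, axis=0)
```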

Run tests (Optional)

After setting up a dataset, check that you can successfully run a single step of the experiment by running the test for celeb_a_hq or modelnet from the repository root ($functadir):

cd $functadir
python3 -m test_celeb_a_hq

python3 -m test_modelnet --loss=MSE

or for srn_cars:

python3 -m test_srn_cars

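On a local machine, set the following JAX memory flags before running the tests, to avoid crashes caused by JAX preallocating a large fraction of the GPU memory (these are the same settings as in Setup):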

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
export XLA_PYTHON_CLIENT_MEM_FRACTION=.90

Run meta-learning experiment

Set the hyperparameters in experiment_meta_learning.py as desired by modifying the config values. Then, inside the virtual environment, run the JAXline experiment for training, and evaluate it with --jaxline_mode='eval':

cd $functadir
python3 -m experiment_meta_learning --config=experiment_meta_learning.py --loss=MSE
python3 -m experiment_meta_learning --config=experiment_meta_learning.py --loss=MSE --jaxline_mode='eval'
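When launching many runs (e.g. one per loss), it can help to build these command lines programmatically. A minimal sketch that reproduces the two commands above; the mode names train, eval and train_eval_multithreaded are the JAXline modes as we understand them (confirm with python3 -m experiment_meta_learning --helpfull), and experiment_command is a hypothetical helper.

```python
import shlex

# Assumed JAXline execution modes; "train" is the default, so it adds no flag.
JAXLINE_MODES = ("train", "eval", "train_eval_multithreaded")


def experiment_command(mode="train", loss="MSE",
                       config="experiment_meta_learning.py"):
    """Build the argv for one JAXline run of the meta-learning experiment."""
    if mode not in JAXLINE_MODES:
        raise ValueError(f"unknown jaxline mode: {mode!r}")
    argv = ["python3", "-m", "experiment_meta_learning",
            f"--config={config}", f"--loss={loss}"]
    if mode != "train":
        argv.append(f"--jaxline_mode={mode}")
    return argv


if __name__ == "__main__":
    for mode in ("train", "eval"):
        print(shlex.join(experiment_command(mode)))
```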

Download pretrained weights

Download pretrained weights for the CelebA-HQ-64 meta-learning experiments for mod_dim=64, 128, 256, 512, 1024 and srn_cars with mod_dim=128 here:

Dataset        Modulation dim   Link
CelebA-HQ-64   64               .npz
CelebA-HQ-64   128              .npz
CelebA-HQ-64   256              .npz
CelebA-HQ-64   512              .npz
CelebA-HQ-64   1024             .npz
SRN CARS       128              .npz

Note that the weights for CelebA-HQ-64 were obtained using the original tfds dataset, so they may differ slightly from those obtained by running the above meta-learning experiment with the custom celeb_a_hq dataset.

How to load these weights into the model is shown in the demo Colab below.

Create or Download modulations for CelebA-HQ-64

modulation_dataset_writer.py creates the modulations for a dataset and saves them as npz. Before running, make sure the pretrained weights for the correct modulation dim have been downloaded. Then pass --mod_dim and --pretrained_weights_dir as input args to the python script, and select the dataset and loss with --data and --loss. Optionally also specify --save_to_dir to store the created modulations as npz in a different directory than the one you are running from. Run via command:

cd $functadir
python3 -m modulation_dataset_writer \
  --mod_dim=512 \
  --pretrained_weights_dir="training/training_scan" \
  --save_to_dir=modulation \
  --data='scannet' \
  --loss='MSE'

python3 -m modulation_dataset_writer \
  --mod_dim=512 \
  --pretrained_weights_dir="training/training_man40_pcd" \
  --save_to_dir=modulation \
  --data='manifold40_pcd' \
  --loss='MSE'
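The idea behind creating a modulation for one datapoint is that the shared network weights stay frozen while a few gradient steps fit a per-datapoint modulation vector from a fixed initialization. The NumPy sketch below illustrates that inner loop on a toy one-hidden-layer sine network with shift modulations and an MSE loss. It is not the repository's implementation: the architecture, the zero initialization, the step count, the learning rate and the helper names are all assumptions.

```python
import numpy as np


def loss_and_grad(coords, target, params, mod):
    """MSE of a sine MLP whose hidden pre-activations are shifted by `mod`.

    Returns the loss and its analytic gradient with respect to `mod` only;
    the shared weights in `params` are treated as frozen.
    """
    z = coords @ params["w1"] + params["b1"] + mod   # (N, H), shift modulation
    h = np.sin(z)
    pred = h @ params["w2"] + params["b2"]           # (N, C)
    err = pred - target
    loss = np.mean(err ** 2)
    d_pred = 2.0 * err / err.size                    # dL/dpred
    d_z = (d_pred @ params["w2"].T) * np.cos(z)      # back through sin
    return loss, d_z.sum(axis=0)                     # mod is shared by all N points


def fit_modulation(coords, target, params, num_steps=3, lr=1e-2):
    """Fit one datapoint's modulation with a few plain gradient steps."""
    mod = np.zeros_like(params["b1"])
    losses = []
    for _ in range(num_steps):
        loss, grad = loss_and_grad(coords, target, params, mod)
        losses.append(float(loss))
        mod = mod - lr * grad
    return mod, losses
```

Here coords are the input coordinates of one datapoint (e.g. pixel or point positions) and target the values to reconstruct; in a real setup the shared weights would come from the meta-learned checkpoint in pretrained_weights_dir.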

To run both the point cloud/mesh generation and the modulation creation, use pcd_prediction (point clouds) or mesh_prediction (meshes):

python3 -m pcd_prediction \
  --mod_dim=512 \
  --pretrained_weights_dir="training/training_shape" \
  --data_pcd="shapenet" \
  --data_split="both" \
  --save_to_dir=modulation \
  --n_out=2048 \
  --bs=1 \
  --loss='MSE'

python3 -m pcd_prediction \
  --mod_dim=512 \
  --pretrained_weights_dir="training/training_man40_pcd" \
  --data_pcd="manifold40_pcd" \
  --data_split="test" \
  --save_to_dir=modulation \
  --n_out=16384 \
  --bs=1 \
  --loss='MSE'

python3 -m mesh_prediction \
  --mod_dim=512 \
  --pretrained_weights_dir="training/training_man40" \
  --data_mesh="manifold40" \
  --data_split="test" \
  --save_to_dir=modulation_mesh \
  --resolution=128 \
  --bs=1 \
  --loss='MSE' \
  --number_mesh_save=1
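Mesh extraction from an implicit representation is commonly done by evaluating the decoder on a dense resolution^3 grid of query points and running marching cubes on the result. Assuming --resolution plays that role in mesh_prediction (not confirmed here), --resolution=128 means 128^3 = 2,097,152 query points. A NumPy sketch of building and chunking such a grid; the [-1, 1] bounds, the chunk size and the helper names are assumptions.

```python
import numpy as np


def query_grid(resolution, bound=1.0):
    """Flattened (resolution**3, 3) grid of xyz points in [-bound, bound]^3."""
    axis = np.linspace(-bound, bound, resolution, dtype=np.float32)
    xs, ys, zs = np.meshgrid(axis, axis, axis, indexing="ij")
    # With "ij" indexing and C-order flattening, z varies fastest.
    return np.stack([xs, ys, zs], axis=-1).reshape(-1, 3)


def iter_chunks(points, chunk_size=65536):
    """Yield consecutive slices so the network is evaluated in bounded memory."""
    for start in range(0, len(points), chunk_size):
        yield points[start:start + chunk_size]
```

The per-point outputs, concatenated in order, can be reshaped to (resolution, resolution, resolution) and passed to a marching cubes routine such as skimage.measure.marching_cubes.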

Alternatively, download the modulations here:

Modulation dim   Link
64               .npz
128              .npz
256              .npz
512              .npz
1024             .npz

Again, note that these modulations were obtained using the original tfds dataset, so they may differ slightly from those produced by running the above script with the custom celeb_a_hq dataset.
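Before using a downloaded or generated modulation file, it helps to list which arrays it contains. A minimal sketch using np.load, whose result exposes the stored array names via .files; the key names inside these particular files are not documented here, so the sketch only reports whatever is stored, and describe_npz is a hypothetical helper.

```python
import numpy as np


def describe_npz(path):
    """Map each array stored in an .npz file to its (shape, dtype)."""
    # allow_pickle stays False (the default); only enable it for trusted files
    # that store Python objects rather than plain arrays.
    with np.load(path) as data:
        return {key: (data[key].shape, data[key].dtype) for key in data.files}
```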

To monitor training with the TensorBoard plugin, run:

tensorboard --logdir training/training_scan/512/train --load_fast=false


To run the classifier on the modulations:

cd $functadir
python3 -m inr_emb_cls \
  --mod_dim=512 \
  --num_class=10 \
  --data="scannet"
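The classifier used by inr_emb_cls is not described in this README. As a quick sanity check that the modulations carry class information, a nearest-centroid baseline can be run on them. This is a different and much simpler technique than the repository's classifier, and the function names and data layout (one modulation vector of size mod_dim per shape, integer labels in [0, num_class)) are assumptions.

```python
import numpy as np


def fit_centroids(mods, labels, num_classes):
    """Mean modulation per class, shape (num_classes, mod_dim).

    Every class must have at least one example, otherwise its mean is NaN.
    """
    mods = np.asarray(mods, dtype=np.float64)
    labels = np.asarray(labels)
    return np.stack([mods[labels == k].mean(axis=0) for k in range(num_classes)])


def predict(mods, centroids):
    """Assign each modulation to the class with the closest centroid."""
    mods = np.asarray(mods, dtype=np.float64)
    dists = ((mods[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)
```

If this baseline already scores well above chance on held-out modulations, the representation is informative, and a learned classifier should do at least as well.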

Demo Colab

We also include a colab that shows how to visualize modulation reconstructions for CelebA-HQ-64.

Paper Figures

Figure 4

Meta-learned initialization + 4 gradient steps and target for test scene.

Figure 7

Course of optimization for imputation of voxel from partial observation.

Panels (partial observation and imputation) shown from back, from front, from left, and from a lidar scan.

Figure 9

Uncurated samples from DDPM (diffusion) trained on 64-dim modulations of SRN-cars.

Figure 10

Latent interpolation between two car scenes with moving pose.

Figure 11

Novel view synthesis from occluded view.

Panels: occluded view, ground truth, inferred, no prior.

Figure 12

Uncurated samples from flow trained on 256-dim modulations on ERA-5 temperature data.

Figure 26

Additional voxel imputation results.

Panels: partial observation, imputation.

Figure 28

Additional novel view synthesis results.

Panels: occluded view, ground truth, inferred, no prior.

Giving Credit

If you use this code in your work, we ask you to please cite our work:

@InProceedings{functa22,
  title = {From data to functa: Your data point is a function and you can treat it like one},
  author = {Dupont, Emilien and Kim, Hyunjik and Eslami, S. M. Ali and Rezende, Danilo Jimenez and Rosenbaum, Dan},
  booktitle = {39th International Conference on Machine Learning (ICML)},
  year = {2022},
}

Raising Issues

Please feel free to raise a GitHub issue.

License and disclaimer

Copyright 2022 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.
