arghosh/difa

DIFA

Code release for the DiFA paper (Differentiable Feature Acquisition, AAAI 2023), with baselines for dynamic feature acquisition.

What this repo contains

  • Training code for methods: difa, gsmrl, jafa, vaeac, eddi, random
  • Data loading and preprocessing for tabular and vision benchmarks
  • Optional SLURM job generation for large experiment sweeps (run.py)

Repository layout

  • src/train.py: main training entrypoint
  • src/models/: model wrappers, RL/policy modules, and data utilities
  • src/utils/utils.py: CLI args, logging, reproducibility helpers
  • run.py: SLURM config generation + submission helper (advanced, optional)
  • references/: dataset metadata, metrics mapping, SLURM template
  • data/fetch_data.sh: helper script for non-torchvision datasets
  • docs/REPRODUCIBILITY.md: paper-style experiment guidance
  • docs/HPC_SLURM.md: cluster-specific usage notes

Requirements

  • Python 3.9 (recommended for pinned dependency compatibility)
  • Linux/macOS shell environment
  • CUDA-capable GPU is optional (CPU works, but slower)

Install dependencies:

python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt

Data setup

From repo root:

cd data
bash fetch_data.sh
cd ..

Notes:

  • mnist, fashionmnist, cifar10, cifar100, svhn, stl, food are downloaded automatically through torchvision when first used.
  • Other datasets are fetched by data/fetch_data.sh.
  • See data/README.md for expected dataset paths and caveats.
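The split described in the notes above can be captured in a small helper. This is a minimal sketch, not code from the repository: `TORCHVISION_DATASETS` and `needs_fetch_script` are hypothetical names, and the dataset keys are copied from the list above on the assumption that they match the values accepted by `--data`.

```python
# Hypothetical helper; not part of the repository.
# Dataset names are copied from the README note on torchvision downloads.
TORCHVISION_DATASETS = frozenset(
    {"mnist", "fashionmnist", "cifar10", "cifar100", "svhn", "stl", "food"}
)


def needs_fetch_script(dataset: str) -> bool:
    """Return True if the dataset is expected to come from data/fetch_data.sh."""
    return dataset.lower() not in TORCHVISION_DATASETS
```

Calling `needs_fetch_script("mnist")` returns `False`, so no manual download step is expected for that dataset. Any name outside the torchvision list returns `True`, which means `data/fetch_data.sh` should be run first.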

Quickstart (local)

src/train.py is the single training entrypoint.

Problem mapping

  • vaeac: trains the imputation model
  • difa: feature acquisition model (requires imputation checkpoint)
  • gsmrl: feature acquisition model (requires imputation checkpoint)
  • random: random acquisition baseline (requires imputation checkpoint)
  • eddi: EDDI baseline (requires imputation checkpoint)
  • jafa: feature acquisition baseline (does not use imputation checkpoint)
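The mapping above can be expressed as a lookup table that catches a missing checkpoint before a run starts. This is an illustrative sketch only: `REQUIRES_IMPUTATION_CHECKPOINT` and `check_problem_args` are hypothetical names, and the repository may validate arguments differently (or not at all).

```python
from typing import Optional

# Hypothetical table mirroring the README's problem mapping.
REQUIRES_IMPUTATION_CHECKPOINT = {
    "vaeac": False,   # trains the imputation model itself
    "difa": True,
    "gsmrl": True,
    "random": True,
    "eddi": True,
    "jafa": False,    # baseline that does not use an imputation checkpoint
}


def check_problem_args(problem: str, imputation_model: Optional[str]) -> None:
    """Raise ValueError if the problem is unknown or lacks a required checkpoint."""
    if problem not in REQUIRES_IMPUTATION_CHECKPOINT:
        raise ValueError(f"unknown problem: {problem!r}")
    if REQUIRES_IMPUTATION_CHECKPOINT[problem] and not imputation_model:
        raise ValueError(f"--problem {problem} needs --imputation_model")
```

A check like this fails fast on the launching side, rather than partway through model construction.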

Stage 1: Train imputation model (vaeac)

python src/train.py \
  --data mnist \
  --problem vaeac \
  --name vaeac_mnist \
  --save

This writes a checkpoint to models/vaeac_mnist.pt.
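Before moving on to Stage 2, it can help to confirm that the checkpoint was actually written. The sketch below assumes the `models/<name>.pt` pattern generalizes from the single example above; `checkpoint_path` and `checkpoint_ready` are hypothetical helpers, not repository functions.

```python
from pathlib import Path


def checkpoint_path(name: str, models_dir: str = "models") -> Path:
    """Path where Stage 1 is assumed to save its checkpoint (models/<name>.pt)."""
    return Path(models_dir) / f"{name}.pt"


def checkpoint_ready(name: str, models_dir: str = "models") -> bool:
    """True if the expected checkpoint file exists and is non-empty."""
    path = checkpoint_path(name, models_dir)
    return path.is_file() and path.stat().st_size > 0
```

For the run above, `checkpoint_ready("vaeac_mnist")` should return `True` when called from the repo root after training finishes. The check only looks at file presence and size; it does not verify that the file loads.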

Stage 2: Train acquisition model with imputation checkpoint

python src/train.py \
  --data mnist \
  --problem difa \
  --n_features 12 \
  --imputation_model models/vaeac_mnist.pt \
  --iters 10 \
  --pretrain_iters 5 \
  --batch_size 128 \
  --workers 0 \
  --debug \
  --name difa_mnist \
  --save

Use --cuda if a GPU is available. For problem choices, see src/models/__init__.py.
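The two stages can also be assembled programmatically, which is convenient for small local sweeps without `run.py`. The following is a sketch under stated assumptions: `build_train_command` is a hypothetical helper, it uses only the flags shown in the commands above, and it assumes the CLI accepts flags in any order.

```python
import shlex
import sys


def build_train_command(data, problem, name, extra=None, python=sys.executable):
    """Build an argv list for src/train.py; `extra` maps flag names to values.

    A value of True emits a bare switch (e.g. --debug); any other value
    is emitted as "--flag value".
    """
    cmd = [python, "src/train.py", "--data", data,
           "--problem", problem, "--name", name, "--save"]
    for flag, value in (extra or {}).items():
        cmd.append(f"--{flag}")
        if value is not True:
            cmd.append(str(value))
    return cmd


# Stage 1: imputation model.
stage1 = build_train_command("mnist", "vaeac", "vaeac_mnist")

# Stage 2: acquisition model, reusing the Stage 1 checkpoint.
stage2 = build_train_command(
    "mnist", "difa", "difa_mnist",
    extra={
        "n_features": 12,
        "imputation_model": "models/vaeac_mnist.pt",
        "iters": 10,
        "pretrain_iters": 5,
        "batch_size": 128,
        "workers": 0,
        "debug": True,
    },
)

for cmd in (stage1, stage2):
    print(shlex.join(cmd))  # inspect the commands before running them
```

To execute the commands instead of printing them, pass each list to `subprocess.run(cmd, check=True)` from the repo root. With `check=True`, a failed Stage 1 raises an exception, so Stage 2 does not start.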

Parameter defaults and sweep settings

  • CLI defaults live in src/utils/utils.py.
  • Paper-style sweep settings are encoded in run.py (HPC/SLURM-oriented).
  • run.py is optional for local usage.

Neptune logging (optional)

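Set the API token and pass --neptune. Because difa requires an imputation checkpoint, the command also passes --imputation_model:

export NEPTUNE_API_TOKEN="<your-token>"
python src/train.py \
  --data mnist \
  --problem difa \
  --n_features 12 \
  --imputation_model models/vaeac_mnist.pt \
  --neptune \
  --project "<workspace>/<project>"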
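When launching runs from a script, the Neptune flags can be added only if the token is actually set. This is a minimal sketch: `neptune_args` is a hypothetical helper, and it assumes that omitting `--neptune` simply disables logging.

```python
import os


def neptune_args(project, environ=None):
    """Return Neptune CLI flags if NEPTUNE_API_TOKEN is set, else an empty list."""
    env = os.environ if environ is None else environ
    if not env.get("NEPTUNE_API_TOKEN"):
        return []  # no token: fall back to a run without Neptune logging
    return ["--neptune", "--project", project]
```

The returned list can be appended to a `src/train.py` argv list. Passing `environ` explicitly keeps the helper easy to test without touching the real environment.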

Reproducibility

For paper-style settings and sweep behavior, see:

  • docs/REPRODUCIBILITY.md
  • docs/HPC_SLURM.md

run.py is intentionally HPC-oriented and generates/submits SLURM jobs; it is not required for local use.

Citation

If this code is useful in your work, cite the paper:

@inproceedings{ghosh2023difa,
  title={DiFA: Differentiable Feature Acquisition},
  author={Ghosh, Aritra and Lan, Andrew S.},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={37},
  number={6},
  pages={7705--7713},
  year={2023},
  doi={10.1609/aaai.v37i6.25934}
}

License

Released under the MIT License. See LICENSE.

Known limitations

  • Some dataset endpoints are external and may change availability over time.
  • Reproducibility depends on using the pinned dependencies in requirements.txt.
  • The provided SLURM template must be customized before use on a new cluster.

About

Official implementation of DiFA: Differentiable Feature Acquisition (AAAI 2023).
