Code release for the DIFA paper and baselines for dynamic feature acquisition.
- Training code for methods: difa, gsmrl, jafa, vaeac, eddi, random
- Data loading and preprocessing for tabular and vision benchmarks
- Optional SLURM job generation for large experiment sweeps (run.py)
- src/train.py: main training entrypoint
- src/models/: model wrappers, RL/policy modules, and data utilities
- src/utils/utils.py: CLI args, logging, reproducibility helpers
- run.py: SLURM config generation + submission helper (advanced, optional)
- references/: dataset metadata, metrics mapping, SLURM template
- data/fetch_data.sh: helper script for non-torchvision datasets
- docs/REPRODUCIBILITY.md: paper-style experiment guidance
- docs/HPC_SLURM.md: cluster-specific usage notes
- Python 3.9 (recommended for pinned dependency compatibility)
- Linux/macOS shell environment
- CUDA-capable GPU is optional (CPU works, but slower)
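The requirements above can be checked from Python before a run. This is a minimal sketch and not part of the repository: the helper names python_matches and cuda_available are hypothetical, and the GPU check assumes PyTorch is the installed backend (the repository uses torchvision), so it only gives a meaningful answer after the dependencies are installed.

```python
import sys


def python_matches(required=(3, 9), version_info=None) -> bool:
    """Return True if the interpreter's major.minor version equals `required`."""
    current = sys.version_info if version_info is None else version_info
    return tuple(current[:2]) == tuple(required)


def cuda_available() -> bool:
    """Return True only if PyTorch is importable and reports a usable GPU."""
    try:
        import torch  # assumption: PyTorch is the backend; absent before installation
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


if __name__ == "__main__":
    if not python_matches():
        print("Warning: Python 3.9 is recommended for the pinned dependencies.")
    print("GPU available:", cuda_available())
```

A False result from cuda_available is not an error, since training also runs on CPU, only more slowly.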
Install dependencies:
python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt

From repo root:
cd data
bash fetch_data.sh
cd ..

Notes:
- mnist, fashionmnist, cifar10, cifar100, svhn, stl, food are downloaded automatically through torchvision when first used.
- Other datasets are fetched by data/fetch_data.sh.
- See data/README.md for expected dataset paths and caveats.
src/train.py is the single training entrypoint.
- vaeac: trains the imputation model
- difa: feature acquisition model (requires imputation checkpoint)
- gsmrl: feature acquisition model (requires imputation checkpoint)
- random: random acquisition baseline (requires imputation checkpoint)
- eddi: EDDI baseline (requires imputation checkpoint)
- jafa: feature acquisition baseline (does not use imputation checkpoint)
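The checkpoint dependency between these problems can be validated before launching a run. The following is a minimal sketch that only restates the list above. The REQUIRES_IMPUTATION table and the check_problem helper are hypothetical names, not code from the repository, and how src/train.py itself reacts to a missing checkpoint is not covered here.

```python
from typing import Optional

# Restates the list above: which --problem values need a trained imputation
# checkpoint. Illustration only; this table is not taken from the repository.
REQUIRES_IMPUTATION = {
    "vaeac": False,   # trains the imputation model itself
    "difa": True,
    "gsmrl": True,
    "random": True,
    "eddi": True,
    "jafa": False,    # baseline that does not use an imputation checkpoint
}


def check_problem(problem: str, imputation_model: Optional[str] = None) -> None:
    """Raise ValueError if the problem is unknown or lacks a required checkpoint."""
    if problem not in REQUIRES_IMPUTATION:
        raise ValueError(f"unknown problem: {problem!r}")
    if REQUIRES_IMPUTATION[problem] and not imputation_model:
        raise ValueError(f"{problem!r} requires --imputation_model")
```

For example, check_problem("difa") raises, while check_problem("difa", "models/vaeac_mnist.pt") and check_problem("jafa") pass silently.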
python src/train.py \
--data mnist \
--problem vaeac \
--name vaeac_mnist \
--save

This writes a checkpoint to models/vaeac_mnist.pt.
python src/train.py \
--data mnist \
--problem difa \
--n_features 12 \
--imputation_model models/vaeac_mnist.pt \
--iters 10 \
--pretrain_iters 5 \
--batch_size 128 \
--workers 0 \
--debug \
--name difa_mnist \
--save

Use --cuda if a GPU is available.
For problem choices, see src/models/__init__.py.
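The two stages above (train the imputation model, then train an acquisition model against its checkpoint) can also be scripted. This is a minimal sketch that only assembles argument lists from the flags shown above. The function names vaeac_command and acquisition_command are hypothetical, and the checkpoint path is passed explicitly rather than derived from --name.

```python
import shlex
import sys
from typing import List


def vaeac_command(data: str, name: str) -> List[str]:
    """Stage 1: train and save the imputation model."""
    return [sys.executable, "src/train.py",
            "--data", data, "--problem", "vaeac",
            "--name", name, "--save"]


def acquisition_command(data: str, problem: str, n_features: int,
                        checkpoint: str, name: str,
                        cuda: bool = False) -> List[str]:
    """Stage 2: train an acquisition model that reads the stage 1 checkpoint."""
    cmd = [sys.executable, "src/train.py",
           "--data", data, "--problem", problem,
           "--n_features", str(n_features),
           "--imputation_model", checkpoint,
           "--name", name, "--save"]
    if cuda:
        cmd.append("--cuda")  # only when a GPU is available
    return cmd


if __name__ == "__main__":
    stages = [
        vaeac_command("mnist", "vaeac_mnist"),
        acquisition_command("mnist", "difa", 12,
                            "models/vaeac_mnist.pt", "difa_mnist"),
    ]
    for cmd in stages:
        # Print only; pass cmd to subprocess.run(cmd, check=True) to execute.
        print(shlex.join(cmd))
```

Building the commands as lists rather than strings avoids shell quoting issues, and running the stages in order keeps the checkpoint in place before the second stage starts.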
- CLI defaults live in src/utils/utils.py.
- Paper sweep-style settings are encoded in run.py (HPC/SLURM-oriented).
- run.py is optional for local usage.
Set the Neptune API token and pass --neptune:
export NEPTUNE_API_TOKEN="<your-token>"
python src/train.py --data mnist --problem difa --n_features 12 --neptune --project "<workspace>/<project>"

For paper-style settings and sweep behavior, see:
- docs/REPRODUCIBILITY.md
- docs/HPC_SLURM.md
run.py is intentionally HPC-oriented and generates/submits SLURM jobs; it is not required for local use.
If this code is useful in your work, cite the paper:
- Paper: https://doi.org/10.1609/aaai.v37i6.25934
- Citation file: CITATION.cff
@inproceedings{ghosh2023difa,
title={DiFA: Differentiable Feature Acquisition},
author={Ghosh, Aritra and Lan, Andrew S.},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={37},
number={6},
pages={7705--7713},
year={2023},
doi={10.1609/aaai.v37i6.25934}
}

Released under the MIT License. See LICENSE.
- Some dataset endpoints are external and may change availability over time.
- Reproducibility depends on using the pinned dependencies in requirements.txt.
- The provided SLURM template must be customized before use on a new cluster.
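
Since reproducibility depends on the pinned dependencies, it can help to compare the active environment against the pins before running experiments. The following is a minimal sketch, not part of the repository. It assumes the pins use the plain name==version form, and the helper names parse_pins and find_mismatches are hypothetical. Lines with extras or other specifiers are not handled.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional, Tuple


def parse_pins(lines: Iterable[str]) -> Dict[str, str]:
    """Collect name==version pins, skipping comments, blanks, and unpinned lines."""
    pins: Dict[str, str] = {}
    for raw in lines:
        # Drop trailing comments and environment markers (text after ';').
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin
    to (pinned, installed); installed is None if the package is missing."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, pinned in pins.items():
        try:
            installed: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches
```

From the repo root, opening requirements.txt and passing the file object to parse_pins, then the result to find_mismatches, returns an empty dict when the environment matches every pin.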