# Disentangling Representations through Multitask Learning

Code for "Disentangling Representations through Multi-task Learning" (ICLR 2025).
This repository contains code for training and analyzing neural networks that learn disentangled representations through multitask learning. The project focuses on how neural networks can generalize out-of-distribution and across different tasks by learning to separate task-relevant dimensions in their internal representations.
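As a rough conceptual illustration of the multitask setup (this is not the repository's task code; the function names, dimensions, and random linear tasks below are invented for illustration), each task can be seen as a binary classification over a shared latent space:

```python
import random

random.seed(0)

def make_tasks(n_tasks, dim=2):
    """Random linear binary tasks over a shared latent space (illustrative only)."""
    return [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n_tasks)]

def task_labels(latent, tasks):
    """Each task labels the same latent point by the sign of a projection."""
    return [int(sum(w * z for w, z in zip(weights, latent)) > 0)
            for weights in tasks]

tasks = make_tasks(n_tasks=4, dim=2)
latent = [0.5, -1.2]              # one point in the shared latent space
labels = task_labels(latent, tasks)
print(labels)                     # four binary labels, one per task
```

Training one network to solve many such tasks over the same latents is what encourages it to represent the latent dimensions separately.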
## Requirements

- Python 3.8 or higher
- Git
## Installation

1. Clone this repository:

   ```bash
   git clone https://github.com/panvaf/DisentangleRes.git
   cd DisentangleRes
   ```

2. Create and activate a virtual environment:

   ```bash
   python -m venv disentangle

   # On Windows
   disentangle\Scripts\activate

   # On macOS/Linux
   source disentangle/bin/activate
   ```

3. Install the required packages:

   ```bash
   pip install -r requirements.txt
   ```

4. Install neurogym separately (required for task environments):

   ```bash
   git clone https://github.com/gyyang/neurogym.git
   cd neurogym
   pip install -e .
   cd ..
   ```
## Repository structure

- `train.py`: Trains autoregressive neural networks on specified tasks. Configurable parameters include network architecture, activation functions, noise levels, and training settings.
- `generalize.py`: Evaluates zero-shot, out-of-distribution generalization of trained networks.
- `analyze.py`: Analyzes the representations learned by trained RNNs through dimensionality reduction, fixed-point analysis, and other techniques.
- `analyze_transformer.py`: Analysis specific to transformer architectures.
- `sparsity.py`: Analyzes and evaluates sparsity in network representations.
- `tasks.py`: Classes for the various cognitive tasks used to train the networks.
- `util.py`: Utility functions for analysis, visualization, and data processing.
- `RNN.py`: Implementation of the recurrent neural network architecture.
- `transformer.py`: Implementation of transformer-based architectures for the tasks.
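`RNN.py` implements the recurrent architecture used in the paper. As a conceptual sketch only (the weights, sizes, and time constant below are invented and this is not the repository's implementation), a leaky vanilla RNN update looks like:

```python
import math
import random

random.seed(1)

def rnn_step(h, x, W_h, W_x, tau=0.2):
    """One leaky-RNN update: h <- (1 - tau) * h + tau * tanh(W_h h + W_x x)."""
    n = len(h)
    pre = [sum(W_h[i][j] * h[j] for j in range(n)) +
           sum(W_x[i][k] * x[k] for k in range(len(x)))
           for i in range(n)]
    return [(1 - tau) * h[i] + tau * math.tanh(pre[i]) for i in range(n)]

n_hidden, n_in = 4, 2
W_h = [[random.gauss(0, 0.5) for _ in range(n_hidden)] for _ in range(n_hidden)]
W_x = [[random.gauss(0, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]

h = [0.0] * n_hidden
for _ in range(10):                 # run a short constant-input sequence
    h = rnn_step(h, [1.0, -1.0], W_h, W_x)
print(h)                            # hidden state after 10 steps
```

The leaky update keeps each hidden unit bounded by the tanh nonlinearity, which is what fixed-point analyses in `analyze.py` exploit.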
## Usage

1. **Train networks**: Use `train.py` to train neural networks on various numbers of tasks, for different hyperparameter choices.

   ```bash
   python train.py
   ```

2. **Evaluate generalization**: Use `generalize.py` to test how well trained networks generalize to new tasks. It performs sweeps across networks with the same hyperparameter choices and different numbers of trained tasks.

   ```bash
   python generalize.py
   ```

3. **Analyze representations**: Use `analyze.py` to visualize the representations that the networks have learned.

   ```bash
   python analyze.py
   ```
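Hyperparameters are set inside the scripts themselves. If you prefer command-line overrides, a minimal `argparse` wrapper is one option; note that the flags below are invented for illustration and are not options that `train.py` actually defines:

```python
import argparse

# Hypothetical command-line interface; the real parameters
# (architecture, activation, noise, number of tasks) live in train.py.
parser = argparse.ArgumentParser(description="Train a network on multiple tasks")
parser.add_argument("--n-tasks", type=int, default=2, help="number of training tasks")
parser.add_argument("--activation", default="relu", help="activation function")
parser.add_argument("--noise", type=float, default=0.1, help="input noise level")

# Parse an explicit argument list (instead of sys.argv) for demonstration
args = parser.parse_args(["--n-tasks", "8", "--noise", "0.05"])
print(args.n_tasks, args.activation, args.noise)
```

This keeps sweeps over task counts and noise levels scriptable from the shell rather than requiring edits to the source.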
## Reproducing figures

- `Figures/`: Contains code to reproduce all main figures. To do so, download the data from trained networks provided here, or train and evaluate the networks yourself by following the instructions at the same link.
## Citation

If you use this code in your research, please cite our paper:

```bibtex
@inproceedings{
  vafidis2025disentangling,
  title={Disentangling Representations through Multi-task Learning},
  author={Pantelis Vafidis and Aman Bhargava and Antonio Rangel},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=yVGGtsOgc7}
}
```