He-JiYe/CentriLearn

CentriLearn: Learning to Identify Key Nodes in Complex Networks

A reinforcement learning framework based on graph neural networks for solving combinatorial optimization problems in complex networks, such as network dismantling.
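For readers new to the task, network dismantling is commonly framed as removing a small set of nodes so that the largest connected component of the graph shrinks as much as possible. The sketch below illustrates that objective with a simple highest-degree baseline in plain Python. The function names and the toy graph are hypothetical and are not part of CentriLearn; a learned agent would replace the degree heuristic with a trained policy.

```python
from collections import deque


def largest_component_size(adj, removed):
    """Size of the largest connected component after deleting `removed` nodes."""
    seen = set(removed)  # removed nodes are never visited
    best = 0
    for start in adj:
        if start in seen:
            continue
        # Breadth-first search over the surviving nodes.
        seen.add(start)
        queue = deque([start])
        size = 0
        while queue:
            node = queue.popleft()
            size += 1
            for nbr in adj[node]:
                if nbr not in seen:
                    seen.add(nbr)
                    queue.append(nbr)
        best = max(best, size)
    return best


def degree_dismantle(adj, budget):
    """Greedy baseline: repeatedly remove the highest-degree surviving node."""
    removed = []
    for _ in range(budget):
        alive = [n for n in adj if n not in removed]
        if not alive:
            break
        # Degree is counted only over neighbours that are still present.
        target = max(alive, key=lambda n: sum(m not in removed for m in adj[n]))
        removed.append(target)
    return removed


# Hypothetical toy graph: two triangles joined by the edge 2-3.
toy_graph = {
    0: [1, 2],
    1: [0, 2],
    2: [0, 1, 3],
    3: [2, 4, 5],
    4: [3, 5],
    5: [3, 4],
}

removed_nodes = degree_dismantle(toy_graph, budget=1)
remaining = largest_component_size(toy_graph, removed_nodes)
```

Removing a single bridge node here splits the graph into two smaller pieces, which is the kind of effect a dismantling policy is trained to find with as few removals as possible.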


Features

  • Graph-Oriented RL: A reinforcement learning framework built on PyTorch Geometric and designed specifically for graph-structured data
  • Modular Architecture: Environments, algorithms, models, and metrics are cleanly separated, so the framework is easy to extend to other graph combinatorial optimization problems
  • Registry System: Components are registered by name and built dynamically, which makes experimental configurations easy to modify
  • Built-in Algorithms: Implementations of the DQN and PPO reinforcement learning algorithms
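To give a feel for the registry idea, here is a minimal sketch of a name-to-class registry with a `build` step driven by a `type` key. It is an illustration under stated assumptions, not CentriLearn's actual API: the `Registry` class, the `BACKBONES` instance, and the placeholder `GraphSAGE` class are all hypothetical.

```python
class Registry:
    """Maps string names to classes so configs can refer to components by name."""

    def __init__(self, name):
        self.name = name
        self._classes = {}

    def register(self, cls):
        """Class decorator: store `cls` under its own class name."""
        if cls.__name__ in self._classes:
            raise KeyError(f"{cls.__name__} is already registered in {self.name}")
        self._classes[cls.__name__] = cls
        return cls

    def build(self, cfg):
        """Instantiate the class named by cfg['type'] with the remaining keys."""
        kwargs = dict(cfg)  # copy so the caller's config is not mutated
        type_name = kwargs.pop("type")
        if type_name not in self._classes:
            raise KeyError(f"{type_name!r} is not registered in {self.name}")
        return self._classes[type_name](**kwargs)


BACKBONES = Registry("backbones")  # hypothetical registry instance


@BACKBONES.register
class GraphSAGE:
    # Placeholder that only records its hyperparameters; a real backbone
    # would be a torch.nn.Module built on PyTorch Geometric layers.
    def __init__(self, in_channels, hidden_channels, num_layers):
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers


backbone = BACKBONES.build(
    {"type": "GraphSAGE", "in_channels": 1, "hidden_channels": 64, "num_layers": 3}
)
```

With this pattern, swapping one component for another only requires changing the `type` string in the configuration, provided the new class is registered and accepts the given keys.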

Project Structure

centrilearn/
├── algorithms/          # RL algorithms (DQN, PPO)
├── buffer/              # Experience replay buffers
├── environments/        # Task environments
├── metrics/             # Evaluation metrics
├── models/              # GNN backbones & prediction heads
│   ├── backbones/       # GraphSAGE, GAT, GIN
│   └── heads/           # QHead, VHead, PolicyHead, etc.
└── utils/               # Builder, registry, training utilities

Installation

pip install -e .

Requirements:

  • Python >= 3.11
  • PyTorch >= 2.7.0
  • PyTorch Geometric >= 2.6.0
  • torch-scatter >= 2.1.0

Quick Start

Training

# DQN training
python tools/train.py configs/network_dismantling/FINDER.yaml

# PPO training
python tools/train.py configs/network_dismantling/CentriLearn.yaml

# With custom parameters
python tools/train.py configs/network_dismantling/CentriLearn.yaml --num_episodes 500 --batch_size 64

Testing

# Test trained model
python tools/test.py configs/network_dismantling/FINDER.yaml --checkpoint ./checkpoints/model.pth

Configuration

All components can be configured via YAML/JSON files:

algorithm:
  type: DQN
  model_cfg:
    type: Qnet
    backbone_cfg:
      type: GraphSAGE
      in_channels: 1
      hidden_channels: 64
      num_layers: 3
    q_head_cfg:
      type: QHead
      in_channels: 128
  device: cuda
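Because every component in the example is selected by a `type` key inside a nested `*_cfg` block, a quick way to check a configuration is to list which component types it references. The sketch below does that for a JSON rendering of the example (same keys and values) using only the standard library. The helper name `collect_types` is hypothetical and does not come from CentriLearn.

```python
import json

# JSON rendering of the YAML example (same keys and values).
CONFIG_JSON = """
{
  "algorithm": {
    "type": "DQN",
    "model_cfg": {
      "type": "Qnet",
      "backbone_cfg": {"type": "GraphSAGE", "in_channels": 1,
                       "hidden_channels": 64, "num_layers": 3},
      "q_head_cfg": {"type": "QHead", "in_channels": 128}
    },
    "device": "cuda"
  }
}
"""


def collect_types(node, path=""):
    """Return (dotted_path, type_name) pairs for every dict with a 'type' key."""
    found = []
    if isinstance(node, dict):
        if "type" in node:
            found.append((path, node["type"]))
        for key, value in node.items():
            child = f"{path}.{key}" if path else key
            found.extend(collect_types(value, child))
    return found


config = json.loads(CONFIG_JSON)
component_types = collect_types(config)
for dotted_path, type_name in component_types:
    print(f"{dotted_path}: {type_name}")
```

A listing like this can be compared against the set of registered component names before training starts, so that a misspelled `type` is caught early.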

Supported Algorithms

Algorithm   Description
DQN         Deep Q-Network with experience replay and target network
PPO         Proximal Policy Optimization with clipped objective
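As a reminder of what the two descriptions refer to, the sketch below writes out the textbook one-step DQN target (computed from target-network Q-values) and the per-sample PPO clipped surrogate in plain Python. These are the standard formulas, not necessarily how CentriLearn implements them; the function names and the default values of `gamma` and `clip_eps` are illustrative assumptions.

```python
def dqn_target(reward, next_q_values, done, gamma=0.99):
    """One-step TD target; next_q_values are assumed to come from the target network."""
    if done or not next_q_values:
        # Terminal transition (or no actions left): no bootstrapping.
        return reward
    return reward + gamma * max(next_q_values)


def ppo_clipped_term(ratio, advantage, clip_eps=0.2):
    """Per-sample clipped surrogate: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    return min(ratio * advantage, clipped_ratio * advantage)


# Illustrative values only.
target = dqn_target(reward=1.0, next_q_values=[0.5, 2.0], done=False, gamma=0.5)
surrogate = ppo_clipped_term(ratio=1.5, advantage=2.0)
```

In the PPO term, `ratio` is the probability of the chosen action under the new policy divided by its probability under the old one; clipping removes the incentive to move that ratio far from 1 in a single update.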

Citation

If this project helps your research, please cite:

@misc{CentriLearn2026,
  title = {CentriLearn: A Reinforcement Learning Framework for Complex Networks},
  author = {CentriLearn Team},
  year = {2026},
  url = {https://github.com/He-JiYe/CentriLearn}
}

License

MIT License. See LICENSE for details.


If this project helps you, please give us a ⭐️!
