Official implementation of "Variational Perturbation Personalized Federated Learning via Prior-Posterior Distance".
VPFL is a personalized federated learning algorithm that leverages Prior-Posterior Distance (PPD) to achieve better personalization in non-IID environments.
- Core Components:
  - PPD (Prior-Posterior Distance): Measures the layerwise discrepancy between the global prior model and the local posterior model (Γ = posterior − prior)
  - PPD-Guided Constrained Update: After local training, each client applies a constrained subtraction step with coefficient c = 1 / (λ · min(max|Γ|))
  - Distribution-Aware Adaptive Aggregation: The server weights clients by α_k ∝ n_k^γ · exp(β · Sim_k), combining prior-posterior cosine similarity with data quantity
  - Variational Perturbation: After aggregation, the server applies layerwise Gaussian noise (low-order N(0, Var/μ) / high-order N(0, Var), auto-selected by PPD magnitude)
- Optimized Performance:
  - Fashion-MNIST Pathological: 86.49%
  - CIFAR-10 Dirichlet: 68.08%
  - CIFAR-10 Pathological: 58.72% (vs 51.19% baseline)
- PFLlib-Compatible Architecture: Easy to integrate with existing FL frameworks
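
The PPD, the constraint coefficient, and the aggregation weights listed above can be sketched in plain Python. This is a minimal illustration, not the repository's implementation: layers are flat lists of floats rather than tensors, the function names (`layerwise_ppd`, `constraint_coefficient`, `aggregation_weights`) are hypothetical, the `eps` guard is an addition for numerical safety, and min(max|Γ|) is read here as the smallest of the per-layer maximum absolute differences, which is an assumption about the intended reduction.

```python
import math


def layerwise_ppd(prior, posterior):
    """Return Γ = posterior − prior for every layer (dict: name -> list of floats)."""
    return {
        name: [q - p for p, q in zip(prior[name], posterior[name])]
        for name in prior
    }


def constraint_coefficient(ppd, lambda_param=10.0, eps=1e-12):
    """c = 1 / (λ · min(max|Γ|)).

    Assumption: max is taken inside each layer, min across layers.
    eps (not part of the formula) avoids division by zero for unchanged layers.
    """
    per_layer_max = [max(abs(g) for g in gamma) for gamma in ppd.values()]
    return 1.0 / (lambda_param * max(min(per_layer_max), eps))


def aggregation_weights(sample_counts, similarities, beta=2.0, gamma=0.5):
    """Normalized weights α_k ∝ n_k^γ · exp(β · Sim_k)."""
    raw = [
        n ** gamma * math.exp(beta * s)
        for n, s in zip(sample_counts, similarities)
    ]
    total = sum(raw)
    return [r / total for r in raw]


# Toy two-layer model: global prior vs. one client's local posterior.
prior = {"conv1": [0.10, -0.20, 0.30], "fc": [0.50, 0.50]}
posterior = {"conv1": [0.15, -0.10, 0.30], "fc": [0.30, 0.55]}

ppd = layerwise_ppd(prior, posterior)
c = constraint_coefficient(ppd, lambda_param=10.0)

# Two clients: (n_k = 100, Sim_k = 0.9) and (n_k = 400, Sim_k = 0.5).
weights = aggregation_weights([100, 400], [0.9, 0.5])
print(c, weights)
```

With gamma below 1, the data quantity term grows sublinearly, so in this toy case the smaller but more similar client still receives the larger weight.
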
# Clone the repository
git clone https://github.com/yourusername/VPFL.git
cd VPFL
# Install dependencies
pip install -r requirements.txt

# Generate CIFAR-10 datasets
cd dataset
python generate_Cifar10.py
# Generate Fashion-MNIST datasets
python generate_FashionMNIST.py
# Or generate all at once
bash generate_all_datasets.sh
cd ..

# Train on Fashion-MNIST (5 clients, pathological)
python main.py --dataset FashionMNIST_5_pat --global_rounds 100
# Train on CIFAR-10 (5 clients, pathological)
python main.py --dataset Cifar10_5_pat --global_rounds 100
# Train with custom hyperparameters
python main.py \
--dataset FashionMNIST_5_pat \
--lambda_param 10.0 \
--momentum 0.9 \
--global_rounds 100

from system.flcore.trainmodel.models import create_model
import torch
# Create model
model = create_model('Cifar10_5_pat', device='cuda')
# Load pre-trained weights
model.load_state_dict(torch.load('results/Cifar10_5_pat_VPFL_global_model.pt'))
# Evaluate
model.eval()

| Dataset | # Clients | Partition | Directory |
|---|---|---|---|
| CIFAR-10 | 5 | Pathological | dataset/Cifar10_5_pat/ |
| CIFAR-10 | 10 | Dirichlet(0.1) | dataset/Cifar10_10_dir/ |
| Fashion-MNIST | 5 | Pathological | dataset/FashionMNIST_5_pat/ |
| Fashion-MNIST | 10 | Dirichlet(0.1) | dataset/FashionMNIST_10_dir/ |
VPFL_CONFIG = {
'lambda_param': 10.0, # PPD constraint strength
'mu': 3.0, # Perturbation scale control
'beta': 2.0, # Temperature for adaptive aggregation weighting
'gamma': 0.5, # Data quantity exponent for adaptive weighting
'lr_decay': 1.0, # Per-round learning rate decay (1.0 = no decay)
}
# Optimizer settings (set automatically inside ClientVPFL)
optimizer = torch.optim.SGD(
model.parameters(),
lr=0.005,
momentum=0.9, # CRITICAL for best performance!
weight_decay=1e-4
)

- lambda_param (λ): Controls the strength of the PPD constraint c = 1/(λ · min(max|Γ|)). Higher values allow more personalization.
- mu (μ): Controls low-order perturbation variance scaling (Var/μ).
- beta (β): Temperature of the similarity-based aggregation weights. Larger values sharpen the weight distribution.
- gamma (γ): Exponent of the data quantity factor n_k^γ in aggregation weights.
- lr_decay: Per-round multiplicative decay applied to client learning rates (lr_t = lr_0 · lr_decay^t).
- momentum: SGD momentum. 0.9 is crucial for best performance (+5.52% improvement). Combined with weight_decay=1e-4.
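
To make the roles of mu and lr_decay concrete, here is a small plain-Python sketch. It is illustrative only: `perturb_layer` and `decayed_lr` are hypothetical names, Var is assumed to be the population variance of the layer's own weights, the decay value 0.99 is an arbitrary example (the default above is 1.0), and the low-order/high-order choice is passed in by the caller because the PPD-based selection rule is not spelled out here.

```python
import random
import statistics


def decayed_lr(lr_0, lr_decay, round_idx):
    """lr_t = lr_0 · lr_decay^t for global round t."""
    return lr_0 * lr_decay ** round_idx


def perturb_layer(weights, mu=3.0, low_order=True, rng=None):
    """Add zero-mean Gaussian noise to one layer.

    Noise variance is Var/μ for low-order layers and Var for high-order ones.
    Assumption: Var is the population variance of this layer's weights.
    """
    if rng is None:
        rng = random.Random()
    var = statistics.pvariance(weights)
    noise_var = var / mu if low_order else var
    std = noise_var ** 0.5
    return [w + rng.gauss(0.0, std) for w in weights]


layer = [0.2, -0.1, 0.4, 0.0, -0.3]
rng = random.Random(0)  # fixed seed so the example is reproducible

low = perturb_layer(layer, mu=3.0, low_order=True, rng=rng)
high = perturb_layer(layer, mu=3.0, low_order=False, rng=rng)

# Learning rate over the first three rounds with a hypothetical decay of 0.99.
lrs = [decayed_lr(0.005, 0.99, t) for t in range(3)]
print(low, high, lrs)
```

A larger mu shrinks the noise on low-order layers by a factor of 1/μ in variance, while lr_decay = 1.0 leaves the learning rate constant across rounds.
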
VPFL/
├── README.md # This file
├── LICENSE # Apache 2.0 License
├── requirements.txt # Python dependencies
├── main.py # Main entry point
│
├── dataset/ # Dataset generation scripts
│ ├── generate_Cifar10.py
│ ├── generate_FashionMNIST.py
│ └── generate_all_datasets.sh
│
├── system/ # Core algorithm (PFLlib-style)
│ └── flcore/
│ ├── servers/
│ │ ├── serverbase.py # Base server
│ │ └── servervpfl.py # VPFL server
│ ├── clients/
│ │ ├── clientbase.py # Base client
│ │ └── clientvpfl.py # VPFL client
│ └── trainmodel/
│ └── models.py # Neural network models
│
├── utils/ # Utilities
│ ├── data_utils.py # Data loading utilities
│ ├── result_utils.py # Result saving utilities
│ ├── ppd.py # Prior-Posterior Distance computation
│ ├── perturbation.py # Variational perturbation module
│ └── vpfl_core.py # Adaptive aggregation weighting
│
└── results/ # Results, saved models and JSON history

Each evaluation round records the following metrics (also dumped as JSON under results/<dataset>_VPFL_seed<t>/round_XXXX.json, with the full run saved to history.json):
- avg_acc / std_acc: Mean and standard deviation of per-client (personalized) test accuracy
- worst5_acc: 5th percentile of per-client accuracy (fairness indicator, higher is better)
- pgap (Personalization Gap): Mean of local_acc − global_acc per client (higher means personalization helps more)
- global_avg_acc: Global model accuracy averaged over clients' test sets
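
The metrics above could be computed from per-client accuracies roughly as follows. This is a sketch under stated assumptions rather than the code in utils/result_utils.py: `round_metrics` and `percentile` are hypothetical helpers, the standard deviation is taken as the population standard deviation, and the 5th percentile uses linear interpolation between sorted values.

```python
import statistics


def percentile(values, q):
    """q-th percentile (0-100) using linear interpolation between sorted values."""
    ordered = sorted(values)
    if len(ordered) == 1:
        return ordered[0]
    pos = (len(ordered) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def round_metrics(local_acc, global_acc):
    """Summarize one evaluation round.

    local_acc[k]  : accuracy of client k's personalized model on its test set
    global_acc[k] : accuracy of the global model on the same test set
    """
    gaps = [l - g for l, g in zip(local_acc, global_acc)]
    return {
        "avg_acc": statistics.fmean(local_acc),
        "std_acc": statistics.pstdev(local_acc),  # assumption: population std
        "worst5_acc": percentile(local_acc, 5),
        "pgap": statistics.fmean(gaps),
        "global_avg_acc": statistics.fmean(global_acc),
    }


# Made-up accuracies for five clients, for illustration only.
metrics = round_metrics(
    local_acc=[0.90, 0.80, 0.70, 0.85, 0.75],
    global_acc=[0.70, 0.65, 0.60, 0.72, 0.58],
)
print(metrics)
```

With only a handful of clients, worst5_acc sits very close to the single worst client's accuracy, so it is most informative when many clients participate.
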
bash scripts/run_experiments.sh

# Quick test (20 rounds)
python main.py --dataset FashionMNIST_5_pat --global_rounds 20
# Full training (100 rounds)
python main.py --dataset Cifar10_5_pat --global_rounds 100
# With model saving
python main.py --dataset FashionMNIST_5_pat --global_rounds 100

| Method | CIFAR-10 Pathological | Fashion-MNIST Pathological | CIFAR-10 Dirichlet |
|---|---|---|---|
| VPFL (Ours) | 59.72% | 86.49% | 67.08% |
| FedAvg | ~50% | ~75% | ~65% |
# Generate datasets first
cd dataset
python generate_Cifar10.py
python generate_FashionMNIST.py

# Reduce batch size
python main.py --dataset Cifar10_5_pat --batch_size 5

# Use fewer rounds
python main.py --dataset Cifar10_5_pat --global_rounds 50

If you use this code in your research, please cite:
@inproceedings{zhou2025variational,
title={Variational Perturbation Personalized Federated Learning via Prior-Posterior Distance},
author={Zhou, Hefeng and Wang, Yuanbin and Wang, Jun and Lou, Jiong and Bao, Wugedele and Wu, Chentao and Li, Jie},
booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={1--5},
year={2025},
organization={IEEE}
}

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
- Architecture inspired by PFLlib
- PyTorch team for the excellent deep learning framework