This is the PyTorch implementation of our paper HarmoFL: Harmonizing Local and Global Drifts in Federated Learning on Heterogeneous Medical Images by Meirui Jiang, Zirui Wang and Qi Dou.
Multiple medical institutions collaboratively training a model using federated learning (FL) has become a promising solution for maximizing the potential of data-driven models, yet the non-independent and identically distributed (non-iid) data in medical images is still an outstanding challenge in real-world practice. The feature heterogeneity caused by diverse scanners or sensors introduces a drift in the learning process, in both local (client) and global (server) optimizations, which harms the convergence as well as model performance. Many previous works have attempted to address the non-iid issue by tackling the drift locally or globally, but how to jointly solve the two essentially coupled drifts is still unclear. In this work, we concentrate on handling both local and global drifts and introduce a new harmonizing framework called HarmoFL. First, we propose to mitigate the local update drift by normalizing amplitudes of images transformed into the frequency domain to mimic a unified scanner/sensor, in order to generate a harmonized feature space across local clients. Second, based on harmonized features, we design a client weight perturbation guiding each local model to reach a flat optimum, where a neighborhood area of the local optimal solution has a uniformly low loss. Without any extra communication cost, the perturbation assists the global model to optimize towards a converged optimal solution by aggregating several local flat optima. We have theoretically analyzed the proposed method and empirically conducted extensive experiments on three medical image classification and segmentation tasks, showing that HarmoFL outperforms a set of recent state-of-the-art methods with promising convergence behavior.
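As a rough illustration of the amplitude normalization idea described above, the following NumPy sketch replaces each image's amplitude spectrum with a shared reference amplitude while keeping its phase. It is a minimal sketch under stated assumptions, not the code shipped in this repository: the function names `mean_amplitude` and `normalize_amplitude`, the use of a plain batch average as the reference, and the random toy batch are all hypothetical.

```python
import numpy as np


def mean_amplitude(images):
    """Average amplitude spectrum over a batch of 2D images (assumed reference)."""
    return np.mean([np.abs(np.fft.fft2(img)) for img in images], axis=0)


def normalize_amplitude(image, ref_amplitude):
    """Keep the phase of `image` but swap in `ref_amplitude` as its amplitude."""
    spectrum = np.fft.fft2(image)
    phase = np.angle(spectrum)
    harmonized = ref_amplitude * np.exp(1j * phase)
    # The result is real up to floating-point error, so drop the imaginary part.
    return np.real(np.fft.ifft2(harmonized))


# Toy batch standing in for images from one client (hypothetical data).
rng = np.random.default_rng(0)
batch = rng.random((4, 8, 8))

ref = mean_amplitude(batch)
harmonized_batch = np.stack([normalize_amplitude(img, ref) for img in batch])
```

After this step every image in `harmonized_batch` shares the same amplitude spectrum, while the phase, which carries most of the structural content, is left untouched.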
Conda
We recommend using conda to set up the environment. See requirements.yaml
for the environment configuration.
If conda is not installed on your machine, please download an installer from https://www.anaconda.com/products/individual
If you have already installed conda, please use the following commands.
conda env create -f environment.yaml
conda activate harmofl
Build Cython file
Build the Cython file for amplitude normalization:
python utils/setup.py build_ext --inplace
- Please download the histology breast cancer classification datasets here, extract them, and put the folder 'patches' under the data/camelyon17 directory.
- Please download the prostate MRI datasets here, and put the folder data under the data/prostate directory.
fed_train.py is the main file to run the federated experiments.
Please use the following command to train a model with the federated learning strategy.
bash train.sh
Below please find some useful options:
- --alpha: specify the degree of weight perturbation; default is 0.05.
- --wk_iters: specify the number of local update epochs; default is 1.
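To give a feel for what the perturbation degree controls, here is a minimal sketch of one perturbed local update on a toy quadratic loss. It assumes the perturbation is applied along the normalized gradient direction and scaled by alpha; the function names `toy_grad` and `perturbed_step`, the learning rate, and the toy loss are hypothetical and are not taken from fed_train.py.

```python
import numpy as np


def toy_grad(w):
    """Gradient of the toy loss 0.5 * ||w||^2 (illustrative only)."""
    return w


def perturbed_step(w, grad_fn, alpha=0.05, lr=0.1):
    """One update that uses the gradient taken at a perturbed copy of the weights."""
    g = grad_fn(w)
    norm = np.linalg.norm(g)
    # Move a distance alpha along the normalized gradient (assumed perturbation rule).
    delta = alpha * g / norm if norm > 0 else np.zeros_like(w)
    g_perturbed = grad_fn(w + delta)
    # Apply the perturbed gradient to the original, unperturbed weights.
    return w - lr * g_perturbed


w0 = np.array([1.0, -2.0])
w1 = perturbed_step(w0, toy_grad, alpha=0.05, lr=0.1)
```

A larger alpha probes the loss further from the current weights, which in this sketch biases the update toward regions where the surrounding loss is also low.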
Assuming your test model is saved at 'model/data/harmofl', run:
bash test.sh