Federated Learning in PyTorch
Implementations of various Federated Learning (FL) algorithms in PyTorch, especially for research purposes.
Implementation Details
Datasets
- Supports all image classification datasets in torchvision.datasets.
- Supports all text classification datasets in torchtext.datasets.
- Supports all datasets in the LEAF benchmark (no need to prepare raw data manually).
- Supports additional image classification datasets (TinyImageNet, CINIC10).
- Supports additional text classification datasets (BeerReviews).
- Supports tabular datasets (Heart, Adult, Cover).
- Supports a temporal dataset (GLEAM).
- NOTE: there is no need to search for the raw files of a dataset; it is downloaded automatically to the designated path when you pass its name!
Statistical Heterogeneity Simulations
- IID (i.e., statistical homogeneity)
- Unbalanced (i.e., sample count heterogeneity)
- Pathological Non-IID (McMahan et al., 2016)
- Dirichlet distribution-based Non-IID (Hsu et al., 2019)
- Pre-defined (for datasets having natural semantic separation, including the LEAF benchmark (Caldas et al., 2018))
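As an illustration of the Dirichlet distribution-based Non-IID scheme, here is a minimal sketch using NumPy. It implements one common variant (for each class, the share given to every client is drawn from a symmetric Dirichlet distribution), which is not necessarily how this repository does it. The function name `dirichlet_partition` and its parameters are hypothetical and exist only for this example.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with label skew controlled by alpha.

    Hypothetical helper for illustration; not part of this repository's API.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        # Indices of all samples of this class, in random order.
        idx = rng.permutation(np.flatnonzero(labels == cls))
        # Share of this class assigned to each client; smaller alpha means more skew.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Convert cumulative shares into cut points inside idx.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(part.tolist())
    return client_indices

# Toy labels: 10 classes with 100 samples each.
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_partition(labels, num_clients=5, alpha=0.5)
```

A small `alpha` (e.g., 0.1) concentrates each class on a few clients, while a large `alpha` (e.g., 100) approaches the IID split.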
Models
- LogReg (logistic regression), StackedTransformer (TransformerEncoder-based classifier)
- TwoNN, TwoCNN, SimpleCNN (McMahan et al., 2016)
- FEMNISTCNN, Sent140LSTM (Caldas et al., 2018)
- LeNet (LeCun et al., 1998), MobileNet (Howard et al., 2019), SqueezeNet (Iandola et al., 2016), VGG (Simonyan et al., 2014), ResNet (He et al., 2015)
- MobileNeXt (Daquan et al., 2020), SqueezeNeXt (Gholami et al., 2016), MobileViT (Mehta et al., 2021)
- DistilBERT (Sanh et al., 2019), SqueezeBERT (Iandola et al., 2020), MobileBERT (Sun et al., 2020)
- M5 (Dai et al., 2016)
Algorithms
- FedAvg and FedSGD (McMahan et al., 2016) Communication-Efficient Learning of Deep Networks from Decentralized Data
- FedAvgM (Hsu et al., 2019) Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification
- FedProx (Li et al., 2018) Federated Optimization in Heterogeneous Networks
- FedOpt (FedAdam, FedYogi, FedAdaGrad) (Reddi et al., 2020) Adaptive Federated Optimization
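To make the server-side step of FedAvg concrete, the sketch below averages client parameters weighted by each client's sample count. It uses NumPy arrays as stand-ins for model tensors, and the name `fedavg_aggregate` is hypothetical; the repository's actual aggregation code may be organized differently.

```python
import numpy as np

def fedavg_aggregate(client_states, client_sizes):
    """Weighted average of client parameters, with weights n_k / sum(n_k).

    Hypothetical helper for illustration; not part of this repository's API.
    """
    total = float(sum(client_sizes))
    aggregated = {}
    for name in client_states[0]:
        # Each client's contribution is proportional to its local sample count.
        aggregated[name] = sum(
            (n / total) * state[name]
            for state, n in zip(client_states, client_sizes)
        )
    return aggregated

# Two toy clients holding 10 and 30 samples, respectively.
states = [
    {"weight": np.array([1.0, 2.0]), "bias": np.array([0.0])},
    {"weight": np.array([3.0, 4.0]), "bias": np.array([1.0])},
]
global_state = fedavg_aggregate(states, client_sizes=[10, 30])
print(global_state["weight"])  # [2.5 3.5]
```

The other listed algorithms change either the local objective (FedProx adds a proximal term) or how the server applies the averaged update (FedAvgM and FedOpt), but they build on this same weighted average.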
Evaluation schemes
- local: evaluate the FL algorithm using holdout sets of (some/all) clients NOT participating in the current round (i.e., evaluation in the personalized federated learning setting).
- global: evaluate the FL algorithm using a global holdout set located at the server (ONLY available if the raw dataset provides a pre-defined validation/test set).
- both: evaluate the FL algorithm using both the local and global schemes.
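The local scheme relies on picking evaluation clients that did not train in the current round. A minimal sketch of that selection, using only the standard library, could look like the following; the function `split_round_clients` and its arguments are hypothetical and only illustrate the idea.

```python
import random

def split_round_clients(client_ids, num_train, num_eval, seed=0):
    """Sample training clients, then evaluation clients from the remainder.

    Hypothetical helper for illustration; not part of this repository's API.
    """
    rng = random.Random(seed)
    participants = rng.sample(client_ids, num_train)
    chosen = set(participants)
    # Only clients that did NOT participate in this round may evaluate.
    candidates = [c for c in client_ids if c not in chosen]
    evaluators = rng.sample(candidates, min(num_eval, len(candidates)))
    return participants, evaluators

participants, evaluators = split_round_clients(list(range(20)), num_train=5, num_eval=8)
```

Each selected evaluator would then score the current global model on its own holdout set, and the per-client results would be aggregated.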
Metrics
- Top-1 Accuracy, Top-5 Accuracy, Precision, Recall, F1
- Area under ROC, Area under PRC, Youden's J
- Seq2Seq Accuracy
- MSE, RMSE, MAE, MAPE
- $R^2$, $D^2$
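Of the metrics above, Youden's J is perhaps the least familiar: it equals sensitivity plus specificity minus one. The sketch below computes it for binary predictions at a single fixed threshold, which is an assumption; the repository may instead take the maximum over all ROC thresholds. The function name `youden_j` is hypothetical.

```python
def youden_j(y_true, y_pred):
    """Youden's J = sensitivity + specificity - 1 for binary labels in {0, 1}.

    Hypothetical helper for illustration; not part of this repository's API.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    sensitivity = tp / (tp + fn)  # true positive rate
    specificity = tn / (tn + fp)  # true negative rate
    return sensitivity + specificity - 1.0

y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 0, 1, 0, 1]
j = youden_j(y_true, y_pred)
print(j)  # 0.5
```

A value of 0 means the classifier is no better than chance, and 1 means perfect separation.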
Requirements
- See requirements.txt. (I recommend building an independent environment for this project, using e.g., Docker or conda.)
- When you install torchtext, please check its version compatibility with torch. (See the official repository.)
- Also, please install torch-related packages with a single command provided by the official guide (see the official installation guide); e.g., conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 torchtext==0.13.0 cudatoolkit=11.6 -c pytorch -c conda-forge
Configurations
- See python3 main.py -h.
Example Commands
- See the shell files prepared in the commands directory.
TODO
- Support more models, especially lightweight ones for the cross-device FL setting (e.g., EdgeNeXt).
- Support more structured datasets, including temporal and tabular data, along with datasets suitable for the cross-silo FL setting (e.g., MedMNIST).
- Add other popular FL algorithms, including personalized FL algorithms (e.g., SuPerFed).
- Attach benchmark results of sample commands.
Contact
Should you have any feedback, please open a thread in the Issues tab. Thank you :)