Paper Title

Distilled One-Shot Federated Learning

Paper Authors

Yanlin Zhou, George Pu, Xiyao Ma, Xiaolin Li, Dapeng Wu

Paper Abstract

Current federated learning algorithms take tens of communication rounds transmitting unwieldy model weights under ideal circumstances and hundreds when data is poorly distributed. Inspired by recent work on dataset distillation and distributed one-shot learning, we propose Distilled One-Shot Federated Learning (DOSFL) to significantly reduce the communication cost while achieving comparable performance. In just one round, each client distills their private dataset, sends the synthetic data (e.g., images or sentences) to the server, and together they train a global model. The distilled data look like noise and are useful only to the specific model weights, i.e., they become useless after the model updates. With this weight-less and gradient-less design, the total communication cost of DOSFL is up to three orders of magnitude less than that of FedAvg, while preserving 93% to 99% of the performance of a centralized counterpart. Afterwards, clients can switch to traditional methods such as FedAvg to fine-tune the last few percent and fit personalized local models with their local datasets. Through comprehensive experiments, we show the accuracy and communication performance of DOSFL on both vision and language tasks with different models, including CNNs, LSTMs, and Transformers. We demonstrate that an eavesdropping attacker cannot properly train a good model using the leaked distilled data without knowing the initial model weights. DOSFL serves as an inexpensive method to quickly converge on a performant pre-trained model with less than 0.1% of the communication cost of traditional methods.
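
The abstract describes a one-round protocol: the server shares an initial model, each client distills its private dataset into a handful of synthetic examples tied to that initialization, and the server trains the global model on the pooled synthetic data. The sketch below illustrates this flow in PyTorch under strong simplifying assumptions (a toy linear classifier, soft-label bi-level distillation, a single server-side SGD step, and a shared random seed standing in for broadcasting the initial weights); all names, dimensions, and hyperparameters are illustrative assumptions, not the paper's implementation.

```python
# A minimal, illustrative sketch of the one-round DOSFL flow described in the
# abstract -- NOT the authors' reference implementation. It uses a toy linear
# classifier and bi-level dataset distillation with soft labels; every name,
# dimension, and hyperparameter below is an assumption made for illustration.
# Requires PyTorch >= 1.10 (cross_entropy with probability targets).
import torch
import torch.nn.functional as F

D_IN, N_CLASSES, N_DISTILL = 20, 4, 8   # assumed toy dimensions


def forward(params, x):
    """Functional forward pass of a toy linear classifier."""
    w, b = params
    return x @ w + b


def init_params(seed):
    """Model initialization shared between server and clients via a common seed."""
    g = torch.Generator().manual_seed(seed)
    w = 0.1 * torch.randn(D_IN, N_CLASSES, generator=g)
    b = torch.zeros(N_CLASSES)
    return [w, b]


def client_distill(real_x, real_y, seed, steps=300, inner_lr=0.1, outer_lr=0.05):
    """Distill a private dataset into a few synthetic examples tied to `seed`."""
    syn_x = torch.randn(N_DISTILL, D_IN, requires_grad=True)
    syn_y = torch.randn(N_DISTILL, N_CLASSES, requires_grad=True)  # soft labels
    opt = torch.optim.Adam([syn_x, syn_y], lr=outer_lr)
    for _ in range(steps):
        params = [p.clone().requires_grad_(True) for p in init_params(seed)]
        # Inner step: one differentiable SGD update on the synthetic data.
        inner_loss = F.cross_entropy(forward(params, syn_x), syn_y.softmax(dim=1))
        grads = torch.autograd.grad(inner_loss, params, create_graph=True)
        updated = [p - inner_lr * g for p, g in zip(params, grads)]
        # Outer step: the updated weights should fit the real private data.
        outer_loss = F.cross_entropy(forward(updated, real_x), real_y)
        opt.zero_grad()
        outer_loss.backward()
        opt.step()
    # Only synthetic data leaves the client -- no weights, no gradients.
    return syn_x.detach(), syn_y.detach()


def server_train(distilled, seed, inner_lr=0.1):
    """One-shot global training on the pooled distilled data.

    The full method applies several distilled batches with learned learning
    rates; a single SGD step from the shared init keeps this sketch short.
    """
    params = [p.clone().requires_grad_(True) for p in init_params(seed)]
    syn_x = torch.cat([x for x, _ in distilled])
    syn_y = torch.cat([y for _, y in distilled])
    loss = F.cross_entropy(forward(params, syn_x), syn_y.softmax(dim=1))
    grads = torch.autograd.grad(loss, params)
    return [(p - inner_lr * g).detach() for p, g in zip(params, grads)]


if __name__ == "__main__":
    seed = 0                                           # broadcast by the server
    # Two clients with small private datasets (random toy data here).
    clients = [(torch.randn(64, D_IN), torch.randint(0, N_CLASSES, (64,)))
               for _ in range(2)]
    distilled = [client_distill(x, y, seed) for x, y in clients]
    global_params = server_train(distilled, seed)
    accs = [(forward(global_params, x).argmax(1) == y).float().mean().item()
            for x, y in clients]
    print("accuracy on each client's data:", [round(a, 2) for a in accs])
```

Because the synthetic examples are optimized against one specific initialization, applying them to a differently initialized model gives little benefit, which mirrors the abstract's point that leaked distilled data is of little use to an eavesdropper who does not know the initial model weights.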
