Paper Title
Trainable Weight Averaging: Accelerating Training and Improving Generalization
Paper Authors
Paper Abstract
Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). While existing approaches like stochastic weight averaging (SWA) rely on pre-set weighting schemes, they can be suboptimal when handling diverse weights. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns optimal weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression for the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization by weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
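The abstract describes learning averaging coefficients over candidate weights by optimization. Below is a minimal, hypothetical sketch of that idea, not the authors' implementation (their code is at the repository above, and the paper's projection-matrix compression and distributed framework are omitted here). It learns coefficients over K checkpoints by minimizing loss on held-out data, in the spirit of the TWA-v variant. The function and argument names (`trainable_weight_average`, `checkpoints`, `val_loader`) are illustrative, the softmax simplex constraint is an assumption of this sketch, and PyTorch >= 2.0 is assumed for `torch.func.functional_call`.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call


def trainable_weight_average(model, checkpoints, val_loader, steps=200, lr=0.5):
    """Learn coefficients for averaging `checkpoints` (a list of state_dicts)
    so that the averaged model minimizes loss on `val_loader` (TWA-v style sketch)."""
    ref = checkpoints[0]
    # Average only floating-point tensors; integer buffers (e.g. BatchNorm's
    # num_batches_tracked) fall back to the live model's own values.
    keys = [k for k in ref if ref[k].is_floating_point()]
    # Flatten every checkpoint into a row of a K x D matrix; the rows span
    # the subspace within which the coefficients are optimized.
    P = torch.stack(
        [torch.cat([sd[k].flatten() for k in keys]) for sd in checkpoints]
    )
    alpha = torch.zeros(len(checkpoints), requires_grad=True)
    opt = torch.optim.Adam([alpha], lr=lr)
    batches = iter(val_loader)
    for _ in range(steps):
        try:
            x, y = next(batches)
        except StopIteration:  # cycle through the held-out data
            batches = iter(val_loader)
            x, y = next(batches)
        coeffs = torch.softmax(alpha, dim=0)  # simplex constraint (an assumption)
        w = coeffs @ P                        # averaged flat weight vector
        # Unflatten into named tensors while keeping the graph back to alpha.
        params, offset = {}, 0
        for k in keys:
            n = ref[k].numel()
            params[k] = w[offset:offset + n].view(ref[k].shape)
            offset += n
        loss = F.cross_entropy(functional_call(model, params, (x,)), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Materialize the learned average as an ordinary state_dict.
    with torch.no_grad():
        w = torch.softmax(alpha, dim=0) @ P
    averaged, offset = dict(ref), 0
    for k in keys:
        n = ref[k].numel()
        averaged[k] = w[offset:offset + n].view(ref[k].shape)
        offset += n
    return averaged
```

Optimizing only K coefficients rather than all D parameters is what keeps this cheap: the gradient with respect to `alpha` flows through the single matrix-vector product `coeffs @ P`, so each step costs one forward/backward pass plus O(K x D) memory for the stacked checkpoints.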