Paper Title
Beyond Lipschitz: Sharp Generalization and Excess Risk Bounds for Full-Batch GD
Paper Authors
Paper Abstract
We provide sharp path-dependent generalization and excess risk guarantees for the full-batch Gradient Descent (GD) algorithm on smooth losses (possibly non-Lipschitz, possibly nonconvex). At the heart of our analysis is an upper bound on the generalization error, which implies that average output stability and a bounded expected optimization error at termination lead to generalization. This result shows that a small generalization error occurs along the optimization path, and allows us to bypass Lipschitz or sub-Gaussian assumptions on the loss prevalent in previous works. For nonconvex, convex, and strongly convex losses, we show the explicit dependence of the generalization error on the accumulated path-dependent optimization error, the terminal optimization error, the number of samples, and the number of iterations. For nonconvex smooth losses, we prove that full-batch GD efficiently generalizes close to any stationary point at termination, and recovers the generalization error guarantees of stochastic algorithms under fewer assumptions. For smooth convex losses, we show that the generalization error is tighter than existing bounds for SGD (by up to one order of error magnitude). Consequently, the excess risk matches that of SGD with quadratically fewer iterations. Lastly, for strongly convex smooth losses, we show that full-batch GD achieves essentially the same excess risk rate as the state of the art for SGD, but with an exponentially smaller number of iterations (logarithmic in the dataset size).
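To make the object of study concrete, below is a minimal sketch (not the paper's method or bounds) of full-batch GD minimizing the empirical risk of a smooth but non-Lipschitz loss, here squared error for linear regression; the synthetic data, step size `eta`, and iteration count `T` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 5                      # number of samples, dimension (assumed)
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=n)

def empirical_risk(w):
    # Smooth quadratic loss; its gradient is unbounded, so the loss is non-Lipschitz.
    return 0.5 * np.mean((X @ w - y) ** 2)

def full_batch_gradient(w):
    # Gradient of the empirical risk over the *entire* dataset (full-batch GD).
    return X.T @ (X @ w - y) / n

w = np.zeros(d)
eta = 0.1                          # constant step size (assumed)
T = 100                            # number of iterations (assumed)
for t in range(T):
    w -= eta * full_batch_gradient(w)

print(f"terminal empirical risk: {empirical_risk(w):.6f}")
```

The quantity bounded in the paper is the gap between this terminal empirical risk and the corresponding population risk; the sketch only illustrates the algorithm whose iterates the path-dependent analysis tracks.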