Paper Title
Leveraging Unlabeled Data to Predict Out-of-Distribution Performance
Paper Authors
Paper Abstract
Real-world machine learning deployments are characterized by mismatches between the source (training) and target (test) distributions that may cause performance drops. In this work, we investigate methods for predicting the target domain accuracy using only labeled source data and unlabeled target data. We propose Average Thresholded Confidence (ATC), a practical method that learns a threshold on the model's confidence, predicting accuracy as the fraction of unlabeled examples for which model confidence exceeds that threshold. ATC outperforms previous methods across several model architectures, types of distribution shifts (e.g., due to synthetic corruptions, dataset reproduction, or novel subpopulations), and datasets (Wilds, ImageNet, Breeds, CIFAR, and MNIST). In our experiments, ATC estimates target performance $2$-$4\times$ more accurately than prior methods. We also explore the theoretical foundations of the problem, proving that, in general, identifying the accuracy is just as hard as identifying the optimal predictor and thus, the efficacy of any method rests upon (perhaps unstated) assumptions on the nature of the shift. Finally, analyzing our method on some toy distributions, we provide insights concerning when it works. Code is available at https://github.com/saurabhgarg1996/ATC_code/.
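The ATC procedure described in the abstract can be sketched roughly as follows: pick a threshold on the model's confidence scores so that the fraction of labeled source examples scoring below it matches the source error rate, then estimate target accuracy as the fraction of unlabeled target examples scoring above that threshold. This is a minimal illustration under assumed inputs (arrays of per-example confidence scores and source correctness indicators); the function names are hypothetical, and the paper's actual implementation (see the linked repository) supports additional score functions such as negative entropy.

```python
import numpy as np

def fit_atc_threshold(source_scores, source_correct):
    """Pick a threshold t so that the fraction of source examples with
    confidence below t equals the observed source error rate.

    source_scores: per-example confidence scores on labeled source data.
    source_correct: 1 if the model's prediction was correct, else 0.
    """
    error_rate = 1.0 - source_correct.mean()
    # The empirical quantile of the scores at the error rate puts
    # (approximately) error_rate mass below the threshold.
    return np.quantile(source_scores, error_rate)

def predict_target_accuracy(target_scores, threshold):
    """Estimated target accuracy = fraction of unlabeled target
    examples whose confidence exceeds the learned threshold."""
    return float((target_scores >= threshold).mean())

# Toy usage with made-up scores: the model is wrong on 1 of 4
# source examples, so the threshold lands near the low-score region.
src_scores = np.array([0.2, 0.4, 0.6, 0.8])
src_correct = np.array([0, 1, 1, 1])
t = fit_atc_threshold(src_scores, src_correct)
est_acc = predict_target_accuracy(np.array([0.1, 0.5, 0.7, 0.9]), t)
```

In practice the confidence score would be, e.g., the maximum softmax probability of a trained classifier, computed on a held-out source split to fit the threshold and on the unlabeled target pool to produce the accuracy estimate.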