How does sklearn.svm.SVC's predict_proba() function work internally?

Updated: 2023-12-02 15:35:58

I am using sklearn.svm.SVC from scikit-learn to do binary classification. I am using its predict_proba() function to get probability estimates. Can anyone tell me how predict_proba() internally calculates the probability?

Scikit-learn uses LibSVM internally, and this in turn uses Platt scaling, as detailed in this note by the LibSVM authors, to calibrate the SVM to produce probabilities in addition to class predictions.
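In scikit-learn this is switched on with the probability flag on SVC. A minimal sketch, with an illustrative synthetic dataset (not part of the original question):

    # Enabling Platt-scaled probabilities in scikit-learn; dataset is illustrative.
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=200, random_state=0)

    clf = SVC(kernel="rbf", probability=True, random_state=0)  # probability=True triggers Platt scaling
    clf.fit(X, y)

    print(clf.decision_function(X[:3]))  # signed distances f(X)
    print(clf.predict_proba(X[:3]))      # calibrated class probabilities P(y|X)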

Platt scaling requires first training the SVM as usual, then optimizing the parameters A and B such that

P(y|X) = 1 / (1 + exp(A * f(X) + B))

where f(X) is the signed distance of a sample from the hyperplane (scikit-learn's decision_function method). You may recognize the logistic sigmoid in this definition, the same function that logistic regression and neural nets use for turning decision functions into probability estimates.
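To make the idea concrete, here is a rough hand-rolled sketch of Platt scaling: fit the sigmoid on top of decision_function with a plain logistic regression. This is only an approximation of the idea, not LibSVM's exact routine (which also uses the cross-validation described below); note that LogisticRegression models P = 1 / (1 + exp(-(w * f + b))), so A = -w and B = -b in the formula above.

    # Hand-rolled sketch of Platt scaling on top of an SVM's decision_function.
    # Not LibSVM's exact implementation; fitted on the training data for brevity.
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=300, random_state=0)

    svm = SVC(kernel="rbf").fit(X, y)             # plain, non-probabilistic SVM
    f = svm.decision_function(X).reshape(-1, 1)   # signed distances f(X)

    sigmoid = LogisticRegression().fit(f, y)      # optimizes the sigmoid under log loss
    A = -sigmoid.coef_[0, 0]
    B = -sigmoid.intercept_[0]
    p_pos = 1.0 / (1.0 + np.exp(A * f.ravel() + B))  # P(y=1|X) per the formula above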

Mind you: the B parameter, the "intercept" or "bias" or whatever you like to call it, can cause predictions based on this model's probability estimates to be inconsistent with the ones you get from the SVM decision function f. E.g., suppose that f(X) = 10; then the prediction for X is positive. But if B = -9.9 and A = 1, then P(y|X) = 1 / (1 + exp(1 * 10 - 9.9)) = 1 / (1 + exp(0.1)) ≈ .475, which implies the negative class. I'm pulling these numbers out of thin air, but this can occur in practice.
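You can check for such disagreements directly, since SVC.predict is based on the decision function while predict_proba goes through the Platt sigmoid. A quick sketch with an illustrative noisy dataset:

    # Count samples where the SVM's own prediction disagrees with the class
    # implied by the Platt-scaled probabilities; dataset is illustrative.
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=300, flip_y=0.2, random_state=0)
    clf = SVC(kernel="rbf", probability=True, random_state=0).fit(X, y)

    pred_from_decision = clf.predict(X)                                  # sign of f(X)
    pred_from_proba = clf.classes_[clf.predict_proba(X).argmax(axis=1)]  # argmax of P(y|X)

    print("disagreements:", np.sum(pred_from_decision != pred_from_proba))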

Effectively, Platt scaling trains a probability model on top of the SVM's outputs under a cross-entropy loss function. To prevent this model from overfitting, it uses an internal five-fold cross-validation, meaning that training an SVM with probability=True can be quite a lot more expensive than training a vanilla, non-probabilistic SVM.
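Two practical notes on that, with illustrative numbers: you can see the extra cost by timing a fit with and without probability=True, and CalibratedClassifierCV with method="sigmoid" performs the same kind of sigmoid (Platt) calibration explicitly, outside the SVM, if you want control over the cross-validation.

    # Rough cost comparison plus an explicit sigmoid-calibration alternative.
    # Timings and the dataset are illustrative.
    import time
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=2000, n_features=50, random_state=0)

    t0 = time.time()
    SVC(kernel="rbf").fit(X, y)                    # vanilla SVM: a single fit
    t1 = time.time()
    SVC(kernel="rbf", probability=True).fit(X, y)  # plus internal 5-fold CV for Platt scaling
    t2 = time.time()
    print(f"plain SVC: {t1 - t0:.2f}s   probability=True: {t2 - t1:.2f}s")

    # Explicit alternative: sigmoid (Platt) calibration via cross-validation.
    calibrated = CalibratedClassifierCV(SVC(kernel="rbf"), method="sigmoid", cv=5).fit(X, y)
    print(calibrated.predict_proba(X[:3]))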