且构网 - Sharing the things programmers run into in programming and development

Logistic regression for multi-class classification using PySpark

Updated: 2022-06-22 22:38:14

Case 1: There is nothing strange here, simply (as the error message says) LogisticRegression does not support multi-class classification, as clearly stated in the documentation.

Case 2: Here you have switched from ML to MLlib, which, however, does not work with dataframes but needs the input as an RDD of LabeledPoint (documentation); hence, again, the error message is expected.
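To feed a dataframe into the MLlib API, each row first has to be turned into a LabeledPoint. A minimal sketch of that conversion, using plain namedtuples as stand-ins for a Spark Row and for pyspark.mllib.regression.LabeledPoint so the logic can be shown without a running Spark session (with a real dataframe the same map would run on df.rdd):

```python
from collections import namedtuple

# Stand-ins for pyspark.sql.Row and pyspark.mllib.regression.LabeledPoint,
# so the conversion logic can be demonstrated without Spark:
Row = namedtuple("Row", ["label", "features"])
LabeledPoint = namedtuple("LabeledPoint", ["label", "features"])

rows = [Row(0.0, [1.0, 0.0]), Row(1.0, [0.0, 1.0])]

# With a real dataframe this would be:
#   trainingData = df.rdd.map(lambda row: LabeledPoint(row.label, row.features))
trainingData = [LabeledPoint(row.label, row.features) for row in rows]
```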

Case 3: Here is where things get interesting. First, you should remove the brackets from your map function, i.e. it should be

trainingData = trainingData.map(lambda row: LabeledPoint(row.label, row.features)) # no brackets after "row:"

Nevertheless, guessing from the code snippets you have provided, most probably you are going to get a different error now:

model = LogisticRegressionWithLBFGS.train(trainingData, numClasses=5)
[...]
: org.apache.spark.SparkException: Input validation failed.

Here is what is happening (it took me some time to figure it out), using some dummy data (it's always a good idea to provide some sample data with your question):

# 3-class classification
data = sc.parallelize([
     LabeledPoint(3.0, SparseVector(100,[10, 98],[1.0, 1.0])),
     LabeledPoint(1.0, SparseVector(100,[1, 22],[1.0, 1.0])),
     LabeledPoint(2.0, SparseVector(100,[36, 54],[1.0, 1.0]))
])

lrm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3) # throws exception
[...]
: org.apache.spark.SparkException: Input validation failed.

The problem is that your labels must start from 0 (and this is nowhere documented - you have to dig into the Scala source code to see that this is the case!); so, mapping the labels in my dummy data above from (1.0, 2.0, 3.0) to (0.0, 1.0, 2.0), we finally get:

# 3-class classification
data = sc.parallelize([
     LabeledPoint(2.0, SparseVector(100,[10, 98],[1.0, 1.0])),
     LabeledPoint(0.0, SparseVector(100,[1, 22],[1.0, 1.0])),
     LabeledPoint(1.0, SparseVector(100,[36, 54],[1.0, 1.0]))
])

lrm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3) # no error now
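The remapping itself is just a shift of the label values. A minimal sketch in plain Python, assuming the labels are consecutive and the offset is their minimum (on an RDD the equivalent would be data.map(lambda lp: LabeledPoint(lp.label - offset, lp.features))):

```python
# Labels as they appear in the original dummy data: they start at 1.0,
# but LogisticRegressionWithLBFGS requires them in [0, numClasses).
labels = [3.0, 1.0, 2.0]

# Shift every label down by the minimum so the smallest becomes 0.0:
offset = min(labels)
shifted = [label - offset for label in labels]
# shifted is now [2.0, 0.0, 1.0], matching the corrected data above
```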

Judging from your numClasses=5 argument, as well as from the label=5.0 in one of your printed records, I guess that most probably your code suffers from the same issue. Change your labels to [0.0, 4.0] and you should be fine.

(I suggest that you delete the other identical question you have opened here, to reduce clutter...)