且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

如何在Keras中进行逐点分类交叉熵损失?

更新时间:2022-11-29 16:50:42

找到了此问题确认我的直觉.

简而言之:softmax将采用2D或3D输入.如果它们是3D角膜,则将采用这样的形状(样本,时维,数字类),并将softmax应用于最后一个.出于某些奇怪的原因,它没有对4D张量执行此操作.

解决方案:将输出调整为一系列像素

reshaped_output = Reshape((height*width, num_classes))(output_tensor)

然后应用您的softmax

new_output = Activation('softmax')(reshaped_output) 

然后将目标张量重塑为2D或将最后一层重塑为(宽度,高度和num_classes).

否则,如果我现在不在手机上,我会尝试使用timedistributed(Activation('softmax')).但是不知道这样是否可行...稍后再试

我希望这会有所帮助:-)

I have a network that produces a 4D output tensor where the value at each position in spatial dimensions (~pixel) is to be interpreted as the class probabilities for that position. In other words, the output is (num_batches, height, width, num_classes). I have labels of the same size where the real class is coded as one-hot. I would like to calculate the categorical-crossentropy loss using this.

Problem #1: The K.softmax function expects a 2D tensor (num_batches, num_classes)

Problem #2: I'm not sure how the losses from each position should be combined. Is it correct to reshape the tensor to (num_batches * height * width, num_classes) and then calling K.categorical_crossentropy on that? Or rather, call K.categorical_crossentropy(num_batches, num_classes) height*width times and average the results?

Found this issue to confirm my intuition.

In short : the softmax will take 2D or 3D inputs. If they are 3D keras will assume a shape like this (samples, timedimension, numclasses) and apply the softmax on the last one. For some weird reasons, it doesnt do that for 4D tensors.

Solution : reshape your output to a sequence of pixels

reshaped_output = Reshape((height*width, num_classes))(output_tensor)

Then apply your softmax

new_output = Activation('softmax')(reshaped_output) 

And then either you reshape your target tensors to 2D or you just reshape that last layer into (width, height, num_classes).

Otherwise, something I would try if i wasn't on my phone right now is to use a timedistributed(Activation('softmax')). But no idea if that would work... will try later

I hope this helps :-)