
Strange behaviour of the loss function in a Keras model with a pretrained convolutional base



I'm trying to create a model in Keras to make numerical predictions from pictures. My model has a DenseNet121 convolutional base with a couple of additional layers on top. All layers except the last two are set to layer.trainable = False. My loss is mean squared error, since it's a regression task. During training I get a loss of ~3, while evaluation on the very same batch of data gives a loss of ~30:

model.fit(x=dat[0],y=dat[1],batch_size=32)

Epoch 1/1 32/32 [==============================] - 0s 11ms/step - loss: 2.5571

model.evaluate(x=dat[0],y=dat[1])

32/32 [==============================] - 2s 59ms/step 29.276123046875
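For context, here is a minimal sketch of the kind of setup described above; the input size, the exact head layers and the optimizer are illustrative assumptions, not the actual model from the question:

```python
from keras.applications import DenseNet121
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

# Pretrained convolutional base (input size and weights are assumptions)
base = DenseNet121(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# A couple of layers on top for the regression task
x = GlobalAveragePooling2D()(base.output)
x = Dense(128, activation='relu')(x)
out = Dense(1)(x)  # one numerical prediction per picture

model = Model(inputs=base.input, outputs=out)

# Freeze everything except the last two layers
for layer in model.layers[:-2]:
    layer.trainable = False

model.compile(optimizer='adam', loss='mean_squared_error')
```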

I feed exactly the same 32 pictures during training and evaluation. I also calculated the loss from the predicted values of y_pred=model.predict(dat[0]) and constructed the mean squared error with numpy. The result was the same as what I got from evaluation (i.e. 29.276123...).
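That manual check can be reproduced with something like the following (reusing dat from the snippets above):

```python
import numpy as np

# Loss reported by Keras on the batch
keras_loss = model.evaluate(x=dat[0], y=dat[1])

# The same quantity computed by hand from the predictions
y_pred = model.predict(dat[0])
manual_mse = np.mean((y_pred.squeeze() - dat[1].squeeze()) ** 2)

# Both values agree (~29.28) and differ from the ~2.56 reported by fit()
print(keras_loss, manual_mse)
```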

There was a suggestion that this behavior might be due to the BatchNormalization layers in the convolutional base (discussion on github). Of course, all BatchNormalization layers in my model have been set to layer.trainable=False as well. Maybe somebody has encountered this problem and figured out the solution?

Looks like I found the solution. As I suggested, the problem is with the BatchNormalization layers. They do three things:

  1. subtract mean and normalize by std
  2. collect statistics on mean and std using running average
  3. train two additional parameters (two per node).

When one sets trainable to False, these two parameters freeze and the layer also stops collecting statistics on the mean and std. But it looks like the layer still performs normalization during training using the training batch. Most likely it's a bug in Keras, or maybe they did it on purpose for some reason. As a result, the forward-pass computation during training is different from the one at prediction time, even though the trainable attribute is set to False.
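One way to observe this directly (a sketch, assuming the TensorFlow backend of Keras, where K.learning_phase() switches between training- and inference-mode behaviour) is to run the same batch through the model in both modes and compare:

```python
import numpy as np
import keras.backend as K

# Forward pass with an explicit learning phase (1 = training, 0 = inference)
forward = K.function([model.input, K.learning_phase()], [model.output])

pred_train_mode = forward([dat[0], 1])[0]  # BatchNorm normalizes with the batch statistics
pred_test_mode = forward([dat[0], 0])[0]   # BatchNorm normalizes with the (frozen) running averages

print(np.mean((pred_train_mode.squeeze() - dat[1].squeeze()) ** 2))  # close to the fit() loss
print(np.mean((pred_test_mode.squeeze() - dat[1].squeeze()) ** 2))   # close to the evaluate() loss
```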

There are two possible solutions I can think of:

  1. Set all BatchNormalization layers to trainable (a sketch follows this list). In this case these layers will collect statistics from your dataset instead of using the pretrained ones (which can be significantly different!). In this case you will adjust all the BatchNorm layers to your custom dataset during training.
  2. Split the model into two parts, model = model_base + model_top. After that, use model_base to extract features with model_base.predict(), then feed these features into model_top and train only model_top.
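A sketch of the first option: iterating over the model's layers and flipping only the BatchNormalization ones is one way to do it, recompiling afterwards so the change takes effect:

```python
from keras.layers import BatchNormalization

# Solution 1: let the BatchNorm layers adapt to the statistics of the new dataset
for layer in model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = True

# Changes to `trainable` only take effect after recompiling
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(x=dat[0], y=dat[1], batch_size=32)
```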

I've just tried the first solution and it looks like it's working:

model.fit(x=dat[0],y=dat[1],batch_size=32)

Epoch 1/1
32/32 [==============================] - 1s 28ms/step - loss: **3.1053**

model.evaluate(x=dat[0],y=dat[1])

32/32 [==============================] - 0s 10ms/step
**2.487905502319336**

This was after some training - one needs to wait until enough statistics on the mean and std have been collected.

I haven't tried the second solution yet, but I'm pretty sure it will work, since forward propagation during training and prediction will be the same (a sketch of what it could look like is below).
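For what it's worth, a sketch of what the second option could look like; the split point, pooling and head here are assumptions, and the point is only that the frozen base is run exclusively through predict(), i.e. in inference mode:

```python
from keras.applications import DenseNet121
from keras.layers import Dense
from keras.models import Sequential

# Frozen base, used only via predict(), so BatchNorm always runs in inference mode
model_base = DenseNet121(weights='imagenet', include_top=False, pooling='avg',
                         input_shape=(224, 224, 3))

# Small trainable head on top of the extracted features
model_top = Sequential([
    Dense(128, activation='relu', input_shape=(model_base.output_shape[-1],)),
    Dense(1),
])
model_top.compile(optimizer='adam', loss='mean_squared_error')

# Extract features once, then train only the head
features = model_base.predict(dat[0])
model_top.fit(features, dat[1], batch_size=32)
```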

Update. I found a great blog post where this issue is discussed in all its details. Check it out here.