Accessing validation data in a custom callback

Updated: 2023-11-14 16:11:28

You can iterate directly over self.validation_data to collect all of the validation data at the end of each epoch. To calculate precision, recall and F1 across the complete validation dataset, rather than per batch:

# Validation metrics callback: validation precision, recall and F1
# Some of the code was adapted from https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2
import numpy as np
from keras import callbacks  # assumed import; the original snippet omits it
from sklearn.metrics import f1_score, precision_score, recall_score


class Metrics(callbacks.Callback):

    def on_train_begin(self, logs=None):
        # Per-epoch history of each validation metric
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        batch_targets = []
        batch_predictions = []

        # For each validation batch
        for batch_index in range(len(self.validation_data)):
            batch = self.validation_data[batch_index]
            # Get the batch target values
            batch_targets.append(batch[1])
            # Get the batch predictions, rounded to 0/1 labels
            batch_predictions.append(
                np.asarray(self.model.predict(batch[0])).round())

        # Stack all batches so the metrics cover the whole validation set
        val_targ = np.vstack(batch_targets)
        val_predict = np.vstack(batch_predictions)

        val_f1 = round(f1_score(val_targ, val_predict), 4)
        val_recall = round(recall_score(val_targ, val_predict), 4)
        val_precis = round(precision_score(val_targ, val_predict), 4)

        self.val_f1s.append(val_f1)
        self.val_recalls.append(val_recall)
        self.val_precisions.append(val_precis)

        # Add custom metrics to the logs, so that we can use them with
        # the EarlyStopping and CSVLogger callbacks
        logs["val_f1"] = val_f1
        logs["val_recall"] = val_recall
        logs["val_precis"] = val_precis

        print(" - val_f1: {} - val_precis: {} - val_recall: {}".format(
            val_f1, val_precis, val_recall))

valid_metrics = Metrics()
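
To illustrate why the callback pools every batch before scoring instead of averaging per-batch scores, here is a minimal pure-Python sketch. The helper precision_recall_f1 and the two small batches are made up for this illustration; they are not part of Keras or scikit-learn, and the numbers are not from any real model.

```python
def precision_recall_f1(y_true, y_pred):
    """Binary precision, recall and F1 from two equal-length 0/1 sequences."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Two hypothetical validation batches of (targets, rounded predictions).
batches = [
    ([1, 1, 0, 0], [1, 0, 0, 0]),
    ([1, 0, 0, 0], [1, 1, 1, 0]),
]

# Pooled: concatenate every batch first, then score once.
all_true = [t for targets, _ in batches for t in targets]
all_pred = [p for _, preds in batches for p in preds]
pooled_f1 = precision_recall_f1(all_true, all_pred)[2]

# Per-batch: score each batch separately and average the scores.
mean_batch_f1 = sum(
    precision_recall_f1(t, p)[2] for t, p in batches) / len(batches)

print("pooled F1:", round(pooled_f1, 4))
print("mean of per-batch F1:", round(mean_batch_f1, 4))
```

The two values differ because F1 is a ratio of counts, not a plain average, so the mean of per-batch F1 scores generally does not equal the F1 of the pooled data.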

Then you can add valid_metrics to the callbacks argument:

your_model.fit_generator(..., callbacks = [valid_metrics])

If you want other callbacks to use these metrics, be sure to put valid_metrics first in the callbacks list, so that its values are already in the logs when the later callbacks run.
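
The ordering matters because the callbacks receive the same logs dictionary one after another at the end of each epoch. The following toy model sketches that behavior in plain Python; the classes AddMetric and ReadMetric and the function run_epoch are hypothetical stand-ins written for this example, not Keras APIs, and the metric value is a placeholder.

```python
class AddMetric:
    """Stands in for the Metrics callback: writes a value into logs."""

    def on_epoch_end(self, epoch, logs):
        logs["val_f1"] = 0.8  # placeholder value, not a real score


class ReadMetric:
    """Stands in for a consumer such as early stopping or a CSV logger."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs):
        # Record whatever is in the logs at the moment this callback runs.
        self.seen.append(logs.get("val_f1"))


def run_epoch(callbacks, epoch=0):
    """Call each callback in list order with one shared logs dict."""
    logs = {"loss": 0.3}
    for cb in callbacks:
        cb.on_epoch_end(epoch, logs)
    return logs


# Metric writer first: the reader sees the value.
reader_after = ReadMetric()
run_epoch([AddMetric(), reader_after])

# Metric writer last: the reader runs too early and sees nothing.
reader_before = ReadMetric()
run_epoch([reader_before, AddMetric()])

print(reader_after.seen)   # [0.8]
print(reader_before.seen)  # [None]
```

In a real training run, the equivalent would presumably be a list such as [valid_metrics, EarlyStopping(monitor="val_f1", mode="max")], with the monitor name matching the key written into logs.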