且构网 - sharing the things programmers run into in development

How to store per-epoch operation results (such as TopK) in Keras

Updated: 2023-12-02 12:01:58

I think you can keep track of them with a checkpointed list.

You need to add code inside the training, so you need to write your own training loop and save the variable at the end of each epoch.

def fit_and_save_log(self, train_X, val_X, nb_epoch=50, batch_size=100, contractive=None):
    import tensorflow as tf
    from tensorflow.keras.optimizers import Adam

    optimizer = Adam(lr=0.0005)
    self.autoencoder.compile(optimizer=optimizer, loss='binary_crossentropy') # kld, binary_crossentropy, mse

    # A Checkpoint can track a Python list; every variable appended to it
    # is saved with the checkpoint.
    save = tf.train.Checkpoint()
    save.listed = []

    # Prepare dataset (TF1-style initializable iterator)
    X, y = train_X
    train_ds = tf.data.Dataset.from_tensor_slices((X, y))
    train_ds = train_ds.shuffle(10000)
    train_ds = train_ds.batch(batch_size)
    iterator = train_ds.make_initializable_iterator()
    next_batch = iterator.get_next()

    sess = tf.keras.backend.get_session()
    for epoch in range(nb_epoch):
        sess.run(iterator.initializer)

        while True:
            try:
                self.autoencoder.train_on_batch(next_batch[0], next_batch[1])
            except tf.errors.OutOfRangeError:
                break

        # Append a copy: appending the same variable every epoch would only
        # checkpoint its final value.
        save.listed.append(tf.Variable(self.encoded_instant.topk_mat))

        # you can compute validation results here

    save_path = save.save('./topk_mat_log', session=sess)
    return self

Or you can use the model.fit function if you prefer it. Doing it this way is easier, as we do not need to care about creating the batches ourselves. However, repeatedly calling model.fit may result in a memory leak; you can give it a try and check how it behaves. [1]

def fit_and_save_log(self, train_X, val_X, nb_epoch=50, batch_size=100, contractive=None):
    import tensorflow as tf
    from tensorflow.keras.optimizers import Adam

    optimizer = Adam(lr=0.0005)
    self.autoencoder.compile(optimizer=optimizer, loss='binary_crossentropy') # kld, binary_crossentropy, mse

    save = tf.train.Checkpoint()
    save.listed = []

    for epoch in range(nb_epoch):
        # One epoch per call; model.fit handles batching and shuffling.
        self.autoencoder.fit(train_X[0], train_X[1],
                epochs=1,
                batch_size=batch_size,
                shuffle=True,
                validation_data=(val_X[0], val_X[1]))

        # Append a copy so each epoch's value is checkpointed separately.
        save.listed.append(tf.Variable(self.encoded_instant.topk_mat))

        # you can compute validation results here

    save_path = save.save('./topk_mat_log', session=tf.keras.backend.get_session())
    return self
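If the repeated model.fit calls are a concern, a custom Keras Callback can take the snapshot in on_epoch_end so that fit is called only once. This is a hedged sketch, not part of the original answer: `SnapshotCallback` is a made-up name, and a layer kernel stands in for `self.encoded_instant.topk_mat`.

```python
# Sketch (assumes TF 2.x / tf.keras): snapshot a variable at the end of
# every epoch from inside a single model.fit call.
import numpy as np
import tensorflow as tf

class SnapshotCallback(tf.keras.callbacks.Callback):  # hypothetical helper
    def __init__(self, var):
        super().__init__()
        self.var = var        # variable to snapshot each epoch
        self.snapshots = []   # one copy per epoch

    def on_epoch_end(self, epoch, logs=None):
        # Copy the current value; these copies could later be appended to a
        # tf.train.Checkpoint list exactly as above.
        self.snapshots.append(tf.Variable(tf.convert_to_tensor(self.var)))

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

tracked = model.layers[-1].kernel  # stand-in for topk_mat
cb = SnapshotCallback(tracked)

X = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(X, y, epochs=3, batch_size=8, callbacks=[cb], verbose=0)
```

After fitting, `cb.snapshots` holds one copy of the tracked variable per epoch.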

Then you can restore the saved variables like this:

restore = tf.train.Checkpoint()
restore.restore(save_path)  # deferred: values load as variables are attached
restore.listed = []
v1 = tf.Variable(0.)
restore.listed.append(v1) # Now v1 corresponds with topk_mat in the first epoch
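The whole checkpointed-list pattern can also be exercised end to end in eager TF 2.x, where no session is needed. In this sketch a plain variable stands in for `topk_mat` and the per-epoch training is faked with an `assign_add`:

```python
# Minimal save/restore roundtrip of a checkpointed list (TF 2.x, eager mode).
import os
import tempfile
import tensorflow as tf

ckpt = tf.train.Checkpoint()
ckpt.listed = []  # Checkpoint tracks variables appended to this list

topk_mat = tf.Variable([1.0, 2.0, 3.0])  # stand-in for topk_mat
for epoch in range(3):
    topk_mat.assign_add([1.0, 1.0, 1.0])       # pretend an epoch of training
    ckpt.listed.append(tf.Variable(topk_mat))  # snapshot a copy per epoch

save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "topk_mat_log"))

# Restoration is deferred: each value is filled in as a matching variable
# is appended to the list, in the same order it was saved.
restore = tf.train.Checkpoint()
restore.restore(save_path)
restore.listed = []
for _ in range(3):
    restore.listed.append(tf.Variable([0.0, 0.0, 0.0]))
```

After the loop, `restore.listed` holds the three per-epoch snapshots in order.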