Updated: 2023-12-02 12:01:58
I think you can keep track of the checkpoints using a list. To do so, you need to add code to the training, which means writing your own training loop and saving the variable at the end of each epoch.
def fit_and_save_log(self, train_X, val_X, nb_epoch=50, batch_size=100, contractive=None):
    import tensorflow as tf
    from tensorflow.keras.optimizers import Adam

    optimizer = Adam(lr=0.0005)
    self.autoencoder.compile(optimizer=optimizer, loss='binary_crossentropy')  # kld, binary_crossentropy, mse

    # Track the per-epoch matrices with a checkpointed list
    save = tf.train.Checkpoint()
    save.listed = []

    # Prepare dataset
    X, y = train_X
    train_ds = tf.data.Dataset.from_tensor_slices((X, y))
    train_ds = train_ds.shuffle(10000)
    train_ds = train_ds.batch(batch_size)
    iterator = train_ds.make_initializable_iterator()
    next_batch = iterator.get_next()

    sess = tf.keras.backend.get_session()
    for epoch in range(nb_epoch):
        sess.run(iterator.initializer)
        while True:
            try:
                self.autoencoder.train_on_batch(next_batch[0], next_batch[1])
            except tf.errors.OutOfRangeError:
                break
        save.listed.append(self.encoded_instant.topk_mat)
        # you can compute validation results here
    save_path = save.save('./topk_mat_log', session=sess)
    return self
Or, if you prefer, you can use the model.fit function. Doing it this way can be easier, since we do not need to take care of creating the batches ourselves. However, repeatedly calling model.fit may result in a memory leak; you can give it a try and check how it behaves. [1]
def fit_and_save_log(self, train_X, val_X, nb_epoch=50, batch_size=100, contractive=None):
    import tensorflow as tf
    from tensorflow.keras.optimizers import Adam

    optimizer = Adam(lr=0.0005)
    self.autoencoder.compile(optimizer=optimizer, loss='binary_crossentropy')  # kld, binary_crossentropy, mse

    # Track the per-epoch matrices with a checkpointed list
    save = tf.train.Checkpoint()
    save.listed = []

    for epoch in range(nb_epoch):
        self.autoencoder.fit(train_X[0], train_X[1],
                             epochs=1,
                             batch_size=batch_size,
                             shuffle=True,
                             validation_data=(val_X[0], val_X[1]))
        save.listed.append(self.encoded_instant.topk_mat)
        # you can compute validation results here
    save_path = save.save('./topk_mat_log', session=tf.keras.backend.get_session())
    return self
Then you can restore the saved variables like this:
restore = tf.train.Checkpoint()
restore.restore(save_path)
restore.listed = []
v1 = tf.Variable(0.)
restore.listed.append(v1) # Now v1 corresponds with topk_mat in the first epoch
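To see the whole round trip in one place, here is a minimal self-contained sketch of the same list-of-variables checkpoint pattern, written in eager (TF2) style rather than the session style above; the scalar variables are just placeholders standing in for the topk_mat matrices:

```python
import tensorflow as tf

# Save: attach a Python list of variables to a Checkpoint object;
# each element is tracked under its index in the list.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(float(i)) for i in range(3)]  # stand-ins for topk_mat per epoch
save_path = save.save('/tmp/topk_mat_log')

# Restore: recreate variables of matching shape, attach them in the
# same order, and call restore() to fill in the saved values.
restore = tf.train.Checkpoint()
restore.listed = [tf.Variable(0.) for _ in range(3)]
restore.restore(save_path)
print([float(v.numpy()) for v in restore.listed])  # [0.0, 1.0, 2.0]
```

Note that restoration is matched by position in the list, so the variables must be re-attached in the same order they were saved in.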