Updated: 2023-12-02 11:44:40
The tf.train.Saver is a "passive" utility for writing checkpoints: it only writes a checkpoint when some other code calls its .save() method. Therefore, the rate at which checkpoints are written depends on what framework you are using to train your model:
If you are using the low-level TensorFlow API (tf.Session) and writing your own training loop, you can simply insert calls to Saver.save() in your own code. A common approach is to do this based on the iteration count:
for i in range(NUM_ITERATIONS):
  sess.run(train_op)
  # ...
  if i % 1000 == 0:
    saver.save(sess, ...)  # Write a checkpoint every 1000 steps.
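The scheduling logic in the loop above is independent of TensorFlow itself, so it can be sketched without a session. The `FakeSaver` class below is a hypothetical stand-in for `tf.train.Saver` (not a real TensorFlow API), used only to show which iterations trigger a save under the `i % 1000 == 0` rule:

```python
# Hypothetical stand-in for tf.train.Saver; a real Saver would
# serialize the session's variables to the given path.
class FakeSaver:
    def __init__(self):
        self.saved_paths = []

    def save(self, session, path):
        self.saved_paths.append(path)

NUM_ITERATIONS = 3500
saver = FakeSaver()
for i in range(NUM_ITERATIONS):
    # In real training code, sess.run(train_op) would go here.
    if i % 1000 == 0:
        saver.save(None, "model.ckpt-%d" % i)  # Checkpoint every 1000 steps.
```

With 3500 iterations this writes checkpoints at steps 0, 1000, 2000, and 3000; note that `i == 0` also triggers a save, which is usually harmless but can be skipped with `if i > 0 and i % 1000 == 0` if an initial checkpoint is unwanted.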
If you are using tf.train.MonitoredTrainingSession, which writes checkpoints for you, you can specify a checkpoint interval (in seconds) in the constructor. By default it saves a checkpoint every 10 minutes. To change this to every minute, you would do:
with tf.train.MonitoredTrainingSession(..., save_checkpoint_secs=60):
  # ...
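A time-based interval like `save_checkpoint_secs=60` amounts to "save whenever at least that many seconds have passed since the last checkpoint." The function below is an illustrative sketch of that idea (it is not MonitoredTrainingSession's actual implementation), using a list of simulated step-completion timestamps in place of a real clock:

```python
def checkpoint_times(step_times, interval):
    """Return the simulated timestamps at which a checkpoint is written,
    given the times (in seconds) at which training steps complete and a
    minimum interval between checkpoints."""
    saves = []
    last_save = step_times[0]  # Treat the start time as the last save.
    for now in step_times:
        if now - last_save >= interval:
            saves.append(now)
            last_save = now
    return saves

# Steps complete every 25 simulated seconds over 5 minutes,
# with a 60-second checkpoint interval (as in the snippet above).
times = checkpoint_times(list(range(0, 301, 25)), 60)
```

Because steps finish every 25 seconds, a save can only happen at a step boundary, so checkpoints land at 75-second spacing rather than exactly every 60 seconds; the real session behaves similarly in that it checks the clock between steps.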