且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

如何在 Tensorflow 中使用 CheckpointReader 恢复变量

更新时间:2023-12-02 19:20:28

你可以使用 string.split() 获取张量名称:

You could use string.split() to get the tensor name:

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]
    print tensor_name
    if reader.has_tensor(tensor_name):
        print 'has tensor'
...

接下来,让我用一个例子来说明如何从 .cpkt 文件中恢复每个可能的变量.首先,让我们将 v2v3 保存在 tmp.ckpt 中:

Next, let me use an example to show how I would restore every possible variable from a .cpkt file. First, let's save v2 and v3 in tmp.ckpt:

import tensorflow as tf

v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')

saver = tf.train.Saver({'v2': v2, 'v3': v3})

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.save(sess, 'tmp.ckpt')

这就是我如何恢复出现在 tmp.ckpt 中的每个变量(属于一个新图形):

That's how I would restore every variable (belonging to a new graph) showing up in tmp.ckpt:

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros([1]), name='v1')
    v2 = tf.Variable(tf.zeros([1]), name='v2')

    reader = tf.train.NewCheckpointReader('tmp.ckpt')
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    saver = tf.train.Saver(restore_dict)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, 'tmp.ckpt')
        print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]

此外,您可能希望确保形状和数据类型匹配.

Also, you may want to ensure that shapes and dtypes match.