且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

在 Tensorflow 中恢复变量子集

更新时间:2023-12-02 19:11:28

要恢复变量子集,您必须创建一个新的 tf.train.Saver 并在可选的 var_list 参数中向其传递要恢复的特定变量列表.

To restore a subset of variables, you must create a new tf.train.Saver and pass it a specific list of variables to restore in the optional var_list argument.

默认情况下,tf.train.Saver 将创建操作,以便 (i) 在您调用 saver.restore().虽然这适用于大多数常见场景,但您必须提供更多信息才能处理变量的特定子集:

By default, a tf.train.Saver will create ops that (i) save every variable in your graph when you call saver.save() and (ii) lookup (by name) every variable in the given checkpoint when you call saver.restore(). While this works for most common scenarios, you have to provide more information to work with specific subsets of the variables:

  1. 如果你只想恢复变量的一个子集,你可以通过调用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX),假设你把g"网络放在一个共同的with tf.name_scope(G_NETWORK_PREFIX):tf.variable_scope(G_NETWORK_PREFIX): 块.然后,您可以将此列表传递给 tf.train.Saver 构造函数.

如果要恢复变量的子集和/或检查点中的变量具有不同的名称,您可以将字典作为var_list传递争论.默认情况下,检查点中的每个变量都与一个 key 相关联,这是其 tf.Variable.name 属性的值.如果目标图中的名称不同(例如,因为您添加了范围前缀),您可以指定一个字典,将字符串键(在检查点文件中)映射到 tf.Variable 对象(在目标中图).

If you want to restore a subset of the variable and/or they variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph).