更新时间:2022-04-11 01:06:59
最简单的解决方案是创建不同的会话,为每个模型使用单独的图形:
The easiest solution is to create different sessions that use separate graphs for each model:
# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
net1 = CreateAlexNet()
saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')
# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
net2 = CreateAlexNet()
saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net1_graph)
saver2.restore(sess2, 'epoch_50.ckpt')
如果由于某种原因这不起作用,并且您必须使用单个 tf.Session
(例如,因为您想在另一个 TensorFlow 计算中组合来自两个网络的结果),***解决方法是:
If this doesn't work for some reason, and you have to use a single tf.Session
(e.g. because you want to combine results from the two network in another TensorFlow computation), the best solution is to:
tf.train.Saver
两个网络的实例,带有一个额外的参数来重新映射变量名称.当构建储户,您可以将字典作为 var_list
参数传递,将检查点中的变量名称(即没有名称范围前缀)映射到您的 tf.Variable
对象已在每个模型中创建.
When constructing the savers, you can pass a dictionary as the var_list
argument, mapping the names of the variables in the checkpoint (i.e. without the name scope prefix) to the tf.Variable
objects you've created in each model.
您可以以编程方式构建 var_list
,并且您应该能够执行以下操作:
You can build the var_list
programmatically, and you should be able to do something like the following:
with tf.name_scope("net1"):
net1 = CreateAlexNet()
with tf.name_scope("net2"):
net2 = CreateAlexNet()
# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
net1_varlist = {v.name.lstrip("net1/"): v
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)
# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.name.lstrip("net2/"): v
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)
# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")