
Running multiple pre-trained TensorFlow networks at the same time

Updated: 2022-04-11 01:06:59

The easiest solution is to create different sessions that use separate graphs for each model:

import tensorflow as tf  # TF 1.x API; `CreateAlexNet()` is the user's model-building function

# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
  net1 = CreateAlexNet()
  saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')

# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
  net2 = CreateAlexNet()
  saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net2_graph)
saver2.restore(sess2, 'epoch_50.ckpt')

If this doesn't work for some reason, and you have to use a single tf.Session (e.g. because you want to combine results from the two networks in another TensorFlow computation), the best solution is to:

  1. Create the different networks inside name scopes, as you are already doing, and
  2. Create separate tf.train.Saver instances for the two networks, with an additional argument to remap the variable names.

When constructing the savers, you can pass a dictionary as the var_list argument, mapping the names of the variables in the checkpoint (i.e. without the name scope prefix) to the tf.Variable objects you've created in each model.
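One pitfall when building this mapping: `str.lstrip` is not a prefix remover, it strips any of the given *characters* from the left, so `v.name.lstrip("net1/")` can eat into the variable name itself; the `:0` output suffix in `v.name` also has to go, since checkpoint keys are op names. A minimal pure-Python sketch of the remapping (the variable names below are made up for illustration):

```python
# Hypothetical scoped variable names in TensorFlow's
# "<scope>/<op-name>:<output-index>" form.
names = ["net1/conv1/weights:0", "net1/net1_fc/bias:0"]

def checkpoint_key(name, prefix):
    """Map a scoped variable name to its checkpoint name by removing
    the scope prefix and the ':0' output suffix."""
    assert name.startswith(prefix)
    return name[len(prefix):].split(":")[0]

print([checkpoint_key(n, "net1/") for n in names])
# -> ['conv1/weights', 'net1_fc/bias']

# By contrast, str.lstrip("net1/") strips any of the characters
# n, e, t, 1, / from the left, mangling the second name:
print("net1/net1_fc/bias:0".lstrip("net1/"))
# -> _fc/bias:0
```

The resulting keys are what the checkpoint saver wrote for the original (un-scoped) model, so they can be used directly as dictionary keys in var_list.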

You can build the var_list programmatically, and you should be able to do something like the following:

with tf.name_scope("net1"):
  net1 = CreateAlexNet()
with tf.name_scope("net2"):
  net2 = CreateAlexNet()

# Strip off the "net1/" prefix (and the ":0" output suffix) to get the names
# of the variables in the checkpoint. Note that str.lstrip() strips
# *characters*, not a prefix, so slicing is used instead.
net1_varlist = {v.name[len("net1/"):].split(":")[0]: v
                for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)

# Strip off the "net2/" prefix in the same way.
net2_varlist = {v.name[len("net2/"):].split(":")[0]: v
                for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)

# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")