且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

TensorFlow 将图形保存到文件中/从文件加载图形

更新时间:2023-12-02 20:47:22

有很多方法可以解决在 TensorFlow 中保存模型的问题,这可能会让人有点困惑.依次回答您的每个子问题:

There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:

  1. 检查点文件(例如通过调用saver.save()tf.train.Saver 对象)仅包含权重,以及在同一程序中定义的任何其他变量.要在另一个程序中使用它们,您必须重新创建关联的图结构(例如,通过运行代码再次构建它,或调用 tf.import_graph_def()),它告诉 TensorFlow 如何处理这些权重.请注意,调用 saver.save() 还会生成一个包含 MetaGraphDef,其中包含一个图表以及如何将来自检查点的权重与该图表相关联的详细信息.有关详细信息,请参阅教程.

  1. The checkpoint files (produced e.g. by calling saver.save() on a tf.train.Saver object) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights. Note that calling saver.save() also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorial for more details.

tf.train.write_graph() 只写图结构;不是权重.

tf.train.write_graph() only writes the graph structure; not the weights.

Bazel 与读取或写入 TensorFlow 图无关.(也许我误解了您的问题:请随时在评论中澄清.)

Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)

可以使用 tf.import_graph_def() 加载冻结图.在这种情况下,权重(通常)嵌入在图中,因此您无需加载单独的检查点.

A frozen graph can be loaded using tf.import_graph_def(). In this case, the weights are (typically) embedded in the graph, so you don't need to load a separate checkpoint.

主要的变化是更新输入模型的张量的名称,以及从模型中获取的张量的名称.在 TensorFlow Android 演示中,这将对应于传递给 TensorFlowClassifier.initializeTensorFlow().

The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the inputName and outputName strings that are passed to TensorFlowClassifier.initializeTensorFlow().

GraphDef 是程序结构,通常不会在训练过程中改变.检查点是训练过程状态的快照,通常在训练过程的每一步都会发生变化.因此,TensorFlow 对这些类型的数据使用不同的存储格式,低级 API 提供了不同的方式来保存和加载它们.更高级别的库,例如 MetaGraphDef 库,Kerasskflow 以这些机制为基础,提供更便捷的方法来保存和恢复整个模型.>

The GraphDef is the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as the MetaGraphDef libraries, Keras, and skflow build on these mechanisms to provide more convenient ways to save and restore an entire model.