且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

将元数据添加到 tensorflow 冻结图 pb

更新时间:2023-12-02 20:51:34

首先,是的,您应该使用新的 SavedModel 格式,因为它会得到 TF 团队的支持,并且也可以与 Keras 一起使用.您可以向模型添加一个额外的端点,它会返回一个带有 XML 数据字符串的常量张量(如您所述).

First of all, yes you should use the new SavedModel format, as it is what will be supported by the TF team going forwards, and works with Keras as well. You can add an additional endpoint to the model, that returns a constant tensor (as you mention) with a string of your XML data.

这很好,因为它是密封的——底层的保存模型格式并不重要,因为您的元数据保存在计算图本身中.

This is good because it's hermetic -- the underlying savemodel format does not matter, because your metadata is saved in the computation graph itself.

查看此问题的答案:保存 TF2 keras具有自定义签名 defs 的模型.对于 Keras,该答案并不能 100% 地为您提供帮助,因为它无法与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在 tf.Module 中.幸运的是,如果添加 tf.function 装饰器,使用 tf.keras.Model 在 TF2 中也能正常工作:

See the answer to this question: Saving a TF2 keras model with custom signature defs . That answer doesn't get you 100% of the way there for Keras, because it doesn't interop nicely with the tf.keras.models.load function, as they wrap it inside a tf.Module. Luckily, using tf.keras.Model works as well in TF2, if you add a tf.function decorator:

class MyModel(tf.keras.Model):

  def __init__(self, metadata, **kwargs):
    super(MyModel, self).__init__(**kwargs)
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.metadata = tf.constant(metadata)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

  @tf.function(input_signature=[])
  def get_metadata(self):
    return self.metadata

model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)

然后您可以按如下方式保存和加载您的模型:

Then you can save and load your model as follows:

tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')

最后使用 model_loaded.get_metadata() 来检索您的常量元数据张量.

And finally use model_loaded.get_metadata() to retrieve your constant metadata tensor.