且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

使用自定义签名defs保存TF2 keras模型

更新时间:2023-12-01 21:37:04

解决方案是为每个签名定义的函数创建一个tf.Module:

The solution is to create a tf.Module with functions for each signature definition:

class MyModule(tf.Module):
  def __init__(self, model, other_variable):
    self.model = model
    self._other_variable = other_variable

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
  def score(self, waveform):
    result = self.model(waveform)
    return { "scores": results }

  @tf.function(input_signature=[])
  def metadata(self):
    return { "other_variable": self._other_variable }

然后保存模块(不是模型):

And then save the module (not the model):

module = MyModule(model, 1234)
tf.saved_model.save(module, export_path, signatures={ "score": module.score, "metadata": module.metadata })

在TF2上使用Keras模型进行了测试.

Tested with Keras model on TF2.