Updated: 2023-12-01 21:37:04
The solution is to create a tf.Module with one tf.function for each signature definition:
```python
import tensorflow as tf

class MyModule(tf.Module):
    def __init__(self, model, other_variable):
        super().__init__()
        self.model = model  # e.g. a trained Keras model
        self._other_variable = other_variable

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
    def score(self, waveform):
        result = self.model(waveform)
        return {"scores": result}

    @tf.function(input_signature=[])
    def metadata(self):
        # Signature outputs must be tensors, so wrap the Python value.
        return {"other_variable": tf.constant(self._other_variable)}
```
And then save the module (not the model):
```python
# `model` is your Keras model; `export_path` is the output directory.
module = MyModule(model, 1234)
tf.saved_model.save(module, export_path,
                    signatures={"score": module.score, "metadata": module.metadata})
```
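To check that both signatures actually end up in the SavedModel, here is a minimal round-trip sketch, assuming TF 2.x with tf.keras. The one-layer Dense model is a hypothetical stand-in for the real model, and the temporary directory is a hypothetical export_path; the sketch saves the module, reloads it with tf.saved_model.load, and calls each signature by name.

```python
import tempfile

import numpy as np
import tensorflow as tf


class MyModule(tf.Module):
    def __init__(self, model, other_variable):
        super().__init__()
        self.model = model
        self._other_variable = other_variable

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
    def score(self, waveform):
        return {"scores": self.model(waveform)}

    @tf.function(input_signature=[])
    def metadata(self):
        return {"other_variable": tf.constant(self._other_variable)}


# Hypothetical stand-in for the real model: one Dense unit applied per time step.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, None, 1))

module = MyModule(model, 1234)
export_path = tempfile.mkdtemp()  # hypothetical export location
tf.saved_model.save(
    module,
    export_path,
    signatures={"score": module.score, "metadata": module.metadata},
)

# Reload the SavedModel and look up each exported signature by name.
loaded = tf.saved_model.load(export_path)
print(sorted(loaded.signatures.keys()))  # ['metadata', 'score']

meta = loaded.signatures["metadata"]()
print(int(meta["other_variable"]))  # 1234

# Signature functions take keyword arguments named after the Python parameters.
waveform = np.zeros((2, 16, 1), dtype=np.float32)
scores = loaded.signatures["score"](waveform=tf.constant(waveform))
print(tuple(scores["scores"].shape))  # (2, 16, 1)
```

Because each signature returns a dict, the outputs come back keyed by the same names ("scores", "other_variable") when the SavedModel is served or reloaded.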
Tested with a Keras model on TF2.