且构网

Saving a Keras Model with Custom Layers

Updated: 2021-10-21 07:18:42

The first correction is to pass custom_objects when loading the saved model, i.e., replace

new_model = tf.keras.models.load_model('model.h5')

with

new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})

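Because the model was built with a custom layer before it was saved, the file records only that layer's class name and configuration, not its Python code. load_model therefore has to be told which class the name 'CustomLayer' refers to, and custom_objects is how you tell it.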
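
To make the lookup concrete, here is a deliberately simplified, pure-Python model of name-based deserialization. It does not use TensorFlow, and Dense, CustomLayer, BUILTIN_LAYERS, serialize_layer and deserialize_layer are hypothetical stand-ins, not Keras internals. It only illustrates the idea: a stored class name that the library does not already know can be resolved only through custom_objects.

```python
import json


class Dense:
    """Hypothetical stand-in for a built-in layer the library already knows."""

    def __init__(self, units, name=None):
        self.units = units
        self.name = name

    def get_config(self):
        return {"units": self.units, "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class CustomLayer:
    """Hypothetical stand-in for a user-defined layer."""

    def __init__(self, k, name=None):
        self.k = k
        self.name = name

    def get_config(self):
        return {"k": self.k, "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Classes that can be resolved by name without any help from the caller.
BUILTIN_LAYERS = {"Dense": Dense}


def serialize_layer(layer):
    # Only the class name and the config dict are stored, never the class code.
    return json.dumps({"class_name": type(layer).__name__,
                       "config": layer.get_config()})


def deserialize_layer(payload, custom_objects=None):
    data = json.loads(payload)
    # Caller-supplied classes are looked up alongside the built-ins.
    lookup = {**BUILTIN_LAYERS, **(custom_objects or {})}
    cls = lookup.get(data["class_name"])
    if cls is None:
        raise ValueError(f"Unknown layer: {data['class_name']}")
    return cls.from_config(data["config"])


# A built-in class round-trips without custom_objects.
print(deserialize_layer(serialize_layer(Dense(1, name="output_layer"))).units)  # 1

saved = serialize_layer(CustomLayer(10, name="custom_layer"))

try:
    deserialize_layer(saved)  # "CustomLayer" is not a built-in name
except ValueError as err:
    print(err)  # Unknown layer: CustomLayer

restored = deserialize_layer(saved, custom_objects={"CustomLayer": CustomLayer})
print(restored.name, restored.k)  # custom_layer 10
```

The same idea explains why the key in custom_objects is the class name ('CustomLayer'), not the layer's instance name ('custom_layer').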

The second correction is to accept **kwargs in the custom layer's __init__ and forward them to the base Layer constructor, like this:

def __init__(self, k, name=None, **kwargs):
    # Forward name and any remaining keyword arguments to the base Layer in one call.
    super(CustomLayer, self).__init__(name=name, **kwargs)
    self.k = k
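
Why the extra keyword arguments matter: the base layer's get_config already contributes keys such as name, trainable and dtype, and from_config rebuilds the layer by passing every key of that dictionary to the constructor. The pure-Python sketch below mimics that round trip with a hypothetical BaseLayer (not the real tf.keras.layers.Layer). LayerWithoutKwargs and LayerWithKwargs are likewise illustrative names.

```python
class BaseLayer:
    """Hypothetical stand-in for a framework base layer."""

    def __init__(self, name=None, trainable=True, dtype="float32"):
        self.name = name
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        # The base class contributes its own keys to every config.
        return {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        # Rebuild by passing every config key to the constructor.
        return cls(**config)


class LayerWithoutKwargs(BaseLayer):
    def __init__(self, k, name=None):
        super().__init__(name=name)
        self.k = k

    def get_config(self):
        config = super().get_config()
        config.update({"k": self.k})
        return config


class LayerWithKwargs(BaseLayer):
    def __init__(self, k, name=None, **kwargs):
        # trainable, dtype, ... arrive in kwargs and go to the base class once.
        super().__init__(name=name, **kwargs)
        self.k = k

    def get_config(self):
        config = super().get_config()
        config.update({"k": self.k})
        return config


config = LayerWithoutKwargs(10, name="custom_layer").get_config()
try:
    LayerWithoutKwargs.from_config(config)
except TypeError as err:
    print("without **kwargs:", err)  # ... unexpected keyword argument 'trainable'

layer = LayerWithKwargs.from_config(LayerWithKwargs(10, name="custom_layer").get_config())
print(layer.name, layer.k, layer.trainable)  # custom_layer 10 True
```

Without **kwargs the rebuilt constructor rejects the base-class keys with a TypeError. With it, they are passed through to the base class and the layer is restored.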

The complete working code is shown below:

import tensorflow as tf

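class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, k, name=None, **kwargs):
        # Pass name and any extra base-class arguments (e.g. trainable, dtype)
        # to the parent constructor in a single call.
        super(CustomLayer, self).__init__(name=name, **kwargs)
        self.k = k

    def get_config(self):
        # Add k to the base config so the layer can be rebuilt when loading.
        config = super(CustomLayer, self).get_config()
        config.update({"k": self.k})
        return config

    def call(self, inputs):
        # k is only stored for serialization here; the layer multiplies by 2.
        return tf.multiply(inputs, 2)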

model = tf.keras.models.Sequential([
    tf.keras.Input(name='input_layer', shape=(10,)),
    CustomLayer(10, name='custom_layer'),
    tf.keras.layers.Dense(1, activation='sigmoid', name='output_layer')
])
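# The .h5 extension selects the HDF5 format.
tf.keras.models.save_model(model, 'model.h5')
# Map the class name stored in the file back to the Python class.
new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})

new_model.summary()  # summary() prints the table itself and returns None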

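Output of the above code is shown below. Keras adds numeric suffixes such as _1 when it generates unique model and layer names, so the names in your summary may differ: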

WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
custom_layer_1 (CustomLayer) (None, 10)                0         
_________________________________________________________________
output_layer (Dense)         (None, 1)                 11        
=================================================================
Total params: 11
Trainable params: 11
Non-trainable params: 0

Hope this helps. Happy Learning!