且构网

How to create a Keras layer with a custom gradient in TF2.0?

Updated: 2023-11-17 23:48:52

First of all, the "unification" of the APIs under Keras (as you call it) doesn't prevent you from doing things the way you did in TensorFlow 1.x. Sessions might be gone, but you can still define your model like any Python function and train it eagerly without Keras (i.e. through tf.GradientTape).
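
The sketch below shows eager training with tf.GradientTape and no Keras at all. The linear model, the synthetic data (y = 3x + 2), the learning rate and the step count are all assumptions chosen for the example.

```python
import tensorflow as tf

# Hypothetical toy problem (an assumption for this sketch): learn y = 3x + 2.
x = tf.random.normal([256, 1])
y = 3.0 * x + 2.0

# Model parameters are plain variables; no Keras objects are involved.
w = tf.Variable(0.0)
b = tf.Variable(0.0)

def model(inputs):
    # The "model" is just an ordinary Python function.
    return w * inputs + b

learning_rate = 0.1  # assumed value
for step in range(200):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    # Gradients of the loss with respect to the variables, computed eagerly.
    dw, db = tape.gradient(loss, [w, b])
    # Plain gradient-descent update.
    w.assign_sub(learning_rate * dw)
    b.assign_sub(learning_rate * db)

print(float(w), float(b))  # should end up close to 3.0 and 2.0
```

The training loop is explicit here, so you control exactly when gradients are computed and applied. Keras' fit() does this for you.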

Now, if you want to build a Keras model with a custom layer that performs a custom operation and has a custom gradient, do the following:

a) Write a function that performs your custom operation and defines its custom gradient. For details on how to do this, see the documentation of tf.custom_gradient.

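import tensorflow as tf

@tf.custom_gradient
def custom_op(x):
    result = ... # do the forward computation on x with TensorFlow ops
    def custom_grad(dy):
        grad = ... # gradient of result w.r.t. x, multiplied by the upstream gradient dy
        return grad
    return result, custom_grad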

Note that inside the function you should treat x and dy as Tensors, not NumPy arrays (i.e. perform tensor operations).
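
Here is a concrete instance of the template above, modeled on the log(1 + exp(x)) example from the tf.custom_gradient documentation. The name log1pexp is illustrative. Both the forward pass and the gradient use only TensorFlow ops, and the gradient reuses x's exponential from the enclosing scope.

```python
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    # Illustrative op (name is hypothetical): forward pass log(1 + e^x),
    # written with TensorFlow ops only.
    e = tf.exp(x)

    def grad(dy):
        # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), times the upstream gradient dy.
        # This form stays finite when e^x overflows to inf (it yields 1.0),
        # whereas autodiff through the forward expression yields NaN there.
        return dy * (1 - 1 / (1 + e))

    return tf.math.log(1 + e), grad

x = tf.constant([0.0, 100.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it has to be watched explicitly
    y = log1pexp(x)
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 0.5 at x = 0 and 1.0 at x = 100
```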

b) Create a custom Keras layer that performs your custom_op. For this example I'll assume that your layer has no trainable parameters and doesn't change the shape of its input, but adding either doesn't change much. For those cases, you can refer to the guide you posted or to the TensorFlow guide on writing custom layers.

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # forward standard args such as name

    def call(self, x):
        # You don't need to define the gradient here: custom_op already carries it
        # because it was registered with @tf.custom_gradient.
        return custom_op(x)
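
The following self-contained sketch checks that a layer really picks up the registered gradient. It uses a hypothetical op, round_ste, which rounds in the forward pass and passes the gradient through unchanged in the backward pass (the straight-through estimator). The op is wrapped in a hypothetical layer named RoundLayer. Neither name comes from the original answer.

```python
import tensorflow as tf

@tf.custom_gradient
def round_ste(x):
    # Hypothetical op. Forward: round to the nearest integer
    # (tf.round itself provides no useful gradient).
    def grad(dy):
        # Backward: pass the upstream gradient through unchanged.
        return dy
    return tf.round(x), grad

class RoundLayer(tf.keras.layers.Layer):
    # Hypothetical layer with no weights and no shape change.
    def call(self, x):
        # The custom gradient travels with round_ste; nothing extra is needed.
        return round_ste(x)

layer = RoundLayer()
x = tf.constant([0.2, 1.7, -2.6])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = layer(x)
    loss = tf.reduce_sum(3.0 * y)
grads = tape.gradient(loss, x)
print(y.numpy())      # rounded values 0, 2, -3
print(grads.numpy())  # 3.0 everywhere, via the identity gradient
```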

Now you can use this layer in a Keras model like any built-in layer, and backpropagation will use your custom gradient during training. For example:

inp = tf.keras.layers.Input(shape=input_shape)  # input_shape excludes the batch dimension
conv = tf.keras.layers.Conv2D(...)(inp)  # add params like the number of filters
cust = CustomLayer()(conv)  # no parameters in custom layer
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)  # no activation, so the outputs are logits

model = tf.keras.models.Model(inputs=[inp], outputs=[fc])
model.compile(loss=..., optimizer=...)  # add loss function and optimizer
model.fit(...)  # fit the model
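
For a quick end-to-end check, the sketch below fills in the placeholders with assumed values:

- 28x28x1 inputs and 10 classes
- a Conv2D with 8 filters
- random synthetic data
- the Adam optimizer, with sparse categorical cross-entropy computed from logits

The custom op is a hypothetical clip_gradient: the identity in the forward pass, with the incoming gradient clipped to [-1, 1] in the backward pass. None of these choices come from the original answer.

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def clip_gradient(x):
    # Hypothetical op. Forward: identity. Backward: clip the upstream gradient.
    def grad(dy):
        return tf.clip_by_value(dy, -1.0, 1.0)
    return tf.identity(x), grad

class CustomLayer(tf.keras.layers.Layer):
    def call(self, x):
        return clip_gradient(x)

# Assumed values for the placeholders in the answer.
input_shape = (28, 28, 1)
num_classes = 10

inp = tf.keras.layers.Input(shape=input_shape)
conv = tf.keras.layers.Conv2D(8, 3, activation="relu")(inp)
cust = CustomLayer()(conv)
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)  # logits, no softmax

model = tf.keras.models.Model(inputs=[inp], outputs=[fc])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
)

# Random synthetic data, only to show that training runs end to end.
x_train = np.random.rand(64, *input_shape).astype("float32")
y_train = np.random.randint(0, num_classes, size=(64,))
history = model.fit(x_train, y_train, epochs=1, batch_size=16, verbose=0)
print(history.history["loss"])
```

Because the Dense layer has no activation, the loss is configured with from_logits=True. The alternative is to add a softmax to the last layer and drop that flag.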