How do I create a custom layer in Keras with "stateful" variables/tensors?

Updated: 2023-12-01 22:06:52

Defining a custom layer can sometimes be confusing. Some of the methods you override are called only once, yet they give the impression that, as in many other OO libraries and frameworks, they will be called many times.

Here is what I mean: when you define a layer and use it in a model, the Python code you write in the overridden call method is not invoked directly on every forward or backward pass. Instead, by default it runs when Keras builds the computational graph, typically once: when the layer is first applied to its inputs or the model is first traced (model.compile itself only configures the optimizer, loss, and metrics). Tracing turns the Python code into a computational graph, and it is this graph, through which the tensors flow, that performs the computation during training and prediction.

That is why debugging your model with a plain print statement won't work as expected: the statement runs only while the graph is being built. To print on every pass, use tf.print, which adds a print operation to the graph.
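
To make the difference concrete, here is a minimal sketch, assuming TensorFlow 2.x. It uses a standalone tf.function as a stand-in for the graph Keras builds around call; the function name double and the list trace_log are illustrative and not part of the original answer.

```python
import tensorflow as tf

trace_log = []  # filled by plain Python code, so only while tracing


@tf.function
def double(x):
    # Python side effects: executed while the graph is traced, not per call.
    trace_log.append("traced")
    print("tracing double()")
    # Graph op: executed every time the traced graph runs.
    tf.print("running with", x)
    return x * 2.0


# Same dtype and shape each time, so the function is traced only once.
for value in (1.0, 2.0, 3.0):
    result = double(tf.constant(value))
```

After the loop, trace_log holds a single entry even though the function ran three times, while the tf.print message appears once per run.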

The same applies to the state variable you want to keep. Instead of simply assigning old + update to new in Python, you need to call a function that adds the update operation to the graph, such as tf.keras.backend.update or the variable's own assign method.

Also note that tensors are immutable, so you need to define the state as a tf.Variable in the __init__ method.
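
A minimal sketch of that pattern, assuming TensorFlow 2.x and again using a standalone tf.function in place of a layer's call. The names state and accumulate are illustrative.

```python
import tensorflow as tf

# Created once, outside the traced function, so it persists between calls.
state = tf.Variable(tf.zeros((3,), dtype=tf.float32), trainable=False)


@tf.function
def accumulate(update):
    # assign_add is recorded as a graph op, so it runs on every execution.
    state.assign_add(update)
    return state.read_value()


accumulate(tf.ones((3,)))
total = accumulate(tf.ones((3,)))  # the variable has now been updated twice
```

Because the variable is mutated through a graph operation rather than rebound in Python, the update survives from one call to the next.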

So I believe this code is closer to what you're looking for:

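import tensorflow as tf

class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Mutable state must be a tf.Variable. Its shape (3,) matches the result of
    # tf.reduce_sum(outputs, axis=0) in call(); a (3, 3) variable would reject it.
    self.state = tf.Variable(tf.zeros((3,), 'float32'), trainable=False)
    self.constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
    self.extra_constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
    self.trainable = False

  def call(self, X):
    m = self.constant
    c = self.extra_constant
    # c has shape (3, 3), so this sum only broadcasts for a batch size of 1 or 3.
    outputs = self.state + tf.matmul(X, m) + c
    # assign() adds the update to the graph, as tf.keras.backend.update did.
    self.state.assign(tf.reduce_sum(outputs, axis=0))

    return outputs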
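
As a usage sketch, here is a smaller, hypothetical layer (RunningSum is not from the original answer) that follows the same pattern. It is called eagerly on two batches, assuming TensorFlow 2.x, so the effect of the state is easy to check by hand.

```python
import tensorflow as tf


class RunningSum(tf.keras.layers.Layer):
    """Hypothetical layer: adds the column sums of the previous outputs."""

    def __init__(self, features, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable state, created once in __init__.
        self.state = tf.Variable(tf.zeros((features,), tf.float32),
                                 trainable=False)

    def call(self, inputs):
        outputs = inputs + self.state
        # Store the per-feature sum of this batch's outputs for the next call.
        self.state.assign(tf.reduce_sum(outputs, axis=0))
        return outputs


layer = RunningSum(features=2)
batch = tf.ones((4, 2))
first = layer(batch)   # state is still zero, so the outputs equal the inputs
second = layer(batch)  # state now holds the column sums of `first`
```

The second call sees the state written by the first, which is the behaviour a plain Python attribute holding a tensor would not give you inside a graph.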