Custom attention layer in Keras

Updated: 2023-12-01 21:58:34

From the code you have shared, it looks like you want to implement Bahdanau's attention layer. You want to attend to all the 'values' (the previous layer's output, i.e. all of its hidden states), and your 'query' would be the last hidden state of the decoder. Your code should actually be very simple and should look like this:

        import tensorflow as tf

        class Bahdanau(tf.keras.layers.Layer):
            def __init__(self, n):
                super(Bahdanau, self).__init__()
                self.w = tf.keras.layers.Dense(n)  # projects the query (decoder state)
                self.u = tf.keras.layers.Dense(n)  # projects the values (encoder states)
                self.v = tf.keras.layers.Dense(1)  # collapses each score to a scalar

            def call(self, query, values):
                # (batch, hidden) -> (batch, 1, hidden) so the query broadcasts over timesteps
                query = tf.expand_dims(query, 1)
                # additive score: v^T tanh(W*query + U*values), shape (batch, timesteps, 1)
                e = self.v(tf.nn.tanh(self.w(query) + self.u(values)))
                # normalize over the timestep axis to get attention weights
                a = tf.nn.softmax(e, axis=1)
                # weighted sum of the values gives the context vector
                c = a * values
                c = tf.reduce_sum(c, axis=1)
                return a, c

        ## Say we want 10 units in the single-layer MLP determining w, u
        attentionlayer = Bahdanau(10)
        ## Call with i/p: decoder state @ t-1 and all encoder hidden states
        a, c = attentionlayer(stminus1, hj)

We are not specifying the tensor shape anywhere in the code. This code will return a context tensor of the same size as 'stminus1', which is the 'query'. It does this after attending to all the 'values' (all output states of the encoder) using Bahdanau's attention mechanism.

So assuming your batch size is 32, timesteps = 100 and embedding dimension = 2048, the shape of stminus1 should be (32, 2048) and the shape of hj should be (32, 100, 2048). The shape of the output context will be (32, 2048). We also return the 100 attention weights in case you want to route them to a nice display.
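
A quick way to sanity-check those shapes is to feed the layer random placeholder tensors of the sizes mentioned above. The snippet below is just an illustrative sketch: it assumes the Bahdanau class defined earlier, and the tf.random.normal tensors merely stand in for a real decoder state and real encoder outputs.

        import tensorflow as tf

        batch_size, timesteps, dim = 32, 100, 2048

        # Placeholders standing in for real states:
        # stminus1 -> decoder state at t-1, hj -> all encoder hidden states
        stminus1 = tf.random.normal((batch_size, dim))
        hj = tf.random.normal((batch_size, timesteps, dim))

        attentionlayer = Bahdanau(10)
        a, c = attentionlayer(stminus1, hj)

        print(a.shape)  # (32, 100, 1) -- one weight per encoder timestep
        print(c.shape)  # (32, 2048)   -- context vector, same size as stminus1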

This is the simplest version of 'Attention'. If you have any other intent, please let me know and I will reformat my answer. For more specific details, please refer to https://towardsdatascience.com/create-your-own-custom-attention-layer-understand-all-flavours-2201b5e8be9e