
How to build an attention model with Keras?

Last updated: 2023-12-02 22:48:58

There is a problem with the way you initialize the attention layer and pass its parameters. You should specify the number of attention units when constructing the layer, and change how the inputs are passed to it:

context_vector, attention_weights = Attention(32)(lstm, state_h)
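
For reference, here is a minimal sketch of such an Attention layer, assuming it follows the Bahdanau-style implementation from the TensorFlow NMT tutorial; the call signature above and the 16,481 parameters shown in the summary below are consistent with units=32 applied to 256-dimensional inputs:

# Minimal Bahdanau-style attention layer (sketch, assuming the custom
# Attention class in the question follows the TF NMT tutorial pattern).
import tensorflow as tf

class Attention(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        # Score network: W1 projects the encoder sequence, W2 projects the
        # query state, V reduces the combined score to one weight per timestep.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        # features: (batch, timesteps, dim) encoder outputs
        # hidden:   (batch, dim) query vector, e.g. concatenated LSTM states
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = self.V(tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis)))
        attention_weights = tf.nn.softmax(score, axis=1)
        # Weighted sum of the encoder outputs gives the context vector.
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights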

Result:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 200)          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 200, 128)     32000       input_1[0][0]                    
__________________________________________________________________________________________________
bi_lstm_0 (Bidirectional)       [(None, 200, 256), ( 263168      embedding[0][0]                  
__________________________________________________________________________________________________
bidirectional (Bidirectional)   [(None, 200, 256), ( 394240      bi_lstm_0[0][0]                  
                                                                 bi_lstm_0[0][1]                  
                                                                 bi_lstm_0[0][2]                  
                                                                 bi_lstm_0[0][3]                  
                                                                 bi_lstm_0[0][4]                  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 256)          0           bidirectional[0][1]              
                                                                 bidirectional[0][3]              
__________________________________________________________________________________________________
attention (Attention)           [(None, 256), (None, 16481       bidirectional[0][0]              
                                                                 concatenate[0][0]                
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            257         attention[0][0]                  
==================================================================================================
Total params: 706,146
Trainable params: 706,146
Non-trainable params: 0
__________________________________________________________________________________________________
None
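
For completeness, here is one way the rest of the graph could be wired to reproduce the summary above. The layer names, the vocabulary size of 250 (implied by the 32,000 embedding parameters), and the sigmoid output activation are assumptions read off the summary, not taken from the original code:

# Sketch of a model whose graph matches the summary above (assumptions noted).
from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Concatenate, Dense
from tensorflow.keras.models import Model

sequence_input = Input(shape=(200,))
embedded = Embedding(250, 128)(sequence_input)  # 250 * 128 = 32,000 params

# First bi-LSTM returns the full sequence plus its four final states.
lstm, f_h, f_c, b_h, b_c = Bidirectional(
    LSTM(128, return_sequences=True, return_state=True), name='bi_lstm_0')(embedded)

# Second bi-LSTM is initialised with the states of the first one.
lstm, forward_h, forward_c, backward_h, backward_c = Bidirectional(
    LSTM(128, return_sequences=True, return_state=True))(
        lstm, initial_state=[f_h, f_c, b_h, b_c])

# Query vector for attention: concatenated forward/backward hidden states.
state_h = Concatenate()([forward_h, backward_h])

# The corrected call: units go to the constructor, inputs go to the call.
context_vector, attention_weights = Attention(32)(lstm, state_h)

output = Dense(1, activation='sigmoid')(context_vector)  # activation assumed
model = Model(inputs=sequence_input, outputs=output)
print(model.summary())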