Visualizing attention activations in TensorFlow

Updated: 2023-12-02 23:23:10

I also wanted to visualize the attention weights of the TensorFlow seq2seq ops for my text summarization task. As a temporary solution, I use session.run() to evaluate the attention mask tensor, as mentioned above. Interestingly, the original seq2seq.py ops are now considered a legacy version and are hard to find on GitHub, so I took the seq2seq.py file from the 0.12.0 wheel distribution and modified it. To draw the heatmap I used the Matplotlib package, which is very convenient.

The final output of attention visualization for news headline textsum looks like this:

I modified the code as outlined below; the full version is at https://github.com/rockingdingo/deepnlp/tree/master/deepnlp/textsum#attention-visualization

seq2seq_attn.py

# Find the attention mask tensor in attention_decoder() -> attention().
# Add the attention mask tensor to the return statement of every function that calls attention_decoder(),
# all the way up to model_with_buckets(), which is the function I use for bucket training.

def attention(query):
  """Put attention masks on hidden using hidden_features and query."""
  ds = []  # Results of attention reads will be stored here.

  # some code

  for a in xrange(num_heads):
    with variable_scope.variable_scope("Attention_%d" % a):
      # some code

      s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                              [2, 3])
      # This is the attention mask tensor we want to extract
      a = nn_ops.softmax(s)

      # some code

  # Also return the attention mask 'a' (with num_heads > 1 this is the last head's mask)
  return ds, a
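
To make the shapes concrete, here is a minimal NumPy sketch of the same score-and-softmax step for a single attention head. It is not the TensorFlow code: the function name attention_mask is made up for illustration, the array names mirror the excerpt above, and the sizes (batch of 2, 5 encoder positions, attention vector size 4) are arbitrary assumptions.

```python
import numpy as np

def attention_mask(hidden_features, y, v):
    """Return softmax attention weights of shape [batch, attn_length]."""
    # hidden_features: [batch, attn_length, 1, attn_vec_size]
    # y:               [batch, 1, 1, attn_vec_size] (projected decoder query)
    # v:               [attn_vec_size]
    s = np.sum(v * np.tanh(hidden_features + y), axis=(2, 3))
    s = s - s.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)

# Random stand-in data; the sizes are assumptions for the demo only.
rng = np.random.default_rng(0)
batch, attn_length, attn_vec_size = 2, 5, 4
hidden_features = rng.normal(size=(batch, attn_length, 1, attn_vec_size))
y = rng.normal(size=(batch, 1, 1, attn_vec_size))
v = rng.normal(size=(attn_vec_size,))
mask = attention_mask(hidden_features, y, v)
```

Each row of mask is a probability distribution over the encoder positions for one example in the batch, which is why it can be drawn directly as one row of a heatmap.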

seq2seq_model_attn.py

# Modified the model and its step() function to also return the attention mask tensors
self.outputs, self.losses, self.attn_masks = seq2seq_attn.model_with_buckets(…)

# Use session.run() to evaluate the attention masks for the current bucket
attn_out = session.run(self.attn_masks[bucket_id], input_feed)
attn_matrix = ...
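
The attn_matrix line is elided in the original. As a rough sketch only, and assuming that session.run() returns one array of shape [batch_size, attn_length] per decoder step (an assumption about the modified model, not something the post confirms), the per-step masks could be stacked into a 2D matrix for a single example like this. The helper name masks_to_matrix and the fake data are hypothetical.

```python
import numpy as np

def masks_to_matrix(attn_out, example=0):
    """Stack per-step masks into a [decoder_steps, attn_length] matrix."""
    # attn_out: sequence of arrays, each assumed to be [batch_size, attn_length]
    stacked = np.stack(attn_out, axis=0)  # [decoder_steps, batch_size, attn_length]
    return stacked[:, example, :]         # keep one example from the batch

# Fake data standing in for the session.run() result (3 steps, batch of 2, 4 inputs).
fake_attn_out = [np.full((2, 4), 0.25) for _ in range(3)]
attn_matrix_demo = masks_to_matrix(fake_attn_out)
```

Rows then correspond to output tokens and columns to input tokens, which matches the [0:ty_cut, 0:tx_cut] slicing used when plotting.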

predict_attn.py and eval.py

# Use the plot_attention function in eval.py to visualize the 2D ndarray during prediction.

eval.plot_attention(attn_matrix[0:ty_cut, 0:tx_cut], X_label=X_label, Y_label=Y_label)
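
For readers without the repository at hand, a heatmap function along these lines can be sketched with Matplotlib. This is a hypothetical stand-in, not the plot_attention from eval.py: only the X_label and Y_label parameter names come from the call above, while the path parameter, the color map, and the axis titles are my own assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

def plot_attention(data, X_label=None, Y_label=None, path=None):
    """Draw a 2D attention matrix as a heatmap and return the figure."""
    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(data, cmap="Blues", aspect="auto")
    fig.colorbar(image, ax=ax)
    if X_label is not None:  # input tokens along the x axis
        ax.set_xticks(np.arange(len(X_label)))
        ax.set_xticklabels(X_label, rotation=45)
    if Y_label is not None:  # output tokens along the y axis
        ax.set_yticks(np.arange(len(Y_label)))
        ax.set_yticklabels(Y_label)
    ax.set_xlabel("Input sequence")
    ax.set_ylabel("Output sequence")
    fig.tight_layout()
    if path is not None:  # optionally save to disk
        fig.savefig(path)
    return fig

# Toy 2x3 attention matrix: 2 output tokens attending over 3 input tokens.
demo = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.8, 0.1]])
fig = plot_attention(demo, X_label=["a", "b", "c"], Y_label=["x", "y"])
```

imshow places the first row at the top, so the first output token appears on the top row of the heatmap.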

TensorFlow will probably offer a better way to extract and visualize the attention weight map in the future. Any thoughts?