TensorFlow sequence-to-sequence models using the seq2seq API (version 1.1 and above)

Decoding layer:

The decoding consists of two parts because of the differences between training and inference:

The decoder input at a particular time-step always comes from the output of the previous time-step. But during training, the output is fixed to the actual target (the actual target is fed back as input), and this has been shown to improve performance.

Both of these are handled using methods from tf.contrib.seq2seq.


  1. The main function for the decoder is seq2seq.dynamic_decode(), which performs dynamic decoding:

tf.contrib.seq2seq.dynamic_decode(decoder,maximum_iterations)

This takes a Decoder instance and maximum_iterations = maximum sequence length as inputs.
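A minimal sketch of the call (the decoder instance is built as in 1.1 below; max_target_len is a hypothetical name for the maximum target sequence length):

    import tensorflow as tf  # TF 1.x, where tf.contrib is available

    # decoder: a seq2seq.Decoder instance (see 1.1 below)
    # max_target_len: hypothetical maximum target sequence length
    outputs, final_state, final_seq_lens = tf.contrib.seq2seq.dynamic_decode(
        decoder, maximum_iterations=max_target_len)
    # outputs.rnn_output holds the logits, outputs.sample_id the predicted ids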

1.1 The Decoder instance comes from:

seq2seq.BasicDecoder(cell, helper, initial_state, output_layer)

The inputs are: cell (an RNNCell instance), helper (a helper instance), initial_state (the initial state of the decoder, which should be the output state of the encoder), and output_layer (an optional dense layer applied to the outputs to make predictions).
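A sketch of constructing the decoder, where cell and helper are built as in 1.2 and 1.3 below, and vocab_size and encoder_state are hypothetical names:

    from tensorflow.python.layers.core import Dense

    # projects the RNN output to vocabulary-sized logits
    output_layer = Dense(vocab_size)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell, helper, initial_state=encoder_state, output_layer=output_layer)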

1.2 An RNNCell instance can be a rnn.MultiRNNCell().
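For example, a stacked cell (num_units and num_layers are hypothetical hyperparameters):

    # a stack of LSTM cells wrapped into a single RNNCell
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.LSTMCell(num_units) for _ in range(num_layers)])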

1.3 The helper instance is what differs between training and inference. During training, we want the inputs to be fed to the decoder, while during inference, we want the output of the decoder at time-step (t) to be passed as the input to the decoder at time-step (t+1).

For training, we use the helper function seq2seq.TrainingHelper(inputs, sequence_length), which just reads the inputs.
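For example (dec_embed_input and target_lengths are hypothetical names for the embedded, <GO>-prefixed target sequences and their true lengths):

    # feeds the ground-truth targets to the decoder at each time-step
    train_helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=dec_embed_input, sequence_length=target_lengths)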

For inference, we call the helper function seq2seq.GreedyEmbeddingHelper() or seq2seq.SampleEmbeddingHelper(); they differ in whether they take the argmax() of the outputs or sample from the output distribution, and they pass the result through an embedding layer to get the next input.
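For example, greedy decoding (embeddings, go_id, eos_id, and batch_size are hypothetical names for the decoder embedding matrix, the start and end token ids, and the batch size):

    # one <GO> token per batch entry to start decoding
    start_tokens = tf.fill([batch_size], go_id)
    infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embeddings, start_tokens, eos_id)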

Putting it together: the Seq2Seq model


  1. Get the encoder state from the encoder layer and pass it as the initial_state to the decoder.
  2. Get the outputs of the training decoder and the inference decoder using seq2seq.dynamic_decode(). When calling both methods, make sure the weights are shared (use variable_scope to reuse the weights); see the sketch after this list.
  3. Then train the network using the loss function seq2seq.sequence_loss.
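A minimal end-to-end sketch of these three steps, reusing the hypothetical names from the snippets above plus targets, the matrix of target token ids (assumed padded to max_target_len):

    with tf.variable_scope("decode"):
        train_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell, train_helper, encoder_state, output_layer)
        train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            train_decoder, maximum_iterations=max_target_len)

    with tf.variable_scope("decode", reuse=True):  # share weights with training
        infer_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell, infer_helper, encoder_state, output_layer)
        infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            infer_decoder, maximum_iterations=max_target_len)

    # mask out padding positions so they do not contribute to the loss
    masks = tf.sequence_mask(target_lengths, max_target_len, dtype=tf.float32)
    loss = tf.contrib.seq2seq.sequence_loss(
        train_outputs.rnn_output, targets, masks)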

Sample code is given here and here.