且构网

Understanding Tensorflow LSTM Input Shape

Updated: 2023-12-02 08:37:10

The documentation of tf.nn.dynamic_rnn states:

inputs: The RNN inputs. If time_major == False (default), this must be a Tensor of shape: [batch_size, max_time, ...], or a nested tuple of such elements.

In your case, this means that the input should have a shape of [batch_size, 10, 2]. Instead of training on all 4000 sequences at once, you'd use only batch_size many of them in each training iteration. Something like the following should work (added reshape for clarity):

import tensorflow as tf  # TensorFlow 1.x API

batch_size = 32
# batch_size sequences of length 10 with 2 values for each timestep
# (get_batch and X stand in for your own data pipeline)
inputs = get_batch(X, batch_size).reshape([batch_size, 10, 2])
# Create LSTM cell with state size 256. Could also use GRUCell, ...
# Note: state_is_tuple=False is deprecated;
# the option might be completely removed in the future
cell = tf.nn.rnn_cell.LSTMCell(256, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell,
                                   inputs,
                                   sequence_length=[10]*batch_size,
                                   dtype=tf.float32)
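
The batching step can be sketched in plain NumPy. This is only an illustration under stated assumptions: the data is assumed to be stored as 4000 flat rows of 20 values (10 timesteps with 2 values each, timestep by timestep), the targets Y are dummy values, and this get_batch is a hypothetical helper whose signature differs from the one used above. It returns inputs and targets together so that both are drawn with the same indices.

```python
import numpy as np

def get_batch(X, Y, batch_size, rng):
    """Hypothetical helper: draw batch_size matching rows from X and Y."""
    # Sample row indices once so inputs and targets stay aligned
    idx = rng.choice(len(X), size=batch_size, replace=False)
    return X[idx], Y[idx]

rng = np.random.default_rng(0)
# Assumed layout: 4000 sequences, each stored flat as 10 timesteps x 2 values
X = rng.standard_normal((4000, 20)).astype(np.float32)
# Dummy targets for illustration only: one value per sequence
Y = X.sum(axis=1, keepdims=True)

batch_size = 32
x_batch, y_batch = get_batch(X, Y, batch_size, rng)
# Row-major reshape: consecutive pairs of values become one timestep
x_batch = x_batch.reshape([batch_size, 10, 2])
y_batch = y_batch.reshape([batch_size, 1])
print(x_batch.shape, y_batch.shape)  # (32, 10, 2) (32, 1)
```

If the two values of a timestep are not stored next to each other in your data, a plain reshape would mix them up and a transpose would be needed first.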

According to the documentation, outputs will have shape [batch_size, 10, 256], i.e. one 256-dimensional output for each timestep. state will be a tuple (c, h) of two tensors, each of shape [batch_size, 256]. You could predict your final value, one for each sequence, from that:

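# state.h is the hidden state after the last timestep, shape [batch_size, 256]
# Note: tf.contrib exists only in TensorFlow 1.x
predictions = tf.contrib.layers.fully_connected(state.h,
                                                num_outputs=1,
                                                activation_fn=None)
# get_loss stands in for your own loss function
loss = get_loss(get_batch(Y, batch_size).reshape([batch_size, 1]), predictions)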
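
To make the shapes concrete without running TensorFlow, here is a rough NumPy sketch of an LSTM forward pass followed by a single-output linear layer. It is not TensorFlow's implementation: the weights are random, there is no forget-gate bias, and lstm_forward is a hypothetical function written only to show where [batch_size, 10, 256] and [batch_size, 256] come from.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(inputs, num_units, rng):
    """Run a randomly initialised LSTM over inputs of shape [batch, time, features]."""
    batch_size, max_time, num_features = inputs.shape
    # One weight matrix covering all four gates
    W = rng.standard_normal((num_features + num_units, 4 * num_units)) * 0.1
    b = np.zeros(4 * num_units)
    c = np.zeros((batch_size, num_units))  # cell state
    h = np.zeros((batch_size, num_units))  # hidden state
    outputs = []
    for t in range(max_time):
        z = np.concatenate([inputs[:, t, :], h], axis=1) @ W + b
        i, j, f, o = np.split(z, 4, axis=1)  # input, candidate, forget, output
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(j)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)  # one hidden vector per timestep
    return np.stack(outputs, axis=1), (c, h)

rng = np.random.default_rng(0)
batch_size = 32
inputs = rng.standard_normal((batch_size, 10, 2))
outputs, (c, h) = lstm_forward(inputs, 256, rng)

# Linear layer with one output and no activation, applied to the final h
W_out = rng.standard_normal((256, 1)) * 0.1
predictions = h @ W_out

print(outputs.shape)      # (32, 10, 256)
print(c.shape, h.shape)   # (32, 256) (32, 256)
print(predictions.shape)  # (32, 1)
```

In this sketch every sequence has the full length of 10, so the final h is the same as outputs[:, -1, :].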

The number 256 in the shapes of outputs and state is determined by cell.output_size and cell.state_size, respectively. When the LSTMCell is created as above, these are the same. Also see the LSTMCell documentation.