
Understanding the input shape of a PyTorch LSTM



You have explained the structure of your input, but you haven't made the connection between your input dimensions and the LSTM's expected input dimensions.


Let's break down your input (assigning names to the dimensions):

  • batch_size: 12
  • seq_len: 384
  • input_size / num_features: 768
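To make those names concrete, here is a minimal sketch; the random tensor below merely stands in for your actual batch:

import torch

# Stand-in for the input: 12 sequences, 384 time steps each,
# 768 features per time step.
x = torch.randn(12, 384, 768)

batch_size, seq_len, input_size = x.shape
print(batch_size, seq_len, input_size)  # 12 384 768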


That means the input_size of the LSTM needs to be 768.
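As a quick illustration, constructing the LSTM with any other input_size and feeding it this input fails (a minimal sketch; the exact error message may vary across PyTorch versions):

import torch
import torch.nn as nn

x = torch.randn(12, 384, 768)

# input_size=512 does not match the 768 features of x ...
lstm = nn.LSTM(input_size=512, hidden_size=512, batch_first=True)

# ... so this call raises a RuntimeError about the input size.
output, _ = lstm(x)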


The hidden_size does not depend on your input; it determines how many features the LSTM creates, and those features make up both the hidden state and the output, since the output at each time step is simply that step's hidden state. You have to decide how many features you want the LSTM to use.
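To illustrate that relationship (hidden_size=512 is an arbitrary choice here), for a single-layer, unidirectional LSTM the last time step of the output coincides with the final hidden state:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)
x = torch.randn(12, 384, 768)

# output: the hidden state at every time step  -> [12, 384, 512]
# h_n:    the hidden state at the last step    -> [1, 12, 512]
output, (h_n, c_n) = lstm(x)

assert torch.allclose(output[:, -1, :], h_n[0])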


Finally, for the input shape, setting batch_first=True requires the input to have the shape [batch_size, seq_len, input_size]; in your case that would be [12, 384, 768].

import torch
import torch.nn as nn

# Size: [batch_size, seq_len, input_size]
# (named x rather than input to avoid shadowing Python's built-in)
x = torch.randn(12, 384, 768)

# input_size must match the 768 features per time step;
# hidden_size is a free choice (512 here).
lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(x)
output.size()  # => torch.Size([12, 384, 512])
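Note that batch_first=True only changes the layout of the input and output tensors; nn.LSTM defaults to batch_first=False, in which case it expects [seq_len, batch_size, input_size] instead, i.e. [384, 12, 768] here.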