
PyTorch LSTM input size

Updated: 2023-12-01 23:37:58

This is an old question, but since it has been viewed 80+ times with no response, let me take a crack at it.

An LSTM network is used to predict a sequence. In NLP, that would be a sequence of words; in economics, a sequence of economic indicators; etc.

The first parameter is the length of those sequences. If your sequence data is made of sentences, then "Tom has a black and ugly cat" is a sequence of length 7 (seq_len), one for each word, and maybe an 8th to indicate the end of the sentence.
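To make the shapes concrete, here is a minimal PyTorch sketch; the input size (word-vector dimension), hidden size, and batch size are made-up values for illustration. By default, nn.LSTM expects input of shape (seq_len, batch, input_size), so the 7-word sentence above corresponds to seq_len = 7.

    import torch
    import torch.nn as nn

    seq_len, batch_size, input_size = 7, 1, 32   # 7 words, one sentence, 32-dim word vectors (made-up sizes)
    hidden_size = 64                             # hypothetical hidden size

    lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)  # batch_first=False by default

    # One sentence of 7 word vectors, shaped (seq_len, batch, input_size)
    x = torch.randn(seq_len, batch_size, input_size)

    output, (h_n, c_n) = lstm(x)
    print(output.shape)  # torch.Size([7, 1, 64]) -- one hidden state per word
    print(h_n.shape)     # torch.Size([1, 1, 64]) -- final hidden state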

Of course, you might object "what if my sequences are of varying length?" which is a common situation.

The two most common solutions are:

  1. Pad your sequences with empty elements. For instance, if the longest sentence you have has 15 words, then encode the sentence above as "[Tom] [has] [a] [black] [and] [ugly] [cat] [EOS] [] [] [] [] [] [] []", where EOS stands for end of sentence. Suddenly, all your sequences become of length 15, which solves your issue. As soon as the [EOS] token is found, the model will quickly learn that it is followed by an unlimited sequence of empty tokens [], and that approach will barely tax your network. (A short padding sketch follows this list.)

  2. Send mini-batches of equal length. For instance, train the network on all sentences with 2 words, then with 3, then with 4. Of course, seq_len will increase with each mini-batch, and the size of each mini-batch will vary based on how many sequences of length N you have in your data.
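To illustrate option 1, here is a rough sketch using torch.nn.utils.rnn.pad_sequence; the word ids and the PAD/EOS values are made up, not taken from the original question. pad_sequence appends padding to the shorter sequences until they all match the longest one.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    PAD, EOS = 0, 1  # hypothetical special-token ids

    # Three sentences already mapped to word ids, each ending with EOS
    sentences = [
        torch.tensor([4, 9, 2, 7, EOS]),
        torch.tensor([3, 5, EOS]),
        torch.tensor([8, 6, 2, 7, 4, 9, 3, EOS]),
    ]

    # Pad every sequence at the end up to the longest one (length 8 here)
    batch = pad_sequence(sentences, batch_first=True, padding_value=PAD)
    print(batch.shape)  # torch.Size([3, 8])
    print(batch)
    # tensor([[4, 9, 2, 7, 1, 0, 0, 0],
    #         [3, 5, 1, 0, 0, 0, 0, 0],
    #         [8, 6, 2, 7, 4, 9, 3, 1]])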

A best-of-both-worlds approach would be to divide your data into mini-batches of roughly equal size, grouping them by approximate length and adding only the necessary padding. For instance, if you mini-batch together sentences of length 6, 7 and 8, then sequences of length 8 will require no padding, whereas sequences of length 6 will require only 2. If you have a large dataset with sequences of widely varying length, that's the best approach.
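A minimal sketch of that bucketing idea (the batch size, the random word ids, and the helper name length_bucketed_batches are arbitrary choices for illustration):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def length_bucketed_batches(sequences, batch_size, padding_value=0):
        """Sort sequences by length, then pad each batch only up to its own longest member."""
        ordered = sorted(sequences, key=len)
        for i in range(0, len(ordered), batch_size):
            chunk = ordered[i:i + batch_size]
            yield pad_sequence(chunk, batch_first=True, padding_value=padding_value)

    # Sequences of length 6, 7 and 8 end up in the same batch and need at most 2 pad tokens
    data = [torch.randint(2, 100, (n,)) for n in (6, 8, 7, 15, 14, 13)]
    for batch in length_bucketed_batches(data, batch_size=3):
        print(batch.shape)
    # torch.Size([3, 8])
    # torch.Size([3, 15])

In practice you would still want to shuffle within and across buckets each epoch; that detail is left out of the sketch.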

Option 1 is the easiest (and laziest) approach, though, and will work great on small datasets.

One last thing... Always pad your data at the end, not at the beginning.
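Note that pad_sequence in the sketches above already pads at the end. If you ever pad by hand, a one-liner with torch.nn.functional.pad does the same thing; the pad id is again a made-up value:

    import torch
    import torch.nn.functional as F

    PAD = 0
    x = torch.tensor([4, 9, 2, 7, 1])      # a length-5 sentence ending in EOS
    padded = F.pad(x, (0, 3), value=PAD)   # (left, right): add 3 pad tokens at the end
    print(padded)                          # tensor([4, 9, 2, 7, 1, 0, 0, 0])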

I hope this helps.