Word-level Seq2Seq with Keras

Last updated: 2023-12-02 20:12:10

I ran into this problem recently as well. I found no solution other than creating small batches (say batch_size=64) in a generator and calling model.fit_generator instead of model.fit, so that only one batch of one-hot targets is held in memory at a time. My generate_batch code is below:

import numpy as np

def generate_batch(X, y, batch_size=64):
    ''' Generate a batch of data.

    Relies on max_encoder_seq_length, max_decoder_seq_length and
    num_decoder_tokens being defined at module level.
    '''
    while True:  # Keras expects the generator to loop indefinitely
        for j in range(0, len(X), batch_size):
            # Fresh zero-filled arrays per batch; unused rows and time steps stay 0
            encoder_input_data = np.zeros((batch_size, max_encoder_seq_length), dtype='float32')
            decoder_input_data = np.zeros((batch_size, max_decoder_seq_length + 2), dtype='float32')
            decoder_target_data = np.zeros((batch_size, max_decoder_seq_length + 2, num_decoder_tokens), dtype='float32')

            for i, (input_text_seq, target_text_seq) in enumerate(zip(X[j:j + batch_size], y[j:j + batch_size])):
                for t, word_index in enumerate(input_text_seq):
                    encoder_input_data[i, t] = word_index  # encoder input seq

                for t, word_index in enumerate(target_text_seq):
                    decoder_input_data[i, t] = word_index  # decoder input seq
                    # Target is the decoder input shifted one step ahead, one-hot
                    # encoded. Index 0 is skipped so it cannot wrap around to -1.
                    if t > 0 and 0 < word_index <= num_decoder_tokens:
                        decoder_target_data[i, t - 1, word_index - 1] = 1.

            yield [encoder_input_data, decoder_input_data], decoder_target_data
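
To make the one-step shift between decoder input and decoder target easier to see, here is a minimal sketch for a single target sequence. The vocabulary size, sequence length, and word indices are made-up toy values, not taken from the original data:

```python
import numpy as np

# Hypothetical toy setup: 5 target words with indices 1..5 (0 is unused/padding).
num_decoder_tokens = 5
max_len = 4

target_seq = [1, 3, 4, 2]  # e.g. <start>, word 3, word 4, <end>

decoder_input = np.zeros(max_len, dtype="float32")
decoder_target = np.zeros((max_len, num_decoder_tokens), dtype="float32")

for t, word_index in enumerate(target_seq):
    decoder_input[t] = word_index
    if t > 0:
        # The target at step t-1 is the word the decoder is fed at step t.
        decoder_target[t - 1, word_index - 1] = 1.0

# Recover the word index encoded in each target row (0 for an all-zero row).
target_words = [int(row.argmax()) + 1 if row.any() else 0 for row in decoder_target]

print(decoder_input.tolist())  # [1.0, 3.0, 4.0, 2.0]
print(target_words)            # [3, 4, 2, 0]
```

The last target row stays all-zero because there is no word after the final token.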

Then train like this:

import math

batch_size = 64
epochs = 2

# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

# Note: fit_generator is deprecated in TensorFlow 2.1 and later, where
# model.fit accepts the generator directly.
model.fit_generator(
    generator=generate_batch(X=X_train_sequences, y=y_train_sequences, batch_size=batch_size),
    # ceil so that the final, partially filled batch is included
    steps_per_epoch=math.ceil(len(X_train_sequences)/batch_size),
    epochs=epochs,
    verbose=1,
    validation_data=generate_batch(X=X_val_sequences, y=y_val_sequences, batch_size=batch_size),
    validation_steps=math.ceil(len(X_val_sequences)/batch_size),
    workers=1,
    )
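
As a quick check of the steps_per_epoch arithmetic, the sketch below uses a made-up dataset size of 1000 samples (an assumption for illustration only) and shows how many batches one pass over the data produces and how many real rows the last batch holds:

```python
import math

n_samples = 1000  # hypothetical dataset size
batch_size = 64

steps_per_epoch = math.ceil(n_samples / batch_size)

# Start offsets the generator visits in one pass over the data
starts = list(range(0, n_samples, batch_size))

# Number of real samples in the final batch; the generator still allocates
# batch_size rows, so the remaining rows are left as zeros.
last_batch_real_rows = n_samples - starts[-1]

print(steps_per_epoch, len(starts), last_batch_real_rows)  # 16 16 40
```

Using math.ceil rather than integer division keeps that final partial batch in each epoch.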

X_train_sequences is a list of lists of word indices, like [[23,34,56], [2, 33544, 6, 10]].
The other sequence variables (y_train_sequences, X_val_sequences, y_val_sequences) have the same structure.
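
For illustration, one way such lists of word indices could be produced from raw sentences is sketched below. The helper names build_vocab and to_sequences and the example sentences are hypothetical, and the answer does not say how its sequences were actually built (a library tokenizer may well have been used instead):

```python
def build_vocab(sentences):
    """Map each distinct word to an integer index starting at 1 (0 is left free for padding)."""
    vocab = {}
    for sentence in sentences:
        for word in sentence.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab) + 1
    return vocab


def to_sequences(sentences, vocab):
    """Convert each sentence to a list of word indices, dropping unknown words."""
    return [[vocab[w] for w in s.lower().split() if w in vocab] for s in sentences]


sentences = ["I like tea", "I like green tea a lot"]
vocab = build_vocab(sentences)
sequences = to_sequences(sentences, vocab)

print(sequences)  # [[1, 2, 3], [1, 2, 4, 3, 5, 6]]
```

Starting the indices at 1 matches the word_index-1 offset used for the one-hot targets in generate_batch.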

I also got help from this blog: word-level-english-to-marathi-nmt