Restoring a Keras seq2seq model

Updated: 2023-12-01 23:20:04

Ok, I solved this problem and the decoder is producing reasonable results. In my code above I missed a couple of details in the decoder step, specifically that it call()s the LSTM and Dense layers in order to wire them up. In addition, the new decoder inputs need unique names so they don't collide with input_1 and input_2 (this detail smells like a Keras bug).

from keras.models import Model
from keras.layers import Input

# Rebuild the encoder: reuse the trained model's input and LSTM states.
encoder_inputs = model.input[0]  # input_1
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1
encoder_states = [state_h_enc, state_c_enc]
encoder_model = Model(encoder_inputs, encoder_states)

# Rebuild the decoder: the new state inputs need unique names so they
# don't collide with the trained model's input_1/input_2.
decoder_inputs = model.input[1]  # input_2
decoder_state_input_h = Input(shape=(latent_dim,), name='input_3')
decoder_state_input_c = Input(shape=(latent_dim,), name='input_4')
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

# call() the trained LSTM and Dense layers to wire them into the new graph.
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)
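Once encoder_model and decoder_model exist, generation runs token by token, feeding the predicted states back in at each step. Below is a minimal, self-contained sketch: it builds a toy training model with the same topology the answer assumes (the dimensions, token counts, and start token are illustrative assumptions, not from the original code), rebuilds the inference models as above, and runs a greedy decoding loop on untrained weights.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, LSTM, Dense

# Illustrative dimensions (assumptions; in practice these come from the
# trained model and its vocabulary).
latent_dim = 8
num_encoder_tokens = 5
num_decoder_tokens = 6
max_decoder_seq_length = 4

# Toy training model with the topology the answer assumes, so that
# model.layers[2]/[3]/[4] are the encoder LSTM, decoder LSTM, and Dense.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
_, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)
decoder_inputs = Input(shape=(None, num_decoder_tokens))
dec_seq, _, _ = LSTM(latent_dim, return_sequences=True,
                     return_state=True)(decoder_inputs,
                                        initial_state=[state_h, state_c])
outputs = Dense(num_decoder_tokens, activation='softmax')(dec_seq)
model = Model([encoder_inputs, decoder_inputs], outputs)

# Rebuild the inference models exactly as in the answer above.
_, eh, ec = model.layers[2].output
encoder_model = Model(model.input[0], [eh, ec])
dsi_h = Input(shape=(latent_dim,), name='input_3')
dsi_c = Input(shape=(latent_dim,), name='input_4')
d_out, dh, dc = model.layers[3](model.input[1], initial_state=[dsi_h, dsi_c])
d_out = model.layers[4](d_out)
decoder_model = Model([model.input[1], dsi_h, dsi_c], [d_out, dh, dc])

# Greedy decoding: feed one token at a time, carrying the states forward.
def decode_sequence(input_seq, start_token=0):
    h, c = encoder_model.predict(input_seq, verbose=0)
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, start_token] = 1.0
    decoded = []
    for _ in range(max_decoder_seq_length):
        probs, h, c = decoder_model.predict([target_seq, h, c], verbose=0)
        next_token = int(np.argmax(probs[0, -1, :]))
        decoded.append(next_token)
        # One-hot encode the predicted token as the next decoder input.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, next_token] = 1.0
    return decoded

print(decode_sequence(np.zeros((1, 3, num_encoder_tokens))))
```

With untrained weights the output tokens are arbitrary, but the loop exercises the same state-passing wiring that a trained model would use; a real setup would also stop on an end-of-sequence token rather than always running to max_decoder_seq_length.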

A big drawback of this code is that we have to know the full architecture in advance. I would eventually like to be able to load an architecture-agnostic decoder.
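One small step in that direction, sketched under strong assumptions: rather than hard-coding model.layers[2]/[3]/[4], the relevant layers can be located by type, and latent_dim recovered from the layer itself. This only works for a model with exactly one encoder LSTM, one decoder LSTM, and one Dense head (the toy model below is illustrative); anything more complex would need a genuine traversal of the layer graph.

```python
from keras.models import Model
from keras.layers import Input, LSTM, Dense

# Small stand-in training model (dimensions are illustrative).
enc_in = Input(shape=(None, 5))
dec_in = Input(shape=(None, 6))
_, h, c = LSTM(8, return_state=True)(enc_in)
dec_seq, _, _ = LSTM(8, return_sequences=True,
                     return_state=True)(dec_in, initial_state=[h, c])
out = Dense(6, activation='softmax')(dec_seq)
model = Model([enc_in, dec_in], out)

# Look the layers up by type instead of by index. model.layers is in
# topological order, so the encoder LSTM precedes the decoder LSTM.
encoder_lstm, decoder_lstm = [l for l in model.layers
                              if isinstance(l, LSTM)]
decoder_dense = [l for l in model.layers if isinstance(l, Dense)][-1]

# latent_dim can be read off the layer instead of being known in advance.
print(decoder_lstm.units)  # 8
```

This removes the magic indices and the externally supplied latent_dim, though the two-LSTM/one-Dense assumption still bakes the architecture in; a truly architecture-agnostic loader remains an open problem here.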