且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

如何在Keras中的LSTM层中解释权重

更新时间:2023-12-02 12:06:52

如果您使用的是 Keras 2.2.0

打印时

print(model.layers[0].trainable_weights)

您应该看到三个张量:lstm_1/kernel, lstm_1/recurrent_kernel, lstm_1/bias:0 每个张量的维度之一应该是

you should see three tensors: lstm_1/kernel, lstm_1/recurrent_kernel, lstm_1/bias:0 One of the dimensions of each tensor should be a product of

4 *单位数量

4 * number_of_units

其中 number_of_units 是您的神经元数量.试试:

where number_of_units is your number of neurons. Try:

units = int(int(model.layers[0].trainable_weights[0].shape[1])/4)
print("No units: ", units)

这是因为每个张量包含四个LSTM单位(按该顺序)的权重:

That is because each tensor contains weights for four LSTM units (in that order):

i(输入),f(忘记),c(单元格状态)和o(输出)

因此,为了提取权重,您可以简单地使用切片运算符:

Therefore in order to extract weights you can simply use slice operator:

W = model.layers[0].get_weights()[0]
U = model.layers[0].get_weights()[1]
b = model.layers[0].get_weights()[2]

W_i = W[:, :units]
W_f = W[:, units: units * 2]
W_c = W[:, units * 2: units * 3]
W_o = W[:, units * 3:]

U_i = U[:, :units]
U_f = U[:, units: units * 2]
U_c = U[:, units * 2: units * 3]
U_o = U[:, units * 3:]

b_i = b[:units]
b_f = b[units: units * 2]
b_c = b[units * 2: units * 3]
b_o = b[units * 3:]

来源: keras代码