且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

将CNN Pytorch中的预训练砝码传递到Tensorflow中的CNN

更新时间:2022-05-04 00:45:21

您可以非常简单地检查所有 keras 层的所有权重的形状:

You can check shapes of all weights of all keras layers quite simply:

for layer in model.layers:
    print([tensor.shape for tensor in layer.get_weights()])

这将为您提供所有权重(包括偏差)的形状,因此您可以相应地准备加载的 numpy 权重.

This would give you shapes of all weights (including biases), so you can prepare loaded numpy weights accordingly.

要设置它们,请执行类似的操作:

To set them, do something similar:

for torch_weight, layer in zip(model.layers, torch_weights):
    layer.set_weights(torch_weight)

其中 torch_weights 应该是包含要加载的 np.array 列表的列表.

where torch_weights should be a list containing lists of np.array which you would have to load.

通常,每个 torch_weights 的元素将包含一个 np.array 用于权重,一个用于偏置.

Usually each element of torch_weights would contain one np.array for weights and one for bias.

记住从打印中收到的形状必须与您在 set_weights 中放入的形状完全相同.

Remember shapes received from print have to be exactly the same as the ones you put in set_weights.

有关更多信息,请参见文档.

See documentation for more info.

顺便说一句.确切的形状取决于图层和模型执行的操作,有时可能需要转置一些数组以适合它们".

BTW. Exact shapes are dependent on layers and operations performed by model, you may have to transpose some arrays sometimes to "fit them in".