且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

Tensorflow:加载预训练的 ResNet 模型时出错

更新时间:2022-06-26 01:22:15

也许你可以使用来自 tf.keras.applications?

Maybe you could use ResNet50 from tf.keras.applications?

根据错误,如果您没有以任何方式更改图形,而且这是您的整个源代码,那么可能真的很难调试.

According to the error, if you haven't altered the graph in any way, and this is your whole source code, it might be really, really hard to debug.

如果您选择合理的tf.keras.applications.resnet50 方式,您可以简单地做到像这样:

If you choose the sane tf.keras.applications.resnet50 way you could do it simply like this:

import tensorflow

in_width, in_height, in_channels = 224, 224, 3

pretrained_resnet = tensorflow.keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(in_width, in_height, in_channels),
)

# You can freeze some layers if you want, depends on your task
# Make "top" (last 3 layers below) whatever fits your task as well

model = tensorflow.keras.models.Sequential(
    [
        pretrained_resnet,
        tensorflow.keras.layers.Flatten(),
        tensorflow.keras.layers.Dense(1024, activation="relu"),
        tensorflow.keras.layers.Dense(10, activation="softmax"),
    ]
)

print(model.summary())

现在推荐这种方法,特别是考虑到即将推出的 Tensorflow 2.0、健全性和可读性.顺便提一句.这个模型和Tensorflow提供的一样,都是从IIRC转过来的.

This approach would be the recommended now, especially in the light of upcoming Tensorflow 2.0, sanity and readability. BTW. This model is the same as the one provided by Tensorflow, it's transferred from it IIRC.

您可以在链接的文档和各种博客文章(如 tf.keras.applications 的更多信息-tuning-using-pre-trained-models/" rel="nofollow noreferrer">这个 或其他网络资源.

You can read more about tf.keras.applications in the linked documentation and in various blog posts like this one or other web resources.

回答评论中的问题

  • 如何将图像传递到网络?:如果要进行预测,请使用model.predict(image),其中图像是np.array.就这么简单.

  • How do I pass images to the network?: use model.predict(image) if you want to make a prediction, where image is np.array. Simple as that.

我如何访问权重?:嗯,这个更复杂……开个玩笑,每一层都有 .get_weights() 方法返回它是权重和偏差,您可以使用 for layer in model.layers() 迭代层.您也可以使用 model.get_weights() 一次性获取所有权重.

How do I access weights?: well, this one is more complicated... Just kidding, each layer has .get_weights() method which returns it's weights and biases, you can iterate over layers with for layer in model.layers(). You can get all weights at once using model.get_weights() as well.

总而言之,您将学习 Keras,并且在比 Tensorflow 更高效的时间内调试此问题.他们有30 秒指南是有原因的.

All in all, you will learn Keras and be more productive in it than in Tensorflow in a shorter time than you can debug this issue. They have 30 seconds guide for a reason.

顺便说一句. Tensorflow 默认提供 Keras,因此,Tensorflow 的 Keras 风格是 Tensorflow 的一部分(无论这听起来多么令人困惑).这就是我在示例中使用 tensorflow 的原因.

BTW. Tensorflow has Keras shipped by default and as such, Tensorflow's Keras flavor is part of Tensorflow (no matter how confusing this sounds). That's why I have used tensorflow in my example.

似乎您可以使用 Hub 加载和微调 Resn​​et50,如此链接.

Is seems you could load and fine-tune Resnet50 using Hub as described in this link.