且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

在 PyTorch 中保存训练模型的***方法?

更新时间:2022-12-18 08:58:33

我找到了 这个页面在他们的 github repo 上,我会把内容贴在这里.

I've found this page on their github repo, I'll just paste the content here.

序列化和恢复模型有两种主要方法.

There are two main approaches for serializing and restoring a model.

第一个(推荐)只保存和加载模型参数:

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

The second saves and loads the entire model:

torch.save(the_model, PATH)

然后:

the_model = torch.load(PATH)

但是在这种情况下,序列化的数据绑定到特定的类以及使用的确切目录结构,因此它可以以各种方式中断在其他项目中使用,或者经过一些严重的重构.

However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.