且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

PyTorch 获取模型的所有层

更新时间:2023-12-01 23:28:40

您可以使用 modules() 方法.这是一个简单的例子:

>>>模型 = nn.Sequential(nn.Linear(2, 2),nn.ReLU(),nn.Sequential(nn.Linear(2, 1),nn.Sigmoid()))>>>l = [model.modules() 中的模块的模块,如果不是 isinstance(module, nn.Sequential)]>>>升[线性(输入特征=2,输出特征=2,偏差=真),ReLU(),线性(输入特征=2,输出特征=1,偏差=真),Sigmoid()]

What's the easiest way to take a pytorch model and get a list of all the layers without any nn.Sequence groupings? For example, a better way to do this?

import pretrainedmodels

def unwrap_model(model):
    for i in children(model):
        if isinstance(i, nn.Sequential): unwrap_model(i)
        else: l.append(i)

model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []
unwrap_model(model)            
            
print(l)
    

You can iterate over all modules of a model (including those inside each Sequential) with the modules() method. Here's a simple example:

>>> model = nn.Sequential(nn.Linear(2, 2), 
                          nn.ReLU(),
                          nn.Sequential(nn.Linear(2, 1),
                          nn.Sigmoid()))

>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]

>>> l

[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]