Updated: 2022-04-29 03:38:35
You can put your layers in an nn.ModuleList container:
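```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim):
        # hidden_dim is a sequence of hidden layer sizes, e.g. [64, 32]
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        current_dim = input_dim
        # nn.ModuleList registers every appended layer as a submodule of Net
        self.layers = nn.ModuleList()
        for hdim in hidden_dim:
            self.layers.append(nn.Linear(current_dim, hdim))
            current_dim = hdim
        # Final projection from the last hidden size to the output size
        self.layers.append(nn.Linear(current_dim, output_dim))

    def forward(self, x):
        # Hidden layers, each followed by a ReLU activation
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        # Output layer; softmax over the last (class) dimension.
        # Passing dim explicitly avoids the implicit-dim deprecation warning.
        out = F.softmax(self.layers[-1](x), dim=-1)
        return out
```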
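It is very important to use PyTorch containers (such as nn.ModuleList) for the layers, not a plain Python list. Layers held in a plain list are not registered as submodules, so their parameters do not show up in model.parameters() (and therefore not in the optimizer), are not moved by model.to(device), and are not saved in model.state_dict(). Please see this answer for more on why.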
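To make the difference concrete, here is a minimal sketch that stores the same two Linear layers once in a plain Python list and once in an nn.ModuleList, then counts the parameters and state_dict entries each model exposes. The class names PlainListNet and ModuleListNet and the helper count_params are hypothetical, chosen only for this illustration.

```python
import torch
import torch.nn as nn


class PlainListNet(nn.Module):
    """Layers kept in a plain Python list: nn.Module does not register them."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 8), nn.Linear(8, 2)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ModuleListNet(nn.Module):
    """Same layers kept in nn.ModuleList: each one is registered as a submodule."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 8), nn.Linear(8, 2)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def count_params(model):
    # Total number of scalar parameters the module reports via .parameters()
    return sum(p.numel() for p in model.parameters())


plain = PlainListNet()
registered = ModuleListNet()

print(count_params(plain))           # 0: the list's layers are invisible
print(count_params(registered))      # 58 = (4*8 + 8) + (8*2 + 2)
print(len(plain.state_dict()))       # 0: nothing would be saved
print(len(registered.state_dict()))  # 4: weight and bias for each layer
```

Both models compute the same forward pass, but only ModuleListNet would actually be trained by an optimizer built from model.parameters(), moved to a GPU by .to(device), or checkpointed by state_dict().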