1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
class GRUModel(nn.Module):
    """Single-layer GRU followed by a linear projection of every time step.

    The GRU is created with ``batch_first=True``, so batched inputs are laid
    out as ``(batch, seq_len, input_num)``.

    Attributes:
        hidden_size: Width of the GRU hidden state (``hidden_num``).
        GRU_layer: The recurrent layer.
        output_linear: Projection from ``hidden_num`` to ``output_num``.
        hidden: Final hidden state produced by the most recent ``forward``
            call, or ``None`` before the first call.
    """

    def __init__(self, input_num: int, hidden_num: int, output_num: int):
        """Build the GRU and the output projection.

        Args:
            input_num: Number of features per time step in the input.
            hidden_num: Size of the GRU hidden state.
            output_num: Number of features per time step in the output.
        """
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_num
        self.GRU_layer = nn.GRU(
            input_size=input_num,
            hidden_size=hidden_num,
            batch_first=True,
        )
        self.output_linear = nn.Linear(hidden_num, output_num)
        self.hidden = None

    def forward(self, x, hidden=None):
        """Run the GRU over ``x`` and project each step's output.

        Args:
            x: Input sequence tensor, ``(batch, seq_len, input_num)`` for
                batched input.
            hidden: Optional initial hidden state handed to the GRU. With the
                default ``None`` the GRU starts from a zero state, which is
                what this method always did before the parameter existed.
                Pass the hidden state returned by a previous call to continue
                a sequence across chunks.

        Returns:
            A ``(output, hidden)`` tuple: the projected per-step outputs and
            the GRU's final hidden state.
        """
        # The stored state is the very tensor returned to the caller, so it is
        # still attached to the autograd graph; detach it before reusing it
        # across training batches if gradients must not flow between them.
        x, self.hidden = self.GRU_layer(x, hidden)
        x = self.output_linear(x)
        return x, self.hidden
|