0%

pytorch 实现 GRU

使用 pytorch 实现 GRU 网络。


相关博文



参考资料



参数



pytorch 相关实现


直接调用集成网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class GRUModel(nn.Module):

def __init__(self, input_num, hidden_num, output_num):
super(GRUModel, self).__init__()
self.hidden_size = hidden_num
# 这里设置了 batch_first=True, 所以应该 inputs = inputs.view(inputs.shape[0], -1, inputs.shape[1])
# 针对时间序列预测问题,相当于将时间步(seq_len)设置为 1。
self.GRU_layer = nn.GRU(input_size=input_num, hidden_size=hidden_num, batch_first=True)
self.output_linear = nn.Linear(hidden_num, output_num)
self.hidden = None

def forward(self, x):
# h_n of shape (num_layers * num_directions, batch, hidden_size)
# 这里不用显式地传入隐层状态 self.hidden
x, self.hidden = self.GRU_layer(x)
x = self.output_linear(x)
return x, self.hidden
请我喝杯咖啡吧~