最新下载
热门教程
- 1
 - 2
 - 3
 - 4
 - 5
 - 6
 - 7
 - 8
 - 9
 - 10
 
pytorch lstm gru rnn得到每个state输出操作代码
时间:2022-06-25 01:58:26 编辑:袖梨 来源:一聚教程网
本篇文章小编给大家分享一下pytorch lstm gru rnn得到每个state输出操作代码,文章代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。
默认只返回最后一个state,所以一次输入一个step的input
# coding=UTF-8
import torch
import torch.autograd as autograd  # torch中自动计算梯度模块
import torch.nn as nn  # 神经网络模块
torch.manual_seed(1)
# lstm单元输入和输出维度都是3
lstm = nn.LSTM(input_size=3, hidden_size=3)
# 生成一个长度为5,每一个元素为1*3的序列作为输入,这里的数字3对应于上句中第一个3
inputs = [autograd.Variable(torch.randn((1, 3)))
          for _ in range(5)]
# 设置隐藏层维度,初始化隐藏层的数据
hidden = (autograd.Variable(torch.randn(1, 1, 3)),
          autograd.Variable(torch.randn((1, 1, 3))))
for i in inputs:
  out, hidden = lstm(i.view(1, 1, -1), hidden)
  print(out.size())
  print(hidden[0].size())
  print("--------")
print("-----------------------------------------------")
# 下面是一次输入多个step的样子
inputs_stack = torch.stack(inputs)
out,hidden = lstm(inputs_stack,hidden)
print(out.size())
print(hidden[0].size())
print结果:
(1L, 1L, 3L)
(1L, 1L, 3L)
--------
(1L, 1L, 3L)
(1L, 1L, 3L)
--------
(1L, 1L, 3L)
(1L, 1L, 3L)
--------
(1L, 1L, 3L)
(1L, 1L, 3L)
--------
(1L, 1L, 3L)
(1L, 1L, 3L)
--------
----------------------------------------------
(5L, 1L, 3L)
(1L, 1L, 3L)
可见LSTM的定义都是不用变的,根据input的step数目,一次输入多少step,就一次输出多少output,但只输出最后一个state
相关文章
- 原神杜林圣遗物选择推荐 11-04
 - 百度网盘SVIP激活码能用的有哪些 百度网盘vip免费领取 11-04
 - 打个螺丝兑换码能用的有哪些 2025最新有效兑换码汇总 11-04
 - 抓大鹅有效兑换码有哪些 2025最新可用兑换码大全 11-04
 - 密室出逃兑换码能用的有哪些 2025最新有效兑换码大全 11-04
 - 猪了个猪兑换码最新可用 2025最新有效兑换码汇总 11-04