RNN梯度更新BPTT

1.损失函数

对于分类任务来说,rnn使用交叉熵损失。

其中:

  • T为序列长度
  • C为类别数
  • yt,i​为真实标签的one-hot表示
  • yhat t,i​为模型预测的概率分布

预测的输出是C维张量,每个数值代表这个类别的概率。真实标签也是一个C维张量,但是只有一个值为1。

对于回归任务来说,rnn使用均方损失。

2.梯度更新过程

假设我们使用一个时间步为3的rnn,那么损失函数由三部分组成。

L1=loss(y1)

L2=loss(y2)

L3=loss(y3)

根据链式法则:

可以看到对于V来说不存在连乘,而对于W和U就有,所以就存在梯度消失和梯度爆炸的问题。

针对梯度爆炸,有:

梯度裁剪方法。

针对梯度爆炸,有:

LSTM方法。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇