1.损失函数
对于分类任务来说,rnn使用交叉熵损失。

其中:
- T为序列长度
- C为类别数
- yt,i为真实标签的one-hot表示
- yhat t,i为模型预测的概率分布
预测的输出是C维张量,每个数值代表这个类别的概率。真实标签也是一个C维张量,但是只有一个值为1。
对于回归任务来说,rnn使用均方损失。
2.梯度更新过程

假设我们使用一个时间步为3的rnn,那么损失函数由三部分组成。
L1=loss(y1)
L2=loss(y2)
L3=loss(y3)

根据链式法则:



可以看到对于V来说不存在连乘,而对于W和U就有,所以就存在梯度消失和梯度爆炸的问题。
针对梯度爆炸,有:
梯度裁剪方法。
针对梯度爆炸,有:
LSTM方法。