核心联系
本质是同一个概念:它们都源于信息论中的交叉熵概念,用于衡量两个概率分布之间的差异。在机器学习中,一个分布是模型预测的分布 (Prediction),另一个是真实的分布 (Truth)。目标就是通过梯度下降等优化方法,最小化这个交叉熵,使得预测分布尽可能接近真实分布。
统一的数学思想:无论是二分类还是多分类,它们的损失函数都是同一个形式:Loss = - Σ (真实标签的分布 * log(模型预测的分布))
区别主要在于真实标签的表示形式和模型输出层的激活函数,这导致了公式具体形式的不同。
详细区别
为了更清晰地对比,我们来看下表:
| 方面 | 二分类交叉熵损失 | 多分类交叉熵损失 |
|---|---|---|
| 问题类型 | 只有两个类别(正/负,是/否,猫/狗) | 有两个以上的类别(猫/狗/鸟,0-9手写数字) |
| 模型输出 | 一个神经元(通常) | K个神经元(K为类别数) |
| 激活函数 | Sigmoid | Softmax |
| 输出意义 | 输出一个值,表示样本属于正类的概率 P(y=1)。属于负类的概率即为 1 - P(y=1)。 | 输出一个概率向量,每个值代表样本属于对应类别的概率。所有输出值之和为1。 |
| 真实标签y | 一个数字,通常是 0 或 1。 | 一个 One-hot 编码 的向量。例如,3个类别中第2类表示为 [0, 1, 0]。 |
| 损失函数公式 | L = - [y * log(p) + (1 - y) * log(1 - p)]其中 p 是预测为正类的概率。 | L = - Σ (y_i * log(p_i))其中 i 从1到K,y_i是one-hot向量中第i位的值(0或1),p_i是预测为第i类的概率。 |
| 计算过程 | 对于每个样本: 1. 如果真实标签 y=1,损失为 -log(p)。2. 如果真实标签 y=0,损失为 -log(1-p)。 | 由于真实标签y是one-hot向量,只有真实类别c的位置y_c=1,其他都为0。所以公式简化为: L = - log(p_c)其中 p_c 是模型预测样本属于于真实类别c的概率。 |
公式推导与直观理解
1. 二分类交叉熵
- 模型输出:一个值
z,经过 Sigmoid 函数后得到p = σ(z),表示P(y=1)。 - 真实分布:可以看作一个伯努利分布。如果真实标签是
y=1,则真实分布是[0, 1](即属于正类的概率为1);如果y=0,真实分布是[1, 0]。 - 损失计算:
- 当
y=1时,我们希望p越大越好。损失是-log(p)。如果p=0.9,-log(0.9) ≈ 0.1(损失小);如果p=0.1,-log(0.1) ≈ 2.3(损失大)。 - 当
y=0时,我们希望1-p越大越好。损失是-log(1-p)。
- 当
这个公式巧妙地组合了这两种情况。
2. 多分类交叉熵
- 模型输出:一个向量
[z1, z2, ..., zK],经过 Softmax 函数后,得到一个概率分布[p1, p2, ..., pK],其中p_i = e^{z_i} / Σ(e^{z_j})。 - 真实分布:一个 One-hot 向量,例如
[0, 0, 1, 0](表示属于第3类)。 - 损失计算:
因为真实标签中只有真实类别c的位置是1,其他都是0,所以求和公式中绝大多数项都为0。L = - [0*log(p1) + 0*log(p2) + 1*log(p_c) + ... + 0*log(pK)] = -log(p_c)
直观理解:多分类交叉熵损失只关心模型对真实类别的预测概率。如果模型非常确定(p_c 接近1),-log(p_c) 就很小;如果模型预测正确但不确定(比如 p_c=0.6),损失会大一些 -log(0.6) ≈ 0.51;如果模型预测错了,把正确的类别预测得很低(比如 p_c=0.1),损失就会非常大 -log(0.1) = 2.3。
关键总结
- 联系:它们都是交叉熵损失,核心思想完全一致——最小化预测概率分布与真实概率分布之间的差异。
- 区别1(输出和激活函数):二分类通常用一个Sigmoid神经元,输出一个标量概率;多分类用Softmax层,输出一个概率向量。
- 区别2(标签形式):二分类的标签是标量(0/1);多分类的标签是One-hot向量。
- 区别3(公式形态):二分类公式需要同时考虑正类和负类
- [y*log(p) + (1-y)*log(1-p)];多分类公式因One-hot编码而简化为只关注真实类别-log(p_c)。
一个常见的疑问
问:二分类问题可以用两个神经元的Softmax吗?
答:完全可以! 在这种情况下,二分类就变成了一个类别数K=2的多分类问题。
- 模型输出两个值,经过Softmax得到
[p_class0, p_class1],且p_class0 + p_class1 = 1。 - 标签需要转换为One-hot形式,例如“猫”是
[1, 0],“狗”是[0, 1]。 - 损失函数则使用多分类交叉熵损失
L = - Σ (y_i * log(p_i))。
在这种情况下,二分类和多分类交叉熵损失就完全统一了。在实际应用中,两种方式都是可行的,但传统上,简单的二分类问题使用单个Sigmoid输出更为常见。
比如李沐动手学习深度学习李word2vec章节的数据:
pred = torch.tensor([[1.1,-2.2, 3.3,-4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
如果看成是多分类问题,那么就对两个样本的pred分别做softmax得到每个类别的概率
- 对于第一个样本: 真实类别是0,损失 = -log(P(类别0))
- 对于第二个样本: 真实类别是1,损失 = -log(P(类别1))
如果看成是二分类问题就要看成两个样本的四个二分类问题。
- Sigmoid转换: 对每个类别的logits独立应用Sigmoid,得到4个独立的概率值
- 损失计算: 对4个类别分别计算二分类损失,然后求平均