我们来详细、清晰地梳理一下 交叉熵损失函数。它是机器学习和深度学习中,尤其是分类任务中,最核心、最重要的损失函数之一。
1. 核心思想:衡量两个概率分布的差异
交叉熵的本质是衡量模型预测的概率分布 与 真实标签的概率分布 之间的“距离”或差异。
- 真实分布:通常是“one-hot”编码。例如,对于手写数字识别,数字“3”的真实分布是
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]。
- 预测分布:模型输出的概率向量。例如,
[0.1, 0.1, 0.0, 0.6, 0.05, ...] 表示模型认为图像是“3”的概率为60%。
我们希望预测分布尽可能地接近真实分布。交叉熵就是给这个“接近程度”打一个分,分越低,说明两者越接近。
2. 公式解析
交叉熵的定义来自于信息论。对于一个单个样本,其交叉熵损失计算公式为:
H(p,q)=-sum_(i=1)^(C)p_(i)log(q_(i))H(p, q) = -\sum_{i=1}^{C} p_i \log(q_i)
其中:
- CC :类别的总数(例如,10个数字就是10类)。
- p_(i)p_i :样本属于第 ii 类的 真实概率。
- q_(i)q_i :模型预测样本属于第 ii 类的 概率。
3. 在分类任务中的具体形式
在分类任务中,真实标签
pp 是
one-hot 向量(只有真实类别为1,其余为0)。假设真实类别是
kk,则公式可以大大简化:
"Loss"=-log(q_(k))\text{Loss} = -\log(q_k)
这里的 q_(k)q_k 就是模型预测为真实类别 kk 的概率。
举个例子:
-
真实类别:第3类(p = [0, 0, 1, 0])。
-
模型预测概率:q = [0.1, 0.2, 0.6, 0.1](预测为第3类的概率是0.6)。
-
交叉熵损失:Loss = -log(0.6) ≈ 0.51
-
如果模型预测得更准:q = [0.05, 0.05, 0.85, 0.05]
-
交叉熵损失:Loss = -log(0.85) ≈ 0.16 (损失变小了)
-
如果模型预测错了:q = [0.8, 0.1, 0.05, 0.05](预测为第3类的概率很低,0.05)
-
交叉熵损失:Loss = -log(0.05) ≈ 3.00 (损失变得非常大!)
4. 直观理解与特性
- 惩罚自信的错误:从上面的例子可以看出,当模型非常确信地做出了错误的预测时(q_(k)q_k 非常小),
-log(q_k) 会变得极其大,给予模型非常严厉的惩罚。这符合我们的直观:错得越离谱,惩罚应该越重。
- 奖励自信的正确:当模型正确且确信时(q_(k)q_k 接近1),
-log(q_k) 接近0,损失很小。
- 非负性:由于 q_(k)q_k 在0到1之间,
log(q_k) 为负,所以损失值始终为非负数。
- 与Softmax激活函数的完美搭配:在神经网络中,交叉熵损失通常与 Softmax 激活函数联用。Softmax将网络最后一层的输出(logits)归一化为一个概率分布(所有概率之和为1,且每个概率>0),正好作为交叉熵的输入 qq。这种组合在数学上求导非常优雅,梯度形式简洁,有利于模型训练。
5. 多样本的扩展:批量损失
对于有
NN 个样本的批次(Batch),损失通常是所有样本交叉熵损失的平均值:
L=-(1)/(N)sum_(n=1)^(N)sum_(i=1)^(C)p_(n,i)log(q_(n,i))L = -\frac{1}{N} \sum_{n=1}^{N} \sum_{i=1}^{C} p_{n,i} \log(q_{n,i})
在代码实现(如PyTorch的 nn.CrossEntropyLoss 或 TensorFlow/Keras的 CategoricalCrossentropy)中,我们通常:
- 直接将模型最后一层的 logits(未经Softmax的原始分数)和 真实的类别索引标签 传给损失函数。
- 损失函数内部会自动进行Softmax和交叉熵计算,并进行数值优化以保证稳定性。
6. 为什么使用交叉熵?(对比均方误差MSE)
对于分类问题,交叉熵比传统的均方误差(MSE)好得多,主要因为:
- 梯度性质更好:交叉熵损失关于模型参数的梯度更干净、更直接。在反向传播时,它能提供更有效、更稳定的梯度信号,尤其是在预测概率接近0或1时。MSE的梯度在饱和区(概率接近0或1)会变得非常小,容易导致训练停滞(梯度消失)。
- 与概率解释一致:它直接衡量概率分布的差异,更符合分类任务的本质。
7. 变体
- 二分类交叉熵 (
Binary Cross-Entropy, BCE):用于二分类任务(C=2)。此时常用Sigmoid函数输出单个概率值 hat(y)\hat{y},损失公式为:L=-[y log( hat(y))+(1-y)log(1- hat(y))]L = -[y \log(\hat{y}) + (1-y) \log(1-\hat{y})]
- 带权重的交叉熵:为不同类别分配不同权重,用于处理类别不平衡的数据集。
- 标签平滑:将真实标签的 one-hot 分布稍微平滑化(如从
[0, 0, 1, 0] 变为 [0.01, 0.01, 0.96, 0.01]),可以防止模型对真实标签过于自信,起到正则化作用,提升泛化能力。
总结
| 特性 |
描述 |
| 目的 |
衡量模型预测概率分布与真实概率分布的差异,用于分类任务。 |
| 核心公式 |
L=-sump_(i)log(q_(i))L = -\sum p_i \log(q_i),对 one-hot 标签简化为 -log(q_("true"))-\log(q_{\text{true}})。 |
| 关键行为 |
严厉惩罚“自信的错误”,鼓励“自信的正确”。 |
| 最佳搭档 |
Softmax(多分类)或 Sigmoid(二分类)激活函数。 |
| 主要优势 |
比MSE等损失函数提供更优的梯度,训练更快、更稳定。 |
| 常见变体 |
二值交叉熵、带权重的交叉熵、结合标签平滑的交叉熵。 |
简单来说,交叉熵损失是驱动分类模型“变得正确且自信”的核心训练信号。理解它,是理解现代深度学习模型如何学习的关键一步。