对于多分类损失函数Cross Entropy Loss,就不过多的解释,网上的博客不计其数。在这里,讲讲对于CE Loss的一些真正的理解。
首先大部分博客给出的公式如下:
其中p为真实标签值,q为预测值。
在低维复现此公式,结果如下。在此强调一点,pytorch中CE Loss并不会将输入的target映射为one-hot编码格式,而是直接取下标进行计算。
import torch import torch.nn as nn import math import numpy as np #官方的实现 entroy=nn.CrossEntropyLoss() input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],]) target = torch.tensor([0,1,2]) output = entroy(input, target) print(output) #输出 tensor(1.1142) #自己实现 input=np.array(input) target = np.array(target) def cross_entorpy(input, target): output = 0 length = len(target) for i in range(length): hou = 0 for j in input[i]: hou += np.log(input[i][target[i]]) output += -hou return np.around(output / length, 4) print(cross_entorpy(input, target)) #输出 3.8162
我们按照官方给的CE Loss和根据公式得到的答案并不相同,说明公式是有问题的。
正确公式
实现代码如下
import torch import torch.nn as nn import math import numpy as np entroy=nn.CrossEntropyLoss() input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],]) target = torch.tensor([0,1,2]) output = entroy(input, target) print(output) #输出 tensor(1.1142) #%% input=np.array(input) target = np.array(target) def cross_entorpy(input, target): output = 0 length = len(target) for i in range(length): hou = 0 for j in input[i]: hou += np.exp(j) output += -input[i][target[i]] + np.log(hou) return np.around(output / length, 4) print(cross_entorpy(input, target)) #输出 1.1142
对比自己实现的公式和官方给出的结果,可以验证公式的正确性。
观察公式可以发现其实nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。
nn.logSoftmax(),公式如下
nn.NLLLoss(),公式如下
将nn.logSoftmax()作为变量带入nn.NLLLoss()可得
因为
可看做一个常量,故上式可化简为:
对比nn.Cross Entropy Loss公式,结果显而易见。
验证代码如下。
import torch import torch.nn as nn import math import numpy as np entroy=nn.CrossEntropyLoss() input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],]) target = torch.tensor([0,1,2]) output = entroy(input, target) print(output) # 输出为tensor(1.1142) m = nn.LogSoftmax() loss = nn.NLLLoss() input=m(input) output = loss(input, target) print(output) # 输出为tensor(1.1142)综上,可得两个结论:
1.nn.Cross Entropy Loss的公式。
2.nn.Cross Entropy Loss为nn.logSoftmax()和nn.NLLLoss()的整合版本。