PyTorch Loss implementation of Su Jianlin's method [Mitigating Class Imbalance via Mutual Information]
from torch import nn
import numpy as np
import torch


class PriorMultiLabelSoftMarginLoss(nn.Module):
    def __init__(self, prior=None, num_labels=None, reduction="mean", eps=1e-9, tau=1.0):
        """PriorMultiLabelSoftMarginLoss
        multi-label-soft-margin-loss-with-prior
        urls: [Mitigating class imbalance via mutual information](https://spaces.ac.cn/archives/7615)
        args:
            prior: List<float>, prior distribution of the labels. eg. [0.6, 0.2, 0.1, 0.1]
            num_labels: int, number of labels. eg. 10
            reduction: str, specifies the reduction to apply to the output.
                eg. ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
            eps: float, small constant for numerical stability. eg. 1e-9
            tau: float, weight of the prior term in the loss. eg. ``1.0``
        returns:
            Tensor of loss.
        examples:
            >>> loss = PriorMultiLabelSoftMarginLoss(prior=prior)(logits, labels)
        """
        super(PriorMultiLabelSoftMarginLoss, self).__init__()
        self.loss_mlsm = torch.nn.MultiLabelSoftMarginLoss(reduction=reduction)
        if prior is None:  # fall back to a uniform prior when none is given
            prior = np.array([1 / num_labels for _ in range(num_labels)])
        if isinstance(prior, list):
            prior = np.array(prior)
        # log prior with shape (1, num_labels), broadcastable over the batch
        self.log_prior = torch.tensor(np.log(prior + eps), dtype=torch.float32).unsqueeze(0)
        self.eps = eps
        self.tau = tau

    def forward(self, logits, labels):
        # shift the logits by tau * log(prior), on the same device as the input labels
        logits = logits + self.tau * self.log_prior.to(labels.device)
        loss = self.loss_mlsm(logits, labels)
        return loss
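
A minimal usage sketch (the prior values, batch size, and tensors below are illustrative assumptions, not from the original post). Per the linked article, the shift logits + tau * log(prior) is applied only during training; at inference, predictions are made from the unshifted logits.

prior = [0.6, 0.2, 0.1, 0.1]                      # label frequencies estimated from the training set
loss_fn = PriorMultiLabelSoftMarginLoss(prior=prior, num_labels=4)

logits = torch.randn(8, 4, requires_grad=True)    # raw scores for a batch of 8 samples
labels = torch.randint(0, 2, (8, 4)).float()      # multi-hot targets
loss = loss_fn(logits, labels)
loss.backward()                                   # scalar loss under the default 'mean' reduction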