import torch
import torch.nn as nn
import torch.nn.functional as F

from .cross_entropy import LabelSmoothingCrossEntropy

class JsdCrossEntropy(nn.Module):
    """ Jensen-Shannon Divergence + Cross-Entropy Loss

    Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
    https://arxiv.org/abs/1912.02781

    Hacked together by / Copyright 2020 Ross Wightman
    """
    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
        super().__init__()
        self.num_splits = num_splits
        # alpha weights the JSD consistency term relative to the cross-entropy term
        self.alpha = alpha
        if smoothing is not None and smoothing > 0:
            self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
        else:
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def __call__(self, output, target):
        # Expect the batch to be the clean images followed by num_splits - 1
        # augmented views of the same samples, concatenated along dim 0.
        split_size = output.shape[0] // self.num_splits
        assert split_size * self.num_splits == output.shape[0]
        logits_split = torch.split(output, split_size)

        # Supervised cross-entropy on the clean (first) split only.
        loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
        probs = [F.softmax(logits, dim=1) for logits in logits_split]

        # Jensen-Shannon consistency: mean KL divergence of each split's
        # distribution from the (clamped) mixture of all splits.
        logp_mixture = torch.clamp(torch.stack(probs).mean(dim=0), 1e-7, 1).log()
        loss += self.alpha * sum([F.kl_div(
            logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
        return loss
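
# Minimal usage sketch (an assumption about the surrounding training pipeline,
# not part of this module): AugMix-style training concatenates the clean batch
# with num_splits - 1 augmented views before the forward pass; only the clean
# split's targets feed the supervised term. `model`, `clean`, `aug1`, `aug2`,
# and `target` are hypothetical placeholders.
#
#   criterion = JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)
#   inputs = torch.cat([clean, aug1, aug2], dim=0)  # (3 * B, C, H, W)
#   logits = model(inputs)                          # (3 * B, num_classes)
#   loss = criterion(logits, target)                # target: (B,) clean labels
#   loss.backward()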