updates
parent
19e0fca745
commit
c3dc32c336
File diff suppressed because one or more lines are too long
@ -0,0 +1,152 @@
|
||||
|
||||
import torch
|
||||
|
||||
class FourierLoss(nn.Module):
|
||||
def __init__(self, weight=None, size_average=True):
|
||||
super(FourierLoss, self).__init__()
|
||||
|
||||
@staticmethod
|
||||
def perm_roll(im, axis, amount):
|
||||
permutation = torch.roll(torch.arange(im.shape[axis], device=im.device), amount, dims=0)
|
||||
return torch.index_select(im, axis, permutation)
|
||||
|
||||
@staticmethod
|
||||
def shift_left(im):
|
||||
tt = perm_roll(im, axis=-2, amount=-(im.shape[-2]+1)//2)
|
||||
tt = perm_roll(tt, axis=-1, amount=-(im.shape[-1]+1)//2)
|
||||
return tt
|
||||
|
||||
@staticmethod
|
||||
def shift_right(im):
|
||||
tt = FourierLoss.perm_roll(im, axis=-2, amount=(im.shape[-2]+1)//2)
|
||||
tt = FourierLoss.perm_roll(tt, axis=-1, amount=(im.shape[-1]+1)//2)
|
||||
return tt
|
||||
|
||||
@staticmethod
|
||||
def aperture(h, w, condition=None, low_pass_frequency=False):
|
||||
"""torch.where(condition(circle_dist), torch.ones_like(circle_dist), torch.zeros_like(circle_dist)) """
|
||||
x, y = torch.meshgrid(torch.arange(-h//2, h//2), torch.arange(-w//2, w//2), indexing='ij')
|
||||
circle_dist = torch.sqrt(x**2 + y**2)
|
||||
if low_pass_frequency:
|
||||
circle_aperture = torch.where(condition(circle_dist), torch.zeros_like(circle_dist), torch.ones_like(circle_dist))
|
||||
else:
|
||||
circle_aperture = torch.where(condition(circle_dist), torch.ones_like(circle_dist), torch.zeros_like(circle_dist))
|
||||
return circle_aperture
|
||||
|
||||
@staticmethod
|
||||
def fft(x, aperture):
|
||||
r = FourierLoss.shift_right(torch.fft.fft2(x, norm='ortho'))*aperture
|
||||
return r.real**2 + r.imag**2
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
b,c,h,w = inputs.shape
|
||||
r = min(h, w)//12
|
||||
low_pass = FourierLoss.aperture(h=h, w=w, condition=lambda circle_dist: circle_dist>r, low_pass_frequency=False).to(inputs.device)
|
||||
high_pass = FourierLoss.aperture(h=h, w=w, condition=lambda circle_dist: circle_dist>r, low_pass_frequency=True).to(inputs.device)
|
||||
low_frequency_loss = F.mse_loss(FourierLoss.fft(inputs, low_pass), FourierLoss.fft(targets, low_pass))*(low_pass.sum()/high_pass.sum())
|
||||
high_frequency_loss = F.mse_loss(FourierLoss.fft(inputs, high_pass), FourierLoss.fft(targets, high_pass))
|
||||
return low_frequency_loss + high_frequency_loss
|
||||
|
||||
|
||||
class FocalFrequencyLoss(nn.Module):
|
||||
"""The torch.nn.Module class that implements focal frequency loss - a
|
||||
frequency domain loss function for optimizing generative models.
|
||||
|
||||
Ref:
|
||||
Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
|
||||
<https://arxiv.org/pdf/2012.12821.pdf>
|
||||
|
||||
Args:
|
||||
loss_weight (float): weight for focal frequency loss. Default: 1.0
|
||||
alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
|
||||
patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
|
||||
ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
|
||||
log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
|
||||
batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
|
||||
super(FocalFrequencyLoss, self).__init__()
|
||||
self.loss_weight = loss_weight
|
||||
self.alpha = alpha
|
||||
self.patch_factor = patch_factor
|
||||
self.ave_spectrum = ave_spectrum
|
||||
self.log_matrix = log_matrix
|
||||
self.batch_matrix = batch_matrix
|
||||
|
||||
def tensor2freq(self, x):
|
||||
# crop image patches
|
||||
patch_factor = self.patch_factor
|
||||
_, _, h, w = x.shape
|
||||
assert h % patch_factor == 0 and w % patch_factor == 0, (
|
||||
'Patch factor should be divisible by image height and width')
|
||||
patch_list = []
|
||||
patch_h = h // patch_factor
|
||||
patch_w = w // patch_factor
|
||||
for i in range(patch_factor):
|
||||
for j in range(patch_factor):
|
||||
patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
|
||||
|
||||
# stack to patch tensor
|
||||
y = torch.stack(patch_list, 1)
|
||||
|
||||
# perform 2D DFT (real-to-complex, orthonormalization)
|
||||
freq = torch.fft.fft2(y, norm='ortho')
|
||||
freq = torch.stack([freq.real, freq.imag], -1)
|
||||
return freq
|
||||
|
||||
def loss_formulation(self, recon_freq, real_freq, matrix=None):
|
||||
# spectrum weight matrix
|
||||
if matrix is not None:
|
||||
# if the matrix is predefined
|
||||
weight_matrix = matrix.detach()
|
||||
else:
|
||||
# if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
|
||||
matrix_tmp = (recon_freq - real_freq) ** 2
|
||||
matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
|
||||
|
||||
# whether to adjust the spectrum weight matrix by logarithm
|
||||
if self.log_matrix:
|
||||
matrix_tmp = torch.log(matrix_tmp + 1.0)
|
||||
|
||||
# whether to calculate the spectrum weight matrix using batch-based statistics
|
||||
if self.batch_matrix:
|
||||
matrix_tmp = matrix_tmp / matrix_tmp.max()
|
||||
else:
|
||||
matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
|
||||
|
||||
matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
|
||||
matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
|
||||
weight_matrix = matrix_tmp.clone().detach()
|
||||
|
||||
assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
|
||||
'The values of spectrum weight matrix should be in the range [0, 1], '
|
||||
'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
|
||||
|
||||
# frequency distance using (squared) Euclidean distance
|
||||
tmp = (recon_freq - real_freq) ** 2
|
||||
freq_distance = tmp[..., 0] + tmp[..., 1]
|
||||
|
||||
# dynamic spectrum weighting (Hadamard product)
|
||||
loss = weight_matrix * freq_distance
|
||||
return torch.mean(loss)
|
||||
|
||||
def forward(self, pred, target, matrix=None, **kwargs):
|
||||
"""Forward function to calculate focal frequency loss.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
|
||||
target (torch.Tensor): of shape (N, C, H, W). Target tensor.
|
||||
matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
|
||||
Default: None (If set to None: calculated online, dynamic).
|
||||
"""
|
||||
pred_freq = self.tensor2freq(pred)
|
||||
target_freq = self.tensor2freq(target)
|
||||
|
||||
# whether to use minibatch average spectrum
|
||||
if self.ave_spectrum:
|
||||
pred_freq = torch.mean(pred_freq, 0, keepdim=True)
|
||||
target_freq = torch.mean(target_freq, 0, keepdim=True)
|
||||
|
||||
# calculate focal frequency loss
|
||||
return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
|
Loading…
Reference in New Issue