import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from common import layers
from common.utils import round_func


class SRBase(nn.Module):
    """Base class for super-resolution models.

    Supplies a default training objective through :meth:`get_loss_fn`;
    subclasses may override that method to use a different loss.
    """

    def __init__(self):
        super().__init__()

    def get_loss_fn(self, data_range=255):
        """Return the default loss function for this model.

        Args:
            data_range: Value that both the prediction and the target are
                divided by before the error is computed. The default of 255
                presumably corresponds to 8-bit pixel intensities in
                ``[0, 255]`` — confirm against the data pipeline.

        Returns:
            A callable ``loss_fn(pred, target)`` returning
            ``F.mse_loss(pred / data_range, target / data_range)``, i.e. the
            mean squared error of the rescaled inputs.

        Raises:
            ValueError: If ``data_range`` is not positive.
        """
        if data_range <= 0:
            raise ValueError(f"data_range must be positive, got {data_range!r}")

        def loss_fn(pred, target):
            # Rescale before taking the MSE so the loss magnitude does not
            # depend on the raw pixel-value scale.
            return F.mse_loss(pred / data_range, target / data_range)

        return loss_fn

    # NOTE: an alternative objective was previously sketched here as
    # commented-out code: SSIM (``losses.SSIM(data_range=255)``) plus a
    # Charbonnier term (``losses.CharbonnierLoss()``), summed. The ``losses``
    # module it relied on is not imported in this file; a subclass wanting
    # that objective should override ``get_loss_fn``.