import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from common.utils import round_func


class SRNetBase(nn.Module):
    """Shared base for super-resolution networks assembled from stages.

    Provides a helper that runs one upscaling stage and re-tiles its output
    into a larger image, a default MSE training loss, and a
    ``get_lut_model`` hook that this base class leaves unimplemented.
    """

    def __init__(self):
        super().__init__()

    def forward_stage(self, x, percieve_pattern, stage):
        """Apply ``percieve_pattern`` then ``stage`` to ``x`` and upscale.

        Args:
            x: 4-D tensor, unpacked here as (batch, channels, height, width).
            percieve_pattern: callable applied to ``x`` first; presumably
                gathers each pixel's neighbourhood for ``stage`` -- TODO confirm.
            stage: callable applied to the pattern output; its
                ``upscale_factor`` attribute is used as the scale.

        Returns:
            Tensor of shape (batch, 1, height * scale, width * scale), where
            ``scale = stage.upscale_factor``.

        NOTE(review): the reshape below needs the stage output to hold exactly
        batch * height * width * scale * scale values (a single output
        channel).  A multi-channel ``x`` only works if ``percieve_pattern``
        folds the channels away -- confirm with callers.
        """
        batch, _, height, width = x.shape
        factor = stage.upscale_factor

        # ``round_func`` comes from common.utils; presumably rounds the stage
        # output to integer values -- TODO confirm against its definition.
        tiles = round_func(stage(percieve_pattern(x)))

        # Give every input pixel a (factor x factor) tile, then interleave tile
        # rows/cols with the pixel grid so the tile for pixel (i, j) covers
        # output rows i*factor..(i+1)*factor-1 and the matching columns.
        tiles = tiles.reshape(batch, 1, height, width, factor, factor)
        tiles = tiles.permute(0, 1, 2, 4, 3, 5)
        return tiles.reshape(batch, 1, height * factor, width * factor)

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """Hook for a lookup-table variant of the network (not implemented here).

        Judging by the name, subclasses presumably override this to export
        the network as a lookup table -- confirm in the concrete subclasses.

        Args:
            quantization_interval: unused in this base class; presumably the
                input sampling step for the table -- TODO confirm.
            batch_size: unused in this base class; presumably how many table
                entries are evaluated per forward pass -- TODO confirm.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def get_loss_fn(self):
        """Return a loss callable: MSE between ``pred / 255`` and ``target / 255``.

        The division by 255 presumably rescales 8-bit-range pixel values to
        [0, 1] -- an assumption about the data pipeline, confirm with callers.
        """
        def loss_fn(pred, target):
            scaled_pred = pred / 255
            scaled_target = target / 255
            return F.mse_loss(scaled_pred, scaled_target)

        return loss_fn