import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path

from common.lut import forward_2x2_input_SxS_output
from common import layers


def _random_stage_lut(quantization_interval, scale):
    """Return a randomly initialised float32 stage LUT.

    The shape is ``(L, L, L, L, scale, scale)`` with
    ``L = 256 // quantization_interval + 1``; entries are whole numbers drawn
    uniformly from [0, 255].
    """
    side = 256 // quantization_interval + 1
    # torch.randint's `high` is exclusive: 256 makes 255 reachable, matching the
    # [0, 255] range the Y variant clamps its output to.
    return torch.randint(0, 256, size=(side,) * 4 + (scale, scale)).type(torch.float32)


def _rotation_ensemble(index, lut, scale):
    """Average LUT predictions over the four 90-degree rotations of ``index``.

    Each rotated copy of ``index`` is passed through
    ``forward_2x2_input_SxS_output``, the prediction is rotated back, and the
    four results are averaged.

    Args:
        index: tensor of shape ``(N, 1, H, W)`` (both callers supply a single
            channel in dim 1).
        lut: stage LUT forwarded unchanged to ``forward_2x2_input_SxS_output``.
        scale: upscaling factor; the helper's output is expected to be
            ``(N, 1, H*scale, W*scale)`` once rotated back.

    Returns:
        float32 tensor of shape ``(N, 1, H*scale, W*scale)`` on ``index``'s device.
    """
    n, _, h, w = index.shape
    output = torch.zeros([n, 1, h * scale, w * scale], dtype=torch.float32, device=index.device)
    for rotations_count in range(4):
        rotated = torch.rot90(index, k=rotations_count, dims=[2, 3])
        rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=lut)
        # Undo the rotation so every prediction is aligned with the input orientation.
        output += torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
    return output / 4


class _StageLutModule(nn.Module):
    """Shared state and constructors for the single-stage LUT super-resolution models.

    Holds ``scale``, ``quantization_interval`` and a learnable ``stage_lut``
    parameter of shape ``(L, L, L, L, scale, scale)`` with
    ``L = 256 // quantization_interval + 1``.
    """

    def __init__(self, quantization_interval, scale, *, stage_lut=None):
        """Create the model.

        Args:
            quantization_interval: spacing between LUT sample points; determines ``L``.
            scale: upscaling factor (size of the last two LUT dimensions).
            stage_lut: optional float32 tensor to use as the LUT. When omitted a
                random LUT is generated; passing one avoids allocating (and
                throwing away) a random table that can be very large for small
                quantization intervals.
        """
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        if stage_lut is None:
            stage_lut = _random_stage_lut(quantization_interval, scale)
        self.stage_lut = nn.Parameter(stage_lut)

    @classmethod
    def init_from_numpy(cls, stage_lut):
        """Build an instance of ``cls`` whose LUT is a float32 copy of ``stage_lut``.

        ``scale`` is read from the last LUT dimension and
        ``quantization_interval`` is recovered as ``256 // (L - 1)`` from the
        first dimension.
        """
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256 // (stage_lut.shape[0] - 1)
        # torch.tensor copies, so training never writes back into the caller's array.
        lut = torch.tensor(stage_lut).type(torch.float32)
        return cls(quantization_interval=quantization_interval, scale=scale, stage_lut=lut)

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"


class SRLut(_StageLutModule):
    """Single LUT pass applied independently to every channel."""

    def forward(self, x):
        """Upscale ``x`` of shape ``(B, C, H, W)`` channel by channel."""
        b, c, h, w = x.shape
        # Fold channels into the batch; reshape (not view) also accepts
        # non-contiguous inputs such as permuted images.
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
        return x.reshape(b, c, x.shape[-2], x.shape[-1])


class SRLutR90(_StageLutModule):
    """LUT pass averaged over the four 90-degree rotations, per channel."""

    def forward(self, x):
        """Upscale ``x`` of shape ``(B, C, H, W)`` to ``(B, C, H*scale, W*scale)``."""
        b, c, h, w = x.shape
        # Same float32 cast as SRLut so both models feed the LUT helper alike.
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        output = _rotation_ensemble(x, self.stage_lut, self.scale)
        return output.reshape(b, c, h * self.scale, w * self.scale)


class SRLutR90Y(_StageLutModule):
    """Rotation-ensembled LUT applied to luma only; chroma is bilinearly upscaled."""

    def __init__(self, quantization_interval, scale, *, stage_lut=None):
        super().__init__(quantization_interval, scale, stage_lut=stage_lut)
        self.rgb_to_ycbcr = layers.RgbToYcbcr()
        self.ycbcr_to_rgb = layers.YcbcrToRgb()

    def forward(self, x):
        """Upscale an RGB batch ``(B, 3, H, W)``; output is clamped to [0, 255].

        Assumes ``RgbToYcbcr`` places Y in channel 0 and Cb/Cr in channels 1+
        (as the slicing below expects) — confirm in ``common.layers``.
        """
        _, _, h, w = x.shape
        ycbcr = self.rgb_to_ycbcr(x)
        y = ycbcr[:, 0:1, :, :]
        cbcr = ycbcr[:, 1:, :, :]
        cbcr_scaled = F.interpolate(cbcr, size=[h * self.scale, w * self.scale], mode='bilinear')
        y_scaled = _rotation_ensemble(y, self.stage_lut, self.scale)
        output = torch.cat([y_scaled, cbcr_scaled], dim=1)
        return self.ycbcr_to_rgb(output).clamp(0, 255)