import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
from pathlib import Path


# ---------------------------------------------------------------------------
# Private helpers shared by the RC-LUT models below.
# ---------------------------------------------------------------------------

def _random_lut(size):
    """Return a float32 placeholder table filled with random integers in [0, 255).

    The models only use these as initial values; ``init_from_lut`` replaces them.
    """
    return torch.randint(0, 255, size=size).type(torch.float32)


def _rc_conv_lut_param(window_size, quantization_interval):
    """Create a random (window_size, window_size, 256//quantization_interval+1) rc-conv LUT."""
    return nn.Parameter(_random_lut((window_size, window_size, 256 // quantization_interval + 1)))


def _dense_conv_lut_param(quantization_interval, out_size):
    """Create a random dense LUT: 4 index axes of 256//quantization_interval+1, then (out_size, out_size)."""
    return nn.Parameter(_random_lut((256 // quantization_interval + 1,) * 4 + (out_size, out_size)))


def _lut_parameter(values):
    """Copy a pre-computed LUT (tensor or array-like) into a fresh float32 ``nn.Parameter``.

    Tensors are detached and cloned instead of going through ``torch.tensor``, which
    produces the same copy but emits a UserWarning for tensor inputs.
    """
    if isinstance(values, torch.Tensor):
        tensor = values.detach().clone()
    else:
        tensor = torch.tensor(values)
    return nn.Parameter(tensor.type(torch.float32))


def _lut_geometry(dense_conv_lut):
    """Infer ``(quantization_interval, scale)`` from a dense LUT laid out like ``_dense_conv_lut_param``."""
    scale = int(dense_conv_lut.shape[-1])
    quantization_interval = 256 // (dense_conv_lut.shape[0] - 1)
    return quantization_interval, scale


def _rcblock_rot90(index, rc_conv_lut, dense_conv_lut):
    """Apply ``forward_rc_conv_rot90`` and then the 2x2-input dense LUT."""
    x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
    return forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)


def _rcblock_centered(index, rc_conv_lut, dense_conv_lut):
    """Apply ``forward_rc_conv_centered`` and then the 2x2-input dense LUT."""
    x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
    return forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)


def _rot90_ensemble(x, block_fn, lut_pairs, out_height, out_width):
    """Average ``block_fn`` predictions over the four 90-degree rotations of ``x``.

    For each rotation k in 0..3 and each ``(rc_conv_lut, dense_conv_lut)`` pair (in order),
    the rotated input is passed through ``block_fn``, the prediction is rotated back by -k
    and accumulated. The sum is divided by ``4 * len(lut_pairs)``.

    Args:
        x: float32 tensor of shape (N, 1, H, W).
        block_fn: callable ``(index, rc_conv_lut, dense_conv_lut) -> tensor``.
        lut_pairs: sequence of ``(rc_conv_lut, dense_conv_lut)`` tuples.
        out_height, out_width: spatial size of the un-rotated predictions.

    Returns:
        float32 tensor of shape (N, 1, out_height, out_width) on ``x.device``.
    """
    output = torch.zeros([x.shape[0], 1, out_height, out_width], dtype=torch.float32, device=x.device)
    for rotations_count in range(4):
        rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
        for rc_conv_lut, dense_conv_lut in lut_pairs:
            prediction = block_fn(index=rotated, rc_conv_lut=rc_conv_lut, dense_conv_lut=dense_conv_lut)
            output += torch.rot90(prediction, k=-rotations_count, dims=[2, 3])
    output /= 4 * len(lut_pairs)
    return output


def _lut_repr(module, names):
    """Build the multi-line ``__repr__`` listing the shape of each named LUT parameter."""
    lines = [f"{module.__class__.__name__}("]
    lines.extend(f" {name} size: {getattr(module, name).shape}" for name in names)
    lines.append(")")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------

class RCLutCentered_3x3(nn.Module):
    """Centered 3x3 rc-conv LUT followed by a 2x2-input dense LUT producing scale x scale outputs."""

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        # forward() pads and crops by window_size // 2; previously this attribute was read
        # but never set, so every forward() call raised AttributeError.
        self.window_size = 3
        self.dense_conv_lut = _dense_conv_lut_param(quantization_interval, scale)
        self.rc_conv_luts = _rc_conv_lut_param(3, quantization_interval)

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs."""
        quantization_interval, scale = _lut_geometry(dense_conv_lut)
        lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = _lut_parameter(dense_conv_lut)
        lut_model.rc_conv_luts = _lut_parameter(rc_conv_luts)
        return lut_model

    def forward(self, x):
        """Run each (batch, channel) plane independently and return (b, c, H', W')."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        pad = self.window_size // 2
        # Replicate-pad so the centered window has support at the borders, then crop it back.
        x = F.pad(x, pad=[pad] * 4, mode='replicate')
        x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
        x = x[:, :, pad:x.shape[-2] - pad, pad:x.shape[-1] - pad]
        x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
        return x.reshape(b, c, x.shape[-2], x.shape[-1])

    def __repr__(self):
        return _lut_repr(self, ["rc_conv_luts", "dense_conv_lut"])


class RCLutCentered_7x7(nn.Module):
    """Centered 7x7 rc-conv LUT followed by a 2x2-input dense LUT producing scale x scale outputs."""

    def __init__(
        self,
        window_size,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        # Stored for introspection. The rc-conv table below is always 7x7 regardless.
        self.window_size = window_size
        self.dense_conv_lut = _dense_conv_lut_param(quantization_interval, scale)
        self.rc_conv_luts = _rc_conv_lut_param(7, quantization_interval)

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs."""
        quantization_interval, scale = _lut_geometry(dense_conv_lut)
        # window_size is a required constructor argument; previously it was omitted here,
        # so this method always raised TypeError.
        lut_model = RCLutCentered_7x7(
            window_size=int(np.shape(rc_conv_luts)[0]),
            quantization_interval=quantization_interval,
            scale=scale,
        )
        lut_model.rc_conv_luts = _lut_parameter(rc_conv_luts)
        lut_model.dense_conv_lut = _lut_parameter(dense_conv_lut)
        return lut_model

    def forward(self, x):
        """Run each (batch, channel) plane independently and return (b, c, H', W')."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
        x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
        return x.reshape(b, c, x.shape[-2], x.shape[-1])

    def __repr__(self):
        return _lut_repr(self, ["rc_conv_luts", "dense_conv_lut"])


class RCLutRot90_3x3(nn.Module):
    """3x3 rc-conv LUT + dense LUT, averaged over the four 90-degree input rotations."""

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.dense_conv_lut = _dense_conv_lut_param(quantization_interval, scale)
        self.rc_conv_luts = _rc_conv_lut_param(3, quantization_interval)

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs."""
        quantization_interval, scale = _lut_geometry(dense_conv_lut)
        lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = _lut_parameter(dense_conv_lut)
        lut_model.rc_conv_luts = _lut_parameter(rc_conv_luts)
        return lut_model

    def forward(self, x):
        """Return the rotation-averaged prediction of shape (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        output = _rot90_ensemble(
            x, _rcblock_rot90, [(self.rc_conv_luts, self.dense_conv_lut)],
            h * self.scale, w * self.scale,
        )
        return output.reshape(b, c, output.shape[-2], output.shape[-1])

    def __repr__(self):
        return _lut_repr(self, ["rc_conv_luts", "dense_conv_lut"])


class RCLutRot90_7x7(nn.Module):
    """7x7 rc-conv LUT + dense LUT, averaged over the four 90-degree input rotations."""

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.dense_conv_lut = _dense_conv_lut_param(quantization_interval, scale)
        self.rc_conv_luts = _rc_conv_lut_param(7, quantization_interval)

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs."""
        quantization_interval, scale = _lut_geometry(dense_conv_lut)
        lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = _lut_parameter(dense_conv_lut)
        lut_model.rc_conv_luts = _lut_parameter(rc_conv_luts)
        return lut_model

    def forward(self, x):
        """Return the rotation-averaged prediction of shape (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        output = _rot90_ensemble(
            x, _rcblock_rot90, [(self.rc_conv_luts, self.dense_conv_lut)],
            h * self.scale, w * self.scale,
        )
        return output.reshape(b, c, output.shape[-2], output.shape[-1])

    def __repr__(self):
        return _lut_repr(self, ["rc_conv_luts", "dense_conv_lut"])


class RCLutx1(nn.Module):
    """Ensemble of 3x3, 5x5 and 7x7 rc-conv blocks, each averaged over 4 rotations."""

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.rc_conv_luts_3x3 = _rc_conv_lut_param(3, quantization_interval)
        self.rc_conv_luts_5x5 = _rc_conv_lut_param(5, quantization_interval)
        self.rc_conv_luts_7x7 = _rc_conv_lut_param(7, quantization_interval)
        self.dense_conv_lut_3x3 = _dense_conv_lut_param(quantization_interval, scale)
        self.dense_conv_lut_5x5 = _dense_conv_lut_param(quantization_interval, scale)
        self.dense_conv_lut_7x7 = _dense_conv_lut_param(quantization_interval, scale)

    @staticmethod
    def init_from_lut(
        rc_conv_luts_3x3, dense_conv_lut_3x3,
        rc_conv_luts_5x5, dense_conv_lut_5x5,
        rc_conv_luts_7x7, dense_conv_lut_7x7
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs."""
        quantization_interval, scale = _lut_geometry(dense_conv_lut_3x3)
        lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
        lut_model.rc_conv_luts_3x3 = _lut_parameter(rc_conv_luts_3x3)
        lut_model.dense_conv_lut_3x3 = _lut_parameter(dense_conv_lut_3x3)
        lut_model.rc_conv_luts_5x5 = _lut_parameter(rc_conv_luts_5x5)
        lut_model.dense_conv_lut_5x5 = _lut_parameter(dense_conv_lut_5x5)
        lut_model.rc_conv_luts_7x7 = _lut_parameter(rc_conv_luts_7x7)
        lut_model.dense_conv_lut_7x7 = _lut_parameter(dense_conv_lut_7x7)
        return lut_model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        return _rcblock_rot90(index=index, rc_conv_lut=rc_conv_lut, dense_conv_lut=dense_conv_lut)

    def forward(self, x):
        """Average the three window sizes over 4 rotations; output is (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        lut_pairs = [
            (self.rc_conv_luts_3x3, self.dense_conv_lut_3x3),
            (self.rc_conv_luts_5x5, self.dense_conv_lut_5x5),
            (self.rc_conv_luts_7x7, self.dense_conv_lut_7x7),
        ]
        output = _rot90_ensemble(x, self._forward_rcblock, lut_pairs, h * self.scale, w * self.scale)
        return output.reshape(b, c, output.shape[-2], output.shape[-1])

    def __repr__(self):
        return _lut_repr(self, [
            "rc_conv_luts_3x3", "dense_conv_lut_3x3",
            "rc_conv_luts_5x5", "dense_conv_lut_5x5",
            "rc_conv_luts_7x7", "dense_conv_lut_7x7",
        ])


class RCLutx2(nn.Module):
    """Two-stage RCLutx1-style model using rot90 rc-conv blocks.

    Stage 1 dense LUTs emit 1x1 values, so that stage keeps the (h, w) resolution.
    Stage 2 dense LUTs emit scale x scale values, producing (h*scale, w*scale).
    """

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.s1_rc_conv_luts_3x3 = _rc_conv_lut_param(3, quantization_interval)
        self.s1_rc_conv_luts_5x5 = _rc_conv_lut_param(5, quantization_interval)
        self.s1_rc_conv_luts_7x7 = _rc_conv_lut_param(7, quantization_interval)
        self.s1_dense_conv_lut_3x3 = _dense_conv_lut_param(quantization_interval, 1)
        self.s1_dense_conv_lut_5x5 = _dense_conv_lut_param(quantization_interval, 1)
        self.s1_dense_conv_lut_7x7 = _dense_conv_lut_param(quantization_interval, 1)
        self.s2_rc_conv_luts_3x3 = _rc_conv_lut_param(3, quantization_interval)
        self.s2_rc_conv_luts_5x5 = _rc_conv_lut_param(5, quantization_interval)
        self.s2_rc_conv_luts_7x7 = _rc_conv_lut_param(7, quantization_interval)
        self.s2_dense_conv_lut_3x3 = _dense_conv_lut_param(quantization_interval, scale)
        self.s2_dense_conv_lut_5x5 = _dense_conv_lut_param(quantization_interval, scale)
        self.s2_dense_conv_lut_7x7 = _dense_conv_lut_param(quantization_interval, scale)

    @staticmethod
    def init_from_lut(
        s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
        s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
        s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
        s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
        s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
        s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs.

        ``scale`` and ``quantization_interval`` are inferred from the stage-2 3x3 dense LUT.
        """
        quantization_interval, scale = _lut_geometry(s2_dense_conv_lut_3x3)
        lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
        lut_model.s1_rc_conv_luts_3x3 = _lut_parameter(s1_rc_conv_luts_3x3)
        lut_model.s1_dense_conv_lut_3x3 = _lut_parameter(s1_dense_conv_lut_3x3)
        lut_model.s1_rc_conv_luts_5x5 = _lut_parameter(s1_rc_conv_luts_5x5)
        lut_model.s1_dense_conv_lut_5x5 = _lut_parameter(s1_dense_conv_lut_5x5)
        lut_model.s1_rc_conv_luts_7x7 = _lut_parameter(s1_rc_conv_luts_7x7)
        lut_model.s1_dense_conv_lut_7x7 = _lut_parameter(s1_dense_conv_lut_7x7)
        lut_model.s2_rc_conv_luts_3x3 = _lut_parameter(s2_rc_conv_luts_3x3)
        lut_model.s2_dense_conv_lut_3x3 = _lut_parameter(s2_dense_conv_lut_3x3)
        lut_model.s2_rc_conv_luts_5x5 = _lut_parameter(s2_rc_conv_luts_5x5)
        lut_model.s2_dense_conv_lut_5x5 = _lut_parameter(s2_dense_conv_lut_5x5)
        lut_model.s2_rc_conv_luts_7x7 = _lut_parameter(s2_rc_conv_luts_7x7)
        lut_model.s2_dense_conv_lut_7x7 = _lut_parameter(s2_dense_conv_lut_7x7)
        return lut_model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        return _rcblock_rot90(index=index, rc_conv_lut=rc_conv_lut, dense_conv_lut=dense_conv_lut)

    def forward(self, x):
        """Run stage 1 at (h, w), then stage 2 up to (h*scale, w*scale); returns (b, c, H', W')."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        stage1_pairs = [
            (self.s1_rc_conv_luts_3x3, self.s1_dense_conv_lut_3x3),
            (self.s1_rc_conv_luts_5x5, self.s1_dense_conv_lut_5x5),
            (self.s1_rc_conv_luts_7x7, self.s1_dense_conv_lut_7x7),
        ]
        stage2_pairs = [
            (self.s2_rc_conv_luts_3x3, self.s2_dense_conv_lut_3x3),
            (self.s2_rc_conv_luts_5x5, self.s2_dense_conv_lut_5x5),
            (self.s2_rc_conv_luts_7x7, self.s2_dense_conv_lut_7x7),
        ]
        x = _rot90_ensemble(x, self._forward_rcblock, stage1_pairs, h, w)
        output = _rot90_ensemble(x, self._forward_rcblock, stage2_pairs, h * self.scale, w * self.scale)
        return output.reshape(b, c, output.shape[-2], output.shape[-1])

    def __repr__(self):
        return _lut_repr(self, [
            "s1_rc_conv_luts_3x3", "s1_dense_conv_lut_3x3",
            "s1_rc_conv_luts_5x5", "s1_dense_conv_lut_5x5",
            "s1_rc_conv_luts_7x7", "s1_dense_conv_lut_7x7",
            "s2_rc_conv_luts_3x3", "s2_dense_conv_lut_3x3",
            "s2_rc_conv_luts_5x5", "s2_dense_conv_lut_5x5",
            "s2_rc_conv_luts_7x7", "s2_dense_conv_lut_7x7",
        ])


class RCLutx2Centered(nn.Module):
    """Same two-stage layout as RCLutx2, but each block uses the centered rc-conv lookup.

    Stage 1 dense LUTs emit 1x1 values (keeps (h, w)); stage 2 emits scale x scale values.
    """

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.s1_rc_conv_luts_3x3 = _rc_conv_lut_param(3, quantization_interval)
        self.s1_rc_conv_luts_5x5 = _rc_conv_lut_param(5, quantization_interval)
        self.s1_rc_conv_luts_7x7 = _rc_conv_lut_param(7, quantization_interval)
        self.s1_dense_conv_lut_3x3 = _dense_conv_lut_param(quantization_interval, 1)
        self.s1_dense_conv_lut_5x5 = _dense_conv_lut_param(quantization_interval, 1)
        self.s1_dense_conv_lut_7x7 = _dense_conv_lut_param(quantization_interval, 1)
        self.s2_rc_conv_luts_3x3 = _rc_conv_lut_param(3, quantization_interval)
        self.s2_rc_conv_luts_5x5 = _rc_conv_lut_param(5, quantization_interval)
        self.s2_rc_conv_luts_7x7 = _rc_conv_lut_param(7, quantization_interval)
        self.s2_dense_conv_lut_3x3 = _dense_conv_lut_param(quantization_interval, scale)
        self.s2_dense_conv_lut_5x5 = _dense_conv_lut_param(quantization_interval, scale)
        self.s2_dense_conv_lut_7x7 = _dense_conv_lut_param(quantization_interval, scale)

    @staticmethod
    def init_from_lut(
        s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
        s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
        s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
        s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
        s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
        s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
    ):
        """Build a model whose tables are copies of the given pre-computed LUTs.

        ``scale`` and ``quantization_interval`` are inferred from the stage-2 3x3 dense LUT.
        """
        quantization_interval, scale = _lut_geometry(s2_dense_conv_lut_3x3)
        lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
        lut_model.s1_rc_conv_luts_3x3 = _lut_parameter(s1_rc_conv_luts_3x3)
        lut_model.s1_dense_conv_lut_3x3 = _lut_parameter(s1_dense_conv_lut_3x3)
        lut_model.s1_rc_conv_luts_5x5 = _lut_parameter(s1_rc_conv_luts_5x5)
        lut_model.s1_dense_conv_lut_5x5 = _lut_parameter(s1_dense_conv_lut_5x5)
        lut_model.s1_rc_conv_luts_7x7 = _lut_parameter(s1_rc_conv_luts_7x7)
        lut_model.s1_dense_conv_lut_7x7 = _lut_parameter(s1_dense_conv_lut_7x7)
        lut_model.s2_rc_conv_luts_3x3 = _lut_parameter(s2_rc_conv_luts_3x3)
        lut_model.s2_dense_conv_lut_3x3 = _lut_parameter(s2_dense_conv_lut_3x3)
        lut_model.s2_rc_conv_luts_5x5 = _lut_parameter(s2_rc_conv_luts_5x5)
        lut_model.s2_dense_conv_lut_5x5 = _lut_parameter(s2_dense_conv_lut_5x5)
        lut_model.s2_rc_conv_luts_7x7 = _lut_parameter(s2_rc_conv_luts_7x7)
        lut_model.s2_dense_conv_lut_7x7 = _lut_parameter(s2_dense_conv_lut_7x7)
        return lut_model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        return _rcblock_centered(index=index, rc_conv_lut=rc_conv_lut, dense_conv_lut=dense_conv_lut)

    def forward(self, x):
        """Run stage 1 at (h, w), then stage 2 up to (h*scale, w*scale); returns (b, c, H', W')."""
        b, c, h, w = x.shape
        x = x.reshape(b * c, 1, h, w).type(torch.float32)
        stage1_pairs = [
            (self.s1_rc_conv_luts_3x3, self.s1_dense_conv_lut_3x3),
            (self.s1_rc_conv_luts_5x5, self.s1_dense_conv_lut_5x5),
            (self.s1_rc_conv_luts_7x7, self.s1_dense_conv_lut_7x7),
        ]
        stage2_pairs = [
            (self.s2_rc_conv_luts_3x3, self.s2_dense_conv_lut_3x3),
            (self.s2_rc_conv_luts_5x5, self.s2_dense_conv_lut_5x5),
            (self.s2_rc_conv_luts_7x7, self.s2_dense_conv_lut_7x7),
        ]
        x = _rot90_ensemble(x, self._forward_rcblock, stage1_pairs, h, w)
        output = _rot90_ensemble(x, self._forward_rcblock, stage2_pairs, h * self.scale, w * self.scale)
        return output.reshape(b, c, output.shape[-2], output.shape[-1])

    def __repr__(self):
        return _lut_repr(self, [
            "s1_rc_conv_luts_3x3", "s1_dense_conv_lut_3x3",
            "s1_rc_conv_luts_5x5", "s1_dense_conv_lut_5x5",
            "s1_rc_conv_luts_7x7", "s1_dense_conv_lut_7x7",
            "s2_rc_conv_luts_3x3", "s2_dense_conv_lut_3x3",
            "s2_rc_conv_luts_5x5", "s2_dense_conv_lut_5x5",
            "s2_rc_conv_luts_7x7", "s2_dense_conv_lut_7x7",
        ])