diff --git a/src/models/rcnet.py b/src/models/rcnet.py index 2d8af31..c08bff0 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -327,7 +327,6 @@ class RCNetx2(nn.Module): output = output.view(b, c, h*self.scale, w*self.scale) return output - class RCNetx2Centered(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(RCNetx2Centered, self).__init__() @@ -390,7 +389,6 @@ class RCNetx2Centered(nn.Module): output = output.view(b, c, h*self.scale, w*self.scale) return output - class ReconstructedConvRot90Unlutable(nn.Module): def __init__(self, hidden_dim, window_size=7): super(ReconstructedConvRot90Unlutable, self).__init__() @@ -545,4 +543,71 @@ class RCNetx2CenteredUnlutable(nn.Module): output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) output /= 3*4 output = output.view(b, c, h*self.scale, w*self.scale) - return output \ No newline at end of file + return output + + + +class ReconstructedConvCenteredv2(nn.Module): + def __init__(self, hidden_dim, window_size=7): + super(ReconstructedConvCenteredv2, self).__init__() + self.window_size = window_size + self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + + def pixel_wise_forward(self, x): + x = (x-127.5)/127.5 + out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2) + out = torch.tanh(out) + out = out*127.5 + 127.5 + return out + + def forward(self, x): + original_shape = x.shape + x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') + x = F.unfold(x, self.window_size) + x = self.pixel_wise_forward(x) + x = (x.mean(1) - (x.max().values - x.min().values))/2 + x = x.reshape(*original_shape) + x = round_func(x) + return x + + def __repr__(self): + return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +class RCBlockCenteredv2(nn.Module): + def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): + super(RCBlockCenteredv2, self).__init__() + self.window_size = window_size + self.rc_conv = ReconstructedConvCenteredv2(hidden_dim=hidden_dim, window_size=window_size) + self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + + def forward(self, x): + b,c,hs,ws = x.shape + x = self.rc_conv(x) + x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self.dense_conv_block(x) + return x + + +class RCNetCentered_7x7v2(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetCentered_7x7v2, self).__init__() + self.hidden_dim = hidden_dim + self.layers_count = layers_count + self.scale = scale + window_size = 7 + self.stage = RCBlockCenteredv2(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + x = self.stage(x) + x = x.view(b, c, h*self.scale, w*self.scale) + return x + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + window_size = self.stage.rc_conv.window_size + rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) + dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) + return lut_model