From 7d28a70226c711f84bde8ff8dbd6551c02c3ec3e Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Mon, 22 Apr 2024 10:08:23 +0000 Subject: [PATCH] add unlutable rc variants --- src/models/__init__.py | 4 +- src/models/rcnet.py | 157 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/src/models/__init__.py b/src/models/__init__.py index 2a089d1..900d147 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -17,7 +17,9 @@ AVAILABLE_MODELS = { 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1, 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, - 'SDYNetx1': sdynet.SDYNetx1 + 'SDYNetx1': sdynet.SDYNetx1, + 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, + 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, } def SaveCheckpoint(model, path): diff --git a/src/models/rcnet.py b/src/models/rcnet.py index 8699b70..2d8af31 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -369,6 +369,163 @@ class RCNetx2Centered(nn.Module): ) return lut_model + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + x = output + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + +class ReconstructedConvRot90Unlutable(nn.Module): + def __init__(self, hidden_dim, window_size=7): + super(ReconstructedConvRot90Unlutable, self).__init__() + self.window_size = window_size + self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + + def pixel_wise_forward(self, x): + x = (x-127.5)/127.5 + out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) + out = torch.tanh(out) + out = out*127.5 + 127.5 + return out + + def forward(self, x): + original_shape = x.shape + x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') + x = F.unfold(x, self.window_size) + x = self.pixel_wise_forward(x) + x = x.mean(1) + x = x.reshape(*original_shape) + # x = round_func(x) # quality likely suffer from this + return x + + def __repr__(self): + return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +class RCBlockRot90Unlutable(nn.Module): + def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): + super(RCBlockRot90Unlutable, self).__init__() + self.window_size = window_size + self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size) + self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + + def forward(self, x): + b,c,hs,ws = x.shape + x = self.rc_conv(x) + x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self.dense_conv_block(x) + return x + +class RCNetx2Unlutable(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetx2Unlutable, self).__init__() + self.scale = scale + self.hidden_dim = hidden_dim + self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) + self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) + self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) + self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) + self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + raise NotImplementedError + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + x = output + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + + +class ReconstructedConvCenteredUnlutable(nn.Module): + def __init__(self, hidden_dim, window_size=7): + super(ReconstructedConvCenteredUnlutable, self).__init__() + self.window_size = window_size + self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + + def pixel_wise_forward(self, x): + x = (x-127.5)/127.5 + out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) + out = torch.tanh(out) + out = out*127.5 + 127.5 + return out + + def forward(self, x): + original_shape = x.shape + x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') + x = F.unfold(x, self.window_size) + x = self.pixel_wise_forward(x) + x = x.mean(1) + x = x.reshape(*original_shape) + # x = round_func(x) # quality likely suffer from this + return x + + def __repr__(self): + return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +class RCBlockCenteredUnlutable(nn.Module): + def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): + super(RCBlockRot90Unlutable, self).__init__() + self.window_size = window_size + self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size) + self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + + def forward(self, x): + b,c,hs,ws = x.shape + x = self.rc_conv(x) + x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self.dense_conv_block(x) + return x + +class RCNetx2CenteredUnlutable(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetx2CenteredUnlutable, self).__init__() + self.scale = scale + self.hidden_dim = hidden_dim + self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) + self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) + self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) + self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) + self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + raise NotImplementedError + def forward(self, x): b,c,h,w = x.shape x = x.view(b*c, 1, h, w)