From b1f2f6d76bbd94c0a7a631f7e39104f09bd6dcf1 Mon Sep 17 00:00:00 2001
From: Vladimir Protsenko
Date: Thu, 25 Apr 2024 22:35:16 +0000
Subject: [PATCH] Add MuLUT-style S/D/Y pattern models (SDYNetx2, SDYLutx1, SDYLutx2)

---
 src/common/layers.py   |  28 ++++++++
 src/common/lut.py      |  18 ++++-
 src/models/__init__.py |   3 +-
 src/models/sdylut.py   | 148 ++++++++++++++++++++++++++++++++---------
 src/models/sdynet.py   | 110 ++++++++++++++++++++++--------
 5 files changed, 247 insertions(+), 60 deletions(-)
 create mode 100644 src/common/layers.py

diff --git a/src/common/layers.py b/src/common/layers.py
new file mode 100644
index 0000000..4c2886f
--- /dev/null
+++ b/src/common/layers.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class PercievePattern():
+    def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]):
+        self.receptive_field_idxes = np.array(receptive_field_idxes)
+        self.window_size = np.max(self.receptive_field_idxes) + 1
+        self.receptive_field_idxes = [
+            self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1],
+            self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1],
+            self.receptive_field_idxes[2,0]*self.window_size + self.receptive_field_idxes[2,1],
+            self.receptive_field_idxes[3,0]*self.window_size + self.receptive_field_idxes[3,1],
+        ]
+
+    def __call__(self, x):
+        b,c,h,w = x.shape
+        x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
+        x = F.unfold(input=x, kernel_size=self.window_size)
+        x = torch.stack([
+            x[:,self.receptive_field_idxes[0],:],
+            x[:,self.receptive_field_idxes[1],:],
+            x[:,self.receptive_field_idxes[2],:],
+            x[:,self.receptive_field_idxes[3],:]
+        ], 2)
+        x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2)
+        return x
\ No newline at end of file
diff --git a/src/common/lut.py b/src/common/lut.py
index 7e2ef13..3c7a4b6 100644
--- a/src/common/lut.py
+++ b/src/common/lut.py
@@ -93,7 +93,23 @@ def forward_2x2_input_SxS_output(index, lut):
     )
     out = out[:,:,0:-1,0:-1,:,:] # unpad
     # Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
-    out = out.permute(0,1,2,4,3,5).reshape(b,1,hs*scale,ws*scale)
+    out = out.permute(0,1,2,4,3,5).reshape(b*c,1,hs*scale,ws*scale)
     out = round_func(out)
     return out
+
+def forward_unfolded_2x2_input_SxS_output(index, lut):
+    b,c,hs,ws = index.shape
+    scale = lut.shape[-1]
+    out = select_index_4dlut_tetrahedral(
+        ixA = index,
+        ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
+        ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
+        ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
+        lut = lut
+    )
+    out = out[:,:,0:-1,0:-1,:,:] # drop the row/col that wrapped around in torch.roll
+    # Pixel Shuffle. Example: [N, 1, 1, 1, 4, 4] -> [N, 1, 1, 4, 1, 4] -> [N, 1, 4, 4]
+    out = out.permute(0,1,2,4,3,5).reshape(b*c,1,scale,scale)
+    out = round_func(out)
+    return out

diff --git a/src/models/__init__.py b/src/models/__init__.py
index 900d147..b1c21bb 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -17,7 +17,8 @@ AVAILABLE_MODELS = {
     'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1,
     'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2,
     'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered,
-    'SDYNetx1': sdynet.SDYNetx1,
+    'SDYNetx1': sdynet.SDYNetx1, 'SDYLutx1': sdylut.SDYLutx1,
+    'SDYNetx2': sdynet.SDYNetx2, 'SDYLutx2': sdylut.SDYLutx2,
     'RCNetx2Unlutable': rcnet.RCNetx2Unlutable,
     'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable,
 }
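For reviewers, a minimal sketch of what the new PercievePattern helper produces. This is illustrative only and assumes src/ is on PYTHONPATH and a single-channel input, as the models below guarantee via view(b*c, 1, h, w):

# Toy run of the three sampling patterns used throughout this patch:
# S (dense 2x2 square), D (dilated diagonal), Y (Y-shaped corner taps).
import torch
from common.layers import PercievePattern

x = torch.arange(16.).reshape(1, 1, 4, 4)  # one 4x4 single-channel image

pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])

# Each call emits one pseudo 2x2 patch per input pixel, shape (b*h*w, 1, 2, 2),
# so all three patterns can share forward_unfolded_2x2_input_SxS_output.
print(pattern_S(x).shape)  # torch.Size([16, 1, 2, 2])
print(pattern_D(x).shape)  # torch.Size([16, 1, 2, 2])
print(pattern_Y(x).shape)  # torch.Size([16, 1, 2, 2])

# The first tap of every pattern is the anchor pixel itself:
assert torch.equal(pattern_S(x)[:, 0, 0, 0], x.flatten())

Replicate padding on the right and bottom keeps one window per pixel, which is why the flattened batch dimension is exactly b*h*w.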
diff --git a/src/models/sdylut.py b/src/models/sdylut.py
index ae0269e..e71fc1d 100644
--- a/src/models/sdylut.py
+++ b/src/models/sdylut.py
@@ -3,74 +3,162 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from pathlib import Path
-from common.lut import forward_2x2_input_SxS_output
+from common.lut import forward_2x2_input_SxS_output, forward_unfolded_2x2_input_SxS_output
+from common.layers import PercievePattern
 
-class SRLut2x2(nn.Module):
+class SDYLutx1(nn.Module):
     def __init__(
         self,
         quantization_interval,
         scale
     ):
-        super(SRLut2x2, self).__init__()
+        super(SDYLutx1, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
+        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
+        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
+        self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 
     @staticmethod
     def init_from_lut(
-        stage_lut
+        stageS, stageD, stageY
     ):
-        scale = int(stage_lut.shape[-1])
-        quantization_interval = 256//(stage_lut.shape[0]-1)
-        lut_model = SRLut2x2(quantization_interval=quantization_interval, scale=scale)
-        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
+        scale = int(stageS.shape[-1])
+        quantization_interval = 256//(stageS.shape[0]-1)
+        lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
+        lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
+        lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
+        lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
         return lut_model
 
     def forward(self, x):
         b,c,h,w = x.shape
         x = x.view(b*c, 1, h, w).type(torch.float32)
-        x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
-        x = x.view(b, c, x.shape[-2], x.shape[-1])
-        return x
-
+        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            rb,rc,rh,rw = rotated.shape
+
+            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS)
+            s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
+            output += s
+
+            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD)
+            d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
+            output += d
+
+            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY)
+            y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
+            output += y
+
+        output /= 4*3
+        output = output.view(b, c, h*self.scale, w*self.scale)
+        return output
+
     def __repr__(self):
-        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
+        return f"{self.__class__.__name__}" + \
+            f"\n stageS size: {self.stageS.shape}" + \
+            f"\n stageD size: {self.stageD.shape}" + \
+            f"\n stageY size: {self.stageY.shape}"
 
 
-class SRLut3x3(nn.Module):
+class SDYLutx2(nn.Module):
     def __init__(
         self,
         quantization_interval,
         scale
     ):
-        super(SRLut3x3, self).__init__()
+        super(SDYLutx2, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
+        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
+        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
+        # Stage-1 LUTs map each pattern to a single pixel, so their output cells are (1,1);
+        # forward() views their results as (..., 1, 1). Only stage 2 upscales by `scale`.
+        self.stageS_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stageD_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stageY_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stageS_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stageD_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stageY_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 
     @staticmethod
     def init_from_lut(
-        stage_lut
+        stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2
     ):
-        scale = int(stage_lut.shape[-1])
-        quantization_interval = 256//(stage_lut.shape[0]-1)
-        lut_model = SRLut3x3(quantization_interval=quantization_interval, scale=scale)
-        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
+        scale = int(stageS_2.shape[-1])
+        quantization_interval = 256//(stageS_2.shape[0]-1)
+        lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
+        lut_model.stageS_1 = nn.Parameter(torch.tensor(stageS_1).type(torch.float32))
+        lut_model.stageD_1 = nn.Parameter(torch.tensor(stageD_1).type(torch.float32))
+        lut_model.stageY_1 = nn.Parameter(torch.tensor(stageY_1).type(torch.float32))
+        lut_model.stageS_2 = nn.Parameter(torch.tensor(stageS_2).type(torch.float32))
+        lut_model.stageD_2 = nn.Parameter(torch.tensor(stageD_2).type(torch.float32))
+        lut_model.stageY_2 = nn.Parameter(torch.tensor(stageY_2).type(torch.float32))
         return lut_model
 
     def forward(self, x):
         b,c,h,w = x.shape
-        x = x.view(b*c, 1, h, w)
-        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
+        x = x.view(b*c, 1, h, w).type(torch.float32)
+        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
         for rotations_count in range(4):
-            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
-            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
-            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
-            output += unrotated_prediction
-        output /= 4
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            rb,rc,rh,rw = rotated.shape
+
+            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_1)
+            s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
+            s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
+            output += s
+
+            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_1)
+            d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
+            d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
+            output += d
+
+            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_1)
+            y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
+            y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
+            output += y
+
+        output /= 4*3
+        output = output.view(b, c, h, w)
+        x = output
+
+        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            rb,rc,rh,rw = rotated.shape
+
+            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_2)
+            s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
+            output += s
+
+            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_2)
+            d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
+            output += d
+
+            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_2)
+            y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
+            output += y
+
+        output /= 4*3
         output = output.view(b, c, h*self.scale, w*self.scale)
         return output
 
     def __repr__(self):
-        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
\ No newline at end of file
+        return f"{self.__class__.__name__}" + \
+            f"\n stageS_1 size: {self.stageS_1.shape}" + \
+            f"\n stageD_1 size: {self.stageD_1.shape}" + \
+            f"\n stageY_1 size: {self.stageY_1.shape}" + \
+            f"\n stageS_2 size: {self.stageS_2.shape}" + \
+            f"\n stageD_2 size: {self.stageD_2.shape}" + \
+            f"\n stageY_2 size: {self.stageY_2.shape}"
\ No newline at end of file
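A quick smoke test for the new LUT model, as a sketch. It only exercises shapes: the tables are random here (real ones come from SDYNetx1.get_lut_model()), and it assumes src/ is on PYTHONPATH and that the select_index_4dlut_tetrahedral helper quantizes 0..255 float inputs as it does elsewhere in the repo:

import torch
from models.sdylut import SDYLutx1

# Randomly initialized tables; shapes are (17, 17, 17, 17, 4, 4) for
# quantization_interval=16, scale=4.
model = SDYLutx1(quantization_interval=16, scale=4)
lr = torch.randint(0, 256, size=(1, 1, 8, 8)).type(torch.float32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 1, 32, 32])

Each of the 12 pattern/rotation branches produces a full SR image; the forward pass averages them (output /= 4*3), which is the standard rotation-ensemble trick from SR-LUT extended to three patterns.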
diff --git a/src/models/sdynet.py b/src/models/sdynet.py
index 9127d99..be37422 100644
--- a/src/models/sdynet.py
+++ b/src/models/sdynet.py
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 import numpy as np
 from common.utils import round_func
 from common import lut
+from common.layers import PercievePattern
 from pathlib import Path
 from . import sdylut
 
@@ -37,32 +38,7 @@ class DenseConvUpscaleBlock(nn.Module):
         x = self.shuffle(self.project_channels(x))
         x = torch.tanh(x)
         x = round_func(x*127.5 + 127.5)
-        return x
-
-class PercievePattern():
-    def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]):
-        self.receptive_field_idxes = np.array(receptive_field_idxes)
-        self.window_size = np.max(self.receptive_field_idxes) + 1
-        self.receptive_field_idxes = [
-            self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1],
-            self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1],
-            self.receptive_field_idxes[2,0]*self.window_size + self.receptive_field_idxes[2,1],
-            self.receptive_field_idxes[3,0]*self.window_size + self.receptive_field_idxes[3,1],
-        ]
-
-    def __call__(self, x):
-        b,c,h,w = x.shape
-        x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
-        x = F.unfold(input=x, kernel_size=self.window_size)
-        x = torch.stack([
-            x[:,self.receptive_field_idxes[0],:],
-            x[:,self.receptive_field_idxes[1],:],
-            x[:,self.receptive_field_idxes[2],:],
-            x[:,self.receptive_field_idxes[3],:]
-        ], 2)
-        x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2)
-        return x
-
+        return x
 
 class SDYNetx1(nn.Module):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
@@ -103,6 +79,84 @@ class SDYNetx1(nn.Module):
         return output
 
     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
-        stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
-        lut_model = sdylut.SDYLutx1.init_from_lut(stage_lut)
+        stageS = lut.transfer_2x2_input_SxS_output(self.stageS, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageD = lut.transfer_2x2_input_SxS_output(self.stageD, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageY = lut.transfer_2x2_input_SxS_output(self.stageY, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = sdylut.SDYLutx1.init_from_lut(stageS, stageD, stageY)
         return lut_model
+
+
+class SDYNetx2(nn.Module):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYNetx2, self).__init__()
+        self.scale = scale
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
+        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
+        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
+        self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stageS_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stageD_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stageY_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w)
+        output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            rb,rc,rh,rw = rotated.shape
+
+            s = self.stageS_1(self._extract_pattern_S(rotated))
+            s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
+            s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
+            output_1 += s
+
+            d = self.stageD_1(self._extract_pattern_D(rotated))
+            d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
+            d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
+            output_1 += d
+
+            y = self.stageY_1(self._extract_pattern_Y(rotated))
+            y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
+            y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
+            output_1 += y
+
+        output_1 /= 4*3
+        output_1 = output_1.view(b, c, h, w)
+        x = output_1
+
+        output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            rb,rc,rh,rw = rotated.shape
+
+            s = self.stageS_2(self._extract_pattern_S(rotated))
+            s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
+            output_2 += s
+
+            d = self.stageD_2(self._extract_pattern_D(rotated))
+            d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
+            output_2 += d
+
+            y = self.stageY_2(self._extract_pattern_Y(rotated))
+            y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
+            y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
+            output_2 += y
+
+        output_2 /= 4*3
+        output_2 = output_2.view(b, c, h*self.scale, w*self.scale)
+        return output_2
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        stageS_1 = lut.transfer_2x2_input_SxS_output(self.stageS_1, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageD_1 = lut.transfer_2x2_input_SxS_output(self.stageD_1, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageY_1 = lut.transfer_2x2_input_SxS_output(self.stageY_1, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageS_2 = lut.transfer_2x2_input_SxS_output(self.stageS_2, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size)
+        stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2)
+        return lut_model
\ No newline at end of file
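Finally, the intended net-to-LUT workflow for the two-stage model, as a sketch. Hyperparameters are illustrative; it assumes src/ is on PYTHONPATH and that inputs are 8-bit intensities stored as float tensors, as elsewhere in this repo. The LUT model is expected to match the network only up to the error introduced by quantizing the input space:

import torch
from models.sdynet import SDYNetx2

net = SDYNetx2(hidden_dim=64, layers_count=4, scale=4)
# ... train net as usual, then distill all six stage blocks into tables:
lut_model = net.get_lut_model(quantization_interval=16, batch_size=2**10)
print(lut_model)  # __repr__ lists the six stage LUT shapes

lr = torch.randint(0, 256, size=(1, 1, 8, 8)).type(torch.float32)
with torch.no_grad():
    err = (net(lr) - lut_model(lr)).abs().mean()
print(err)  # small but nonzero: the tables see only quantized inputs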