diff --git a/src/models/__init__.py b/src/models/__init__.py index 204a452..397c9c2 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -29,7 +29,7 @@ AVAILABLE_MODELS = { 'SRNetY': srnet.SRNetY, 'SRLutY': srlut.SRLutY, 'HDBNet': hdbnet.HDBNet, - 'HDBLNet': hdbnet.HDBLNet, + 'HDBLut': hdblut.HDBLut, # 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, # 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, # 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, diff --git a/src/models/hdblut.py b/src/models/hdblut.py index d1fa943..17db7ba 100644 --- a/src/models/hdblut.py +++ b/src/models/hdblut.py @@ -14,19 +14,47 @@ class HDBLut(nn.Module): scale ): super(HDBLut, self).__init__() + assert scale == 4 self.scale = scale self.quantization_interval = quantization_interval - self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + + self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + + self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32)) + + self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3) + self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3) + self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3) + self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2) + self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2) @staticmethod def init_from_numpy( - stage_lut + stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D, + stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D ): scale = int(stage_lut.shape[-1]) quantization_interval = 256//(stage_lut.shape[0]-1) lut_model = HDBLut(quantization_interval=quantization_interval, scale=scale) - lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + lut_model.stage1_3H = nn.Parameter(torch.tensor(stage1_3H).type(torch.float32)) + lut_model.stage1_3D = nn.Parameter(torch.tensor(stage1_3D).type(torch.float32)) + lut_model.stage1_3B = nn.Parameter(torch.tensor(stage1_3B).type(torch.float32)) + lut_model.stage1_2H = nn.Parameter(torch.tensor(stage1_2H).type(torch.float32)) + lut_model.stage1_2D = nn.Parameter(torch.tensor(stage1_2D).type(torch.float32)) + + lut_model.stage2_3H = nn.Parameter(torch.tensor(stage2_3H).type(torch.float32)) + lut_model.stage2_3D = nn.Parameter(torch.tensor(stage2_3D).type(torch.float32)) + lut_model.stage2_3B = nn.Parameter(torch.tensor(stage2_3B).type(torch.float32)) + lut_model.stage2_2H = nn.Parameter(torch.tensor(stage2_2H).type(torch.float32)) + lut_model.stage2_2D = nn.Parameter(torch.tensor(stage2_2D).type(torch.float32)) return lut_model def forward_stage(self, x, scale, percieve_pattern, lut): @@ -41,161 +69,41 @@ class HDBLut(nn.Module): def forward(self, x): b,c,h,w = x.shape - x = x.reshape(b*c, 1, h, w).type(torch.float32) - x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut) + x = x.reshape(b*c, 1, h, w) + lsb = x % 16 + msb = x - lsb + output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device) + output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3]) + rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3]) + output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3]) + output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3]) + output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3]) + output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3]) + output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3]) + output_msb /= 4*3 + output_lsb /= 4*2 + output_msb = output_msb + output_lsb + x = output_msb + lsb = x % 16 + msb = x - lsb + output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device) + output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3]) + rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3]) + output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3]) + output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3]) + output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3]) + output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3]) + output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3]) + output_msb /= 4*3 + output_lsb /= 4*2 + output_msb = output_msb + output_lsb + x = output_msb x = x.reshape(b, c, h*self.scale, w*self.scale) return x def __repr__(self): - return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" - -# class SRLutY(nn.Module): -# def __init__( -# self, -# quantization_interval, -# scale -# ): -# super(SRLutY, self).__init__() -# self.scale = scale -# self.quantization_interval = quantization_interval -# self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) -# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) -# self.rgb_to_ycbcr = layers.RgbToYcbcr() -# self.ycbcr_to_rgb = layers.YcbcrToRgb() - -# @staticmethod -# def init_from_numpy( -# stage_lut -# ): -# scale = int(stage_lut.shape[-1]) -# quantization_interval = 256//(stage_lut.shape[0]-1) -# lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale) -# lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) -# return lut_model - -# def forward_stage(self, x, scale, percieve_pattern, lut): -# b,c,h,w = x.shape -# x = percieve_pattern(x) -# x = select_index_4dlut_tetrahedral(index=x, lut=lut) -# x = round_func(x) -# x = x.reshape(b, c, h, w, scale, scale) -# x = x.permute(0,1,2,4,3,5) -# x = x.reshape(b, c, h*scale, w*scale) -# return x - -# def forward(self, x): -# b,c,h,w = x.shape -# x = self.rgb_to_ycbcr(x) -# y = x[:,0:1,:,:] -# cbcr = x[:,1:,:,:] -# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') - -# output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut) -# output = torch.cat([output, cbcr_scaled], dim=1) -# output = self.ycbcr_to_rgb(output).clamp(0, 255) -# return output - -# def __repr__(self): -# return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" - -# class SRLutR90(nn.Module): -# def __init__( -# self, -# quantization_interval, -# scale -# ): -# super(SRLutR90, self).__init__() -# self.scale = scale -# self.quantization_interval = quantization_interval -# self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) -# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) - -# @staticmethod -# def init_from_numpy( -# stage_lut -# ): -# scale = int(stage_lut.shape[-1]) -# quantization_interval = 256//(stage_lut.shape[0]-1) -# lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale) -# lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) -# return lut_model - -# def forward_stage(self, x, scale, percieve_pattern, lut): -# b,c,h,w = x.shape -# x = percieve_pattern(x) -# x = select_index_4dlut_tetrahedral(index=x, lut=lut) -# x = round_func(x) -# x = x.reshape(b, c, h, w, scale, scale) -# x = x.permute(0,1,2,4,3,5) -# x = x.reshape(b, c, h*scale, w*scale) -# return x - -# def forward(self, x): -# b,c,h,w = x.shape -# x = x.reshape(b*c, 1, h, w) -# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) -# output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut) -# for rotations_count in range(1, 4): -# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) -# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3]) -# output /= 4 -# output = output.reshape(b, c, h*self.scale, w*self.scale) -# return output - -# def __repr__(self): -# return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" - - -# class SRLutR90Y(nn.Module): -# def __init__( -# self, -# quantization_interval, -# scale -# ): -# super(SRLutR90Y, self).__init__() -# self.scale = scale -# self.quantization_interval = quantization_interval -# self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) -# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) -# self.rgb_to_ycbcr = layers.RgbToYcbcr() -# self.ycbcr_to_rgb = layers.YcbcrToRgb() - -# @staticmethod -# def init_from_numpy( -# stage_lut -# ): -# scale = int(stage_lut.shape[-1]) -# quantization_interval = 256//(stage_lut.shape[0]-1) -# lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale) -# lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) -# return lut_model - -# def forward_stage(self, x, scale, percieve_pattern, lut): -# b,c,h,w = x.shape -# x = percieve_pattern(x) -# x = select_index_4dlut_tetrahedral(index=x, lut=lut) -# x = round_func(x) -# x = x.reshape(b, c, h, w, scale, scale) -# x = x.permute(0,1,2,4,3,5) -# x = x.reshape(b, c, h*scale, w*scale) -# return x - -# def forward(self, x): -# b,c,h,w = x.shape -# x = self.rgb_to_ycbcr(x) -# y = x[:,0:1,:,:] -# cbcr = x[:,1:,:,:] -# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') - -# output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) -# output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut) -# for rotations_count in range(1,4): -# rotated = torch.rot90(y, k=rotations_count, dims=[2, 3]) -# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3]) -# output /= 4 -# output = torch.cat([output, cbcr_scaled], dim=1) -# output = self.ycbcr_to_rgb(output).clamp(0, 255) -# return output - -# def __repr__(self): -# return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" \ No newline at end of file + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" \ No newline at end of file diff --git a/src/models/hdbnet.py b/src/models/hdbnet.py index e0ad96f..01c45e1 100644 --- a/src/models/hdbnet.py +++ b/src/models/hdbnet.py @@ -11,6 +11,7 @@ from common import layers class HDBNet(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(HDBNet, self).__init__() + assert scale == 4 self.scale = scale self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) @@ -79,203 +80,20 @@ class HDBNet(nn.Module): return x def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLut.init_from_numpy(stage_lut) - return lut_model - -class HDBLNet(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(HDBLNet, self).__init__() - self.scale = scale - self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - - self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - self.stage2_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2) - - self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3) - self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3) - self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3) - self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=3) - - def forward_stage(self, x, scale, percieve_pattern, stage): - b,c,h,w = x.shape - x = percieve_pattern(x) - x = stage(x) - x = round_func(x) - x = x.reshape(b, c, h, w, scale, scale) - x = x.permute(0,1,2,4,3,5) - x = x.reshape(b, c, h*scale, w*scale) - return x - - def forward(self, x): - b,c,h,w = x.shape - x = x.reshape(b*c, 1, h, w) - lsb = x % 16 - msb = x - lsb - output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device) - output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3]) - rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3]) - output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3]) - output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3]) - output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3]) - output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_3L, self.stage1_3L), k=-rotations_count, dims=[2, 3]) - output_msb /= 4*3 - output_lsb /= 4 - output_msb = output_msb + output_lsb - x = output_msb - lsb = x % 16 - msb = x - lsb - output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device) - output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3]) - rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3]) - output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3]) - output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3]) - output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3]) - output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_3L, self.stage2_3L), k=-rotations_count, dims=[2, 3]) - output_msb /= 4*3 - output_lsb /= 4 - output_msb = output_msb + output_lsb - x = output_msb - x = x.reshape(b, c, h*self.scale, w*self.scale) - return x - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLut.init_from_numpy(stage_lut) - return lut_model - - -# class SRNetY(nn.Module): -# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): -# super(SRNetY, self).__init__() -# self.scale = scale -# self.stage1_S = layers.UpscaleBlock( -# hidden_dim=hidden_dim, -# layers_count=layers_count, -# upscale_factor=self.scale -# ) -# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) -# self.rgb_to_ycbcr = layers.RgbToYcbcr() -# self.ycbcr_to_rgb = layers.YcbcrToRgb() - -# def forward_stage(self, x, scale, percieve_pattern, stage): -# b,c,h,w = x.shape -# x = percieve_pattern(x) -# x = stage(x) -# x = round_func(x) -# x = x.reshape(b, c, h, w, scale, scale) -# x = x.permute(0,1,2,4,3,5) -# x = x.reshape(b, c, h*scale, w*scale) -# return x - -# def forward(self, x): -# b,c,h,w = x.shape -# x = self.rgb_to_ycbcr(x) -# y = x[:,0:1,:,:] -# cbcr = x[:,1:,:,:] -# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') - -# x = y.view(b, 1, h, w) -# output = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) -# output = torch.cat([output, cbcr_scaled], dim=1) -# output = self.ycbcr_to_rgb(output).clamp(0, 255) -# return output - -# def get_lut_model(self, quantization_interval=16, batch_size=2**10): -# stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) -# lut_model = srlut.SRLutY.init_from_numpy(stage_lut) -# return lut_model - -# class SRNetR90(nn.Module): -# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): -# super(SRNetR90, self).__init__() -# self.scale = scale -# self.stage1_S = layers.UpscaleBlock( -# hidden_dim=hidden_dim, -# layers_count=layers_count, -# upscale_factor=self.scale -# ) -# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) - -# def forward_stage(self, x, scale, percieve_pattern, stage): -# b,c,h,w = x.shape -# x = percieve_pattern(x) -# x = stage(x) -# x = round_func(x) -# x = x.reshape(b, c, h, w, scale, scale) -# x = x.permute(0,1,2,4,3,5) -# x = x.reshape(b, c, h*scale, w*scale) -# return x - -# def forward(self, x): -# b,c,h,w = x.shape -# x = x.reshape(b*c, 1, h, w) -# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) -# output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) -# for rotations_count in range(1,4): -# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) -# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) -# output /= 4 -# output = output.reshape(b, c, h*self.scale, w*self.scale) -# return output - -# def get_lut_model(self, quantization_interval=16, batch_size=2**10): -# stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) -# lut_model = srlut.SRLutR90.init_from_numpy(stage_lut) -# return lut_model - -# class SRNetR90Y(nn.Module): -# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): -# super(SRNetR90Y, self).__init__() -# self.scale = scale -# s_pattern=[[0,0],[0,1],[1,0],[1,1]] -# self.stage1_S = layers.UpscaleBlock( -# hidden_dim=hidden_dim, -# layers_count=layers_count, -# upscale_factor=self.scale -# ) -# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) -# self.rgb_to_ycbcr = layers.RgbToYcbcr() -# self.ycbcr_to_rgb = layers.YcbcrToRgb() - -# def forward_stage(self, x, scale, percieve_pattern, stage): -# b,c,h,w = x.shape -# x = percieve_pattern(x) -# x = stage(x) -# x = round_func(x) -# x = x.reshape(b, c, h, w, scale, scale) -# x = x.permute(0,1,2,4,3,5) -# x = x.reshape(b, c, h*scale, w*scale) -# return x - -# def forward(self, x): -# b,c,h,w = x.shape -# x = self.rgb_to_ycbcr(x) -# y = x[:,0:1,:,:] -# cbcr = x[:,1:,:,:] -# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') - -# x = y.view(b, 1, h, w) -# output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) -# output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) -# for rotations_count in range(1,4): -# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) -# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) -# output /= 4 -# output = torch.cat([output, cbcr_scaled], dim=1) -# output = self.ycbcr_to_rgb(output).clamp(0, 255) -# return output - -# def get_lut_model(self, quantization_interval=16, batch_size=2**10): -# stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) -# lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut) -# return lut_model \ No newline at end of file + stage1_3H = lut.transfer_2x2_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size) + stage1_3D = lut.transfer_2x2_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size) + stage1_3B = lut.transfer_2x2_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size) + stage1_2H = lut.transfer_2x2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size) + stage1_2D = lut.transfer_2x2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size) + + stage2_3H = lut.transfer_2x2_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_3D = lut.transfer_2x2_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_3B = lut.transfer_2x2_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_2H = lut.transfer_2x2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_2D = lut.transfer_2x2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size) + + lut_model = hdblut.HDBLut.init_from_numpy( + stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D, + stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D + ) + return lut_model \ No newline at end of file