diff --git a/src/common/layers.py b/src/common/layers.py index 38b1b54..47b8b6e 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -367,8 +367,8 @@ class ChebyKANLayer(nn.Module): self.outdim = out_features self.degree = degree - self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1)) - nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1))) + self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1)) + nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1))) self.register_buffer("arange", torch.arange(0, degree + 1, 1)) def forward(self, x): @@ -390,6 +390,7 @@ class ChebyKANLayer(nn.Module): y = torch.einsum( "btid,iod->bto", x, self.cheby_coeffs ) # shape = (batch_size, hw, outdim) + y = y.view(b, hw, self.outdim) return y @@ -404,10 +405,10 @@ class UpscaleBlockChebyKAN(nn.Module): self.linear_projections = [] for i in range(layers_count): - self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, bias=True)) + self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim)) self.linear_projections = nn.ModuleList(self.linear_projections) - self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) + self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) self.in_bias = self.in_scale = input_max_value/2 self.out_bias = self.out_scale = output_max_value/2 diff --git a/src/common/lut.py b/src/common/lut.py index 3a1ef8c..8c2f8d6 100644 --- a/src/common/lut.py +++ b/src/common/lut.py @@ -113,7 +113,7 @@ def transfer_rc_conv(rc_conv, quantization_interval=1): print() return lut -def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255): +def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255): bucket_count = (max_value+1)//quantization_interval scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1 lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2 diff --git a/src/models/__init__.py b/src/models/__init__.py index 875d7f7..d51bfd2 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -5,48 +5,17 @@ from . import srlut from . import sdynet from . import hdbnet from . import hdblut +from models.base import SRNetBase from common import losses import torch import numpy as np from pathlib import Path +import inspect, sys -AVAILABLE_MODELS = { - 'SRNet': srnet.SRNet, - 'SRLut': srlut.SRLut, - 'SRNetR90': srnet.SRNetR90, - 'SRLutR90': srlut.SRLutR90, - 'SRNetR90Y': srnet.SRNetR90Y, - 'SRLutR90Y': srlut.SRLutR90Y, - 'SDYNetx1': sdynet.SDYNetx1, - 'SDYLutx1': sdylut.SDYLutx1, - 'SDYNetx2': sdynet.SDYNetx2, - 'SDYLutx2': sdylut.SDYLutx2, - 'SDYNetx3': sdynet.SDYNetx3, - 'SDYLutx3': sdylut.SDYLutx3, - 'SDYNetR90x1': sdynet.SDYNetR90x1, - 'SDYLutR90x1': sdylut.SDYLutR90x1, - 'SDYNetR90x2': sdynet.SDYNetR90x2, - 'SDYLutR90x2': sdylut.SDYLutR90x2, - 'SRNetY': srnet.SRNetY, - 'SRLutY': srlut.SRLutY, - 'HDBNet': hdbnet.HDBNet, - 'HDBLut': hdblut.HDBLut, - 'HDBLNet': hdbnet.HDBLNet, - 'HDBHNet': hdbnet.HDBHNet, - 'SRMsbLsbNet': srnet.SRMsbLsbNet, - 'SRMsbLsbShiftNet': srnet.SRMsbLsbShiftNet, - 'SRMsbLsbR90Net': srnet.SRMsbLsbR90Net, - 'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net, - # 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, - # 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, - # 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, - # 'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7, - # 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1, - # 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, - # 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, - # 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, - # 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, -} +AVAILABLE_MODELS = {} +for module_name in sys.modules.keys(): + if 'models.' in module_name: + AVAILABLE_MODELS.update({k:v for k,v in inspect.getmembers(sys.modules[module_name], lambda x: inspect.isclass(x) and SRNetBase in x.__bases__)}) def SaveCheckpoint(model, path): model_container = { diff --git a/src/models/base.py b/src/models/base.py new file mode 100644 index 0000000..dd577cb --- /dev/null +++ b/src/models/base.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from common.utils import round_func + +class SRNetBase(nn.Module): + def __init__(self): + super(SRNetBase, self).__init__() + + def forward_stage(self, x, scale, percieve_pattern, stage): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = stage(x) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + raise NotImplementedError + + def get_loss_fn(self): + def loss_fn(pred, target): + return F.mse_loss(pred/255, target/255) + return loss_fn \ No newline at end of file diff --git a/src/models/srlut.py b/src/models/srlut.py index e2b0e17..a9524f0 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -6,8 +6,9 @@ from pathlib import Path from common.lut import select_index_4dlut_tetrahedral from common import layers from common.utils import round_func +from models.base import SRNetBase -class SRLut(nn.Module): +class SRLut(SRNetBase): def __init__( self, quantization_interval, @@ -54,7 +55,7 @@ class SRLut(nn.Module): return F.mse_loss(pred/255, target/255) return loss_fn -class SRLutY(nn.Module): +class SRLutY(SRNetBase): def __init__( self, quantization_interval, @@ -108,7 +109,7 @@ class SRLutY(nn.Module): return F.mse_loss(pred/255, target/255) return loss_fn -class SRLutR90(nn.Module): +class SRLutR90(SRNetBase): def __init__( self, quantization_interval, @@ -161,7 +162,7 @@ class SRLutR90(nn.Module): return loss_fn -class SRLutR90Y(nn.Module): +class SRLutR90Y(SRNetBase): def __init__( self, quantization_interval, @@ -215,6 +216,60 @@ class SRLutR90Y(nn.Module): def __repr__(self): return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" + def get_loss_fn(self): + def loss_fn(pred, target): + return F.mse_loss(pred/255, target/255) + return loss_fn + +class SRLutR90YCbCr(SRNetBase): + def __init__( + self, + quantization_interval, + scale + ): + super(SRLutR90YCbCr, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + + @staticmethod + def init_from_numpy( + stage_lut + ): + scale = int(stage_lut.shape[-1]) + quantization_interval = 256//(stage_lut.shape[0]-1) + lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale) + lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + return lut_model + + def forward_stage(self, x, scale, percieve_pattern, lut): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = select_index_4dlut_tetrahedral(index=x, lut=lut) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + + def forward(self, x, config=None): + b,c,h,w = x.shape + y = x[:,0:1,:,:] + cbcr = x[:,1:,:,:] + cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') + output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut) + for rotations_count in range(1,4): + rotated = torch.rot90(y, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3]) + output /= 4 + output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255) + return output + + def __repr__(self): + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" + def get_loss_fn(self): def loss_fn(pred, target): return F.mse_loss(pred/255, target/255) diff --git a/src/models/srnet.py b/src/models/srnet.py index fbe2acc..dc3a03d 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -8,35 +8,37 @@ from pathlib import Path from . import srlut from common import layers from common import losses +from models.base import SRNetBase -class SRNetBase(nn.Module): - def __init__(self): - super(SRNetBase, self).__init__() - - def forward_stage(self, x, scale, percieve_pattern, stage): +class SRNet(SRNetBase): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNet, self).__init__() + self.scale = scale + self.stage1_S = layers.UpscaleBlock( + hidden_dim=hidden_dim, + layers_count=layers_count, + upscale_factor=self.scale + ) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + + def forward(self, x, config=None): b,c,h,w = x.shape - x = percieve_pattern(x) - x = stage(x) - x = round_func(x) - x = x.reshape(b, c, h, w, scale, scale) - x = x.permute(0,1,2,4,3,5) - x = x.reshape(b, c, h*scale, w*scale) + x = x.reshape(b*c, 1, h, w) + x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) + x = x.reshape(b, c, h*self.scale, w*self.scale) return x def get_lut_model(self, quantization_interval=16, batch_size=2**10): - raise NotImplementedError - - def get_loss_fn(self): - def loss_fn(pred, target): - return F.mse_loss(pred/255, target/255) - return loss_fn + stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = srlut.SRLut.init_from_numpy(stage_lut) + return lut_model -class SRNet(SRNetBase): +class SRNetChebyKan(SRNetBase): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(SRNet, self).__init__() + super(SRNetChebyKan, self).__init__() self.scale = scale - self.stage1_S = layers.UpscaleBlock( + self.stage1_S = layers.UpscaleBlockChebyKAN( hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale @@ -51,7 +53,7 @@ class SRNet(SRNetBase): return x def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) lut_model = srlut.SRLut.init_from_numpy(stage_lut) return lut_model @@ -83,7 +85,7 @@ class SRNetY(SRNetBase): return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) lut_model = srlut.SRLutY.init_from_numpy(stage_lut) return lut_model @@ -103,8 +105,7 @@ class SRNetR90(SRNetBase): b,c,h,w = x.shape x = x.reshape(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) - for rotations_count in range(1,4): + for rotations_count in range(4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output /= 4 @@ -112,10 +113,36 @@ class SRNetR90(SRNetBase): return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) lut_model = srlut.SRLutR90.init_from_numpy(stage_lut) return lut_model +class SRNetChebyKanR90(SRNetBase): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNetChebyKanR90, self).__init__() + self.scale = scale + self.stage1_S = layers.UpscaleBlockChebyKAN( + hidden_dim=hidden_dim, + layers_count=layers_count, + upscale_factor=self.scale + ) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + + def forward(self, x, config=None): + b,c,h,w = x.shape + x = x.reshape(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) + output /= 4 + output = output.reshape(b, c, h*self.scale, w*self.scale) + return output + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = srlut.SRLutR90.init_from_numpy(stage_lut) + return lut_model class SRNetR90Y(SRNetBase): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): @@ -140,8 +167,7 @@ class SRNetR90Y(SRNetBase): x = y.view(b, 1, h, w) output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) - for rotations_count in range(1,4): + for rotations_count in range(4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output /= 4 @@ -149,6 +175,41 @@ class SRNetR90Y(SRNetBase): output = self.ycbcr_to_rgb(output).clamp(0, 255) return output + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = srlut.SRLutR90Yq.init_from_numpy(stage_lut) + return lut_model + + +class SRNetR90Ycbcr(SRNetBase): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNetR90Ycbcr, self).__init__() + self.scale = scale + s_pattern=[[0,0],[0,1],[1,0],[1,1]] + self.stage1_S = layers.UpscaleBlock( + hidden_dim=hidden_dim, + layers_count=layers_count, + upscale_factor=self.scale + ) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + self.rgb_to_ycbcr = layers.RgbToYcbcr() + self.ycbcr_to_rgb = layers.YcbcrToRgb() + + def forward(self, x, config=None): + b,c,h,w = x.shape + y = x[:,0:1,:,:] + cbcr = x[:,1:,:,:] + cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') + + x = y.view(b, 1, h, w) + output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) + output /= 4 + output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255) + return output + def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut) diff --git a/src/test.py b/src/test.py index b16b108..541babd 100644 --- a/src/test.py +++ b/src/test.py @@ -26,7 +26,7 @@ class ValOptions(): self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.") self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.") self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.") - self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') + self.parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')