From 5c22148bed602d88087930f2a14f70c3e2d40f4d Mon Sep 17 00:00:00 2001
From: protsenkovi
Date: Wed, 3 Jul 2024 22:59:42 +0400
Subject: [PATCH] updates

---
 {models => experiments}/.gitignore |   0
 src/common/layers.py               |  89 +++++++-
 src/common/test.py                 |   2 +-
 src/models/base.py                 |  11 +-
 src/models/hdbnet.py               | 154 +++++++++++---
 src/models/sdynet.py               | 316 +++++++++++++++++++++++++++--
 src/models/srnet.py                | 118 ++++++++++-
 src/test.py                        |  10 +-
 src/train.py                       |  58 +++---
 9 files changed, 656 insertions(+), 102 deletions(-)
 rename {models => experiments}/.gitignore (100%)

diff --git a/models/.gitignore b/experiments/.gitignore
similarity index 100%
rename from models/.gitignore
rename to experiments/.gitignore
diff --git a/src/common/layers.py b/src/common/layers.py
index 7c75f65..b32c516 100644
--- a/src/common/layers.py
+++ b/src/common/layers.py
@@ -5,19 +5,35 @@ import numpy as np
 from .utils import round_func
 
 class PercievePattern():
-    def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2):
-        assert window_size >= (np.max(receptive_field_idxes)+1)
-        receptive_field_idxes = np.array(receptive_field_idxes)
+    """
+    Coordinates scheme: [channel, height, width]. Channel can be omitted to address all channels.
+    Examples:
+    1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]
+    2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]]
+    """
+    def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1):
+        assert window_size >= (np.max([x for y in receptive_field_idxes for x in y])+1)
+        tmp = []
+        for coords in receptive_field_idxes:
+            if len(coords) < 3:
+                for i in range(channels):
+                    tmp.append([i,] + coords)
+            if len(coords) == 3:
+                tmp.append(coords)
+        receptive_field_idxes = np.array(sorted(tmp))
         self.window_size = window_size
         self.center = center
-        self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size + receptive_field_idxes[i,1] for i in range(len(receptive_field_idxes))]
+        self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size*self.window_size + receptive_field_idxes[i,1]*self.window_size + receptive_field_idxes[i,2] for i in range(len(receptive_field_idxes))]
+        assert len(np.unique(self.receptive_field_idxes)) == len(self.receptive_field_idxes), "Duplicated coordinates found. Coordinates scheme: [channel, height, width]."
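# Illustration (not from this patch): the flattened indices above rely on
# F.unfold laying out each patch in channel-major order, so a
# [channel, row, col] coordinate maps to channel*window_size**2 +
# row*window_size + col. For window_size=2, channels=1:
#   [0,1,1] -> 0*4 + 1*2 + 1 = 3
#   [[0,0],[0,1],[1,0],[1,1]] -> offsets [0, 1, 2, 3]
def flat_index(channel, row, col, window_size):
    return channel * window_size * window_size + row * window_size + col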
 
     def __call__(self, x):
         b,c,h,w = x.shape
         x = F.pad(
             x,
-            pad=[self.center[0], self.window_size-self.center[0]-1,
-                 self.center[1], self.window_size-self.center[1]-1],
+            pad=[
+                self.center[0], self.window_size-self.center[0]-1,
+                self.center[1], self.window_size-self.center[1]-1
+            ],
             mode='replicate'
         )
         x = F.unfold(input=x, kernel_size=self.window_size)
@@ -451,4 +467,63 @@ class UpscaleBlockEffKAN(nn.Module):
         x = self.project_channels(x)
         x = torch.tanh(x)
         x = x*self.out_scale + self.out_bias
-        return x
\ No newline at end of file
+        return x
+
+
+class ComplexGaborLayer2D(nn.Module):
+    '''
+    Implicit representation with complex Gabor nonlinearity with 2D activation function
+    https://github.com/vishwa91/wire/blob/main/modules/wire2d.py
+    Inputs:
+        in_features: Input features
+        out_features: Output features
+        bias: if True, enable bias for the linear operation
+        is_first: Legacy SIREN parameter
+        omega_0: Legacy SIREN parameter
+        omega0: Frequency of Gabor sinusoid term
+        sigma0: Scaling of Gabor Gaussian term
+        trainable: If True, omega and sigma are trainable parameters
+    '''
+
+    def __init__(self, in_features, out_features, bias=True,
+                 is_first=False, omega0=10.0, sigma0=10.0,
+                 trainable=False):
+        super().__init__()
+        self.omega_0 = omega0
+        self.scale_0 = sigma0
+        self.is_first = is_first
+
+        self.in_features = in_features
+
+        if self.is_first:
+            dtype = torch.float
+        else:
+            dtype = torch.cfloat
+
+        # Set trainable parameters if they are to be simultaneously optimized
+        self.omega_0 = nn.Parameter(self.omega_0*torch.ones(1), trainable)
+        self.scale_0 = nn.Parameter(self.scale_0*torch.ones(1), trainable)
+
+        self.linear = nn.Linear(in_features,
+                                out_features,
+                                bias=bias,
+                                dtype=dtype)
+
+        # Second Gaussian window
+        self.scale_orth = nn.Linear(in_features,
+                                    out_features,
+                                    bias=bias,
+                                    dtype=dtype)
+
+    def forward(self, input):
+        lin = self.linear(input)
+
+        scale_x = lin
+        scale_y = self.scale_orth(input)
+
+        freq_term = torch.exp(1j*self.omega_0*lin)
+
+        arg = scale_x.abs().square() + scale_y.abs().square()
+        gauss_term = torch.exp(-self.scale_0*self.scale_0*arg)
+
+        return freq_term*gauss_term
\ No newline at end of file
diff --git a/src/common/test.py b/src/common/test.py
index 569ca07..30352b0 100644
--- a/src/common/test.py
+++ b/src/common/test.py
@@ -92,7 +92,7 @@ def test_steps(model, datasets, config, log_prefix="", print_progress = False):
             total_area += lr_area
 
         row = [
-            dataset_name,
+            f"{dataset_name} {config.color_model}",
             np.mean(psnrs),
             np.mean(ssims),
             np.mean(run_times_ns)*1e-9,
diff --git a/src/models/base.py b/src/models/base.py
index dd577cb..0f485fb 100644
--- a/src/models/base.py
+++ b/src/models/base.py
@@ -8,14 +8,15 @@ class SRNetBase(nn.Module):
     def __init__(self):
         super(SRNetBase, self).__init__()
 
-    def forward_stage(self, x, scale, percieve_pattern, stage):
+    def forward_stage(self, x, percieve_pattern, stage):
         b,c,h,w = x.shape
-        x = percieve_pattern(x)
+        scale = stage.upscale_factor
+        x = percieve_pattern(x)
         x = stage(x)
         x = round_func(x)
-        x = x.reshape(b, c, h, w, scale, scale)
-        x = x.permute(0,1,2,4,3,5)
-        x = x.reshape(b, c, h*scale, w*scale)
+        x = x.reshape(b, 1, h, w, scale, scale)
+        x = x.permute(0, 1, 2, 4, 3, 5)
+        x = x.reshape(b, 1, h*scale, w*scale)
         return x
 
     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
diff --git a/src/models/hdbnet.py b/src/models/hdbnet.py
index 15fd3a7..f894222 100644
--- a/src/models/hdbnet.py
+++ b/src/models/hdbnet.py
@@ -113,11 +113,12 @@ class HDBNet(SRNetBase):
         return loss_fn
 
 class HDBNetv2(SRNetBase):
-    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, rotations = 4):
+    def __init__(self, hidden_dim = 64, layers_count = 2, scale = 4, rotations = 4):
         super(HDBNetv2, self).__init__()
         assert scale == 4
         self.scale = scale
         self.rotations = rotations
+        self.layers_count = layers_count
         self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
         self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
         self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
@@ -208,8 +209,55 @@ class HDBNetv2(SRNetBase):
             return F.mse_loss(pred/255, target/255)
         return loss_fn
 
+class HDBLNet(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(HDBLNet, self).__init__()
+        self.scale = scale
+        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+        self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+
+        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
 
-class HDBLNet(nn.Module):
+    def forward_stage(self, x, scale, percieve_pattern, stage):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = stage(x)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        lsb = x % 16
+        msb = x - lsb
+        output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
+                     self.forward_stage(msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
+                     self.forward_stage(msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
+        output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
+        output_msb /= 3
+        if not config is None and config.current_iter % config.display_step == 0:
+            config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
+            config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
+        output = output_msb + output_lsb
+        x = output.clamp(0, 255)
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+
+class HDBLNetR90(SRNetBase):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
-        super(HDBLNet, self).__init__()
+        super(HDBLNetR90, self).__init__()
         self.scale = scale
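# Illustration (not from this patch): the HDB* models split each 8-bit
# intensity into a most-significant and a least-significant 4-bit part,
# process the parts in separate branches, and sum the branch outputs.
import torch
x = torch.tensor([203.0])
lsb = x % 16    # 11.0, the low 4 bits
msb = x - lsb   # 192.0, the high 4 bits (a multiple of 16)
assert torch.equal(msb + lsb, x)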
@@ -243,17 +291,19 @@ class HDBLNet(nn.Module):
         for rotations_count in range(4):
             rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
             rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
-            output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
-            output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
-            output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
-            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L), k=-rotations_count, dims=[2, 3])
+            output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
+                          self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
+                          self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
+            output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
+            if not config is None and config.current_iter % config.display_step == 0:
+                config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
+                config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
+            output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
         output_msb /= 4*3
         output_lsb /= 4
-        output_msb = round_func((output_msb / 255) * 16) * 15
-        output_lsb = (output_lsb / 255) * 15
-
-        x = output_msb + output_lsb
-
+        output = output_msb + output_lsb
+        x = output.clamp(0, 255)
         x = x.reshape(b, c, h*self.scale, w*self.scale)
         return x
 
@@ -262,39 +312,77 @@ class HDBLNet(nn.Module):
             return F.mse_loss(pred/255, target/255)
         return loss_fn
 
-    # def get_lut_model(self, quantization_interval=16, batch_size=2**10):
-    #     stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
-
-    #     stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
-    #     stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
-
-    #     lut_model = hdblut.HDBLut.init_from_numpy(
-    #         stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
-    #         stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
-    #     )
-    #     return lut_model
-
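# Minimal sketch (assumes a toy single-input `net`; not from this patch):
# the R90 variants run the model on all four 90-degree rotations, rotate each
# prediction back, and average, making the ensemble orientation-invariant.
import torch
def rot90_ensemble(net, x):
    out = 0
    for k in range(4):
        out = out + torch.rot90(net(torch.rot90(x, k=k, dims=[2, 3])), k=-k, dims=[2, 3])
    return out / 4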
-class HDBHNet(nn.Module):
+
+class HDBLNetR90KAN(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(HDBLNetR90KAN, self).__init__()
+        self.scale = scale
+        self.stage1_3H = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+        self.stage1_3D = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+        self.stage1_3B = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+        self.stage1_3L = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
+
+        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
+
+    def forward_stage(self, x, scale, percieve_pattern, stage):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = stage(x)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        lsb = x % 16
+        msb = x - lsb
+        output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
+                          self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
+                          self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
+            output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
+            if not config is None and config.current_iter % config.display_step == 0:
+                config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
+                config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
+            output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
        output_msb /= 4*3
+        output_lsb /= 4
+        output = output_msb + output_lsb
+        x = output.clamp(0, 255)
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+
+class HDBHNet(SRNetBase):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(HDBHNet, self).__init__()
         self.scale = scale
         self.hidden_dim = hidden_dim
         self.layers_count = layers_count
 
         self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
             in_features=4,
             hidden_dim=hidden_dim,
             layers_count=layers_count,
             upscale_factor=self.scale
         ) for x in range(1)])
         self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
             in_features=4,
             hidden_dim=hidden_dim,
             layers_count=layers_count,
diff --git a/src/models/sdynet.py b/src/models/sdynet.py
index 745ce07..cbac125 100644
--- a/src/models/sdynet.py
+++ b/src/models/sdynet.py
@@ -92,6 +92,54 @@ class SDYNetx2(SRNetBase):
             return F.mse_loss(pred/255, target/255)
         return loss_fn
 
+
+class SDYNetx2Inv(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYNetx2Inv, self).__init__()
+        self.scale = scale
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
+        self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = 0.0
+        output += self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
+        output += self.forward_stage(x, self._extract_pattern_D, self.stage1_D)
+        output += self.forward_stage(x, self._extract_pattern_Y, self.stage1_Y)
+        output /= 3
+        x = output
+        output = 0.0
+        output += self.forward_stage(x, self._extract_pattern_S, self.stage2_S)
+        output += self.forward_stage(x, self._extract_pattern_D, self.stage2_D)
+        output += self.forward_stage(x, self._extract_pattern_Y, self.stage2_Y)
+        output /= 3
+        x = (output + x)/2
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
+        return lut_model
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
 class SDYNetx3(SRNetBase):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(SDYNetx3, self).__init__()
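# Sketch with stand-in modules (not the repo's API): SDYNetx2Inv inverts the
# usual SDY stage order -- stage 1 upscales, stage 2 refines at the upscaled
# resolution, and the refinement is blended with its own input as a
# residual-style correction, (output + x) / 2.
import torch
import torch.nn as nn
s1 = nn.Upsample(scale_factor=4)   # stand-in for the x4 stage1_* blocks
s2 = nn.Identity()                 # stand-in for the x1 stage2_* blocks
lr = torch.rand(1, 1, 8, 8)
y1 = s1(lr)
out = (s2(y1) + y1) / 2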
@@ -151,7 +199,7 @@ class SDYNetx3(SRNetBase):
             return F.mse_loss(pred/255, target/255)
         return loss_fn
 
-class SDYNetR90x1(nn.Module):
+class SDYNetR90x1(SRNetBase):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(SDYNetR90x1, self).__init__()
         self.scale = scale
@@ -162,16 +210,6 @@ class SDYNetR90x1(nn.Module):
         self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
         self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
 
-    def forward_stage(self, x, scale, percieve_pattern, stage):
-        b,c,h,w = x.shape
-        x = percieve_pattern(x)
-        x = stage(x)
-        x = round_func(x)
-        x = x.reshape(b, c, h, w, scale, scale)
-        x = x.permute(0,1,2,4,3,5)
-        x = x.reshape(b, c, h*scale, w*scale)
-        return x
-
     def forward(self, x, config=None):
         b,c,h,w = x.shape
         x = x.reshape(b*c, 1, h, w)
@@ -219,10 +257,7 @@ class SDYNetR90x2(SRNetBase):
         b,c,h,w = x.shape
         x = x.view(b*c, 1, h, w)
         output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
-        output_1 += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
-        output_1 += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
-        output_1 += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
-        for rotations_count in range(1,4):
+        for rotations_count in range(4):
             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
             output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
             output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
@@ -230,10 +265,7 @@ class SDYNetR90x2(SRNetBase):
         output_1 /= 4*3
         x = output_1
         output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
-        output_2 += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
-        output_2 += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
-        output_2 += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
-        for rotations_count in range(1,4):
+        for rotations_count in range(4):
             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
             output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
             output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
@@ -253,6 +285,252 @@ class SDYNetR90x2(SRNetBase):
         lut_model = sdylut.SDYLutR90x2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
         return lut_model
 
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+class SDYEHONetR90x1(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYEHONetR90x1, self).__init__()
+        self.scale = scale
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_E = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,0],[0,3],[3,3]], center=[0,0], window_size=4)
+        self._extract_pattern_H = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,2],[2,3],[3,2]], center=[0,0], window_size=4)
+        self._extract_pattern_O = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,1],[2,2],[1,3]], center=[0,0], window_size=4)
+
+        self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_E = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_H = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_O = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w)
+        output_1 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_E, self.stage1_E), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_H, self.stage1_H), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_O, self.stage1_O), k=-rotations_count, dims=[-2, -1])
+        output_1 /= 4*6
+        x = output_1
+        x = x.view(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        raise NotImplementedError
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+
+class SDYEHONetR90x2(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYEHONetR90x2, self).__init__()
+        self.scale = scale
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_E = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,0],[0,3],[3,3]], center=[0,0], window_size=4)
+        self._extract_pattern_H = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,2],[2,3],[3,2]], center=[0,0], window_size=4)
+        self._extract_pattern_O = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,1],[2,2],[1,3]], center=[0,0], window_size=4)
+
+        self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_E = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_H = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_O = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+        self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_E = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_H = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_O = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w)
+        output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_E, self.stage1_E), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_H, self.stage1_H), k=-rotations_count, dims=[-2, -1])
+            output_1 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_O, self.stage1_O), k=-rotations_count, dims=[-2, -1])
+        output_1 /= 4*6
+        x = output_1
+        output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
+            output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
+            output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
+            output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
+            output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_E, self.stage2_E), k=-rotations_count, dims=[-2, -1])
+            output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_H, self.stage2_H), k=-rotations_count, dims=[-2, -1])
+            output_2 += torch.rot90(self.forward_stage(rotated, self._extract_pattern_O, self.stage2_O), k=-rotations_count, dims=[-2, -1])
+        output_2 /= 4*6
+        x = output_2
+        x = x.view(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        raise NotImplementedError
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
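# Summary of the patterns above (derived from their coordinates): the E, H
# and O branches extend the classic S/D/Y trio to a 4x4 window --
#   E: [0,0] [3,0] [0,3] [3,3]  corners of the window
#   H: [0,0] [2,2] [2,3] [3,2]  anchor plus a distant 2x2 cluster
#   O: [0,0] [3,1] [2,2] [1,3]  anchor plus the anti-diagonal
# so each branch sees a different long-range neighborhood before the rotation
# ensemble averages all 4*6 predictions.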
+
+class SDYMixNetx1(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1, self).__init__()
+        self.scale = scale
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,2],[2,1],[2,2]], center=[1,1], window_size=4)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,3],[3,1],[3,3]], center=[1,1], window_size=4)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes=[[1,1],[2,2],[3,2],[2,3]], center=[1,1], window_size=4)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes=[[1,1],[0,0],[0,1],[1,0]], center=[1,1], window_size=4)
+        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes=[[1,1],[0,2],[0,3]], center=[1,1], window_size=4)
+        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes=[[1,1],[2,0],[3,0]], center=[1,1], window_size=4)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_5 = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_6 = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=6, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
+        output = output.permute(0, 2, 3, 1).view(b, h*w, 6)
+        output = self.stage1_Mix(output)
+        output = output.view(b, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+class SDYMixNetx1v2(SRNetBase):
+    """
+    22
+    12 23 32 21
+    11 13 33 31
+    10 14 34 30
+    01 03 43 41
+    00 04 44 40
+    """
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1v2, self).__init__()
+        self.scale = scale
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes=[[2,2]], center=[2,2], window_size=5)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes=[[1,2],[2,3],[3,2],[2,1]], center=[2,2], window_size=5)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,3],[3,3],[3,1]], center=[2,2], window_size=5)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes=[[1,0],[1,4],[3,4],[3,0]], center=[2,2], window_size=5)
+        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes=[[0,1],[0,3],[4,3],[4,1]], center=[2,2], window_size=5)
+        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,4],[4,4],[4,0]], center=[2,2], window_size=5)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=6, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
+        output = output.permute(0, 2, 3, 1).view(b, h*w, 6)
+        output = self.stage1_Mix(output)
+        output = output.view(b, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+class SDYMixNetx1v3(SRNetBase):
+    """
+    22
+    12 23 32 21
+    11 13 33 31
+    10 14 34 30
+    01 03 43 41
+    00 04 44 40
+    """
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1v3, self).__init__()
+        self.scale = scale
+        self._extract_pattern_0 = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=6)
+
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes=[[2,2]], center=[2,2], window_size=5)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes=[[1,2],[2,3],[3,2],[2,1]], center=[2,2], window_size=5)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,3],[3,3],[3,1]], center=[2,2], window_size=5)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes=[[1,0],[1,4],[3,4],[3,0]], center=[2,2], window_size=5)
+        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes=[[0,1],[0,3],[4,3],[4,1]], center=[2,2], window_size=5)
+        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,4],[4,4],[4,0]], center=[2,2], window_size=5)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=6, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
+        output = self.forward_stage(output, self._extract_pattern_0, self.stage1_Mix)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
     def get_loss_fn(self):
         def loss_fn(pred, target):
             return F.mse_loss(pred/255, target/255)
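# Illustration (not from this patch): SDYMixNetx1v3 depends on the
# channel-aware PercievePattern added in src/common/layers.py. With
# channels=6, a 2-coordinate index is expanded across every channel, so
# _extract_pattern_0 gathers all six branch outputs at each pixel for
# stage1_Mix:
coords, channels = [[0, 0]], 6
expanded = [[i] + c for c in coords for i in range(channels)]
# -> [[0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [4, 0, 0], [5, 0, 0]]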
diff --git a/src/models/srnet.py b/src/models/srnet.py
index 7c4ecd9..41579fd 100644
--- a/src/models/srnet.py
+++ b/src/models/srnet.py
@@ -24,7 +24,7 @@ class SRNet(SRNetBase):
     def forward(self, x, config=None):
         b,c,h,w = x.shape
         x = x.reshape(b*c, 1, h, w)
-        x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
+        x = self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
         x = x.reshape(b, c, h*self.scale, w*self.scale)
         return x
 
@@ -48,7 +48,7 @@ class SRNetChebyKan(SRNetBase):
     def forward(self, x, config=None):
         b,c,h,w = x.shape
         x = x.reshape(b*c, 1, h, w)
-        x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
+        x = self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
         x = x.reshape(b, c, h*self.scale, w*self.scale)
         return x
 
@@ -79,7 +79,7 @@ class SRNetY(SRNetBase):
         cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
 
         x = y.view(b, 1, h, w)
-        output = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
+        output = self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
         output = torch.cat([output, cbcr_scaled], dim=1)
         output = self.ycbcr_to_rgb(output).clamp(0, 255)
         return output
@@ -107,7 +107,7 @@ class SRNetR90(SRNetBase):
         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
         for rotations_count in range(4):
             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
-            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
         output /= 4
         output = output.reshape(b, c, h*self.scale, w*self.scale)
         return output
@@ -134,7 +134,7 @@ class SRNetChebyKanR90(SRNetBase):
         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
         for rotations_count in range(4):
             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
-            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
         output /= 4
         output = output.reshape(b, c, h*self.scale, w*self.scale)
         return output
@@ -169,7 +169,7 @@ class SRNetR90Y(SRNetBase):
         output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
         for rotations_count in range(4):
             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
-            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
         output /= 4
         output = torch.cat([output, cbcr_scaled], dim=1)
         output = self.ycbcr_to_rgb(output).clamp(0, 255)
@@ -205,7 +205,7 @@ class SRNetR90Ycbcr(SRNetBase):
         output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
         for rotations_count in range(4):
             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
-            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
         output /= 4
         output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
         return output
@@ -248,8 +248,8 @@ class SRMsbLsbNet(SRNetBase):
         lsb = x % 16
         msb = x - lsb
 
-        output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_S, self.msb_fn)
-        output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
+        output_msb = self.forward_stage(msb, self._extract_pattern_S, self.msb_fn)
+        output_lsb = self.forward_stage(lsb, self._extract_pattern_S, self.lsb_fn)
 
         if not config is None and config.current_iter % config.display_step == 0:
             config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
@@ -262,6 +262,53 @@ class SRMsbLsbNet(SRNetBase):
         raise NotImplementedError
 
 
+class SRMsbLsbNetChebyKAN(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SRMsbLsbNetChebyKAN, self).__init__()
+        self.scale = scale
+        self.hidden_dim = hidden_dim
+        self.layers_count = layers_count
+
+        self.msb_fn = layers.UpscaleBlockChebyKAN(
+            in_features=4,
+            hidden_dim=hidden_dim,
+            layers_count=layers_count,
+            upscale_factor=self.scale,
+            input_max_value=255,
+            output_max_value=255
+        )
+        self.lsb_fn = layers.UpscaleBlockChebyKAN(
+            in_features=4,
+            hidden_dim=hidden_dim,
+            layers_count=layers_count,
+            upscale_factor=self.scale,
+            input_max_value=15,
+            output_max_value=255
+        )
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+
+        lsb = x % 16
+        msb = x - lsb
+
+        output_msb = self.forward_stage(msb, self._extract_pattern_S, self.msb_fn)
+        output_lsb = self.forward_stage(lsb, self._extract_pattern_S, self.lsb_fn)
+
+        if not config is None and config.current_iter % config.display_step == 0:
+            config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
+            config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
+        x = output_msb + output_lsb
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        raise NotImplementedError
+
+
 class SRMsbLsbShiftNet(SRNetBase):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(SRMsbLsbShiftNet, self).__init__()
@@ -316,6 +363,59 @@ class SRMsbLsbShiftNet(SRNetBase):
         raise NotImplementedError
 
 
+class SRMsbLsbCenterShiftNet(SRNetBase):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, count = 4):
+        super(SRMsbLsbCenterShiftNet, self).__init__()
+        self.scale = scale
+        self.hidden_dim = hidden_dim
+        self.layers_count = layers_count
+        self.count = count
+        self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
+            in_features=4,
+            hidden_dim=hidden_dim,
+            layers_count=layers_count,
+            upscale_factor=self.scale,
+            input_max_value=255,
+            output_max_value=255
+        ) for x in range(self.count)])
+        self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
+            in_features=4,
+            hidden_dim=hidden_dim,
+            layers_count=layers_count,
+            upscale_factor=self.scale,
+            input_max_value=15,
+            output_max_value=255
+        ) for x in range(self.count)])
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+
+        lsb = x % 16
+        msb = x - lsb
+
+        output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for (i,j), msb_fn, lsb_fn in zip([[-2,-2],[-2,2],[2,-2],[2,2]], self.msb_fns, self.lsb_fns):
+            output_msb_s = self.forward_stage(msb, self._extract_pattern_S, msb_fn)
+            output_lsb_s = self.forward_stage(lsb, self._extract_pattern_S, lsb_fn)
+            output_msb += torch.nn.functional.pad(output_msb_s, [2, 2, 2, 2], mode='replicate')[:,:,2+i:2+i+h*self.scale,2+j:2+j+w*self.scale]
+            output_lsb += torch.nn.functional.pad(output_lsb_s, [2, 2, 2, 2], mode='replicate')[:,:,2+i:2+i+h*self.scale,2+j:2+j+w*self.scale]
+        output_msb /= self.count
+        output_lsb /= self.count
+
+        if not config is None and config.current_iter % config.display_step == 0:
+            config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
+            config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
+
+        x = output_msb + output_lsb
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        raise NotImplementedError
+
 class SRMsbLsbR90Net(SRNetBase):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(SRMsbLsbR90Net, self).__init__()
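# Toy sketch (not from this patch) of the shift-and-average trick in
# SRMsbLsbCenterShiftNet: each branch's output is padded by 2 and re-cropped
# at offsets (+/-2, +/-2), so the four branches vote on slightly shifted
# alignments before averaging.
import torch
import torch.nn.functional as F
y = torch.rand(1, 1, 16, 16)   # one branch's (already upscaled) output
acc = 0
for i, j in [[-2, -2], [-2, 2], [2, -2], [2, 2]]:
    acc = acc + F.pad(y, [2, 2, 2, 2], mode='replicate')[:, :, 2 + i:2 + i + 16, 2 + j:2 + j + 16]
acc = acc / 4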
diff --git a/src/test.py b/src/test.py
index d155649..73481df 100644
--- a/src/test.py
+++ b/src/test.py
@@ -24,27 +24,29 @@ import argparse
 class TestOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
+        self.parser.add_argument('--model_path', type=str, default="../experiments/last.pth", help="Model path.")
         self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
         self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.")
         self.parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
         self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
         self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
-        self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')
+        self.parser.add_argument('--print_progress', type=bool, default=True, help='Show progress bar')
         self.parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
+        self.parser.add_argument('--test_worker_num', default=1, type=int, help='Test parallelism. Use 1 for time measurement.')
 
     def parse_args(self):
         args = self.parser.parse_args()
         args.datasets_dir = Path(args.datasets_dir).resolve()
         args.test_datasets = args.test_datasets.split(',')
         args.exp_dir = Path(args.model_path).resolve().parent.parent
+        print(args.exp_dir)
         args.model_path = Path(args.model_path).resolve()
         args.model_name = args.model_path.stem
         args.test_dir = Path(args.exp_dir).resolve() / 'test'
         if not args.test_dir.exists():
             args.test_dir.mkdir()
         args.current_iter = int(args.model_name.split('_')[-1])
-        args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.device}.csv')
+        args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.color_model}_{args.device}.csv')
 
         # Tensorboard for monitoring
         writer = SummaryWriter(log_dir=args.test_dir)
         logger_name = f'test_{args.model_path.stem}'
@@ -90,7 +92,7 @@ if __name__ == "__main__":
         reset_cache=config.reset_cache,
     )
 
-    results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress,)
+    results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.print_progress)
     results.to_csv(config.results_path)
 
     print()
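# Caveat on --print_progress above (a suggestion, not a change to the patch):
# argparse's type=bool cannot parse "False", because bool('False') is True --
# any non-empty string is truthy. One common alternative (Python 3.9+):
import argparse
p = argparse.ArgumentParser()
p.add_argument('--print_progress', action=argparse.BooleanOptionalAction, default=True)
print(p.parse_args(['--no-print_progress']).print_progress)  # False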
help="Number of dataloader workers") + parser.add_argument('--loader_worker_num', type=int, default=1, help="Number of dataloader workers") + parser.add_argument('--test_worker_num', type=int, default=1, help="Test parallelism. Use 1 for time measurement.") parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.") parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') parser.add_argument('--device', default='cuda', help='Device of the model') parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}") parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache') + parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate') self.parser = parser @@ -66,6 +69,9 @@ class TrainOptions: args.start_iter = int(args.model_path.stem.split("_")[-1]) return args + def save_config(self, config): + yaml.dump(config, open(config.exp_dir / "config.yaml", 'w')) + def __repr__(self): config = self.parse_args() message = '' @@ -79,33 +85,37 @@ class TrainOptions: message += '----------------- End -------------------' return message -def prepare_experiment_folder(config): - assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." - assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"On of the {config.test_datasets} was not found in {config.datasets_dir}." - - config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve() + def prepare_config(self): + config = self.parse_args() + + assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." + assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"On of the {config.test_datasets} was not found in {config.datasets_dir}." 
+
+        config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
+
+        if not config.exp_dir.exists():
+            config.exp_dir.mkdir()
 
-    if not config.exp_dir.exists():
-        config.exp_dir.mkdir()
+        config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve()
+        if not config.checkpoint_dir.exists():
+            config.checkpoint_dir.mkdir()
 
-    config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve()
-    if not config.checkpoint_dir.exists():
-        config.checkpoint_dir.mkdir()
+        config.test_dir = (config.exp_dir / 'val').resolve()
+        if not config.test_dir.exists():
+            config.test_dir.mkdir()
 
-    config.test_dir = (config.exp_dir / 'val').resolve()
-    if not config.test_dir.exists():
-        config.test_dir.mkdir()
+        config.logs_dir = (config.exp_dir / 'logs').resolve()
+        if not config.logs_dir.exists():
+            config.logs_dir.mkdir()
 
-    config.logs_dir = (config.exp_dir / 'logs').resolve()
-    if not config.logs_dir.exists():
-        config.logs_dir.mkdir()
+        return config
 
 if __name__ == "__main__":
     # torch.set_float32_matmul_precision('high')
     script_start_time = datetime.now()
 
     config_inst = TrainOptions()
-    config = config_inst.parse_args()
+    config = config_inst.prepare_config()
 
     if not config.model_path is None:
         model = LoadCheckpoint(config.model_path)
@@ -117,10 +127,10 @@ if __name__ == "__main__":
         model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
     model = model.to(torch.device(config.device))
     # model = torch.compile(model)
 
-    optimizer = AdamWScheduleFree(model.parameters(), betas=(0.9, 0.95))
-    print(optimizer)
+    optimizer = AdamWScheduleFree(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.95))
+    print(optimizer)
 
-    prepare_experiment_folder(config)
+    config_inst.save_config(config)
 
     # Tensorboard for monitoring
     writer = SummaryWriter(log_dir=config.logs_dir)
@@ -147,7 +157,7 @@ if __name__ == "__main__":
     train_loader = DataLoader(
         dataset = train_dataset,
         batch_size = config.batch_size,
-        num_workers = config.worker_num,
+        num_workers = config.loader_worker_num,
         shuffle = True,
         drop_last = False,
         pin_memory = True,
@@ -172,7 +182,7 @@ if __name__ == "__main__":
     i = config.start_iter
     if not config.model_path is None:
        config.current_iter = i
-        valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
+        test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
 
     loss_fn = model.get_loss_fn()
     for i in range(config.start_iter + 1, config.total_iter + 1):
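# Usage note (based on the schedulefree package's documented API, not on code
# shown in this patch): schedule-free optimizers such as AdamWScheduleFree
# keep two parameter sequences and must be switched between modes, e.g.
#
#     optimizer.train()   # before training steps
#     ...
#     optimizer.eval()    # before test_steps() / SaveCheckpoint()
#
# otherwise validation and saved checkpoints see the wrong parameter set.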