From cf7bab78afcc81cd2a00f07845cd2691d32d878f Mon Sep 17 00:00:00 2001
From: protsenkovi
Date: Tue, 29 Oct 2024 22:22:41 +0400
Subject: [PATCH] HDBNet (HKLUT) implementation

---
 src/common/base.py   |  2 +-
 src/common/data.py   |  7 ++--
 src/models/models.py | 80 ++++++++++++++++++++++++++++++++++++++++----
 src/train.py         |  1 -
 4 files changed, 78 insertions(+), 12 deletions(-)

diff --git a/src/common/base.py b/src/common/base.py
index 92132e7..3e1f08b 100644
--- a/src/common/base.py
+++ b/src/common/base.py
@@ -12,7 +12,7 @@ class SRBase(nn.Module):
 
     def get_loss_fn(self):
         def loss_fn(pred, target):
-            return F.mse_loss(pred/255, target/255)
+            return F.mse_loss((pred - 127.5)/127.5, (target - 127.5)/127.5)
         return loss_fn
 
     # def get_loss_fn(self):
diff --git a/src/common/data.py b/src/common/data.py
index 9298f78..5c3d4a9 100644
--- a/src/common/data.py
+++ b/src/common/data.py
@@ -65,7 +65,6 @@ class SRTrainDataset(Dataset):
             :
         ]
 
-
         if self.rigid_aug:
             if random.uniform(0, 1) < 0.5:
                 hr_patch = np.fliplr(hr_patch)
@@ -78,9 +77,9 @@ class SRTrainDataset(Dataset):
             k = random.choice([0, 1, 2, 3])
             hr_patch = np.rot90(hr_patch, k)
             lr_patch = np.rot90(lr_patch, k)
-
-        hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
-        lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
+
+        hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
+        lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
 
         hr_patch = hr_patch.permute(2,0,1)
         lr_patch = lr_patch.permute(2,0,1)
diff --git a/src/models/models.py b/src/models/models.py
index de47063..3a23370 100644
--- a/src/models/models.py
+++ b/src/models/models.py
@@ -44,7 +44,7 @@ class ChebyKANBase(SRBase):
         super(ChebyKANBase, self).__init__()
         self.config = None
         self.stage1_S = layers.UpscaleBlock(None)
-        window_size = 7
+        window_size = 5
         self._extract_pattern = layers.PercievePattern(
             receptive_field_idxes=[[i,j] for i in range(window_size) for j in range(window_size)],
             center=[window_size//2,window_size//2],
@@ -62,21 +62,21 @@ class ChebyKANNet(ChebyKANBase):
     def __init__(self, config):
         super(ChebyKANNet, self).__init__()
         self.config = config
-        window_size = 7
+        window_size = 5
         self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
             in_features=window_size*window_size,
             out_channels=1,
             hidden_dim=16,
-            layers_count=self.config.layers_count,
+            layers_count=2,  # was: self.config.layers_count
             upscale_factor=self.config.upscale_factor,
-            degree=8
+            degree=3
         )
 
 class ChebyKANLut(ChebyKANBase):
     def __init__(self, config):
         super(ChebyKANLut, self).__init__()
         self.config = config
-        window_size = 7
+        window_size = 5
         self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
             in_features=window_size*window_size,
             out_channels=1,
@@ -92,4 +92,72 @@ class ChebyKANLut(ChebyKANBase):
             return ssim_loss(pred, target) + l1_loss(pred, target)
         return loss_fn
 
-TRANSFERER.register(ChebyKANNet, ChebyKANLut)
\ No newline at end of file
+TRANSFERER.register(ChebyKANNet, ChebyKANLut)
+
+
+class HDBNetBase(SRBase):
+    def __init__(self):
+        super(HDBNetBase, self).__init__()
+        self.config = None
+        self.stage_3H = layers.UpscaleBlock(None)
+        self.stage_3D = layers.UpscaleBlock(None)
+        self.stage_3B = layers.UpscaleBlock(None)
+        self.stage_2H = layers.UpscaleBlock(None)
+        self.stage_2D = layers.UpscaleBlock(None)
+        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
+        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward(self, x, script_config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        lsb = x % 16   # low 4 bits of each pixel, values in [0, 15]
+        msb = x - lsb  # high 4 bits, kept as multiples of 16
+        output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):  # self-ensemble over four rotations
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msb = self.stage_3H(rotated_msb, self._extract_pattern_3H) + \
+                         self.stage_3D(rotated_msb, self._extract_pattern_3D) + \
+                         self.stage_3B(rotated_msb, self._extract_pattern_3B)
+            output_lsb = self.stage_2H(rotated_lsb, self._extract_pattern_2H) + \
+                         self.stage_2D(rotated_lsb, self._extract_pattern_2D)
+            output_msb /= 3  # mean of the three MSB branches
+            output_lsb /= 2  # mean of the two LSB branches
+            if script_config is not None and script_config.current_iter % script_config.display_step == 0:
+                script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter)
+                script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter)
+            output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
+        output /= 4  # average over the four rotations
+        x = output
+        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.l1_loss((pred - 127.5)/127.5, (target - 127.5)/127.5)
+        return loss_fn
+
+class HDBNet(HDBNetBase):
+    def __init__(self, config):
+        super(HDBNet, self).__init__()
+        self.config = config
+        self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+
+class HDBLut(HDBNetBase):
+    def __init__(self, config):
+        super(HDBLut, self).__init__()
+        self.config = config
+        self.stage_3H.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_3D.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_3B.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_2H.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_2D.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+
+TRANSFERER.register(HDBNet, HDBLut)
\ No newline at end of file
diff --git a/src/train.py b/src/train.py
index 4a068f1..42affe5 100644
--- a/src/train.py
+++ b/src/train.py
@@ -242,7 +242,6 @@ if __name__ == "__main__":
         if i % config.display_step == 0:
             config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
             config.writer.add_scalar('loss', loss.item(), i)
-            config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, loss:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format( model.__class__.__name__, i,
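
Note on the loss change in src/common/base.py: (x - 127.5)/127.5 maps the
[0, 255] pixel range onto [-1, 1]. Since both MSE and L1 depend only on the
difference pred - target, the shared -127.5 shift cancels inside the metric;
only the 1/127.5 scale (versus the previous 1/255) changes the loss and
gradient magnitude. A quick sanity check with illustrative values:

    import torch
    import torch.nn.functional as F

    pred = torch.tensor([0.0, 127.5, 255.0])
    target = torch.tensor([255.0, 127.5, 0.0])

    shifted = F.l1_loss((pred - 127.5) / 127.5, (target - 127.5) / 127.5)
    plain = F.l1_loss(pred / 127.5, target / 127.5)
    assert torch.allclose(shifted, plain)  # the shift cancels in the metric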
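
Note on HDBNetBase.forward: each 8-bit pixel x is split losslessly into a low
nibble lsb = x % 16 (values 0..15) and a high part msb = x - lsb (multiples of
16), so x == msb + lsb exactly. The 3-pixel 3H/3D/3B patterns then process the
MSB plane, the 2-pixel 2H/2D patterns the LSB plane, and the predictions are
averaged over four 90-degree rotations of the input. Below is a minimal
stand-alone sketch of the split and the rotation ensemble; the identity
branch() standing in for the stage_* blocks is an assumption for illustration
only (the real branches also upscale):

    import torch

    x = torch.randint(0, 256, (1, 1, 4, 4)).float()
    lsb = x % 16                       # low 4 bits, in [0, 15]
    msb = x - lsb                      # high 4 bits, multiples of 16
    assert torch.equal(msb + lsb, x)   # lossless decomposition

    def branch(t):                     # placeholder for the LUT/MLP branches
        return t

    out = torch.zeros_like(x)
    for k in range(4):
        rotated = torch.rot90(x, k=k, dims=[2, 3])              # rotate in
        out += torch.rot90(branch(rotated), k=-k, dims=[2, 3])  # rotate back
    out /= 4                           # average the four predictions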