From a2542c7b1fbc68a43740243678d09ba4fda6f888 Mon Sep 17 00:00:00 2001 From: protsenkovi Date: Thu, 9 May 2024 15:20:30 +0400 Subject: [PATCH] added output values clamp for y srnet/srlut --- src/common/layers.py | 3 +-- src/common/validation.py | 4 ++-- src/image_demo.py | 10 ++++++++-- src/models/srlut.py | 6 +++--- src/models/srnet.py | 2 +- src/validate.py | 4 ++-- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/common/layers.py b/src/common/layers.py index f15d7c9..6f4bca3 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -101,7 +101,7 @@ class ConvUpscaleBlock(nn.Module): x = round_func(x) return x - +# https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105 class RgbToYcbcr(nn.Module): r"""Convert an image from RGB to YCbCr. @@ -131,7 +131,6 @@ class RgbToYcbcr(nn.Module): cr = (r - y) * 0.713 + delta return torch.stack([y, cb, cr], -3) - class YcbcrToRgb(nn.Module): r"""Convert an image from YCbCr to Rgb. diff --git a/src/common/validation.py b/src/common/validation.py index db5636a..c06bb1d 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -87,8 +87,8 @@ def valid_steps(model, datasets, config, log_prefix=""): 'Dataset', 'AVG PSNR', 'AVG SSIM', - f'AVG {config.device} Time, s', - f'P95 {config.device} Time, s', + f'AVG {config.device} time, s', + f'P95 {config.device} time, s', 'Image count', 'AVG image area', 'Total area', diff --git a/src/image_demo.py b/src/image_demo.py index 609981d..ba99934 100644 --- a/src/image_demo.py +++ b/src/image_demo.py @@ -17,6 +17,7 @@ class ImageDemoOptions(): self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path") self.parser.add_argument('--project_path', type=str, default="../", help="Project path.") self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") + self.parser.add_argument('--mirror', action='store_true', default=False) def parse_args(self): args = self.parser.parse_args() @@ -51,8 +52,13 @@ lut_model = LoadCheckpoint(config.lut_model_path).cuda() print(net_model) print(lut_model) -lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1].copy() -image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1].copy() +lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1] +image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1] +if config.mirror: + lr_image = lr_image[:,::-1,:] + image_gt = image_gt[:,::-1,:] +lr_image = lr_image.copy() +image_gt = image_gt.copy() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() diff --git a/src/models/srlut.py b/src/models/srlut.py index 3c1589c..ed80dda 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -99,14 +99,14 @@ class SRLutRot90Y(nn.Module): lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model - + def forward(self, x): b,c,h,w = x.shape x = self.rgb_to_ycbcr(x) y = x[:,0:1,:,:] cbcr = x[:,1:,:,:] cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') - + output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) for rotations_count in range(4): rotated = torch.rot90(y, k=rotations_count, dims=[2, 3]) @@ -115,7 +115,7 @@ class SRLutRot90Y(nn.Module): output += unrotated_prediction output /= 4 output = torch.cat([output, cbcr_scaled], dim=1) - output = self.ycbcr_to_rgb(output) + output = self.ycbcr_to_rgb(output).clamp(0, 255) return output def __repr__(self): diff --git a/src/models/srnet.py b/src/models/srnet.py index f242da8..86ecce4 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -106,7 +106,7 @@ class SRNetDenseRot90Y(nn.Module): output += torch.rot90(rx, k=-rotations_count, dims=[2, 3]) output /= 4 output = torch.cat([output, cbcr_scaled], dim=1) - output = self.ycbcr_to_rgb(output) + output = self.ycbcr_to_rgb(output).clamp(0, 255) return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): diff --git a/src/validate.py b/src/validate.py index 9ebac2a..630e37c 100644 --- a/src/validate.py +++ b/src/validate.py @@ -84,10 +84,10 @@ if __name__ == "__main__": lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", ) - results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}") + results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") results.to_csv(config.results_path) - print() + print(config.model_name) print(results) print() print(f"Results saved to {config.results_path}")