added output values clamp for y srnet/srlut

main
protsenkovi 6 months ago
parent 58dab59a11
commit a2542c7b1f

@ -101,7 +101,7 @@ class ConvUpscaleBlock(nn.Module):
x = round_func(x) x = round_func(x)
return x return x
# https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
class RgbToYcbcr(nn.Module): class RgbToYcbcr(nn.Module):
r"""Convert an image from RGB to YCbCr. r"""Convert an image from RGB to YCbCr.
@ -131,7 +131,6 @@ class RgbToYcbcr(nn.Module):
cr = (r - y) * 0.713 + delta cr = (r - y) * 0.713 + delta
return torch.stack([y, cb, cr], -3) return torch.stack([y, cb, cr], -3)
class YcbcrToRgb(nn.Module): class YcbcrToRgb(nn.Module):
r"""Convert an image from YCbCr to Rgb. r"""Convert an image from YCbCr to Rgb.

@ -87,8 +87,8 @@ def valid_steps(model, datasets, config, log_prefix=""):
'Dataset', 'Dataset',
'AVG PSNR', 'AVG PSNR',
'AVG SSIM', 'AVG SSIM',
f'AVG {config.device} Time, s', f'AVG {config.device} time, s',
f'P95 {config.device} Time, s', f'P95 {config.device} time, s',
'Image count', 'Image count',
'AVG image area', 'AVG image area',
'Total area', 'Total area',

@ -17,6 +17,7 @@ class ImageDemoOptions():
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path") self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--project_path', type=str, default="../", help="Project path.") self.parser.add_argument('--project_path', type=str, default="../", help="Project path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
def parse_args(self): def parse_args(self):
args = self.parser.parse_args() args = self.parser.parse_args()
@ -51,8 +52,13 @@ lut_model = LoadCheckpoint(config.lut_model_path).cuda()
print(net_model) print(net_model)
print(lut_model) print(lut_model)
lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1].copy() lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1]
image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1].copy() image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1]
if config.mirror:
lr_image = lr_image[:,::-1,:]
image_gt = image_gt[:,::-1,:]
lr_image = lr_image.copy()
image_gt = image_gt.copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()

@ -99,14 +99,14 @@ class SRLutRot90Y(nn.Module):
lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale) lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model return lut_model
def forward(self, x): def forward(self, x):
b,c,h,w = x.shape b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x) x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:] y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:] cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4): for rotations_count in range(4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
@ -115,7 +115,7 @@ class SRLutRot90Y(nn.Module):
output += unrotated_prediction output += unrotated_prediction
output /= 4 output /= 4
output = torch.cat([output, cbcr_scaled], dim=1) output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output) output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output return output
def __repr__(self): def __repr__(self):

@ -106,7 +106,7 @@ class SRNetDenseRot90Y(nn.Module):
output += torch.rot90(rx, k=-rotations_count, dims=[2, 3]) output += torch.rot90(rx, k=-rotations_count, dims=[2, 3])
output /= 4 output /= 4
output = torch.cat([output, cbcr_scaled], dim=1) output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output) output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):

@ -84,10 +84,10 @@ if __name__ == "__main__":
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
) )
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}") results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
results.to_csv(config.results_path) results.to_csv(config.results_path)
print() print(config.model_name)
print(results) print(results)
print() print()
print(f"Results saved to {config.results_path}") print(f"Results saved to {config.results_path}")

Loading…
Cancel
Save