diff --git a/src/common/data.py b/src/common/data.py index ff3214a..e12c839 100644 --- a/src/common/data.py +++ b/src/common/data.py @@ -50,17 +50,17 @@ class SRTrainDataset(Dataset): i = random.randint(0, lr_image.shape[0] - self.sz) j = random.randint(0, lr_image.shape[1] - self.sz) - c = random.choice([0, 1, 2]) + # c = random.choice([0, 1, 2]) hr_patch = hr_image[ (i*scale):(i*scale + self.sz*scale), (j*scale):(j*scale + self.sz*scale), - c:(c+1) + : ] lr_patch = lr_image[ i:(i + self.sz), j:(j + self.sz), - c:(c+1) + : ] if self.rigid_aug: diff --git a/src/common/layers.py b/src/common/layers.py index 54cf3f1..f15d7c9 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -99,4 +99,67 @@ class ConvUpscaleBlock(nn.Module): x = torch.tanh(x) x = x*127.5 + 127.5 x = round_func(x) - return x \ No newline at end of file + return x + + +class RgbToYcbcr(nn.Module): + r"""Convert an image from RGB to YCbCr. + + The image data is assumed to be in the range of (0, 1). + + Returns: + YCbCr version of the image. + + Shape: + - image: :math:`(*, 3, H, W)` + - output: :math:`(*, 3, H, W)` + + Examples: + >>> input = torch.rand(2, 3, 4, 5) + >>> ycbcr = RgbToYcbcr() + >>> output = ycbcr(input) # 2x3x4x5 + """ + + def forward(self, image): + r = image[..., 0, :, :] + g = image[..., 1, :, :] + b = image[..., 2, :, :] + + delta = 0.5 + y = 0.299 * r + 0.587 * g + 0.114 * b + cb = (b - y) * 0.564 + delta + cr = (r - y) * 0.713 + delta + return torch.stack([y, cb, cr], -3) + + +class YcbcrToRgb(nn.Module): + r"""Convert an image from YCbCr to Rgb. + + The image data is assumed to be in the range of (0, 1). + + Returns: + RGB version of the image. + + Shape: + - image: :math:`(*, 3, H, W)` + - output: :math:`(*, 3, H, W)` + + Examples: + >>> input = torch.rand(2, 3, 4, 5) + >>> rgb = YcbcrToRgb() + >>> output = rgb(input) # 2x3x4x5 + """ + + def forward(self, image): + y = image[..., 0, :, :] + cb = image[..., 1, :, :] + cr = image[..., 2, :, :] + + delta = 0.5 + cb_shifted = cb - delta + cr_shifted = cr - delta + + r = y + 1.403 * cr_shifted + g = y - 0.714 * cr_shifted - 0.344 * cb_shifted + b = y + 1.773 * cb_shifted + return torch.stack([r, g, b], -3) diff --git a/src/common/utils.py b/src/common/utils.py index 01a4e66..be0811c 100644 --- a/src/common/utils.py +++ b/src/common/utils.py @@ -6,6 +6,7 @@ from scipy import signal import torch import os + def round_func(input): # Backward Pass Differentiable Approximation (BPDA) # This is equivalent to replacing round function (non-differentiable) diff --git a/src/common/validation.py b/src/common/validation.py index c8d0c7f..db5636a 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -14,11 +14,11 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1) lr_image = lr_image.unsqueeze(0).to(torch.device(device)) b, c, h, w = lr_image.shape - lr_image = lr_image.reshape(b*c, 1, h, w) + lr_image = lr_image.reshape(b, c, h, w) # predict pred_lr_image = model(lr_image) # postprocess - pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8) + pred_lr_image = pred_lr_image.squeeze(0).permute(1,2,0).type(torch.uint8) pred_lr_image = pred_lr_image.cpu().numpy() run_time_ns = time.perf_counter_ns() - start_time torch.cuda.empty_cache() diff --git a/src/models/__init__.py b/src/models/__init__.py index 42a48db..26e5516 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -11,6 +11,7 @@ AVAILABLE_MODELS = { 'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut, 'SRNetDense': srnet.SRNetDense, 'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90, + 'SRNetDenseRot90Y': srnet.SRNetDenseRot90Y, 'SRLutRot90Y': srlut.SRLutRot90Y, 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, diff --git a/src/models/srlut.py b/src/models/srlut.py index 4a35e5c..3c1589c 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -4,6 +4,7 @@ import torch.nn.functional as F import numpy as np from pathlib import Path from common.lut import forward_2x2_input_SxS_output +from common.layers import RgbToYcbcr, YcbcrToRgb class SRLut(nn.Module): def __init__( @@ -72,5 +73,50 @@ class SRLutRot90(nn.Module): output = output.view(b, c, h*self.scale, w*self.scale) return output + def __repr__(self): + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" + + +class SRLutRot90Y(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(SRLutRot90Y, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.rgb_to_ycbcr = RgbToYcbcr() + self.ycbcr_to_rgb = YcbcrToRgb() + + @staticmethod + def init_from_lut( + stage_lut + ): + scale = int(stage_lut.shape[-1]) + quantization_interval = 256//(stage_lut.shape[0]-1) + lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale) + lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = self.rgb_to_ycbcr(x) + y = x[:,0:1,:,:] + cbcr = x[:,1:,:,:] + cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') + + output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(y, k=rotations_count, dims=[2, 3]) + rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = torch.cat([output, cbcr_scaled], dim=1) + output = self.ycbcr_to_rgb(output) + return output + def __repr__(self): return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" \ No newline at end of file diff --git a/src/models/srnet.py b/src/models/srnet.py index e52d53c..f242da8 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -5,8 +5,9 @@ import numpy as np from common.utils import round_func from common import lut from pathlib import Path -from .srlut import SRLut, SRLutRot90 -from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock +from . import srlut +from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock, RgbToYcbcr, YcbcrToRgb + class SRNet(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): @@ -26,7 +27,7 @@ class SRNet(nn.Module): def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = SRLut.init_from_lut(stage_lut) + lut_model = srlut.SRLut.init_from_lut(stage_lut) return lut_model @@ -48,7 +49,7 @@ class SRNetDense(nn.Module): def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = SRLut.init_from_lut(stage_lut) + lut_model = srlut.SRLut.init_from_lut(stage_lut) return lut_model class SRNetDenseRot90(nn.Module): @@ -75,5 +76,40 @@ class SRNetDenseRot90(nn.Module): def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = SRLutRot90.init_from_lut(stage_lut) - return lut_model \ No newline at end of file + lut_model = srlut.SRLutRot90.init_from_lut(stage_lut) + return lut_model + +class SRNetDenseRot90Y(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNetDenseRot90Y, self).__init__() + self.scale = scale + self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.rgb_to_ycbcr = RgbToYcbcr() + self.ycbcr_to_rgb = YcbcrToRgb() + + def forward(self, x): + b,c,h,w = x.shape + x = self.rgb_to_ycbcr(x) + y = x[:,0:1,:,:] + cbcr = x[:,1:,:,:] + cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') + + x = y.view(b, 1, h, w) + output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rx = torch.rot90(x, k=rotations_count, dims=[2, 3]) + _,_,rh,rw = rx.shape + rx = self._extract_pattern_S(rx) + rx = self.stage(rx) + rx = rx.view(b, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b, 1, rh*self.scale, rw*self.scale) + output += torch.rot90(rx, k=-rotations_count, dims=[2, 3]) + output /= 4 + output = torch.cat([output, cbcr_scaled], dim=1) + output = self.ycbcr_to_rgb(output) + return output + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = srlut.SRLutRot90Y.init_from_lut(stage_lut) + return lut_model \ No newline at end of file diff --git a/src/train.py b/src/train.py index b797a5b..8a97294 100644 --- a/src/train.py +++ b/src/train.py @@ -91,21 +91,6 @@ def prepare_experiment_folder(config): if not config.logs_dir.exists(): config.logs_dir.mkdir() - - -def dice_loss(inputs, targets, smooth=1): - #comment out if your model contains a sigmoid or equivalent activation layer - inputs = F.sigmoid(inputs) - - #flatten label and prediction tensors - inputs = inputs.view(-1) - targets = targets.view(-1) - - intersection = (inputs * targets).sum() - dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) - return 1 - dice - - if __name__ == "__main__": script_start_time = datetime.now()