diff --git a/src/common/base.py b/src/common/base.py
deleted file mode 100644
index 3e1f08b..0000000
--- a/src/common/base.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from common.utils import round_func
-from common import layers
-import copy
-
-class SRBase(nn.Module):
-    def __init__(self):
-        super(SRBase, self).__init__()
-
-    def get_loss_fn(self):
-        def loss_fn(pred, target):
-            return F.mse_loss(pred/127.5-127.5, target/127.5-127.5)
-        return loss_fn
-
-    # def get_loss_fn(self):
-    #     ssim_loss = losses.SSIM(data_range=255)
-    #     l1_loss = losses.CharbonnierLoss()
-    #     def loss_fn(pred, target):
-    #         return ssim_loss(pred, target) + l1_loss(pred, target)
-    #     return loss_fn
-
diff --git a/src/common/color.py b/src/common/color.py
index b06a928..dfcff67 100644
--- a/src/common/color.py
+++ b/src/common/color.py
@@ -10,6 +10,7 @@ PIL_CONVERT_COLOR = {
     'full_YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
     'full_Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
     'sdtv_Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0] if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
+    'sdtv2_Y': lambda pil_image: rgb2y(np.array(pil_image)) if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
     'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
 }
 
@@ -31,4 +32,88 @@ def _rgb2ycbcr(img, maxVal=255):
     t[:, 2] += O[2]
 
     ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
-    return ycbcr
\ No newline at end of file
+    return ycbcr
+
+
+def rgb2y(im):
+    """
+    this impl:
+     0.301,  0.586,  0.113 =  77/256,  150/256,  29/256
+    -0.172, -0.340,  0.512 = -44/256,  -87/256, 131/256
+     0.512, -0.430, -0.082 = 131/256, -110/256, -21/256
+
+    ycbcr 601 sdtv spec [1]:
+     0.299,  0.587,  0.114
+    -0.172, -0.339,  0.511
+     0.511, -0.428, -0.083
+
+    ycbcr 601 sdtv spec [2]:
+     0.299,  0.587,  0.114
+    -0.169, -0.331,  0.500
+     0.500, -0.419, -0.081
+
+    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
+    [2] Colour Space Conversions, Adrian Ford, Alan Roberts
+    """
+    im = im.astype(np.float32)
+    R, G, B = im[:,:,0], im[:,:,1], im[:,:,2]
+    Y = 77/256*R + 150/256*G + 29/256*B
+    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
+    return Y.clip(0,255).astype(np.uint8)
+
+def rgb2yuv(im):
+    """
+    this impl:
+     0.301,  0.586,  0.113 =  77/256,  150/256,  29/256
+    -0.172, -0.340,  0.512 = -44/256,  -87/256, 131/256
+     0.512, -0.430, -0.082 = 131/256, -110/256, -21/256
+
+    ycbcr 601 sdtv spec [1]:
+     0.299,  0.587,  0.114
+    -0.172, -0.339,  0.511
+     0.511, -0.428, -0.083
+
+    ycbcr 601 sdtv spec [2]:
+     0.299,  0.587,  0.114
+    -0.169, -0.331,  0.500
+     0.500, -0.419, -0.081
+
+    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
+    [2] Colour Space Conversions, Adrian Ford, Alan Roberts
+    """
+    im = im.astype(np.float32)
+    R, G, B = im[:,:,0], im[:,:,1], im[:,:,2]
+    Y = 77/256*R + 150/256*G + 29/256*B
+    U = -44/256*R - 87/256*G + 131/256*B
+    V = 131/256*R - 110/256*G - 21/256*B
+    Y, U, V = Y, U + 128, V + 128
+    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
+    return np.stack([Y,U,V], axis=-1).clip(0,255).astype(np.uint8)
+
+def yuv2rgb(im):
+    """
+    this impl:
+    1,  0,      1.406 = 1,    0,        360/256
+    1, -0.344, -0.719 = 1,  -88/256,  -184/256
+    1,  1.777,  0     = 1,  455/256,     0
+
+    ycbcr 601 sdtv spec [1]:
+    1,  0,      1.371
+    1, -0.336, -0.698
+    1,  1.732,  0
+
+    ycbcr 601 sdtv spec [2]:
+    1,  0,      1.403
+    1, -0.344, -0.714
+    1,  1.773,  0
+
+    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
+    [2] Colour Space Conversions, Adrian Ford, Alan Roberts
+    """
+    im = im.astype(np.float32)
+    Y, Ud, Vd = im[:,:,0], im[:,:,1]-128, im[:,:,2]-128
+    R = Y + 360/256*Vd
+    G = Y - 88/256*Ud - 184/256*Vd
+    B = Y + 455/256*Ud
+    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
+    return np.stack([R, G, B], axis=-1).clip(0,255).astype(np.uint8)
\ No newline at end of file
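
Reviewer note (illustration, not part of the patch): a quick sanity check for the integer-coefficient conversions added above. Exact recovery is not expected, since Y/U/V are truncated to uint8 and the forward and inverse matrices are only approximate inverses of each other, so per-channel errors of a few code values are typical. Assumes the patch is applied and src/ is on PYTHONPATH:

import numpy as np
from common.color import rgb2yuv, yuv2rgb

rgb = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
roundtrip = yuv2rgb(rgb2yuv(rgb))
err = np.abs(roundtrip.astype(np.int16) - rgb.astype(np.int16))
print(err.max(), err.mean())  # expect small per-channel errors, not exact recovery
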
diff --git a/src/common/layers.py b/src/common/layers.py
index e3a052b..0440296 100644
--- a/src/common/layers.py
+++ b/src/common/layers.py
@@ -60,7 +60,7 @@ class UpscaleBlock(nn.Module):
         return x
 
 class LinearUpscaleBlockNet(nn.Module):
-    def __init__(self, in_features=4, out_channels=1, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
+    def __init__(self, in_features=4, out_channels=1, hidden_dim = 32, layers_count=6, upscale_factor=1, input_max_value=255, output_max_value=255):
         super(LinearUpscaleBlockNet, self).__init__()
         assert layers_count > 0
         self.in_features = in_features
@@ -78,7 +78,7 @@ class LinearUpscaleBlockNet(nn.Module):
 
     def forward(self, x):
         x = (x-self.in_bias)/self.in_scale
-        x = torch.nn.functional.gelu(self.embed(x))
+        x = self.embed(x)
         for linear_projection in self.linear_projections:
             x = torch.cat([x, torch.nn.functional.gelu(linear_projection(x))], dim=2)
         x = self.project_channels(x)
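
Reviewer note (illustration, not part of the patch): after this change the embedding feeds the cascade without a GELU, while each projection's GELU output is still concatenated onto the running features along dim=2. A standalone sketch of that shape flow; the growing projection widths are an assumption inferred from the concat pattern, not taken from the repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, layers_count = 32, 6   # layers_count=6 matches the new default above
embed = nn.Linear(4, hidden_dim)   # in_features=4: one value per 2x2 window index
# assumption: the i-th projection sees the embedding plus all i previous outputs
projections = nn.ModuleList(
    nn.Linear(hidden_dim * (i + 1), hidden_dim) for i in range(layers_count)
)

x = torch.rand(8, 16, 4)           # (batch, positions, in_features)
x = embed(x)                       # after this patch: linear embedding, no GELU
for proj in projections:
    x = torch.cat([x, F.gelu(proj(x))], dim=2)
print(x.shape)                     # torch.Size([8, 16, 224]) == 32 * (6 + 1)
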
diff --git a/src/common/transferer.py b/src/common/transferer.py
index 34496c6..ffad1b8 100644
--- a/src/common/transferer.py
+++ b/src/common/transferer.py
@@ -20,7 +20,6 @@ class Transferer():
         for attr, value in model.named_children():
             if isinstance(value, layers.UpscaleBlock):
                 getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
-
         return qmodel
 
 TRANSFERER = Transferer()
\ No newline at end of file
diff --git a/src/models/models.py b/src/models/models.py
index 8713321..9a83394 100644
--- a/src/models/models.py
+++ b/src/models/models.py
@@ -7,9 +7,18 @@ from common import lut
 from pathlib import Path
 from common import layers
 from common import losses
-from common.base import SRBase
 from common.transferer import TRANSFERER
 
+class SRBase(nn.Module):
+    def __init__(self):
+        super(SRBase, self).__init__()
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/127.5-127.5, target/127.5-127.5)
+        return loss_fn
+
+
 class SRNetBase(SRBase):
     def __init__(self):
         super(SRNetBase, self).__init__()
@@ -39,6 +48,40 @@ class SRLut(SRNetBase):
 
 TRANSFERER.register(SRNet, SRLut)
 
+class SRNetR90Base(SRBase):
+    def __init__(self):
+        super(SRNetR90Base, self).__init__()
+        self.config = None
+        self.stage1_S = layers.UpscaleBlock(None)
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward(self, x, script_config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            rotated_output = self.stage1_S(rotated, self._extract_pattern_S)
+            output += torch.rot90(rotated_output, k=-rotations_count, dims=[2, 3])
+        x = output / 4
+        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
+        return x
+
+class SRNetR90(SRNetR90Base):
+    def __init__(self, config):
+        super(SRNetR90, self).__init__()
+        self.config = config
+        self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+
+class SRLutR90(SRNetR90Base):
+    def __init__(self, config):
+        super(SRLutR90, self).__init__()
+        self.config = config
+        self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+
+TRANSFERER.register(SRNetR90, SRLutR90)
+
+
 class ChebyKANBase(SRBase):
     def __init__(self):
         super(ChebyKANBase, self).__init__()
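
Reviewer note (illustration, not part of the patch): SRNetR90Base.forward averages four rotated passes, which makes the model equivariant to 90-degree rotations of the input even though the 2x2 perceive pattern itself is not. A standalone sketch of that property; the roll stage is a deliberately non-rotation-invariant stand-in for self.stage1_S, and a square input keeps the shapes simple:

import torch

def stage(t):
    # stand-in for self.stage1_S(rotated, self._extract_pattern_S);
    # roll is intentionally not rotation-invariant, so the check below is non-trivial
    return torch.roll(t, shifts=1, dims=3)

def forward_r90(x):                    # mirrors the loop in SRNetR90Base.forward
    out = torch.zeros_like(x)
    for k in range(4):
        rotated = torch.rot90(x, k=k, dims=[2, 3])
        out += torch.rot90(stage(rotated), k=-k, dims=[2, 3])
    return out / 4

x = torch.rand(1, 1, 8, 8)
y1 = torch.rot90(forward_r90(x), k=1, dims=[2, 3])
y2 = forward_r90(torch.rot90(x, k=1, dims=[2, 3]))
print(torch.allclose(y1, y2))          # True: rotating the input rotates the output
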