update models

main
Vladimir 2 months ago
parent 4e1f51278e
commit 6f3229d2f5

@ -1,24 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import layers
import copy
class SRBase(nn.Module):
    def __init__(self):
        super(SRBase, self).__init__()

    def get_loss_fn(self):
        def loss_fn(pred, target):
            # normalize [0, 255] pixel values to [-1, 1] before the MSE
            return F.mse_loss(pred/127.5 - 1, target/127.5 - 1)
        return loss_fn

    # def get_loss_fn(self):
    #     ssim_loss = losses.SSIM(data_range=255)
    #     l1_loss = losses.CharbonnierLoss()
    #     def loss_fn(pred, target):
    #         return ssim_loss(pred, target) + l1_loss(pred, target)
    #     return loss_fn
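A quick check of the normalization used in loss_fn above: dividing by 127.5 and subtracting 1 maps the 8-bit range onto [-1, 1]. A minimal sketch, separate from the module itself:

import torch

# endpoints and midpoint of the 8-bit range map to -1, 0, 1
x = torch.tensor([0.0, 127.5, 255.0])
print(x/127.5 - 1)  # tensor([-1., 0., 1.])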

@ -10,6 +10,7 @@ PIL_CONVERT_COLOR = {
    'full_YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
    'full_Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
    'sdtv_Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0] if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
    'sdtv2_Y': lambda pil_image: rgb2y(np.array(pil_image)) if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
    'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
}
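A usage sketch for the dispatch table above; "example.png" is a placeholder path. Note that the 'sdtv2_Y' entry returns a uint8 NumPy array via rgb2y, while the PIL-based entries return PIL images:

from PIL import Image

img = Image.open("example.png").convert("RGB")  # placeholder input
y = PIL_CONVERT_COLOR['sdtv2_Y'](img)           # H x W uint8 luma array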
@ -31,4 +32,88 @@ def _rgb2ycbcr(img, maxVal=255):
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
    return ycbcr
def rgb2y(im):
    """
    this impl:
         0.301,  0.586,  0.113 =  77/256,  150/256,  29/256
        -0.172, -0.340,  0.512 = -44/256,  -87/256, 131/256
         0.512, -0.430, -0.082 = 131/256, -110/256, -21/256
    ycbcr 601 sdtv spec [1]:
         0.299,  0.587,  0.114
        -0.172, -0.339,  0.511
         0.511, -0.428, -0.083
    ycbcr 601 sdtv spec [2]:
         0.299,  0.587,  0.114
        -0.169, -0.331,  0.5
         0.5,   -0.419, -0.081
    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
    [2] Colour Space Conversions, Adrian Ford and Alan Roberts
    """
    im = im.astype(np.float32)
    R, G, B = im[:,:,0], im[:,:,1], im[:,:,2]
    Y = 77/256*R + 150/256*G + 29/256*B
    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
    return Y.clip(0,255).astype(np.uint8)
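The integer coefficients are 256-denominator fixed-point approximations of the BT.601 weights quoted in the docstring; a quick sanity check in plain NumPy, as an illustration rather than part of the module:

import numpy as np

approx = np.array([77, 150, 29]) / 256   # 0.3008, 0.5859, 0.1133
spec = np.array([0.299, 0.587, 0.114])   # BT.601 luma weights [1]
print(np.abs(approx - spec).max())       # ~1.8e-3
print(approx.sum())                      # exactly 1.0, since 77+150+29 == 256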
def rgb2yuv(im):
    """
    this impl:
         0.301,  0.586,  0.113 =  77/256,  150/256,  29/256
        -0.172, -0.340,  0.512 = -44/256,  -87/256, 131/256
         0.512, -0.430, -0.082 = 131/256, -110/256, -21/256
    ycbcr 601 sdtv spec [1]:
         0.299,  0.587,  0.114
        -0.172, -0.339,  0.511
         0.511, -0.428, -0.083
    ycbcr 601 sdtv spec [2]:
         0.299,  0.587,  0.114
        -0.169, -0.331,  0.5
         0.5,   -0.419, -0.081
    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
    [2] Colour Space Conversions, Adrian Ford and Alan Roberts
    """
    im = im.astype(np.float32)
    R, G, B = im[:,:,0], im[:,:,1], im[:,:,2]
    Y = 77/256*R + 150/256*G + 29/256*B
    U = -44/256*R - 87/256*G + 131/256*B + 128  # offset chroma to the unsigned 8-bit range
    V = 131/256*R - 110/256*G - 21/256*B + 128
    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
    return np.stack([Y,U,V], axis=-1).clip(0,255).astype(np.uint8)
def yuv2rgb(im):
    """
    this impl:
        1,  0,      1.406 = 1,   0,      360/256
        1, -0.344, -0.719 = 1,  -88/256, -184/256
        1,  1.777,  0     = 1,  455/256,  0
    ycbcr 601 sdtv spec [1]:
        1,  0,      1.371
        1, -0.336, -0.698
        1,  1.732,  0
    ycbcr 601 sdtv spec [2]:
        1,  0,      1.403
        1, -0.344, -0.714
        1,  1.773,  0
    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
    [2] Colour Space Conversions, Adrian Ford and Alan Roberts
    """
    im = im.astype(np.float32)
    Y, Ud, Vd = im[:,:,0], im[:,:,1]-128, im[:,:,2]-128
    R = Y + 360/256*Vd
    G = Y - 88/256*Ud - 184/256*Vd
    B = Y + 455/256*Ud
    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
    return np.stack([R, G, B], axis=-1).clip(0,255).astype(np.uint8)
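Because the forward and inverse matrices are only approximate inverses of each other, and both directions round to uint8, a round trip is close but not exact. A sketch of the expected behavior, assuming rgb2yuv and yuv2rgb above are in scope:

import numpy as np

rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
rt = yuv2rgb(rgb2yuv(rgb))
# stays within a few 8-bit levels (matrix approximation plus rounding)
print(np.abs(rgb.astype(np.int16) - rt.astype(np.int16)).max())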

@ -60,7 +60,7 @@ class UpscaleBlock(nn.Module):
        return x

class LinearUpscaleBlockNet(nn.Module):
    def __init__(self, in_features=4, out_channels=1, hidden_dim=32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
    def __init__(self, in_features=4, out_channels=1, hidden_dim=32, layers_count=6, upscale_factor=1, input_max_value=255, output_max_value=255):
        super(LinearUpscaleBlockNet, self).__init__()
        assert layers_count > 0
        self.in_features = in_features
@ -78,7 +78,7 @@ class LinearUpscaleBlockNet(nn.Module):
    def forward(self, x):
        x = (x-self.in_bias)/self.in_scale
        x = torch.nn.functional.gelu(self.embed(x))
        x = self.embed(x)
        for linear_projection in self.linear_projections:
            x = torch.cat([x, torch.nn.functional.gelu(linear_projection(x))], dim=2)
        x = self.project_channels(x)
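The loop in forward grows features DenseNet-style: each step concatenates a GELU-activated linear projection onto the running tensor along the last dimension. A standalone sketch of that growth pattern; the sizes are illustrative, and the real model keeps its projections in self.linear_projections:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 16, 32)                    # (batch, positions, features)
for _ in range(3):
    proj = nn.Linear(x.shape[2], x.shape[2])  # fresh layer per step, sketch only
    x = torch.cat([x, F.gelu(proj(x))], dim=2)
print(x.shape)  # torch.Size([1, 16, 256]): 32 -> 64 -> 128 -> 256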

@ -20,7 +20,6 @@ class Transferer():
        for attr, value in model.named_children():
            if isinstance(value, layers.UpscaleBlock):
                getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
        return qmodel

TRANSFERER = Transferer()
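A minimal sketch of the swap pattern in the loop above: walk the trained model's children and, for each upscale block, replace the stage on a parallel copy. Block and the nn.Identity replacement are hypothetical stand-ins for layers.UpscaleBlock and stage.get_lut_model(...):

import torch.nn as nn

class Block(nn.Module):          # stand-in for layers.UpscaleBlock
    def __init__(self):
        super().__init__()
        self.stage = nn.Linear(4, 4)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1_S = Block()

model, qmodel = Net(), Net()
for attr, value in model.named_children():
    if isinstance(value, Block):
        getattr(qmodel, attr).stage = nn.Identity()  # stand-in for get_lut_model(...)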

@ -7,9 +7,18 @@ from common import lut
from pathlib import Path
from common import layers
from common import losses
from common.base import SRBase
from common.transferer import TRANSFERER
class SRBase(nn.Module):
    def __init__(self):
        super(SRBase, self).__init__()

    def get_loss_fn(self):
        def loss_fn(pred, target):
            # normalize [0, 255] pixel values to [-1, 1] before the MSE
            return F.mse_loss(pred/127.5 - 1, target/127.5 - 1)
        return loss_fn

class SRNetBase(SRBase):
    def __init__(self):
        super(SRNetBase, self).__init__()
@ -39,6 +48,40 @@ class SRLut(SRNetBase):
TRANSFERER.register(SRNet, SRLut)
class SRNetR90Base(SRBase):
    def __init__(self):
        super(SRNetR90Base, self).__init__()
        self.config = None
        self.stage1_S = layers.UpscaleBlock(None)
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)

    def forward(self, x, script_config=None):
        b,c,h,w = x.shape
        x = x.reshape(b*c, 1, h, w)
        output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated_output = self.stage1_S(rotated, self._extract_pattern_S)
            output += torch.rot90(rotated_output, k=-rotations_count, dims=[2, 3])
        x = output / 4
        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
        return x
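The forward pass above is a rotation self-ensemble: the input is processed at 0, 90, 180 and 270 degrees, each result is rotated back, and the four outputs are averaged. A standalone sketch of the bookkeeping, with an identity stand-in for the upscale stage:

import torch

def rot_ensemble(x, f):
    # average f over the four 90-degree rotations, undoing each rotation
    out = torch.zeros_like(x)
    for k in range(4):
        out += torch.rot90(f(torch.rot90(x, k=k, dims=[2, 3])), k=-k, dims=[2, 3])
    return out / 4

x = torch.randn(1, 1, 8, 8)
print(torch.allclose(rot_ensemble(x, lambda t: t), x))  # True for an identity stage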
class SRNetR90(SRNetR90Base):
    def __init__(self, config):
        super(SRNetR90, self).__init__()
        self.config = config
        self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)

class SRLutR90(SRNetR90Base):
    def __init__(self, config):
        super(SRLutR90, self).__init__()
        self.config = config
        self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)

TRANSFERER.register(SRNetR90, SRLutR90)
class ChebyKANBase(SRBase):
    def __init__(self):
        super(ChebyKANBase, self).__init__()
