update models

main
Vladimir 2 months ago
parent 4e1f51278e
commit 6f3229d2f5

@@ -1,24 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from common.utils import round_func
-from common import layers
-import copy
-
-class SRBase(nn.Module):
-    def __init__(self):
-        super(SRBase, self).__init__()
-
-    def get_loss_fn(self):
-        def loss_fn(pred, target):
-            return F.mse_loss(pred/127.5-127.5, target/127.5-127.5)
-        return loss_fn
-
-    # def get_loss_fn(self):
-    #     ssim_loss = losses.SSIM(data_range=255)
-    #     l1_loss = losses.CharbonnierLoss()
-    #     def loss_fn(pred, target):
-    #         return ssim_loss(pred, target) + l1_loss(pred, target)
-    #     return loss_fn
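
The normalization in the deleted loss_fn (re-added verbatim in a later hunk) looks off at first glance: pred/127.5-127.5 is not the usual [-1, 1] mapping, which would be pred/127.5-1. The shared constant offset cancels inside MSE, though, so the loss value is unaffected; only the 1/127.5 scaling matters. A quick check:

import torch
import torch.nn.functional as F

pred = torch.rand(8) * 255
target = torch.rand(8) * 255
# A constant shift applied to both arguments cancels in the squared difference.
a = F.mse_loss(pred / 127.5 - 127.5, target / 127.5 - 127.5)
b = F.mse_loss(pred / 127.5 - 1.0, target / 127.5 - 1.0)
assert torch.allclose(a, b)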

@@ -10,6 +10,7 @@ PIL_CONVERT_COLOR = {
     'full_YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
     'full_Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
     'sdtv_Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0] if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
+    'sdtv2_Y': lambda pil_image: rgb2y(np.array(pil_image)) if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
     'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
 }
@@ -31,4 +32,88 @@ def _rgb2ycbcr(img, maxVal=255):
     t[:, 2] += O[2]
     ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
     return ycbcr
+
+def rgb2y(im):
+    """
+    this impl:
+     0.301,  0.586,  0.113 =  77/256,  150/256,  29/256
+    -0.172, -0.340,  0.512 = -44/256,  -87/256, 131/256
+     0.512, -0.430, -0.082 = 131/256, -110/256, -21/256
+    ycbcr 601 sdtv spec [1]:
+     0.299,  0.587,  0.114
+    -0.172, -0.339,  0.511
+     0.511, -0.428, -0.083
+    ycbcr 601 sdtv spec [2]:
+     0.299,  0.587,  0.114
+    -0.169, -0.331,  0.5
+     0.5,   -0.419, -0.081
+    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
+    [2] Colour Space Conversions, Adrian Ford, Alan Roberts
+    """
+    im = im.astype(np.float32)
+    R, G, B = im[:,:,0], im[:,:,1], im[:,:,2]
+    Y = 77/256*R + 150/256*G + 29/256*B
+    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
+    return Y.clip(0,255).astype(np.uint8)
+
+def rgb2yuv(im):
+    """
+    this impl:
+     0.301,  0.586,  0.113 =  77/256,  150/256,  29/256
+    -0.172, -0.340,  0.512 = -44/256,  -87/256, 131/256
+     0.512, -0.430, -0.082 = 131/256, -110/256, -21/256
+    ycbcr 601 sdtv spec [1]:
+     0.299,  0.587,  0.114
+    -0.172, -0.339,  0.511
+     0.511, -0.428, -0.083
+    ycbcr 601 sdtv spec [2]:
+     0.299,  0.587,  0.114
+    -0.169, -0.331,  0.5
+     0.5,   -0.419, -0.081
+    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
+    [2] Colour Space Conversions, Adrian Ford, Alan Roberts
+    """
+    im = im.astype(np.float32)
+    R, G, B = im[:,:,0], im[:,:,1], im[:,:,2]
+    Y = 77/256*R + 150/256*G + 29/256*B
+    U = -44/256*R - 87/256*G + 131/256*B
+    V = 131/256*R - 110/256*G - 21/256*B
+    U, V = U + 128, V + 128
+    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
+    return np.stack([Y,U,V], axis=-1).clip(0,255).astype(np.uint8)
+
+def yuv2rgb(im):
+    """
+    this impl:
+    1,  0,      1.406 = 1,   0,      360/256
+    1, -0.344, -0.719 = 1, -88/256, -184/256
+    1,  1.777,  0     = 1, 455/256,    0
+    ycbcr 601 sdtv spec [1]:
+    1,  0,      1.371
+    1, -0.336, -0.698
+    1,  1.732,  0
+    ycbcr 601 sdtv spec [2]:
+    1,  0,      1.403
+    1, -0.344, -0.714
+    1,  1.773,  0
+    [1] Video Demystified: A Handbook for the Digital Engineer, 4th ed., Keith Jack, Chapter 3
+    [2] Colour Space Conversions, Adrian Ford, Alan Roberts
+    """
+    im = im.astype(np.float32)
+    Y, Ud, Vd = im[:,:,0], im[:,:,1]-128, im[:,:,2]-128
+    R = Y + 360/256*Vd
+    G = Y - 88/256*Ud - 184/256*Vd
+    B = Y + 455/256*Ud
+    # [1] Note that 8-bit YCbCr and R'G'B' data should be saturated at the 0 and 255 levels to avoid underflow and overflow
+    return np.stack([R, G, B], axis=-1).clip(0,255).astype(np.uint8)
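
A quick round-trip sanity check for the two new helpers (a sketch: exact equality is not expected, since the fixed-point forward and inverse matrices are not exact inverses of each other and both functions quantize to uint8):

import numpy as np

rgb = np.random.randint(0, 256, size=(16, 16, 3), dtype=np.uint8)
restored = yuv2rgb(rgb2yuv(rgb))
# The 256-denominator fixed-point matrices lose a little precision per pass,
# so inspect the error magnitude instead of asserting exact equality.
err = np.abs(restored.astype(np.int16) - rgb.astype(np.int16))
print("max abs round-trip error:", err.max())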

@@ -60,7 +60,7 @@ class UpscaleBlock(nn.Module):
         return x

 class LinearUpscaleBlockNet(nn.Module):
-    def __init__(self, in_features=4, out_channels=1, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
+    def __init__(self, in_features=4, out_channels=1, hidden_dim = 32, layers_count=6, upscale_factor=1, input_max_value=255, output_max_value=255):
         super(LinearUpscaleBlockNet, self).__init__()
         assert layers_count > 0
         self.in_features = in_features
@@ -78,7 +78,7 @@ class LinearUpscaleBlockNet(nn.Module):

     def forward(self, x):
         x = (x-self.in_bias)/self.in_scale
-        x = torch.nn.functional.gelu(self.embed(x))
+        x = self.embed(x)
         for linear_projection in self.linear_projections:
             x = torch.cat([x, torch.nn.functional.gelu(linear_projection(x))], dim=2)
         x = self.project_channels(x)
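
The second hunk drops the gelu after embed, so the first mapping is now purely linear; nonlinearity enters only through the densely concatenated projections. A minimal sketch of that growth pattern, with the layer widths assumed (the real sizes are set in an __init__ not shown in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed: each projection emits hidden_dim features, so the running width
# grows by hidden_dim per layer, ending at hidden_dim * (layers_count + 1).
hidden_dim, layers_count = 32, 6
embed = nn.Linear(4, hidden_dim)
projections = nn.ModuleList(
    nn.Linear(hidden_dim * (i + 1), hidden_dim) for i in range(layers_count)
)

x = torch.rand(1, 100, 4)   # [batch, positions, in_features]
x = embed(x)                # linear only, matching the change above
for proj in projections:
    x = torch.cat([x, F.gelu(proj(x))], dim=2)
print(x.shape)              # torch.Size([1, 100, 224]) = 32 * 7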

@@ -20,7 +20,6 @@ class Transferer():
         for attr, value in model.named_children():
             if isinstance(value, layers.UpscaleBlock):
                 getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
         return qmodel

 TRANSFERER = Transferer()
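
Only the inner loop of the transfer is visible in this hunk; how qmodel is created lies outside the diff. A sketch of the overall pattern it implies (the registry lookup and constructor call are assumptions, not the repository's actual code):

from common import layers

def transfer(registry, model, quantization_interval, batch_size):
    # Hypothetical: look up the LUT counterpart registered for this class
    # and build it from the same config as the float model.
    qmodel = registry[model.__class__](config=model.config)
    # This part mirrors the loop in the hunk above: every UpscaleBlock stage
    # is replaced by the lookup table baked out of the trained stage.
    for attr, value in model.named_children():
        if isinstance(value, layers.UpscaleBlock):
            getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(
                quantization_interval=quantization_interval, batch_size=batch_size)
    return qmodel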

@@ -7,9 +7,18 @@ from common import lut
 from pathlib import Path
 from common import layers
 from common import losses
-from common.base import SRBase
 from common.transferer import TRANSFERER

+class SRBase(nn.Module):
+    def __init__(self):
+        super(SRBase, self).__init__()
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/127.5-127.5, target/127.5-127.5)
+        return loss_fn
+
 class SRNetBase(SRBase):
     def __init__(self):
         super(SRNetBase, self).__init__()
@@ -39,6 +48,40 @@ class SRLut(SRNetBase):

 TRANSFERER.register(SRNet, SRLut)

+class SRNetR90Base(SRBase):
+    def __init__(self):
+        super(SRNetR90Base, self).__init__()
+        self.config = None
+        self.stage1_S = layers.UpscaleBlock(None)
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward(self, x, script_config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            rotated_output = self.stage1_S(rotated, self._extract_pattern_S)
+            output += torch.rot90(rotated_output, k=-rotations_count, dims=[2, 3])
+        x = output / 4
+        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
+        return x
+
+class SRNetR90(SRNetR90Base):
+    def __init__(self, config):
+        super(SRNetR90, self).__init__()
+        self.config = config
+        self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+
+class SRLutR90(SRNetR90Base):
+    def __init__(self, config):
+        super(SRLutR90, self).__init__()
+        self.config = config
+        self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+
+TRANSFERER.register(SRNetR90, SRLutR90)

 class ChebyKANBase(SRBase):
     def __init__(self):
         super(ChebyKANBase, self).__init__()
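
A usage sketch for the new rotation-ensemble model: the forward upscales each of the four 90° rotations, rotates the results back, and averages them, so the output keeps the upscaled input geometry. The config fields are the ones read above; the values here are illustrative:

from types import SimpleNamespace
import torch

config = SimpleNamespace(hidden_dim=32, layers_count=6, upscale_factor=2)
model = SRNetR90(config)

lr = torch.randint(0, 256, (1, 1, 24, 24)).float()  # single-channel LR input
sr = model(lr)
# Four rotated passes are averaged; spatial dims scale by upscale_factor.
assert sr.shape == (1, 1, 48, 48)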
