New framework implementation; adds ChebyKAN and linear models.

main
protsenkovi 10 months ago
parent 67bd678763
commit d34cc7833e

File diff suppressed because one or more lines are too long

@ -24,3 +24,10 @@ Requirements:
```
pip install schedulefree tensorboard opencv-python-headless scipy pandas matplotlib
```
python train.py --model ChebyKANNet --total_iter 200000 --train_datasets mix-01-blur-1+ffmpeg --test_datasets mix-01-blur-test --upscale_factor 1 --val_step 210000 --layers_count 4
python transfer_to_lut.py --model_path ../experiments/ChebyKANNet_RGB_mix-01-blur-1+ffmpeg_x1/checkpoints/ChebyKANNet_200000.pth
python test.py --model_path ../experiments/last_transfered_lut.pth --test_datasets src --save_predictions

@ -0,0 +1,24 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import layers
import copy
class SRBase(nn.Module):
def __init__(self):
super(SRBase, self).__init__()
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
# def get_loss_fn(self):
# ssim_loss = losses.SSIM(data_range=255)
# l1_loss = losses.CharbonnierLoss()
# def loss_fn(pred, target):
# return ssim_loss(pred, target) + l1_loss(pred, target)
# return loss_fn
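The default objective computes MSE on values rescaled from [0, 255] to [0, 1]. A minimal usage sketch (the subclass here is hypothetical, for illustration only):
```
import torch
from common.base import SRBase

class IdentitySR(SRBase):  # hypothetical subclass, for illustration only
    def forward(self, x):
        return x

model = IdentitySR()
loss_fn = model.get_loss_fn()
pred = torch.rand(1, 1, 8, 8) * 255    # the pipeline carries [0, 255] images
target = torch.rand(1, 1, 8, 8) * 255
loss = loss_fn(pred, target)           # MSE on values rescaled to [0, 1]
```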

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils import round_func
from . import lut
class PercievePattern():
"""
@ -11,7 +12,7 @@ class PercievePattern():
1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]
2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]]
"""
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1, mode='replicate'):
assert window_size >= (np.max([x for y in receptive_field_idxes for x in y])+1)
tmp = []
for coords in receptive_field_idxes:
@ -21,6 +22,7 @@ class PercievePattern():
if len(coords) == 3:
tmp.append(coords)
receptive_field_idxes = np.array(sorted(tmp))
self.mode = mode
self.window_size = window_size
self.center = center
self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size*self.window_size + receptive_field_idxes[i,1]*self.window_size + receptive_field_idxes[i,2] for i in range(len(receptive_field_idxes))]
@ -28,33 +30,49 @@ class PercievePattern():
def __call__(self, x):
b,c,h,w = x.shape
if not self.mode is None:
x = F.pad(
x,
pad=[
self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1
],
mode=self.mode
)
x = F.unfold(input=x, kernel_size=self.window_size)
x = torch.stack([x[:,self.receptive_field_idxes[i],:] for i in range(len(self.receptive_field_idxes))], 2)
return x
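For intuition: the pattern pads the image, unfolds it into sliding windows, and gathers the selected receptive-field pixels for every position. A quick shape sketch (values assumed, single-channel input):
```
import torch
from common import layers

pattern = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
windows = pattern(x)
print(windows.shape)  # torch.Size([1, 16, 4]): 4 receptive-field pixels per input position
```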
class UpscaleBlock(nn.Module):
def __init__(self, stage):
super(UpscaleBlock, self).__init__()
self.stage = stage
def forward(self, x, percieve_pattern):
b,c,h,w = x.shape
upscale_factor = self.stage.upscale_factor
out_channels = self.stage.out_channels
x = percieve_pattern(x)
x = self.stage(x)
x = x.reshape(b, h, w, out_channels, upscale_factor, upscale_factor)
x = x.permute(0, 3, 1, 4, 2, 5)
x = x.reshape(b, out_channels, h*upscale_factor, w*upscale_factor)
return x
class LinearUpscaleBlockNet(nn.Module):
def __init__(self, in_features=4, out_channels=1, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
super(LinearUpscaleBlockNet, self).__init__()
assert layers_count > 0
self.in_features = in_features
self.upscale_factor = upscale_factor
self.out_channels = out_channels
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(nn.Linear(in_features=(i+1)*hidden_dim, out_features=hidden_dim, bias=True))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=out_channels*upscale_factor*upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
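Put together, UpscaleBlock feeds the unfolded windows to its stage and folds the per-position outputs back into an upscaled image; a sketch of the wiring (shapes follow the reshape logic above):
```
import torch
from common import layers

pattern = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
stage = layers.LinearUpscaleBlockNet(in_features=4, out_channels=1, upscale_factor=2)
block = layers.UpscaleBlock(stage)

x = torch.randint(0, 256, (1, 1, 16, 16)).float()
y = block(x, pattern)
print(y.shape)  # torch.Size([1, 1, 32, 32])
```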
@ -66,334 +84,67 @@ class UpscaleBlock(nn.Module):
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
x = round_func(x)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage = lut.transfer_N_input_SxS_output(self, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = LinearUpscaleBlockLut.init_from_numpy(stage)
return lut_model
class RgbToYcbcr(nn.Module):
r"""Convert an image from RGB to YCbCr.
The image data is assumed to be in the range of (0, 1).
Returns:
YCbCr version of the image.
Shape:
- image: :math:`(*, 3, H, W)`
- output: :math:`(*, 3, H, W)`
Examples:
>>> input = torch.rand(2, 3, 4, 5)
>>> ycbcr = RgbToYcbcr()
>>> output = ycbcr(input) # 2x3x4x5
https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
"""
def forward(self, image):
r = image[..., 0, :, :]
g = image[..., 1, :, :]
b = image[..., 2, :, :]
delta = 0.5
y = 0.299 * r + 0.587 * g + 0.114 * b
cb = (b - y) * 0.564 + delta
cr = (r - y) * 0.713 + delta
return torch.stack([y, cb, cr], -3)
class YcbcrToRgb(nn.Module):
r"""Convert an image from YCbCr to Rgb.
The image data is assumed to be in the range of (0, 1).
Returns:
RGB version of the image.
Shape:
- image: :math:`(*, 3, H, W)`
- output: :math:`(*, 3, H, W)`
Examples:
>>> input = torch.rand(2, 3, 4, 5)
>>> rgb = YcbcrToRgb()
>>> output = rgb(input) # 2x3x4x5
https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127
"""
def forward(self, image):
y = image[..., 0, :, :]
cb = image[..., 1, :, :]
cr = image[..., 2, :, :]
delta = 0.5
cb_shifted = cb - delta
cr_shifted = cr - delta
r = y + 1.403 * cr_shifted
g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
b = y + 1.773 * cb_shifted
return torch.stack([r, g, b], -3)
class LinearUpscaleBlockLut(nn.Module):
def __init__(self, quantization_interval, upscale_factor):
super(LinearUpscaleBlockLut, self).__init__()
self.out_channels = 1
self.upscale_factor = upscale_factor
self.quantization_interval = quantization_interval
self.stage = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (upscale_factor,upscale_factor)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage
):
upscale_factor = int(stage.shape[-1])
quantization_interval = 256//(stage.shape[0]-1)
lut_model = LinearUpscaleBlockLut(quantization_interval=quantization_interval, upscale_factor=upscale_factor)
lut_model.stage = nn.Parameter(torch.tensor(stage).type(torch.float32))
return lut_model
def forward(self, x):
x = lut.select_index_4dlut_tetrahedral(x, self.stage)
return x
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage.shape}"
class EffKAN(torch.nn.Module):
"""
https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6
"""
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(EffKAN, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.size(-1) == self.in_features
original_shape = x.shape
x = x.view(-1, self.in_features)
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
output = base_output + spline_output
output = output.view(*original_shape[:-1], self.out_features)
return output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
This is a dumb simulation of the original L1 regularization as stated in the
paper, since the original one requires computing absolutes and entropy from the
expanded (batch, in_features, out_features) intermediate tensor, which is hidden
behind the F.linear function if we want an memory efficient implementation.
The L1 regularization is now computed as mean absolute value of the spline
weights. The authors implementation also includes this term in addition to the
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
# This is inspired by Kolmogorov-Arnold Networks but using Chebyshev polynomials instead of splines coefficients
class ChebyKANLayer(nn.Module):
"""
https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py
"""
def __init__(self, in_features, out_features, degree=8, input_max_value=255, output_max_value=255):
super(ChebyKANLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.degree = degree
self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1))
nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1)))
self.register_buffer("arange", torch.arange(0, degree + 1, 1))
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
def forward_all_to_all(self, x):
# Chebyshev polynomials are defined on [-1, 1],
# so rescale x from [0, 255] to [-1, 1] first
b, hw, c = x.shape
assert c == self.in_features, f"Input features count={c} is not equal to specified for this layer count={self.in_features}."
x = (x-self.in_bias)/self.in_scale
# x = torch.tanh(x)
# View and repeat input degree + 1 times
x = x.view((b, hw, self.in_features, 1)).expand(
-1, -1, -1, self.degree + 1
) # shape = (batch_size, hw, in_features, self.degree + 1)
# Apply acos
@ -404,157 +155,152 @@ class ChebyKANLayer(nn.Module):
x = x.cos()
# Compute the Chebyshev interpolation
y = torch.einsum(
"btid,iod->btio", x, self.cheby_coeffs
)
# shape = (batch_size, hw, in_features, out_features)
# print("btio", y.shape)
y = torch.tanh(y)
# y = (y-1) / self.degree
y = y*self.out_scale + self.out_bias
y = y.view(b, hw, self.in_features, self.out_features)
# y = y.clamp(0,255)
# print("btio", y.shape)
return y
def forward(self, x):
out = self.forward_all_to_all(x)
# print("net", out.min(), out.max())
out = (out - 127.5)/127.5
out = out.sum(dim=-2)
# print("net", out.min(), out.max())
# out = torch.tanh(out)
out = out / self.in_features
out = out*127.5 + 127.5
return out
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
in_features = self.in_features
out_features = self.out_features
domain_values = torch.cat([torch.arange(0, 256, quantization_interval, dtype=torch.uint8), torch.tensor([255])])
inputs = domain_values.type(torch.float32).to(self.cheby_coeffs.device)
lut = np.full((in_features, out_features, bucket_count+1), dtype=np.uint8, fill_value=255)
# lut_min = np.full((in_features, out_features), dtype=np.uint8, fill_value=0)
# lut_length = np.full((in_features, out_features), dtype=np.uint8, fill_value=255)
# print(in_features, out_features)
qmodel = ChebyKANLut(in_features=self.in_features, out_features=self.out_features, quantization_interval=quantization_interval)
with torch.no_grad():
for d in range(len(domain_values)):
out = self.forward_all_to_all(inputs[d].view(1,1,1).expand(1,1,in_features)).squeeze(0).squeeze(1).cpu().numpy()
lut[:,:,d] = out.astype(np.uint8)
# lut_min[:,:] = lut.min(axis=-1).astype(np.uint8)
# lut_length[:,:] = lut.max(axis=-1).astype(np.uint8)-lut.min(axis=-1).astype(np.uint8)
qmodel.lut = nn.Parameter(torch.tensor(lut).type(torch.float32))
# qmodel.lut_min = nn.Parameter(torch.tensor(lut_min).type(torch.float32))
# qmodel.lut_length = nn.Parameter(torch.tensor(lut_length).type(torch.float32))
return qmodel
def __repr__(self):
return (f"{self.__class__.__name__}\n"
f"cheby coefs size: {self.cheby_coeffs.shape}\n"
f"in/out bias, scale: {self.in_bias}, {self.in_scale} / {self.out_bias}, {self.out_scale}" )
class ChebyKANLut(nn.Module):
def __init__(self, in_features, out_features, quantization_interval=16):
super(ChebyKANLut, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.lut = nn.Parameter(
torch.randint(0, 255, size=(in_features, out_features, 256//quantization_interval+1)).type(torch.float32)
)
# self.lut_min = nn.Parameter(torch.ones((in_features, out_features)).type(torch.float32)*0)
# self.lut_max = nn.Parameter(torch.ones((in_features, out_features)).type(torch.float32)*255)
def forward_all_to_all(self, x):
b,w,c = x.shape
out = torch.zeros([b, w, self.in_features, self.out_features], dtype=x.dtype, device=x.device)
for i in range(self.in_features):
for j in range(self.out_features):
out[:,:, i, j:(j+1)] = lut.select_index_1dlut_linear(x[:,:,i:(i+1)], self.lut[i,j])
return out
def forward(self, x):
out = self.forward_all_to_all(x)
out = (out - 127.5)/127.5
out = out.sum(dim=-2)
# out = torch.tanh(out)
out = out / self.in_features
out = out*127.5 + 127.5
return out
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.lut.shape}"
class UpscaleBlockChebyKAN(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8):
super(UpscaleBlockChebyKAN, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = self.embed(x)
for linear_projection in self.linear_projections:
x = self.layer_norm(linear_projection(x))
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
return x
class ChebyKANUpscaleBlockNet(nn.Module):
def __init__(self, in_features=4, out_channels=1, hidden_dim=64, layers_count=2, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8):
super(ChebyKANUpscaleBlockNet, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.out_channels = out_channels
if layers_count == 1:
self.kan_layers = [
ChebyKANLayer(in_features=in_features, out_features=out_channels*self.upscale_factor*self.upscale_factor, degree=degree)
]
elif layers_count > 1:
self.kan_layers = [
ChebyKANLayer(in_features=in_features, out_features=hidden_dim, degree=degree),
]
for i in range(layers_count-2):
self.kan_layers.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree))
self.kan_layers.append(ChebyKANLayer(in_features=hidden_dim, out_features=out_channels*self.upscale_factor*self.upscale_factor, degree=degree))
self.kan_layers = nn.ModuleList(self.kan_layers)
def forward(self, x):
for kan in self.kan_layers:
x = kan(x)
# print(x.min(), x.max())
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
lut_kan_layers = [v.get_lut_model(quantization_interval=quantization_interval) for v in self.kan_layers]
qmodel = ChebyKANUpscaleBlockLut.init_from_numpy(lut_kan_layers=lut_kan_layers)
qmodel.upscale_factor = self.upscale_factor
return qmodel
class UpscaleBlockEffKAN(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
super(UpscaleBlockEffKAN, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(EffKAN(in_features=hidden_dim, out_features=hidden_dim, bias=True))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = self.embed(x)
for linear_projection in self.linear_projections:
x = self.layer_norm(linear_projection(x))
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
return x
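Because every connection is a function of a single scalar, the whole block factorizes into per-pair 1D maps. A sketch of building the float network and checking its per-window output shape (arguments as defined above):
```
import torch
from common import layers

net = layers.ChebyKANUpscaleBlockNet(in_features=4, out_channels=1, hidden_dim=64, layers_count=2, upscale_factor=2, degree=8)
windows = torch.randint(0, 256, (1, 64, 4)).float()  # (batch, positions, window pixels)
out = net(windows)
print(out.shape)  # torch.Size([1, 64, 4]): out_channels * 2 * 2 values per position
```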
class ComplexGaborLayer2D(nn.Module):
'''
Implicit representation with complex Gabor nonlinearity with 2D activation function
https://github.com/vishwa91/wire/blob/main/modules/wire2d.py
Inputs;
in_features: Input features
out_features; Output features
bias: if True, enable bias for the linear operation
is_first: Legacy SIREN parameter
omega_0: Legacy SIREN parameter
omega0: Frequency of Gabor sinusoid term
sigma0: Scaling of Gabor Gaussian term
trainable: If True, omega and sigma are trainable parameters
'''
def __init__(self, in_features, out_features, bias=True,
is_first=False, omega0=10.0, sigma0=10.0,
trainable=False):
super().__init__()
self.omega_0 = omega0
self.scale_0 = sigma0
self.is_first = is_first
self.in_features = in_features
if self.is_first:
dtype = torch.float
else:
dtype = torch.cfloat
# Set trainable parameters if they are to be simultaneously optimized
self.omega_0 = nn.Parameter(self.omega_0*torch.ones(1), trainable)
self.scale_0 = nn.Parameter(self.scale_0*torch.ones(1), trainable)
self.linear = nn.Linear(in_features,
out_features,
bias=bias,
dtype=dtype)
# Second Gaussian window
self.scale_orth = nn.Linear(in_features,
out_features,
bias=bias,
dtype=dtype)
def forward(self, input):
lin = self.linear(input)
scale_x = lin
scale_y = self.scale_orth(input)
freq_term = torch.exp(1j*self.omega_0*lin)
arg = scale_x.abs().square() + scale_y.abs().square()
gauss_term = torch.exp(-self.scale_0*self.scale_0*arg)
return freq_term*gauss_term
class UpscaleBlockGabor(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
super(UpscaleBlockGabor, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = ComplexGaborLayer2D(in_features=in_features, out_features=hidden_dim, is_first=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(ComplexGaborLayer2D(in_features=hidden_dim, out_features=hidden_dim, is_first=False))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = self.embed(x)
for linear_projection in self.linear_projections:
x = linear_projection(x)
x = x.real
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
return x
class ChebyKANUpscaleBlockLut(nn.Module):
def __init__(self, layer_count, in_features_list, out_features_list, quantization_interval, upscale_factor):
super(ChebyKANUpscaleBlockLut, self).__init__()
self.quantization_interval = quantization_interval
self.upscale_factor = upscale_factor
self.out_channels = 1
self.lut_kan_layers = nn.ModuleList([
ChebyKANLut(in_features=in_features, out_features=out_features, quantization_interval=quantization_interval)
for i, in_features, out_features in zip(range(layer_count), in_features_list, out_features_list)])
@staticmethod
def init_from_numpy(
lut_kan_layers
):
in_features_list = [l.lut.shape[0] for l in lut_kan_layers]
out_features_list = [l.lut.shape[1] for l in lut_kan_layers]
quantization_interval = 256//(lut_kan_layers[0].lut.shape[2]-1) # each table axis has bucket_count+1 entries
qmodel = ChebyKANUpscaleBlockLut(layer_count=len(lut_kan_layers), in_features_list=in_features_list, out_features_list=out_features_list, quantization_interval=quantization_interval, upscale_factor=out_features_list[-1])
qmodel.lut_kan_layers = nn.ModuleList(lut_kan_layers)
return qmodel
def forward(self, x):
for kan in self.lut_kan_layers:
x = kan(x)
# print(x.min(), x.max())
return x
def __repr__(self):
return "\n".join([f"{self.__class__.__name__}"] + [f"lut size: {l.lut.shape}" for l in self.lut_kan_layers])

@ -8,23 +8,53 @@ import numpy as np
##################### TRANSFER ##########################
# TODO make one function instead of 4
class Domain1DValues(Dataset):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain1DValues, self).__init__()
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*1)).view(-1, 1, 1)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, batch = [], []
for i in range(idx.start, idx.stop):
ix1, values = self.__getitem__(i)
ix1s.append(ix1)
batch.append(values)
return ix1s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(len(self.values)):
yield self.__getitem__(i)
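Each Domain*DValues dataset enumerates the quantized input lattice along with the bucket indices used to address the LUT; for example:
```
ds = Domain1DValues(quantization_interval=16)
ix, value = ds[3]
print(ix, value)  # tensor(3), tensor([[48]]): bucket 3 corresponds to input value 48
```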
class Domain2DValues(Dataset):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain2DValues, self).__init__()
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*2)).view(-1, 1, 2)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, batch = [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
batch.append(values)
return ix1s, ix2s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
@ -38,10 +68,10 @@ class Domain2DValues(Dataset):
yield self.__getitem__(i)
class Domain3DValues(Dataset):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain3DValues, self).__init__()
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*3)).view(-1, 1, 3)
@ -99,25 +129,34 @@ class Domain4DValues(Dataset):
yield self.__getitem__(i)
def transfer_1_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 1D LUT for a single input value
domain_values = Domain1DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
shuffle=False,
)
counter = 0
for idx, (ix1s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 2D LUT for a 2-pixel input window
domain_values = Domain2DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -126,21 +165,21 @@ def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**1
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 3D LUT for a 3-pixel input window
domain_values = Domain3DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -159,11 +198,11 @@ def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**1
print()
return lut
def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -172,16 +211,38 @@ def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**1
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).to(list(block.parameters())[0].device)
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, ix4s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_N_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
match block.in_features:
case 1: return transfer_1_input_SxS_output(block, quantization_interval, batch_size, max_value)
case 2: return transfer_2_input_SxS_output(block, quantization_interval, batch_size, max_value)
case 3: return transfer_3_input_SxS_output(block, quantization_interval, batch_size, max_value)
case 4: return transfer_4_input_SxS_output(block, quantization_interval, batch_size, max_value)
# TODO revive rc
# def transfer_rc_conv(rc_conv, quantization_interval=1):
# receptive_field_pixel_count = rc_conv.window_size**2
# bucket_count = 256//quantization_interval
# lut = np.full((receptive_field_pixel_count, bucket_count+1), dtype=np.uint8, fill_value=255)
# for pixel_id in range(receptive_field_pixel_count):
# for idx, value in enumerate(range(0, 256, quantization_interval)):
# inputs = torch.tensor([value]).type(torch.float32).view(1,1,1).cuda()
# with torch.no_grad():
# outputs = rc_conv.pixel_wise_forward(inputs)
# lut[:,idx] = outputs.flatten().cpu().numpy().astype(np.uint8)
# print(f"\r {rc_conv.__class__.__name__} {pixel_id*bucket_count + idx +1}/{receptive_field_pixel_count*bucket_count}", end=" ")
# print()
# return lut
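The transfer functions sweep the whole quantized input domain and record the block's responses; a sketch of the resulting table geometry (assuming a trained 4-input block with upscale_factor=2 and parameters on CUDA):
```
table = transfer_N_input_SxS_output(block, quantization_interval=16, batch_size=2**10)
# 256//16 = 16 buckets per axis plus the max-value endpoint -> shape (17, 17, 17, 17, 2, 2)
print(table.shape)
```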
##################### FORWARD ##########################
# def forward_rc_conv_centered(index, lut):
@ -221,23 +282,19 @@ def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**1
##################### UTILS ##########################
def select_index_1dlut_linear(index, lut):
b, hw, c = index.shape
lut = torch.clamp(lut, 0, 255)
L = lut.shape[0]
Q = 256/(L-1)
msbA = torch.floor_divide(index, Q).type(torch.int64)
msbB = msbA + 1
msbA = msbA.flatten()
msbB = msbB.flatten()
lsb = index % Q
outA = lut[msbA].reshape((b, hw, c))
outB = lut[msbB].reshape((b, hw, c))
lsb_coef = (lsb / Q).reshape((b, hw, c))
out = outA + lsb_coef*(outB-outA)
return out
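The 1D lookup interpolates linearly between the two nearest table entries; a worked sketch for a single value (17-entry table, so Q = 256/16 = 16):
```
import torch

lut_1d = torch.linspace(0, 255, 17)      # stand-in table with 17 entries
index = torch.full((1, 1, 1), 40.0)      # one input value
out = select_index_1dlut_linear(index, lut_1d)
# msb = 40 // 16 = 2, lsb = 8 -> out = lut[2] + (8/16) * (lut[3] - lut[2])
```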
def select_index_3dlut_tetrahedral(index, lut):
@ -306,6 +363,7 @@ def select_index_3dlut_tetrahedral(index, lut):
def select_index_4dlut_tetrahedral(index, lut):
b, hw, c = index.shape
index = index.clamp(0, 255)
lut = torch.clamp(lut, 0, 255)
dimA, dimB, dimC, dimD = lut.shape[:4]
q = 256/(dimA-1)
@ -316,7 +374,6 @@ def select_index_4dlut_tetrahedral(index, lut):
msbA = torch.floor_divide(index, q).type(torch.int64)
msbB = msbA + 1
lsb = index % q
img_a1 = msbA[:,:,0].reshape(b*hw, 1)
img_b1 = msbA[:,:,1].reshape(b*hw, 1)
img_c1 = msbA[:,:,2].reshape(b*hw, 1)

@ -38,7 +38,7 @@ def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=No
Image.fromarray(cmaplut[pred_lr_image[:,:,0]]).save(output_image_path)
# metrics
hr_image = modcrop(hr_image, model.config.upscale_factor)
if pred_lr_image.shape[-1] == 3 and color_model == 'RGB':
Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
@ -47,7 +47,7 @@ def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=No
Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0]
lr_area = np.prod(lr_image.shape[-2:])
psnr = PSNR(Y_left, Y_right, model.config.upscale_factor)
ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area

@ -0,0 +1,26 @@
from common import layers
class Transferer():
def __init__(self):
self.registered_types = {}
def register(self, input_class, output_class):
self.registered_types[input_class] = output_class
def transfer(self, input_model, batch_size, quantization_interval):
input_class = input_model.__class__
if not input_class in self.registered_types:
raise Exception(f"No transfer function is registered for class {input_class}")
transfered_model = self.transfer_model(input_model, self.registered_types[input_class], batch_size, quantization_interval)
return transfered_model
def transfer_model(self, model, output_model_class, batch_size, quantization_interval):
qmodel = output_model_class(config = model.config)
model.config.quantization_interval = quantization_interval
for attr, value in model.named_children():
if isinstance(value, layers.UpscaleBlock):
getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
return qmodel
TRANSFERER = Transferer()
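A sketch of how the registry is meant to be used; the module path and the registered class pair are assumptions here, not names confirmed by this commit:
```
from common.transferer import TRANSFERER   # assumed module path
from models import models

# hypothetical float/LUT model pair; register once, transfer after training
TRANSFERER.register(models.ChebyKANNet, models.ChebyKANLutNet)
# trained_net: a trained instance of the registered input class
lut_model = TRANSFERER.transfer(trained_net, batch_size=2**10, quantization_interval=16)
```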

@ -13,11 +13,12 @@ import argparse
class ImageDemoOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
default_paths = ["../experiments/last_trained_net.pth","../experiments/last_transfered_lut.pth"]
default_paths = [f for f in default_paths if Path(f).exists()]
self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=default_paths, help="Model paths for comparison. Example --model_paths ./A/ ./B/")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--output_path', type=str, default="../experiments/", help="Output path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
@ -29,6 +30,7 @@ class ImageDemoOptions():
args.hr_image_path = Path(args.hr_image_path).resolve()
args.lr_image_path = Path(args.lr_image_path).resolve()
args.model_paths = [Path(x).resolve() for x in args.model_paths]
args.output_name = args.hr_image_path.stem + "_" + "_".join([Path(f).resolve().stem for f in args.model_paths]) + ".png"
return args
def __repr__(self):
@ -96,4 +98,5 @@ for i in range(row_count):
canvas = np.concatenate(columns, axis=0).astype(np.uint8)
Image.fromarray(canvas).save(config.output_path / config.output_name)
print(f"Saved to {config.output_path / config.output_name}")
print(datetime.now() - start_script_time)

@ -1,12 +1,5 @@
from . import models
from common.base import SRBase
import torch
import numpy as np
from pathlib import Path
@ -15,7 +8,10 @@ import inspect, sys
AVAILABLE_MODELS = {}
for module_name in sys.modules.keys():
if 'models.' in module_name:
AVAILABLE_MODELS.update({
k:v for k,v in inspect.getmembers(sys.modules[module_name],
lambda x: inspect.isclass(x) and SRBase in inspect.getmro(x) and (not 'base' in x.__name__.lower()) )
})
def SaveCheckpoint(model, path):
model_container = {
@ -29,7 +25,8 @@ def LoadCheckpoint(model_path):
model_path = Path(model_path).absolute()
if model_path.exists():
model_container = torch.load(model_path)
init_arg_names = list(inspect.signature(AVAILABLE_MODELS[model_container['model']]).parameters.keys())
model = AVAILABLE_MODELS[model_container['model']](**{k:model_container[k] for k in init_arg_names})
model.load_state_dict(model_container['state_dict'], strict=True)
return model
else:
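The loader re-instantiates the class by name and passes back only the constructor arguments recorded in the container; a sketch of the roundtrip this enables (paths assumed):
```
from models import SaveCheckpoint, LoadCheckpoint

SaveCheckpoint(model, "../experiments/ChebyKANNet_200000.pth")  # path assumed
model = LoadCheckpoint("../experiments/ChebyKANNet_200000.pth")
```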

@ -1,28 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
def forward_stage(self, x, percieve_pattern, stage):
b,c,h,w = x.shape
scale = stage.upscale_factor
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, 1, h, w, scale, scale)
x = x.permute(0, 1, 2, 4, 3, 5)
x = x.reshape(b, 1, h*scale, w*scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -1,134 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common import layers
from common.utils import round_func
class HDBLut(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(HDBLut, self).__init__()
assert scale == 4
self.scale = scale
self.quantization_interval = quantization_interval
self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
):
# quantization_interval = 256//(stage1_3H.shape[0]-1)
quantization_interval = 16
lut_model = HDBLut(quantization_interval=quantization_interval, scale=4)
lut_model.stage1_3H = nn.Parameter(torch.tensor(stage1_3H).type(torch.float32))
lut_model.stage1_3D = nn.Parameter(torch.tensor(stage1_3D).type(torch.float32))
lut_model.stage1_3B = nn.Parameter(torch.tensor(stage1_3B).type(torch.float32))
lut_model.stage1_2H = nn.Parameter(torch.tensor(stage1_2H).type(torch.float32))
lut_model.stage1_2D = nn.Parameter(torch.tensor(stage1_2D).type(torch.float32))
lut_model.stage2_3H = nn.Parameter(torch.tensor(stage2_3H).type(torch.float32))
lut_model.stage2_3D = nn.Parameter(torch.tensor(stage2_3D).type(torch.float32))
lut_model.stage2_3B = nn.Parameter(torch.tensor(stage2_3B).type(torch.float32))
lut_model.stage2_2H = nn.Parameter(torch.tensor(stage2_2H).type(torch.float32))
lut_model.stage2_2D = nn.Parameter(torch.tensor(stage2_2D).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
print(np.prod(x.shape))
x = percieve_pattern(x)
shifts = torch.tensor([lut.shape[0]**d for d in range(len(lut.shape)-2)], device=x.device).flip(0).reshape(1,1,len(lut.shape)-2)
print(x.shape, x.min(), x.max())
x = torch.sum(x * shifts, dim=-1)
print(x.shape)
lut = torch.clamp(lut, 0, 255)
lut = lut.reshape(-1, scale, scale)
x = x.flatten().type(torch.int64)
x = lut[x]
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
print(x.shape)
# raise RuntimeError
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
print(output_msb.min(), output_msb.max())
print(output_lsb.min(), output_lsb.max())
output_msb = output_msb + output_lsb
x = output_msb
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
print("STAGE2", msb.min(), msb.max(), lsb.min(), lsb.max())
for rotations_count in range(4):
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
output_msb = output_msb + output_lsb
x = output_msb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
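# The x4 factor comes from two cascaded x2 LUT stages: stage 1 produces an
# intermediate x2 image, which is split into msb/lsb nibbles again and passed
# through stage 2. The divisors 4*3 and 4*2 are the ensemble sizes: 4 rotations
# times 3 msb patterns (3H/3D/3B) and 4 rotations times 2 lsb patterns (2H/2D).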
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_3H size: {self.stage1_3H.shape}" + \
f"\n stage1_3D size: {self.stage1_3D.shape}" + \
f"\n stage1_3B size: {self.stage1_3B.shape}" + \
f"\n stage1_2H size: {self.stage1_2H.shape}" + \
f"\n stage1_2D size: {self.stage1_2D.shape}" + \
f"\n stage2_3H size: {self.stage2_3H.shape}" + \
f"\n stage2_3D size: {self.stage2_3D.shape}" + \
f"\n stage2_3B size: {self.stage2_3B.shape}" + \
f"\n stage2_2H size: {self.stage2_2H.shape}" + \
f"\n stage2_2D size: {self.stage2_2D.shape}"

@ -1,441 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from . import hdblut
from common import layers
from itertools import cycle
from models.base import SRNetBase
class HDBNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, rotations = 4):
super(HDBNet, self).__init__()
assert scale == 4
self.scale = scale
self.rotations = rotations
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
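# Same perceive -> predict -> depth-to-space pipeline as HDBLut.forward_stage,
# with a trainable UpscaleBlock in place of the lookup table; round_func keeps
# the intermediate values integer-valued (judging by its use during training,
# presumably with a straight-through gradient).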
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
output_msb /= 3
output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
output_lsb /= 2
if config is not None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
output /= self.rotations
x = output
lsb = x % 16
msb = x - lsb
output = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
output_msb /= 3
output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
output_lsb /= 2
if config is not None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s2_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s2_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
output /= self.rotations
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
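# Averaging over the four 90-degree rotations lets the three asymmetric 3-pixel
# patterns (3H/3D/3B) jointly cover a symmetric neighborhood of each pixel;
# the division by self.rotations above is that ensemble mean.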
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = hdblut.HDBLut.init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
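# Minimal usage sketch for the net -> LUT transfer (file names here are
# illustrative, not part of the repository's scripts):
#   net = HDBNet(hidden_dim=64, layers_count=4, scale=4)
#   net.load_state_dict(torch.load("hdbnet_checkpoint.pth"))
#   lut_model = net.get_lut_model(quantization_interval=16)
#   torch.save(lut_model.state_dict(), "hdblut.pth")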
class HDBNetv2(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 2, scale = 4, rotations = 4):
super(HDBNetv2, self).__init__()
assert scale == 4
self.scale = scale
self.rotations = rotations
self.layers_count = layers_count
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
if config is not None and config.current_iter % config.display_step == 0:
# log the per-rotation outputs; the output_lsb/output_msb accumulators are only partially summed at this point
config.writer.add_histogram('s1_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= self.rotations*3
output_lsb /= self.rotations*2
output = output_msb + output_lsb
x = output.clamp(0, 255)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
if config is not None and config.current_iter % config.display_step == 0:
# log the per-rotation outputs; the accumulators are only partially summed at this point
config.writer.add_histogram('s2_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s2_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= self.rotations*3
output_lsb /= self.rotations*2
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = hdblut.HDBLut.init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
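# HDBNetv2 differs from HDBNet mainly in where averaging and clamping happen:
# the msb and lsb streams are accumulated across rotations separately, and only
# their combined sum is clamped to 0..255, instead of clamping each rotation's
# output inside the loop.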
class HDBLNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNet, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
output_msb /= 3
if config is not None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNetR90(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNetR90, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
if config is not None and config.current_iter % config.display_step == 0:
# log the per-rotation outputs; the accumulators are only partially summed at this point
config.writer.add_histogram('s1_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNetR90KAN(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNetR90KAN, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
if config is not None and config.current_iter % config.display_step == 0:
# log the per-rotation outputs; the accumulators are only partially summed at this point
config.writer.add_histogram('s1_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBHNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBHNet, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
# nn.ModuleList replaces the undefined name SRNetBaseList
self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(1)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(1)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
# re-quantize: the msb branch is rounded to 16ths of its range and scaled by 15,
# the lsb branch is scaled to 0..15, so their sum spans roughly 0..255
output_msb_r = round_func((output_msb_r / 255)*16) * 15
output_lsb_r = (output_lsb_r / 255) * 15
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
output_lsb /= 4
if config is not None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
# FocalFrequencyLoss and FourierLoss are not imported in this file;
# they are assumed to be provided by the project's common.losses module
from common.losses import FocalFrequencyLoss, FourierLoss
fourier_loss_fn = FocalFrequencyLoss()
high_frequency_loss_fn = FourierLoss()
def loss_fn(pred, target):
a = fourier_loss_fn(pred/255, target/255) * 1e8
# b = F.mse_loss(pred/255, target/255) #* 1e3
# c = high_frequency_loss_fn(pred/255, target/255) * 1e6
return a #+ b #+ c
return loss_fn

@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from common import layers
from common import losses
from common.base import SRBase
from common.transferer import TRANSFERER
class SRNetBase(SRBase):
def __init__(self):
super(SRNetBase, self).__init__()
self.config = None
self.stage1_S = layers.UpscaleBlock(None) # the concrete stage is attached by subclasses
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, script_config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.stage1_S(x, self._extract_pattern_S)
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
return x
class SRNet(SRNetBase):
def __init__(self, config):
super(SRNet, self).__init__()
self.config = config
self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
class SRLut(SRNetBase):
def __init__(self, config):
super(SRLut, self).__init__()
self.config = config
self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
TRANSFERER.register(SRNet, SRLut)
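# TRANSFERER maintains a registry of network -> LUT class pairs; transfer_to_lut.py
# presumably looks up a trained model's class here to construct its LUT twin.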
class ChebyKANBase(SRBase):
def __init__(self):
super(ChebyKANBase, self).__init__()
self.config = None
self.stage1_S = layers.UpscaleBlock(None) # the concrete stage is attached by subclasses
window_size = 7
self._extract_pattern = layers.PercievePattern(
receptive_field_idxes=[[i,j] for i in range(window_size) for j in range(window_size)],
center=[window_size//2,window_size//2],
window_size=window_size
)
def forward(self, x, script_config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.stage1_S(x, self._extract_pattern)
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
return x
class ChebyKANNet(ChebyKANBase):
def __init__(self, config):
super(ChebyKANNet, self).__init__()
self.config = config
window_size = 7
self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
in_features=window_size*window_size,
out_channels=1,
hidden_dim=16,
layers_count=self.config.layers_count,
upscale_factor=self.config.upscale_factor,
degree=8
)
class ChebyKANLut(ChebyKANBase):
def __init__(self, config):
super(ChebyKANLut, self).__init__()
self.config = config
window_size = 7
# note: unlike ChebyKANNet, no explicit degree is passed here; the block's
# default degree is assumed to match the trained network's
self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
in_features=window_size*window_size,
out_channels=1,
hidden_dim=16,
layers_count=self.config.layers_count,
upscale_factor=self.config.upscale_factor
).get_lut_model(quantization_interval=self.config.quantization_interval)
def get_loss_fn(self):
ssim_loss = losses.SSIM(data_range=255)
l1_loss = losses.CharbonnierLoss()
def loss_fn(pred, target):
return ssim_loss(pred, target) + l1_loss(pred, target)
return loss_fn
TRANSFERER.register(ChebyKANNet, ChebyKANLut)
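# A minimal sketch of the Chebyshev basis the ChebyKAN blocks are assumed to
# build on (illustrative only; the actual implementation lives in common/layers.py):
def _chebyshev_basis_sketch(x, degree=8):
# x: tensor already squashed to [-1, 1] (e.g. by tanh);
# returns the stacked basis [T_0(x), ..., T_degree(x)]
basis = [torch.ones_like(x), x]
for _ in range(degree - 1):
basis.append(2 * x * basis[-1] - basis[-2])  # T_{k+1} = 2 x T_k - T_{k-1}
return torch.stack(basis, dim=-1)
# A ChebyKAN layer then takes a learned linear combination of these basis
# functions per input-output pair, so degree=8 means 9 coefficients each.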

@ -1,505 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
# from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
# from pathlib import Path
# class RCLutCentered_3x3(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutCentered_3x3, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
# x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
# x = x.view(b, c, x.shape[-2], x.shape[-1])
# return x
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutCentered_7x7(nn.Module):
# def __init__(
# self,
# window_size,
# quantization_interval,
# scale
# ):
# super(RCLutCentered_7x7, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
# # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
# x = x.view(b, c, x.shape[-2], x.shape[-1])
# return x
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutRot90_3x3(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutRot90_3x3, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutRot90_7x7(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutRot90_7x7, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# window_size = rc_conv_luts.shape[0]
# lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutx1(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutx1, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts_3x3, dense_conv_lut_3x3,
# rc_conv_luts_5x5, dense_conv_lut_5x5,
# rc_conv_luts_7x7, dense_conv_lut_7x7
# ):
# scale = int(dense_conv_lut_3x3.shape[-1])
# quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
# lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
# lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
# lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
# lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
# lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
# lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
# lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
# return lut_model
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
# return x
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
# f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
# f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
# f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
# f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
# f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
# ")"])
# class RCLutx2(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutx2, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
# ):
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
# lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
# return lut_model
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
# return x
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
# ")"])
# class RCLutx2Centered(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutx2Centered, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
# ):
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
# lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
# return lut_model
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
# x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
# return x
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
# ")"])

@ -1,568 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from pathlib import Path
from common import lut
from . import rclut
from common import layers
# class ReconstructedConvCentered(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvCentered, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
# x = round_func(x)
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockCentered(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockCentered, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetCentered_3x3(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetCentered_3x3, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# x = self.stage(x)
# x = x.view(b, c, h*self.scale, w*self.scale)
# return x
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# class RCNetCentered_7x7(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetCentered_7x7, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# window_size = 7
# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# x = self.stage(x)
# x = x.view(b, c, h*self.scale, w*self.scale)
# return x
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# class ReconstructedConvRot90(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvRot90, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
#         x = round_func(x) # quality likely suffers from this
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockRot90(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockRot90, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetRot90_3x3(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetRot90_3x3, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated_prediction = self.stage(rotated)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetRot90_7x7(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetRot90_7x7, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated_prediction = self.stage(rotated)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetx1(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx1, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx1.init_from_numpy(
# rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3,
# rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5,
# rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetx2(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetx2Centered(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2Centered, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2Centered.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class ReconstructedConvRot90Unlutable(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvRot90Unlutable, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
#         # x = round_func(x) # quality likely suffers from this
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockRot90Unlutable(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockRot90Unlutable, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetx2Unlutable(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2Unlutable, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class ReconstructedConvCenteredUnlutable(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvCenteredUnlutable, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
#         # x = round_func(x) # quality likely suffers from this
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockCenteredUnlutable(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
#         super(RCBlockCenteredUnlutable, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetx2CenteredUnlutable(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2CenteredUnlutable, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2Centered.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output

@@ -1,395 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common.layers import PercievePattern
from common.utils import round_func
class SDYLutx1(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutx1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stageS, stageD, stageY
):
scale = int(stageS.shape[-1])
quantization_interval = 256//(stageS.shape[0]-1)
lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
output /= 3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stageS size: {self.stageS.shape}" + \
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
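# A minimal sketch (not part of the original module) showing that the
# reshape/permute in forward_stage is the standard depth-to-space layout:
# each input pixel's (scale, scale) LUT output becomes a scale x scale
# patch of the upscaled image, exactly as F.pixel_shuffle arranges it.
if __name__ == "__main__":
    _b, _c, _h, _w, _s = 1, 1, 2, 3, 4
    _x = torch.arange(_b*_c*_h*_w*_s*_s, dtype=torch.float32).reshape(_b, _c, _h, _w, _s, _s)
    _out = _x.permute(0, 1, 2, 4, 3, 5).reshape(_b, _c, _h*_s, _w*_s)
    # fold the (s, s) dims into channels, then let pixel_shuffle scatter them
    _ref = F.pixel_shuffle(
        _x.reshape(_b, _c, _h, _w, _s*_s).permute(0, 1, 4, 2, 3).reshape(_b, _c*_s*_s, _h, _w),
        _s)
    assert torch.allclose(_out, _ref)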
class SDYLutx2(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutx2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
):
scale = int(stage2_S.shape[-1])
quantization_interval = 256//(stage2_S.shape[0]-1)
lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
output /= 3
output = round_func(output)
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
output /= 3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_S size: {self.stage1_S.shape}" + \
f"\n stage1_D size: {self.stage1_D.shape}" + \
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutx3(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutx3, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage3_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage3_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage3_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y
):
scale = int(stage3_S.shape[-1])
quantization_interval = 256//(stage3_S.shape[0]-1)
lut_model = SDYLutx3(quantization_interval=quantization_interval, scale=scale)
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
lut_model.stage3_S = nn.Parameter(torch.tensor(stage3_S).type(torch.float32))
lut_model.stage3_D = nn.Parameter(torch.tensor(stage3_D).type(torch.float32))
lut_model.stage3_Y = nn.Parameter(torch.tensor(stage3_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
output /= 3
output = round_func(output)
x = output
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage2_Y)
output /= 3
output = round_func(output)
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage3_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage3_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage3_Y)
output /= 3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_S size: {self.stage1_S.shape}" + \
f"\n stage1_D size: {self.stage1_D.shape}" + \
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}" + \
f"\n stage3_S size: {self.stage3_S.shape}" + \
f"\n stage3_D size: {self.stage3_D.shape}" + \
f"\n stage3_Y size: {self.stage3_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x1(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutR90x1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stageS, stageD, stageY
):
scale = int(stageS.shape[-1])
quantization_interval = 256//(stageS.shape[0]-1)
lut_model = SDYLutR90x1(quantization_interval=quantization_interval, scale=scale)
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stageS), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stageD), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stageY), k=-rotations_count, dims=[-2, -1])
output /= 4*3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stageS size: {self.stageS.shape}" + \
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
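# A minimal sketch (fn is any image-to-image callable, an assumption) of the
# rotation self-ensemble used by the R90 variants: evaluate the model on all
# four 90-degree rotations of the input, rotate each prediction back to the
# original orientation, and average.
def _rot90_ensemble(fn, x):
    out = None
    for k in range(4):
        pred = torch.rot90(fn(torch.rot90(x, k=k, dims=[-2, -1])), k=-k, dims=[-2, -1])
        out = pred if out is None else out + pred
    return out / 4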
class SDYLutR90x2(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutR90x2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
):
scale = int(stage2_S.shape[-1])
quantization_interval = 256//(stage2_S.shape[0]-1)
lut_model = SDYLutR90x2(quantization_interval=quantization_interval, scale=scale)
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
x = round_func(output)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_S size: {self.stage1_S.shape}" + \
f"\n stage1_D size: {self.stage1_D.shape}" + \
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
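# A hypothetical round-trip sketch: LUTs transferred from a trained network
# arrive as numpy arrays of shape (256//qi + 1,)*4 + (scale, scale) and are
# loaded with init_from_numpy; the values here are random stand-ins.
if __name__ == "__main__":
    _qi, _scale = 16, 4
    _shape = (256 // _qi + 1,) * 4 + (_scale, _scale)
    _luts = [np.random.randint(0, 256, size=_shape).astype(np.float32) for _ in range(3)]
    _model = SDYLutx1.init_from_numpy(*_luts)
    _img = torch.randint(0, 256, (1, 1, 16, 16)).float()
    _sr = _model(_img)  # (1, 1, 64, 64)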

File diff suppressed because it is too large

@@ -1,276 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common import layers
from common.utils import round_func
from models.base import SRNetBase
class SRLut(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLut, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLut(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w).type(torch.float32)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
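# A quick size check (sketch, not in the original file): with the defaults
# quantization_interval=16 and scale=4 the table has 17**4 = 83,521 entries,
# each holding a 4x4 output patch, i.e. 83,521 * 16 * 4 bytes, roughly
# 5.1 MiB of float32 parameters.
if __name__ == "__main__":
    _entries = (256 // 16 + 1) ** 4
    _size_mib = _entries * 4 * 4 * 4 / 2**20  # entries * scale^2 * 4 bytes
    print(f"{_entries} entries, {_size_mib:.1f} MiB")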
class SRLutY(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutY, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
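# Hypothetical usage sketch: SRLutY upscales only the luma channel through
# the LUT and bilinearly interpolates chroma, so both input and output are
# RGB tensors valued in [0, 255].
if __name__ == "__main__":
    _model = SRLutY(quantization_interval=16, scale=2)
    _rgb = torch.randint(0, 256, (1, 3, 24, 24)).float()
    _sr = _model(_rgb)  # (1, 3, 48, 48)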
class SRLutR90(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90Y(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90Y, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90YCbCr(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90YCbCr, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

File diff suppressed because it is too large

@@ -87,7 +87,7 @@ if __name__ == "__main__":
     for test_dataset_name in config.test_datasets:
         test_datasets[test_dataset_name] = SRTestDataset(
             hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
-            lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
+            lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.config.upscale_factor}",
             color_model=config.color_model,
             reset_cache=config.reset_cache,
         )
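test.py now derives the LR folder name from model.config.upscale_factor instead of the bare model.scale attribute; together with the SimpleNamespace import added to train.py below, this suggests (an assumption, not shown in this diff) that models keep their hyperparameters in a config namespace:

```
from types import SimpleNamespace

# hypothetical model skeleton: constructor arguments are kept in a
# namespace so downstream scripts can read model.config.upscale_factor
class TinyModel:
    def __init__(self, upscale_factor=4, hidden_dim=64):
        self.config = SimpleNamespace(upscale_factor=upscale_factor,
                                      hidden_dim=hidden_dim)
```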

@@ -26,33 +26,37 @@ torch.backends.cudnn.benchmark = True
 import argparse
 from schedulefree import AdamWScheduleFree
 from datetime import datetime
+from types import SimpleNamespace
 import signal

 class SignalHandler:
     def __init__(self, signal_code):
         self.is_on = False
+        self.count = 0
         signal.signal(signal_code, self.exit_gracefully)

     def exit_gracefully(self, signum, frame):
         print("Early stopping.")
         self.is_on = True
+        self.count += 1
+        if self.count == 3:
+            exit(1)

 signal_interraption_handler = SignalHandler(signal.SIGINT)
class TrainOptions: class TrainOptions:
def __init__(self): def __init__(self):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}") parser.add_argument('--model', type=str, default='SRNet', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.") # parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.")
parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.") parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.") parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor") parser.add_argument('--upscale_factor', '-s', type=int, default=4, help="up scale factor")
parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers") parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers") parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers")
parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size') parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training") parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
parser.add_argument('--models_dir', type=str, default='../experiments/', help="experiment folder") parser.add_argument('--experiment_dir', type=str, default='../experiments/', help="experiment folder")
parser.add_argument('--datasets_dir', type=str, default="../data/") parser.add_argument('--datasets_dir', type=str, default="../data/")
parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further') parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further')
parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations') parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
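
The interrupt pattern introduced above can be exercised standalone. A minimal sketch (the counting loop is a stand-in for the training loop; first Ctrl+C requests a clean stop, the third one force-exits):

```
import signal
import time

class SignalHandler:
    def __init__(self, signal_code):
        self.is_on = False
        self.count = 0
        signal.signal(signal_code, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print("Early stopping.")
        self.is_on = True
        self.count += 1
        if self.count == 3:
            exit(1)

handler = SignalHandler(signal.SIGINT)
for i in range(1000):      # stand-in for the training loop
    if handler.is_on:
        break              # the current iteration finishes, then we stop cleanly
    time.sleep(0.01)
```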
@@ -69,25 +73,25 @@ class TrainOptions:
         parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
         parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate')
         parser.add_argument('--grad_step', type=int, default=1, help='Optimizer step.')
         self.parser = parser

     def parse_args(self):
         args = self.parser.parse_args()
         args.datasets_dir = Path(args.datasets_dir).resolve()
-        args.models_dir = Path(args.models_dir).resolve()
-        args.model_path = Path(args.model_path) if not args.model_path is None else None
+        args.experiment_dir = Path(args.experiment_dir).resolve()
+        if Path(args.model).exists():
+            args.model = Path(args.model)
+            args.start_iter = int(args.model.stem.split("_")[-1])
         args.train_datasets = args.train_datasets.split(',')
         args.test_datasets = args.test_datasets.split(',')
-        if not args.model_path is None:
-            args.start_iter = int(args.model_path.stem.split("_")[-1])
+        args.quantization_interval = 2**(8-args.quantization_bits)
         return args

     def save_config(self, config):
         yaml.dump(config, open(config.exp_dir / "config.yaml", 'w'))

     def __repr__(self):
-        config = self.parse_args()
+        config = self.parse_args() if self.config is None else self.config
         message = ''
         message += '----------------- Options ---------------\n'
         for k, v in sorted(vars(config).items()):
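
Two details in this hunk are worth spelling out: `--model` is now overloaded to accept either a model name or a checkpoint path (the resume iteration is parsed from the filename stem), and the quantization interval is derived from the bit count. A sketch with a hypothetical checkpoint path following the `<ModelName>_<iteration>.pth` convention:

```
from pathlib import Path

ckpt = Path("../experiments/SRNet_RGB_DIV2K_x4/checkpoints/SRNet_200000.pth")
start_iter = int(ckpt.stem.split("_")[-1])            # 200000

quantization_bits = 4
quantization_interval = 2 ** (8 - quantization_bits)  # 16 intensity levels per bucket
```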
@@ -99,13 +103,24 @@ class TrainOptions:
         message += '----------------- End -------------------'
         return message

-    def prepare_config(self):
+    def prepare_experiment(self):
         config = self.parse_args()
+        if isinstance(config.model, Path):
+            model = LoadCheckpoint(config.model)
+            config.model = model.__class__.__name__
+        else:
+            config_dict = vars(config)
+            model = AVAILABLE_MODELS[config.model](
+                config = SimpleNamespace(**{k:config_dict[k] for k in config_dict.keys() if k in ['hidden_dim', 'layers_count', 'upscale_factor', 'quantization_interval']})
+            )
+        model = model.to(torch.device(config.device))
+        # model = torch.compile(model)
         assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
         assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"One of the {config.test_datasets} was not found in {config.datasets_dir}."
-        config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
+        config.exp_dir = (config.experiment_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.upscale_factor}").resolve()

         if not config.exp_dir.exists():
             config.exp_dir.mkdir()
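
The model now receives only the hyperparameters it needs, packed into a `SimpleNamespace`. A standalone sketch of the filtering idiom (toy values, not from the repo):

```
from types import SimpleNamespace

args = {'hidden_dim': 64, 'layers_count': 4, 'upscale_factor': 4,
        'quantization_interval': 16, 'batch_size': 16}
keys = ['hidden_dim', 'layers_count', 'upscale_factor', 'quantization_interval']
model_config = SimpleNamespace(**{k: v for k, v in args.items() if k in keys})
print(model_config.upscale_factor)  # 4; batch_size is deliberately excluded
```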
@@ -122,29 +137,10 @@ class TrainOptions:
         if not config.logs_dir.exists():
             config.logs_dir.mkdir()

-        return config
-
-
-if __name__ == "__main__":
-    # torch.set_float32_matmul_precision('high')
-    script_start_time = datetime.now()
-
-    config_inst = TrainOptions()
-    config = config_inst.prepare_config()
-
-    if not config.model_path is None:
-        model = LoadCheckpoint(config.model_path)
-        config.model = model.__class__.__name__
-    else:
-        if 'net' in config.model.lower():
-            model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale, layers_count=config.layers_count)
-        if 'lut' in config.model.lower():
-            model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
-    model = model.to(torch.device(config.device))
-    # model = torch.compile(model)
-
         optimizer = AdamWScheduleFree(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.95))
-    print(optimizer)
-
-    config_inst.save_config(config)
+        self.save_config(config)
+        self.config = config

         # Tensorboard for monitoring
         writer = SummaryWriter(log_dir=config.logs_dir)
@@ -154,15 +150,25 @@ if __name__ == "__main__":
         config.writer = writer
         config.logger = logger
-        config.logger.info(config_inst)
+        config.logger.info(self)
         config.logger.info(model)
+        config.logger.info(optimizer)
+
+        return config, model, optimizer
+
+
+if __name__ == "__main__":
+    # torch.set_float32_matmul_precision('high')
+    script_start_time = datetime.now()
+
+    config_inst = TrainOptions()
+    config, model, optimizer = config_inst.prepare_experiment()

     # Training dataset
     train_datasets = []
     for train_dataset_name in config.train_datasets:
         train_datasets.append(SRTrainDataset(
             hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR",
-            lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
+            lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.upscale_factor}",
             patch_size = config.crop_size,
             color_model = config.color_model,
             reset_cache=config.reset_cache
@@ -183,7 +189,7 @@ if __name__ == "__main__":
     for test_dataset_name in config.test_datasets:
         test_datasets[test_dataset_name] = SRTestDataset(
             hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
-            lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
+            lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.upscale_factor}",
             color_model = config.color_model,
             reset_cache=config.reset_cache
         )
@@ -194,7 +200,7 @@ if __name__ == "__main__":

     # TRAINING
     i = config.start_iter
-    if not config.model_path is None:
+    if isinstance(config.model, Path):
         config.current_iter = i
         test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
@@ -216,9 +222,10 @@ if __name__ == "__main__":
         start_time = time.time()
         with torch.set_grad_enabled(True):
-            pred = model(x=lr_patch, config=config)
+            pred = model(x=lr_patch, script_config=config)
             loss = loss_fn(pred=pred, target=hr_patch) / config.grad_step
             loss.backward()
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

         if i % config.grad_step == 0 and i > 0:
             optimizer.step()
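
The `loss / config.grad_step` division plus the conditional `optimizer.step()` is gradient accumulation. A minimal sketch of the pattern in isolation (`loader`, `model`, `loss_fn`, and `optimizer` are assumed to exist):

```
grad_step = 4  # effective batch size = batch_size * grad_step
for i, (lr_patch, hr_patch) in enumerate(loader):
    pred = model(lr_patch)
    # divide so the accumulated gradient equals the mean over grad_step batches
    loss = loss_fn(pred, hr_patch) / grad_step
    loss.backward()
    if i % grad_step == 0 and i > 0:
        optimizer.step()
        optimizer.zero_grad()
```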
@@ -260,18 +267,18 @@ if __name__ == "__main__":
     print("Saved to ", model_path)

     # check if it is network or lut
-    if hasattr(model, 'get_lut_model'):
-        link = Path(config.models_dir / f"last_trained_net.pth")
-        if link.exists():
+    if 'net' in model.__class__.__name__.lower():
+        link = Path(config.experiment_dir / f"last_trained_net.pth")
+        if link.exists(follow_symlinks=False):
             link.unlink()
         link.symlink_to(model_path)
     else:
-        link = Path(config.models_dir / f"last_trained_lut.pth")
-        if link.exists():
+        link = Path(config.experiment_dir / f"last_trained_lut.pth")
+        if link.exists(follow_symlinks=False):
             link.unlink()
         link.symlink_to(model_path)

-    link = Path(config.models_dir / f"last.pth")
-    if link.exists():
+    link = Path(config.experiment_dir / f"last.pth")
+    if link.exists(follow_symlinks=False):
         link.unlink()
     link.symlink_to(model_path)
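
Portability note: the `follow_symlinks` parameter of `Path.exists()` was added in Python 3.12. Its point here is that a dangling symlink still counts as existing, so stale links get removed before relinking:

```
from pathlib import Path

link = Path("last.pth")                        # hypothetical link name
target = Path("checkpoints/SRNet_200000.pth")  # hypothetical target

# Python 3.12+: True even if the link's target is gone, unlike plain exists()
if link.exists(follow_symlinks=False):
    link.unlink()
link.symlink_to(target)
```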

@@ -15,11 +15,13 @@ torch.backends.cudnn.benchmark = True
 from datetime import datetime
 import argparse
 import models
+from common.transferer import TRANSFERER

 class TransferToLutOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--model_path', '-m', type=str, default='../models/last_trained_net.pth', help="model path folder")
+        self.parser.add_argument('--model_path', '-m', type=str, default='../experiments/last_trained_net.pth', help="model path folder")
         self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
         self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
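
To make the `--quantization_bits` help string concrete: with `q` bits, the 8-bit intensity axis splits into `2**q` buckets of width `2**(8-q)`, and the stored LUT carries one extra grid point per axis so interpolation covers 255:

```
q = 4                              # --quantization_bits
interval = 2 ** (8 - q)            # 16: bucket width in intensity levels
buckets = 2 ** q                   # 16 buckets per axis
grid_points = 256 // interval + 1  # 17, matching the (256//interval + 1,)*4 LUT shape
```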
@@ -31,7 +33,7 @@ class TransferToLutOptions():
         return args

     def __repr__(self):
-        config = self.parser.parse_args()
+        config = self.parse_args()
         message = ''
         message += '----------------- Options ---------------\n'
         for k, v in sorted(vars(config).items()):
@@ -53,16 +55,15 @@ if __name__ == "__main__":
     print(config_inst)

     model = models.LoadCheckpoint(config.model_path).cuda()
-    if getattr(model, 'get_lut_model', None) is None:
+    if not 'net' in model.__class__.__name__.lower():
         print("Transfer to lut can be applied only to the network model.")
         exit(1)

     print(model)
     print()
     print("Transferring:")
-    lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
+    lut_model = TRANSFERER.transfer(model, quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
     print()
-    print(lut_model)

     lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth"
     models.SaveCheckpoint(model=lut_model, path=lut_path)
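
The per-model `get_lut_model` hook is replaced by a central `TRANSFERER` from `common.transferer`, whose implementation is not shown in this diff. A registry of roughly this shape would support the `TRANSFERER.transfer(model, ...)` call above; everything beyond that call signature is an assumption:

```
class Transferer:
    """Hypothetical sketch: maps a network class to its LUT-conversion routine."""
    def __init__(self):
        self._registry = {}

    def register(self, net_cls, transfer_fn):
        self._registry[net_cls] = transfer_fn

    def transfer(self, model, quantization_interval, batch_size):
        transfer_fn = self._registry[type(model)]
        return transfer_fn(model, quantization_interval, batch_size)

TRANSFERER = Transferer()
```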
