New framework implementation. Added ChebyKAN and linear models.

main
protsenkovi 2 months ago
parent 67bd678763
commit d34cc7833e

File diff suppressed because one or more lines are too long

@ -24,3 +24,10 @@ Requirements:
```
pip install schedulefree tensorboard opencv-python-headless scipy pandas matplotlib
```
python train.py --model ChebyKANNet --total_iter 200000 --train_datasets mix-01-blur-1+ffmpeg --test_datasets mix-01-blur-test --upscale_factor 1 --val_step 210000 --layers_count 4
python transfer_to_lut.py --model_path ../experiments/ChebyKANNet_RGB_mix-01-blur-1+ffmpeg_x1/checkpoints/ChebyKANNet_200000.pth
python test.py --model_path ../experiments/last_transfered_lut.pth --test_datasets src --save_predictions
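The three commands above form the full pipeline: train a network, transfer it to a LUT, then test the LUT. For programmatic use, a minimal loading sketch (assuming the `LoadCheckpoint` helper added in `models/__init__.py` below, and a checkpoint written by `SaveCheckpoint`):
```
import torch
from models import LoadCheckpoint

# Hedged sketch: load the transferred LUT model and run it on a stand-in image.
model = LoadCheckpoint("../experiments/last_transfered_lut.pth").eval()
lr = torch.randint(0, 256, (1, 1, 32, 32)).float()  # pixel values in [0, 255]
with torch.no_grad():
    sr = model(lr)
print(sr.shape)
```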

@ -0,0 +1,24 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import layers
import copy
class SRBase(nn.Module):
def __init__(self):
super(SRBase, self).__init__()
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
# def get_loss_fn(self):
# ssim_loss = losses.SSIM(data_range=255)
# l1_loss = losses.CharbonnierLoss()
# def loss_fn(pred, target):
# return ssim_loss(pred, target) + l1_loss(pred, target)
# return loss_fn
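A small sketch of what the factory above returns: MSE on [0, 1]-normalized tensors, so a constant 64-level pixel error yields roughly (64/255)^2 ≈ 0.063.
```
import torch
import torch.nn.functional as F

def loss_fn(pred, target):
    # same normalization as SRBase.get_loss_fn
    return F.mse_loss(pred / 255, target / 255)

pred = torch.full((1, 1, 4, 4), 128.0)
target = torch.full((1, 1, 4, 4), 192.0)
print(loss_fn(pred, target))  # ~0.0630
```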

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils import round_func
from . import lut
class PercievePattern():
"""
@ -11,7 +12,7 @@ class PercievePattern():
1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]
2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]]
"""
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1):
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1, mode='replicate'):
assert window_size >= (np.max([x for y in receptive_field_idxes for x in y])+1)
tmp = []
for coords in receptive_field_idxes:
@ -21,6 +22,7 @@ class PercievePattern():
if len(coords) == 3:
tmp.append(coords)
receptive_field_idxes = np.array(sorted(tmp))
self.mode = mode
self.window_size = window_size
self.center = center
self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size*self.window_size + receptive_field_idxes[i,1]*self.window_size + receptive_field_idxes[i,2] for i in range(len(receptive_field_idxes))]
@ -28,33 +30,49 @@ class PercievePattern():
def __call__(self, x):
b,c,h,w = x.shape
if self.mode is not None:
x = F.pad(
x,
pad=[
self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1
],
mode='replicate'
mode=self.mode
)
x = F.unfold(input=x, kernel_size=self.window_size)
x = torch.stack([x[:,self.receptive_field_idxes[i],:] for i in range(len(self.receptive_field_idxes))], 2)
return x
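A usage sketch for the pattern extractor above (shapes per this diff; the elided branch is assumed to pad 2-coordinate indexes with a channel index of 0):
```
import torch

# a 2x2 window on a 1-channel 8x8 image yields one 4-pixel vector
# per spatial location: (b, h*w, 4)
pattern = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]],
                          center=[0,0], window_size=2)
x = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
print(pattern(x).shape)  # torch.Size([1, 64, 4])
```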
class UpscaleBlock(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
def __init__(self, stage):
super(UpscaleBlock, self).__init__()
self.stage = stage
def forward(self, x, percieve_pattern):
b,c,h,w = x.shape
upscale_factor = self.stage.upscale_factor
out_channels = self.stage.out_channels
x = percieve_pattern(x)
x = self.stage(x)
x = x.reshape(b, h, w, out_channels, upscale_factor, upscale_factor)
x = x.permute(0, 3, 1, 4, 2, 5)
x = x.reshape(b, out_channels, h*upscale_factor, w*upscale_factor)
return x
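The reshape/permute pair in `UpscaleBlock.forward` is a depth-to-space rearrangement; a self-contained check against `F.pixel_shuffle` (an equivalence sketch, not part of this commit):
```
import torch
import torch.nn.functional as F

b, h, w, oc, r = 1, 3, 3, 1, 2
y = torch.randn(b, h, w, oc * r * r)              # per-location stage output
a = (y.reshape(b, h, w, oc, r, r)
      .permute(0, 3, 1, 4, 2, 5)
      .reshape(b, oc, h * r, w * r))
ref = F.pixel_shuffle(y.permute(0, 3, 1, 2), r)   # (b, oc*r*r, h, w) -> (b, oc, h*r, w*r)
print(torch.allclose(a, ref))                     # True
```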
class LinearUpscaleBlockNet(nn.Module):
def __init__(self, in_features=4, out_channels=1, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
super(LinearUpscaleBlockNet, self).__init__()
assert layers_count > 0
self.in_features = in_features
self.upscale_factor = upscale_factor
self.out_channels = out_channels
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(nn.Linear(in_features=(i+1)*hidden_dim, out_features=hidden_dim, bias=True))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=out_channels*upscale_factor*upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
@ -66,334 +84,67 @@ class UpscaleBlock(nn.Module):
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
x = round_func(x)
return x
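A worked example of the output mapping above: `tanh` lands in (-1, 1), and `out_scale`/`out_bias` stretch that onto the (0, 255) pixel range before rounding (plain `torch.round` stands in for the repo's `round_func` here):
```
import torch

z = torch.tensor([-5.0, 0.0, 5.0])
pix = torch.tanh(z) * 127.5 + 127.5   # out_scale = out_bias = 255/2
print(torch.round(pix))               # tensor([  0., 128., 255.])
```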
class RgbToYcbcr(nn.Module):
r"""Convert an image from RGB to YCbCr.
The image data is assumed to be in the range of (0, 1).
Returns:
YCbCr version of the image.
Shape:
- image: :math:`(*, 3, H, W)`
- output: :math:`(*, 3, H, W)`
Examples:
>>> input = torch.rand(2, 3, 4, 5)
>>> ycbcr = RgbToYcbcr()
>>> output = ycbcr(input) # 2x3x4x5
https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
"""
def forward(self, image):
r = image[..., 0, :, :]
g = image[..., 1, :, :]
b = image[..., 2, :, :]
delta = 0.5
y = 0.299 * r + 0.587 * g + 0.114 * b
cb = (b - y) * 0.564 + delta
cr = (r - y) * 0.713 + delta
return torch.stack([y, cb, cr], -3)
class YcbcrToRgb(nn.Module):
r"""Convert an image from YCbCr to Rgb.
The image data is assumed to be in the range of (0, 1).
Returns:
RGB version of the image.
Shape:
- image: :math:`(*, 3, H, W)`
- output: :math:`(*, 3, H, W)`
Examples:
>>> input = torch.rand(2, 3, 4, 5)
>>> rgb = YcbcrToRgb()
>>> output = rgb(input) # 2x3x4x5
https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127
"""
def forward(self, image):
y = image[..., 0, :, :]
cb = image[..., 1, :, :]
cr = image[..., 2, :, :]
delta = 0.5
cb_shifted = cb - delta
cr_shifted = cr - delta
r = y + 1.403 * cr_shifted
g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
b = y + 1.773 * cb_shifted
return torch.stack([r, g, b], -3)
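The two conversions above are approximate inverses (the inverse coefficients are rounded), which a quick round-trip sketch confirms:
```
import torch

rgb = torch.rand(2, 3, 4, 5)                 # values in (0, 1)
back = YcbcrToRgb()(RgbToYcbcr()(rgb))
print(torch.allclose(rgb, back, atol=1e-2))  # True up to coefficient rounding
```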
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage = lut.transfer_N_input_SxS_output(self, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = LinearUpscaleBlockLut.init_from_numpy(stage)
return lut_model
class LinearUpscaleBlockLut(nn.Module):
def __init__(self, quantization_interval, upscale_factor):
super(LinearUpscaleBlockLut, self).__init__()
self.out_channels = 1
self.upscale_factor = upscale_factor
self.quantization_interval = quantization_interval
self.stage = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (upscale_factor,upscale_factor)).type(torch.float32))
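A size sketch for the LUT parameter above: with quantization_interval=16 each input axis has 256//16 + 1 = 17 nodes, so the 4-input table stores 17^4 cells of upscale_factor^2 values each:
```
quantization_interval, upscale_factor = 16, 1
nodes = 256 // quantization_interval + 1   # 17
cells = nodes ** 4 * upscale_factor ** 2   # 83521
print(nodes, cells)                        # 17 83521
```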
class EffKAN(torch.nn.Module):
"""
https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6
"""
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
@staticmethod
def init_from_numpy(
stage
):
super(EffKAN, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
upscale_factor = int(stage.shape[-1])
quantization_interval = 256//(stage.shape[0]-1)
lut_model = LinearUpscaleBlockLut(quantization_interval=quantization_interval, upscale_factor=upscale_factor)
lut_model.stage = nn.Parameter(torch.tensor(stage).type(torch.float32))
return lut_model
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.size(-1) == self.in_features
original_shape = x.shape
x = x.view(-1, self.in_features)
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
output = base_output + spline_output
output = output.view(*original_shape[:-1], self.out_features)
return output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def forward(self, x):
x = lut.select_index_4dlut_tetrahedral(x, self.stage)
return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage.shape}"
This is a dumb simulation of the original L1 regularization as stated in the
paper, since the original one requires computing absolutes and entropy from the
expanded (batch, in_features, out_features) intermediate tensor, which is hidden
behind the F.linear function if we want a memory-efficient implementation.
The L1 regularization is now computed as mean absolute value of the spline
weights. The authors' implementation also includes this term in addition to the
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
# This is inspired by Kolmogorov-Arnold Networks but uses Chebyshev polynomials instead of spline coefficients
class ChebyKANLayer(nn.Module):
"""
https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py
"""
def __init__(self, in_features, out_features, degree=8):
def __init__(self, in_features, out_features, degree=8, input_max_value=255, output_max_value=255):
super(ChebyKANLayer, self).__init__()
self.inputdim = in_features
self.outdim = out_features
self.in_features = in_features
self.out_features = out_features
self.degree = degree
self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1))
nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1)))
self.register_buffer("arange", torch.arange(0, degree + 1, 1))
def forward(self, x):
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
def forward_all_to_all(self, x):
# Chebyshev polynomials are defined on [-1, 1],
# so x must be normalized to [-1, 1] first
b,hw,c = x.shape
x = torch.tanh(x)
b, hw, c = x.shape
assert c == self.in_features, f"Input feature count {c} does not match this layer's in_features={self.in_features}."
x = (x-self.in_bias)/self.in_scale
# x = torch.tanh(x)
# View and repeat input degree + 1 times
x = x.view((b, hw, self.inputdim, 1)).expand(
x = x.view((b, hw, self.in_features, 1)).expand(
-1, -1, -1, self.degree + 1
) # shape = (batch_size, inputdim, self.degree + 1)
# Apply acos
@ -404,157 +155,152 @@ class ChebyKANLayer(nn.Module):
x = x.cos()
# Compute the Chebyshev interpolation
y = torch.einsum(
"btid,iod->bto", x, self.cheby_coeffs
) # shape = (batch_size, hw, outdim)
y = y.view(b, hw, self.outdim)
"btid,iod->btio", x, self.cheby_coeffs
)
# shape = (batch_size, hw, outdim)
# print("btio", y.shape)
y = torch.tanh(y)
# y = (y-1) / self.degree
y = y*self.out_scale + self.out_bias
y = y.view(b, hw, self.in_features, self.out_features)
# y = y.clamp(0,255)
# print("btio", y.shape)
return y
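The acos, scale-by-degree, cos pipeline above evaluates Chebyshev polynomials through the identity T_k(x) = cos(k * arccos x); a self-contained check of T_2(x) = 2x^2 - 1 (not part of this commit):
```
import torch

x = torch.linspace(-0.99, 0.99, 7)
k = torch.arange(0, 3).float()                # degrees 0, 1, 2
T = torch.cos(k[None, :] * x[:, None].acos()) # columns: T_0, T_1, T_2
print(torch.allclose(T[:, 2], 2 * x**2 - 1, atol=1e-5))  # True
```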
def forward(self, x):
out = self.forward_all_to_all(x)
# print("net", out.min(), out.max())
out = (out - 127.5)/127.5
out = out.sum(dim=-2)
# print("net", out.min(), out.max())
# out = torch.tanh(out)
out = out / self.in_features
out = out*127.5 + 127.5
return out
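The aggregation above is an average in the normalized (-1, 1) space: if every input-to-output edge emits the same pixel value v, the layer returns v unchanged. A short check of the arithmetic:
```
import torch

v, in_features = 200.0, 4
edges = torch.full((1, 1, in_features, 1), v)  # (b, hw, in, out) as in forward_all_to_all
out = ((edges - 127.5) / 127.5).sum(dim=-2) / in_features * 127.5 + 127.5
print(out)                                     # tensor([[[200.]]])
```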
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
in_features = self.in_features
out_features = self.out_features
domain_values = torch.cat([torch.arange(0, 256, quantization_interval, dtype=torch.uint8), torch.tensor([255])])
inputs = domain_values.type(torch.float32).to(self.cheby_coeffs.device)
lut = np.full((in_features, out_features, bucket_count+1), dtype=np.uint8, fill_value=255)
# lut_min = np.full((in_features, out_features), dtype=np.uint8, fill_value=0)
# lut_length = np.full((in_features, out_features), dtype=np.uint8, fill_value=255)
# print(in_features, out_features)
qmodel = ChebyKANLut(in_features=self.in_features, out_features=self.out_features, quantization_interval=quantization_interval)
with torch.no_grad():
for d in range(len(domain_values)):
out = self.forward_all_to_all(inputs[d].view(1,1,1).expand(1,1,in_features)).squeeze(0).squeeze(1).cpu().numpy()
lut[:,:,d] = out.astype(np.uint8)
# lut_min[:,:] = lut.min(axis=-1).astype(np.uint8)
# lut_length[:,:] = lut.max(axis=-1).astype(np.uint8)-lut.min(axis=-1).astype(np.uint8)
qmodel.lut = nn.Parameter(torch.tensor(lut).type(torch.float32))
# qmodel.lut_min = nn.Parameter(torch.tensor(lut_min).type(torch.float32))
# qmodel.lut_length = nn.Parameter(torch.tensor(lut_length).type(torch.float32))
return qmodel
def __repr__(self):
return (f"{self.__class__.__name__}\n"
f"cheby coefs size: {self.cheby_coeffs.shape}\n"
f"in/out bias, scale: {self.in_bias}, {self.in_scale} / {self.out_bias}, {self.out_scale}" )
class ChebyKANLut(nn.Module):
def __init__(self, in_features, out_features, quantization_interval=16):
super(ChebyKANLut, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.lut = nn.Parameter(
torch.randint(0, 255, size=(in_features, out_features, 256//quantization_interval+1)).type(torch.float32)
)
# self.lut_min = nn.Parameter(torch.ones((in_features, out_features)).type(torch.float32)*0)
# self.lut_max = nn.Parameter(torch.ones((in_features, out_features)).type(torch.float32)*255)
class UpscaleBlockChebyKAN(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8):
super(UpscaleBlockChebyKAN, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
def forward_all_to_all(self, x):
b,w,c = x.shape
out = torch.zeros([b, w, self.in_features, self.out_features], dtype=x.dtype, device=x.device)
for i in range(self.in_features):
for j in range(self.out_features):
out[:,:, i, j:(j+1)] = lut.select_index_1dlut_linear(x[:,:,i:(i+1)], self.lut[i,j])
return out
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree))
self.linear_projections = nn.ModuleList(self.linear_projections)
def forward(self, x):
out = self.forward_all_to_all(x)
out = (out - 127.5)/127.5
out = out.sum(dim=-2)
# out = torch.tanh(out)
out = out / self.in_features
out = out*127.5 + 127.5
return out
self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.lut.shape}"
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = self.embed(x)
for linear_projection in self.linear_projections:
x = self.layer_norm(linear_projection(x))
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
return x
class UpscaleBlockEffKAN(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
super(UpscaleBlockEffKAN, self).__init__()
class ChebyKANUpscaleBlockNet(nn.Module):
def __init__(self, in_features=4, out_channels=1, hidden_dim=64, layers_count=2, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8):
super(ChebyKANUpscaleBlockNet, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(EffKAN(in_features=hidden_dim, out_features=hidden_dim, bias=True))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.out_channels = out_channels
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
if layers_count == 1:
self.kan_layers = [
ChebyKANLayer(in_features=in_features, out_features=out_channels*self.upscale_factor*self.upscale_factor, degree=degree)
]
elif layers_count > 1:
self.kan_layers = [
ChebyKANLayer(in_features=in_features, out_features=hidden_dim, degree=degree),
]
for i in range(layers_count-2):
self.kan_layers.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree))
self.kan_layers.append(ChebyKANLayer(in_features=hidden_dim, out_features=out_channels*self.upscale_factor*self.upscale_factor, degree=degree))
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
self.kan_layers = nn.ModuleList(self.kan_layers)
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = self.embed(x)
for linear_projection in self.linear_projections:
x = self.layer_norm(linear_projection(x))
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
for kan in self.kan_layers:
x = kan(x)
# print(x.min(), x.max())
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
lut_kan_layers = [v.get_lut_model(quantization_interval=quantization_interval) for v in self.kan_layers]
qmodel = ChebyKANUpscaleBlockLut.init_from_numpy(lut_kan_layers=lut_kan_layers)
qmodel.upscale_factor = self.upscale_factor
return qmodel
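A hedged end-to-end sketch of the conversion path above (class and method names from this diff; only the shapes are asserted, and CPU execution is assumed to be supported by the transfer code):
```
import torch

net = ChebyKANUpscaleBlockNet(in_features=4, hidden_dim=64,
                              layers_count=2, upscale_factor=1)
x = torch.randint(0, 256, (1, 16, 4)).float()  # (b, hw, in_features)
lut_block = net.get_lut_model(quantization_interval=16)
with torch.no_grad():
    print(net(x).shape, lut_block(x).shape)    # both (1, 16, 1)
```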
class ComplexGaborLayer2D(nn.Module):
'''
Implicit representation with a complex Gabor nonlinearity (2D activation function)
https://github.com/vishwa91/wire/blob/main/modules/wire2d.py
Inputs:
in_features: Input features
out_features: Output features
bias: if True, enable bias for the linear operation
is_first: Legacy SIREN parameter
omega_0: Legacy SIREN parameter
omega0: Frequency of Gabor sinusoid term
sigma0: Scaling of Gabor Gaussian term
trainable: If True, omega and sigma are trainable parameters
'''
def __init__(self, in_features, out_features, bias=True,
is_first=False, omega0=10.0, sigma0=10.0,
trainable=False):
super().__init__()
self.omega_0 = omega0
self.scale_0 = sigma0
self.is_first = is_first
self.in_features = in_features
if self.is_first:
dtype = torch.float
else:
dtype = torch.cfloat
# Set trainable parameters if they are to be simultaneously optimized
self.omega_0 = nn.Parameter(self.omega_0*torch.ones(1), trainable)
self.scale_0 = nn.Parameter(self.scale_0*torch.ones(1), trainable)
self.linear = nn.Linear(in_features,
out_features,
bias=bias,
dtype=dtype)
# Second Gaussian window
self.scale_orth = nn.Linear(in_features,
out_features,
bias=bias,
dtype=dtype)
def forward(self, input):
lin = self.linear(input)
scale_x = lin
scale_y = self.scale_orth(input)
freq_term = torch.exp(1j*self.omega_0*lin)
arg = scale_x.abs().square() + scale_y.abs().square()
gauss_term = torch.exp(-self.scale_0*self.scale_0*arg)
return freq_term*gauss_term
class UpscaleBlockGabor(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
super(UpscaleBlockGabor, self).__init__()
assert layers_count > 0
class ChebyKANUpscaleBlockLut(nn.Module):
def __init__(self, layer_count, in_features_list, out_features_list, quantization_interval, upscale_factor):
super(ChebyKANUpscaleBlockLut, self).__init__()
self.quantization_interval = quantization_interval
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = ComplexGaborLayer2D(in_features=in_features, out_features=hidden_dim, is_first=True)
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(ComplexGaborLayer2D(in_features=hidden_dim, out_features=hidden_dim, is_first=False))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2
self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
self.out_channels = 1
self.lut_kan_layers = nn.ModuleList([
ChebyKANLut(in_features=in_features, out_features=out_features, quantization_interval=quantization_interval)
for i, in_features, out_features in zip(range(layer_count), in_features_list, out_features_list)])
@staticmethod
def init_from_numpy(
lut_kan_layers
):
in_features_list = [l.lut.shape[0] for l in lut_kan_layers]
out_features_list = [l.lut.shape[1] for l in lut_kan_layers]
quantization_interval = 256//(lut_kan_layers[0].lut.shape[2]-1)
qmodel = ChebyKANUpscaleBlockLut(layer_count=len(lut_kan_layers), in_features_list=in_features_list, out_features_list=out_features_list, quantization_interval=quantization_interval, upscale_factor=out_features_list[-1])
qmodel.lut_kan_layers = nn.ModuleList(lut_kan_layers)
return qmodel
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = self.embed(x)
for linear_projection in self.linear_projections:
x = linear_projection(x)
x = x.real
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
for kan in self.lut_kan_layers:
x = kan(x)
# print(x.min(), x.max())
return x
def __repr__(self):
return "\n".join([f"{self.__class__.__name__}"] + [f"lut size: {l.lut.shape}" for l in self.lut_kan_layers])

@ -8,23 +8,53 @@ import numpy as np
##################### TRANSFER ##########################
# TODO make one function instead of 4
class Domain1DValues(Dataset):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain1DValues, self).__init__()
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*1)).view(-1, 1, 1)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, batch = [], []
for i in range(idx.start, idx.stop):
ix1, values = self.__getitem__(i)
ix1s.append(ix1)
batch.append(values)
return ix1s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(len(self.values)):
yield self.__getitem__(i)
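A sketch of the grid these Domain datasets enumerate: for quantization_interval=16 each axis holds bucket_count + 1 = 17 nodes, with the extra node at 256 so the last LUT bin has an upper anchor:
```
import torch

values1d = torch.arange(0, 256, 16, dtype=torch.int64)
values1d = torch.cat([values1d, torch.tensor([256])])
print(len(values1d), values1d.tolist())  # 17 [0, 16, ..., 240, 256]
```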
class Domain2DValues(Dataset):
def __init__(self, quantization_interval=1):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain2DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*2)).view(-1, 1, 2)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, batch = [], [], [], []
ix1s, ix2s, batch = [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
batch.append(values)
return ix1s, ix2s, ix3s, batch
return ix1s, ix2s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
@ -38,10 +68,10 @@ class Domain2DValues(Dataset):
yield self.__getitem__(i)
class Domain3DValues(Dataset):
def __init__(self, quantization_interval=1):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain3DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*3)).view(-1, 1, 3)
@ -99,25 +129,34 @@ class Domain4DValues(Dataset):
yield self.__getitem__(i)
def transfer_rc_conv(rc_conv, quantization_interval=1):
receptive_field_pixel_count = rc_conv.window_size**2
bucket_count = 256//quantization_interval
lut = np.full((receptive_field_pixel_count, bucket_count+1), dtype=np.uint8, fill_value=255)
for pixel_id in range(receptive_field_pixel_count):
for idx, value in enumerate(range(0, 256, quantization_interval)):
inputs = torch.tensor([value]).type(torch.float32).view(1,1,1).cuda()
def transfer_1_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 1D LUT over a single input value
domain_values = Domain1DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
shuffle=False,
)
counter = 0
for idx, (ix1s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = rc_conv.pixel_wise_forward(inputs)
lut[:,idx] = outputs.flatten().cpu().numpy().astype(np.uint8)
print(f"\r {rc_conv.__class__.__name__} {pixel_id*bucket_count + idx +1}/{receptive_field_pixel_count*bucket_count}", end=" ")
outputs = block(inputs)
lut[ix1s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval, max_value=max_value)
lut = np.full((bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 2D LUT for a 2-pixel input window
domain_values = Domain2DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -126,21 +165,21 @@ def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**1
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).to(list(block.parameters())[0].device)
for idx, (ix1s, ix2s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, ix4s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
lut[ix1s, ix2s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain3DValues(quantization_interval=quantization_interval)
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 3D LUT for a 3-pixel input window
domain_values = Domain3DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -159,11 +198,11 @@ def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**1
print()
return lut
def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain2DValues(quantization_interval=quantization_interval)
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -172,16 +211,38 @@ def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**1
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).to(list(block.parameters())[0].device)
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
lut[ix1s, ix2s, ix3s, ix4s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_N_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
match block.in_features:
case 1: return transfer_1_input_SxS_output(block, quantization_interval, batch_size, max_value)
case 2: return transfer_2_input_SxS_output(block, quantization_interval, batch_size, max_value)
case 3: return transfer_3_input_SxS_output(block, quantization_interval, batch_size, max_value)
case 4: return transfer_4_input_SxS_output(block, quantization_interval, batch_size, max_value)
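A hedged usage sketch for the dispatcher above (the 4-input branch moves inputs to the block's parameter device, so CPU works; other branches call `.cuda()`; the elided `LinearUpscaleBlockNet.forward` is assumed functional):
```
import torch

block = LinearUpscaleBlockNet(in_features=4, hidden_dim=32,
                              layers_count=4, upscale_factor=2)
lut_4d = transfer_N_input_SxS_output(block, quantization_interval=16,
                                     batch_size=2**10)
print(lut_4d.shape)  # (17, 17, 17, 17, 2, 2)
```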
# TODO revive rc
# def transfer_rc_conv(rc_conv, quantization_interval=1):
# receptive_field_pixel_count = rc_conv.window_size**2
# bucket_count = 256//quantization_interval
# lut = np.full((receptive_field_pixel_count, bucket_count+1), dtype=np.uint8, fill_value=255)
# for pixel_id in range(receptive_field_pixel_count):
# for idx, value in enumerate(range(0, 256, quantization_interval)):
# inputs = torch.tensor([value]).type(torch.float32).view(1,1,1).cuda()
# with torch.no_grad():
# outputs = rc_conv.pixel_wise_forward(inputs)
# lut[:,idx] = outputs.flatten().cpu().numpy().astype(np.uint8)
# print(f"\r {rc_conv.__class__.__name__} {pixel_id*bucket_count + idx +1}/{receptive_field_pixel_count*bucket_count}", end=" ")
# print()
# return lut
##################### FORWARD ##########################
# def forward_rc_conv_centered(index, lut):
@ -221,23 +282,19 @@ def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**1
##################### UTILS ##########################
# TODO rewrite for unfolded
def select_index_1dlut_linear(ixA, lut):
lut = torch.clamp(lut, 0, 255)
b,c,h,w = ixA.shape
ixA = ixA.flatten()
def select_index_1dlut_linear(index, lut):
b, hw, c = index.shape
L = lut.shape[0]
Q = 256/(L-1)
msbA = torch.floor_divide(ixA, Q).type(torch.int64)
msbA = torch.floor_divide(index, Q).type(torch.int64)
msbB = msbA + 1
msbA = msbA.flatten()
msbB = msbB.flatten()
lsb = ixA % Q
outA = lut[msbA]
outB = lut[msbB]
lsb_coef = lsb / Q
lsb = index % Q
outA = lut[msbA].reshape((b, hw, c))
outB = lut[msbB].reshape((b, hw, c))
lsb_coef = (lsb / Q).reshape((b, hw, c))
out = outA + lsb_coef*(outB-outA)
out = out.reshape((b,c,h,w))
return out
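A sanity sketch for the rewritten `select_index_1dlut_linear`: with an identity LUT whose 17 nodes sit at 0, 16, ..., 256, linear interpolation reproduces the input:
```
import torch

lut_nodes = torch.arange(0, 257, 16).float()    # L = 17, Q = 256/(L-1) = 16
x = torch.tensor([[[0.0], [100.0], [255.0]]])   # (b, hw, c) = (1, 3, 1)
print(select_index_1dlut_linear(x, lut_nodes))  # ~[[[0.], [100.], [255.]]]
```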
def select_index_3dlut_tetrahedral(index, lut):
@ -306,6 +363,7 @@ def select_index_3dlut_tetrahedral(index, lut):
def select_index_4dlut_tetrahedral(index, lut):
b, hw, c = index.shape
index = index.clamp(0, 255)
lut = torch.clamp(lut, 0, 255)
dimA, dimB, dimC, dimD = lut.shape[:4]
q = 256/(dimA-1)
@ -316,7 +374,6 @@ def select_index_4dlut_tetrahedral(index, lut):
msbA = torch.floor_divide(index, q).type(torch.int64)
msbB = msbA + 1
lsb = index % q
img_a1 = msbA[:,:,0].reshape(b*hw, 1)
img_b1 = msbA[:,:,1].reshape(b*hw, 1)
img_c1 = msbA[:,:,2].reshape(b*hw, 1)

@ -38,7 +38,7 @@ def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=No
Image.fromarray(cmaplut[pred_lr_image[:,:,0]]).save(output_image_path)
# metrics
hr_image = modcrop(hr_image, model.scale)
hr_image = modcrop(hr_image, model.config.upscale_factor)
if pred_lr_image.shape[-1] == 3 and color_model == 'RGB':
Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
@ -47,7 +47,7 @@ def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=No
Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0]
lr_area = np.prod(lr_image.shape[-2:])
psnr = PSNR(Y_left, Y_right, model.scale)
psnr = PSNR(Y_left, Y_right, model.config.upscale_factor)
ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area
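`modcrop` is assumed here to follow the usual SR convention: crop height and width down to multiples of the upscale factor so the prediction and HR reference align (a hypothetical sketch, not this repo's implementation):
```
import numpy as np

def modcrop_sketch(img: np.ndarray, scale: int) -> np.ndarray:
    # hypothetical stand-in for the repo's modcrop helper
    h, w = img.shape[:2]
    return img[:h - h % scale, :w - w % scale]

print(modcrop_sketch(np.zeros((101, 99, 3)), 4).shape)  # (100, 96, 3)
```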

@ -0,0 +1,26 @@
from common import layers
class Transferer():
def __init__(self):
self.registered_types = {}
def register(self, input_class, output_class):
self.registered_types[input_class] = output_class
def transfer(self, input_model, batch_size, quantization_interval):
input_class = input_model.__class__
if input_class not in self.registered_types:
raise Exception(f"No transfer function is registered for class {input_class}")
transfered_model = self.transfer_model(input_model, self.registered_types[input_class], batch_size, quantization_interval)
return transfered_model
def transfer_model(self, model, output_model_class, batch_size, quantization_interval):
qmodel = output_model_class(config = model.config)
model.config.quantization_interval = quantization_interval
for attr, value in model.named_children():
if isinstance(value, layers.UpscaleBlock):
getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
return qmodel
TRANSFERER = Transferer()
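A hedged registration/usage sketch for the registry above (the concrete net/LUT class names are illustrative, not taken from this hunk):
```
# Hypothetical pairing; substitute the real classes from models/:
# TRANSFERER.register(ChebyKANNet, ChebyKANLutNet)   # illustrative names
# lut_model = TRANSFERER.transfer(trained_net,
#                                 batch_size=2**10,
#                                 quantization_interval=16)
```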

@ -13,11 +13,12 @@ import argparse
class ImageDemoOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../experiments/last_transfered_net.pth","../experiments/last_transfered_lut.pth"], help="Model paths for comparison")
default_paths = ["../experiments/last_trained_net.pth","../experiments/last_transfered_lut.pth"]
default_paths = [f for f in default_paths if Path(f).exists()]
self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=default_paths, help="Model paths for comparison. Example --model_paths ./A/ ./B/")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--output_path', type=str, default="../experiments/", help="Output path.")
self.parser.add_argument('--output_name', type=str, default="image_demo.png", help="Output name.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
@ -29,6 +30,7 @@ class ImageDemoOptions():
args.hr_image_path = Path(args.hr_image_path).resolve()
args.lr_image_path = Path(args.lr_image_path).resolve()
args.model_paths = [Path(x).resolve() for x in args.model_paths]
args.output_name = args.hr_image_path.stem + "_" + "_".join([Path(f).resolve().stem for f in args.model_paths]) + ".png"
return args
def __repr__(self):
@ -96,4 +98,5 @@ for i in range(row_count):
canvas = np.concatenate(columns, axis=0).astype(np.uint8)
Image.fromarray(canvas).save(config.output_path / config.output_name)
print(f"Saved to {config.output_path / config.output_name}")
print(datetime.now() - start_script_time )

@ -1,12 +1,5 @@
from . import rcnet
from . import rclut
from . import srnet
from . import srlut
from . import sdynet
from . import hdbnet
from . import hdblut
from models.base import SRNetBase
from common import losses
from . import models
from common.base import SRBase
import torch
import numpy as np
from pathlib import Path
@ -15,7 +8,10 @@ import inspect, sys
AVAILABLE_MODELS = {}
for module_name in sys.modules.keys():
if 'models.' in module_name:
AVAILABLE_MODELS.update({k:v for k,v in inspect.getmembers(sys.modules[module_name], lambda x: inspect.isclass(x) and SRNetBase in x.__bases__)})
AVAILABLE_MODELS.update({
k:v for k,v in inspect.getmembers(sys.modules[module_name],
lambda x: inspect.isclass(x) and SRBase in inspect.getmro(x) and (not 'base' in x.__name__.lower()) )
})
def SaveCheckpoint(model, path):
model_container = {
@ -29,7 +25,8 @@ def LoadCheckpoint(model_path):
model_path = Path(model_path).absolute()
if model_path.exists():
model_container = torch.load(model_path)
model = AVAILABLE_MODELS[model_container['model']](**{k:v for k,v in model_container.items() if k != "model" and k != "state_dict"})
init_arg_names = list(inspect.signature(AVAILABLE_MODELS[model_container['model']]).parameters.keys())
model = AVAILABLE_MODELS[model_container['model']](**{k:model_container[k] for k in init_arg_names})
model.load_state_dict(model_container['state_dict'], strict=True)
return model
else:

@ -1,28 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
def forward_stage(self, x, percieve_pattern, stage):
b,c,h,w = x.shape
scale = stage.upscale_factor
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, 1, h, w, scale, scale)
x = x.permute(0, 1, 2, 4, 3, 5)
x = x.reshape(b, 1, h*scale, w*scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -1,134 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common import layers
from common.utils import round_func
class HDBLut(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(HDBLut, self).__init__()
assert scale == 4
self.scale = scale
self.quantization_interval = quantization_interval
self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
):
# quantization_interval = 256//(stage1_3H.shape[0]-1)
quantization_interval = 16
lut_model = HDBLut(quantization_interval=quantization_interval, scale=4)
lut_model.stage1_3H = nn.Parameter(torch.tensor(stage1_3H).type(torch.float32))
lut_model.stage1_3D = nn.Parameter(torch.tensor(stage1_3D).type(torch.float32))
lut_model.stage1_3B = nn.Parameter(torch.tensor(stage1_3B).type(torch.float32))
lut_model.stage1_2H = nn.Parameter(torch.tensor(stage1_2H).type(torch.float32))
lut_model.stage1_2D = nn.Parameter(torch.tensor(stage1_2D).type(torch.float32))
lut_model.stage2_3H = nn.Parameter(torch.tensor(stage2_3H).type(torch.float32))
lut_model.stage2_3D = nn.Parameter(torch.tensor(stage2_3D).type(torch.float32))
lut_model.stage2_3B = nn.Parameter(torch.tensor(stage2_3B).type(torch.float32))
lut_model.stage2_2H = nn.Parameter(torch.tensor(stage2_2H).type(torch.float32))
lut_model.stage2_2D = nn.Parameter(torch.tensor(stage2_2D).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
print(np.prod(x.shape))
x = percieve_pattern(x)
shifts = torch.tensor([lut.shape[0]**d for d in range(len(lut.shape)-2)], device=x.device).flip(0).reshape(1,1,len(lut.shape)-2)
print(x.shape, x.min(), x.max())
x = torch.sum(x * shifts, dim=-1)
print(x.shape)
lut = torch.clamp(lut, 0, 255)
lut = lut.reshape(-1, scale, scale)
x = x.flatten().type(torch.int64)
x = lut[x]
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
print(x.shape)
# raise RuntimeError
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
print(output_msb.min(), output_msb.max())
print(output_lsb.min(), output_lsb.max())
output_msb = output_msb + output_lsb
x = output_msb
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
print("STAGE2", msb.min(), msb.max(), lsb.min(), lsb.max())
for rotations_count in range(4):
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
output_msb = output_msb + output_lsb
x = output_msb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_3H size: {self.stage1_3H.shape}" + \
f"\n stage1_3D size: {self.stage1_3D.shape}" + \
f"\n stage1_3B size: {self.stage1_3B.shape}" + \
f"\n stage1_2H size: {self.stage1_2H.shape}" + \
f"\n stage1_2D size: {self.stage1_2D.shape}" + \
f"\n stage2_3H size: {self.stage2_3H.shape}" + \
f"\n stage2_3D size: {self.stage2_3D.shape}" + \
f"\n stage2_3B size: {self.stage2_3B.shape}" + \
f"\n stage2_2H size: {self.stage2_2H.shape}" + \
f"\n stage2_2D size: {self.stage2_2D.shape}"

@ -1,441 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from . import hdblut
from common import layers
from itertools import cycle
from models.base import SRNetBase
class HDBNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, rotations = 4):
super(HDBNet, self).__init__()
assert scale == 4
self.scale = scale
self.rotations = rotations
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
output_msb /= 3
output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
output_lsb /= 2
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
output /= self.rotations
x = output
lsb = x % 16
msb = x - lsb
output = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
output_msb /= 3
output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
output_lsb /= 2
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s2_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s2_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
output /= self.rotations
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
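A worked example of the base-16 split used in both stages above: each pixel decomposes into a most significant part (a multiple of 16, driving the 3-pixel branches) and a 4-bit remainder (driving the 2-pixel branches):
```
x = 203
lsb = x % 16     # 11
msb = x - lsb    # 192 == 12 * 16
assert msb + lsb == x
print(msb, lsb)  # 192 11
```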
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = hdblut.HDBLut.init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBNetv2(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 2, scale = 4, rotations = 4):
super(HDBNetv2, self).__init__()
assert scale == 4
self.scale = scale
self.rotations = rotations
self.layers_count = layers_count
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
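        # v2 defers normalization: rotated branch outputs are accumulated and divided once after the loop, with clamping applied only between the two stages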
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
            if config is not None and config.current_iter % config.display_step == 0:
                config.writer.add_histogram('s1_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
                config.writer.add_histogram('s1_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= self.rotations*3
output_lsb /= self.rotations*2
output = output_msb + output_lsb
x = output.clamp(0, 255)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
for rotations_count in range(self.rotations):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
            if config is not None and config.current_iter % config.display_step == 0:
                config.writer.add_histogram('s2_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
                config.writer.add_histogram('s2_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= self.rotations*3
output_lsb /= self.rotations*2
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = hdblut.HDBLut.init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNet, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
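        # 3L: an L-shaped 3-pixel footprint inside a 2x2 window, used for the LSB branch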
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
output_msb /= 3
        if config is not None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNetR90(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(HDBLNetR90, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
            if config is not None and config.current_iter % config.display_step == 0:
                config.writer.add_histogram('s1_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
                config.writer.add_histogram('s1_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNetR90KAN(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNetR90KAN, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
            if config is not None and config.current_iter % config.display_step == 0:
                config.writer.add_histogram('s1_output_lsb', output_lsbt.detach().cpu().numpy(), config.current_iter)
                config.writer.add_histogram('s1_output_msb', output_msbt.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBHNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBHNet, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
        self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
            in_features=4,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale
        ) for _ in range(1)])
        self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
            in_features=4,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale
        ) for _ in range(1)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
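        # cycle() lets a short list of shared blocks serve all four rotations (here a single block per branch)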
for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
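            # requantize: snap the MSB branch to 16-level steps (multiples of 15) and rescale the LSB branch into [0, 15]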
output_msb_r = round_func((output_msb_r / 255)*16) * 15
output_lsb_r = (output_lsb_r / 255) * 15
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
output_lsb /= 4
        if config is not None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
fourier_loss_fn = FocalFrequencyLoss()
high_frequency_loss_fn = FourierLoss()
def loss_fn(pred, target):
a = fourier_loss_fn(pred/255, target/255) * 1e8
# b = F.mse_loss(pred/255, target/255) #* 1e3
# c = high_frequency_loss_fn(pred/255, target/255) * 1e6
return a #+ b #+ c
return loss_fn

@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from common import layers
from common import losses
from common.base import SRBase
from common.transferer import TRANSFERER
class SRNetBase(SRBase):
def __init__(self):
super(SRNetBase, self).__init__()
self.config = None
self.stage1_S = layers.UpscaleBlock(None)
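        # the concrete stage module is injected by subclasses (a trainable net or its LUT counterpart)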
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, script_config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.stage1_S(x, self._extract_pattern_S)
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
return x
class SRNet(SRNetBase):
def __init__(self, config):
super(SRNet, self).__init__()
self.config = config
self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
class SRLut(SRNetBase):
def __init__(self, config):
super(SRLut, self).__init__()
self.config = config
self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
TRANSFERER.register(SRNet, SRLut)
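# Usage sketch (assumption, not part of this file): a trained SRNet is mapped to
# its registered SRLut counterpart through the TRANSFERER registry declared above;
# the exact transfer call is defined in common.transferer and is not shown here.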
class ChebyKANBase(SRBase):
def __init__(self):
super(ChebyKANBase, self).__init__()
self.config = None
self.stage1_S = layers.UpscaleBlock(None)
window_size = 7
self._extract_pattern = layers.PercievePattern(
receptive_field_idxes=[[i,j] for i in range(window_size) for j in range(window_size)],
center=[window_size//2,window_size//2],
window_size=window_size
)
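        # dense 7x7 receptive field centered on the target pixel: 49 inputs per prediction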
def forward(self, x, script_config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.stage1_S(x, self._extract_pattern)
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
return x
class ChebyKANNet(ChebyKANBase):
def __init__(self, config):
super(ChebyKANNet, self).__init__()
self.config = config
window_size = 7
self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
in_features=window_size*window_size,
out_channels=1,
hidden_dim=16,
layers_count=self.config.layers_count,
upscale_factor=self.config.upscale_factor,
degree=8
)
class ChebyKANLut(ChebyKANBase):
def __init__(self, config):
super(ChebyKANLut, self).__init__()
self.config = config
window_size = 7
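        # build the float ChebyKAN net with the same topology, then immediately sample it into a LUT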
self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
in_features=window_size*window_size,
out_channels=1,
hidden_dim=16,
layers_count=self.config.layers_count,
upscale_factor=self.config.upscale_factor
).get_lut_model(quantization_interval=self.config.quantization_interval)
def get_loss_fn(self):
ssim_loss = losses.SSIM(data_range=255)
l1_loss = losses.CharbonnierLoss()
def loss_fn(pred, target):
return ssim_loss(pred, target) + l1_loss(pred, target)
return loss_fn
TRANSFERER.register(ChebyKANNet, ChebyKANLut)

@ -1,505 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
# from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
# from pathlib import Path
# class RCLutCentered_3x3(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutCentered_3x3, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
# x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
# x = x.view(b, c, x.shape[-2], x.shape[-1])
# return x
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutCentered_7x7(nn.Module):
# def __init__(
# self,
# window_size,
# quantization_interval,
# scale
# ):
# super(RCLutCentered_7x7, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
# # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
# x = x.view(b, c, x.shape[-2], x.shape[-1])
# return x
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutRot90_3x3(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutRot90_3x3, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutRot90_7x7(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutRot90_7x7, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts, dense_conv_lut
# ):
# scale = int(dense_conv_lut.shape[-1])
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
# window_size = rc_conv_luts.shape[0]
# lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
# ")"])
# class RCLutx1(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutx1, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# rc_conv_luts_3x3, dense_conv_lut_3x3,
# rc_conv_luts_5x5, dense_conv_lut_5x5,
# rc_conv_luts_7x7, dense_conv_lut_7x7
# ):
# scale = int(dense_conv_lut_3x3.shape[-1])
# quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
# lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
# lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
# lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
# lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
# lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
# lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
# lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
# return lut_model
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
# return x
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
# f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
# f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
# f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
# f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
# f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
# ")"])
# class RCLutx2(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutx2, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
# ):
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
# lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
# return lut_model
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
# return x
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
# ")"])
# class RCLutx2Centered(nn.Module):
# def __init__(
# self,
# quantization_interval,
# scale
# ):
# super(RCLutx2Centered, self).__init__()
# self.scale = scale
# self.quantization_interval = quantization_interval
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
# @staticmethod
# def init_from_numpy(
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
# ):
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
# lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
# return lut_model
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
# x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
# return x
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w).type(torch.float32)
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
# k=-rotations_count,
# dims=[2, 3]
# )
# output += torch.rot90(
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
# k=-rotations_count,
# dims=[2, 3]
# )
# output /= 3*4
# output = output.view(b, c, output.shape[-2], output.shape[-1])
# return output
# def __repr__(self):
# return "\n".join([
# f"{self.__class__.__name__}(",
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
# ")"])

@ -1,568 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from pathlib import Path
from common import lut
from . import rclut
from common import layers
# class ReconstructedConvCentered(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvCentered, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
# x = round_func(x)
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockCentered(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockCentered, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetCentered_3x3(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetCentered_3x3, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# x = self.stage(x)
# x = x.view(b, c, h*self.scale, w*self.scale)
# return x
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# class RCNetCentered_7x7(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetCentered_7x7, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# window_size = 7
# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# x = self.stage(x)
# x = x.view(b, c, h*self.scale, w*self.scale)
# return x
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# class ReconstructedConvRot90(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvRot90, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
# x = round_func(x) # quality likely suffer from this
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockRot90(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockRot90, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetRot90_3x3(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetRot90_3x3, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated_prediction = self.stage(rotated)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetRot90_7x7(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetRot90_7x7, self).__init__()
# self.hidden_dim = hidden_dim
# self.layers_count = layers_count
# self.scale = scale
# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# window_size = self.stage.rc_conv.window_size
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# rotated_prediction = self.stage(rotated)
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
# output += unrotated_prediction
# output /= 4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetx1(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx1, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx1.init_from_numpy(
# rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3,
# rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5,
# rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetx2(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class RCNetx2Centered(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2Centered, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2Centered.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class ReconstructedConvRot90Unlutable(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvRot90Unlutable, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
# # x = round_func(x) # quality would likely suffer from this
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockRot90Unlutable(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockRot90Unlutable, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetx2Unlutable(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2Unlutable, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
# class ReconstructedConvCenteredUnlutable(nn.Module):
# def __init__(self, hidden_dim, window_size=7):
# super(ReconstructedConvCenteredUnlutable, self).__init__()
# self.window_size = window_size
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
# def pixel_wise_forward(self, x):
# x = (x-127.5)/127.5
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
# out = torch.tanh(out)
# out = out*127.5 + 127.5
# return out
# def forward(self, x):
# original_shape = x.shape
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
# x = F.unfold(x, self.window_size)
# x = self.pixel_wise_forward(x)
# x = x.mean(1)
# x = x.reshape(*original_shape)
# # x = round_func(x) # quality would likely suffer from this
# return x
# def __repr__(self):
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
# class RCBlockCenteredUnlutable(nn.Module):
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
# super(RCBlockCenteredUnlutable, self).__init__()
# self.window_size = window_size
# self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
# def forward(self, x):
# b,c,hs,ws = x.shape
# x = self.rc_conv(x)
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
# x = self.dense_conv_block(x)
# return x
# class RCNetx2CenteredUnlutable(nn.Module):
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
# super(RCNetx2CenteredUnlutable, self).__init__()
# self.scale = scale
# self.hidden_dim = hidden_dim
# self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
# self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
# self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
# self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
# self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
# self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = rclut.RCLutx2Centered.init_from_numpy(
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
# )
# return lut_model
# def forward(self, x):
# b,c,h,w = x.shape
# x = x.view(b*c, 1, h, w)
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# x = output
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
# output /= 3*4
# output = output.view(b, c, h*self.scale, w*self.scale)
# return output
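# A pattern shared by all RCNet variants above: each stage is evaluated on the
# four 90-degree rotations of the input, the results are rotated back and
# averaged. A minimal sketch of that self-ensemble (illustrative helper, not
# part of the original API):
#
# def rot90_ensemble(stage, x):
#     out = 0
#     for k in range(4):
#         out = out + torch.rot90(stage(torch.rot90(x, k=k, dims=[2, 3])), k=-k, dims=[2, 3])
#     return out / 4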

@ -1,395 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common.layers import PercievePattern
from common.utils import round_func
class SDYLutx1(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutx1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stageS, stageD, stageY
):
scale = int(stageS.shape[-1])
quantization_interval = 256//(stageS.shape[0]-1)
lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
output /= 3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stageS size: {self.stageS.shape}" + \
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
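# Usage sketch, assuming quantization_interval=16 and scale=2: each stage is a
# (17, 17, 17, 17, 2, 2) float32 table (256//16 + 1 nodes per axis); the S, D
# and Y patterns each gather four pixels, the tables are sampled with
# tetrahedral interpolation, and the three branch outputs are averaged and
# rounded back into [0, 255].
#
# lut = SDYLutx1(quantization_interval=16, scale=2)
# sr = lut(torch.randint(0, 256, (1, 1, 32, 32)).float())  # -> (1, 1, 64, 64)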
class SDYLutx2(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutx2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
):
scale = int(stage2_S.shape[-1])
quantization_interval = 256//(stage2_S.shape[0]-1)
lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
output /= 3
output = round_func(output)
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
output /= 3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_S size: {self.stage1_S.shape}" + \
f"\n stage1_D size: {self.stage1_D.shape}" + \
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutx3(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutx3, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage3_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage3_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage3_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y
):
scale = int(stage3_S.shape[-1])
quantization_interval = 256//(stage3_S.shape[0]-1)
lut_model = SDYLutx3(quantization_interval=quantization_interval, scale=scale)
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
lut_model.stage3_S = nn.Parameter(torch.tensor(stage3_S).type(torch.float32))
lut_model.stage3_D = nn.Parameter(torch.tensor(stage3_D).type(torch.float32))
lut_model.stage3_Y = nn.Parameter(torch.tensor(stage3_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
output /= 3
output = round_func(output)
x = output
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage2_Y)
output /= 3
output = round_func(output)
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage3_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage3_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage3_Y)
output /= 3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_S size: {self.stage1_S.shape}" + \
f"\n stage1_D size: {self.stage1_D.shape}" + \
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}" + \
f"\n stage3_S size: {self.stage3_S.shape}" + \
f"\n stage3_D size: {self.stage3_D.shape}" + \
f"\n stage3_Y size: {self.stage3_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x1(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutR90x1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stageS, stageD, stageY
):
scale = int(stageS.shape[-1])
quantization_interval = 256//(stageS.shape[0]-1)
lut_model = SDYLutR90x1(quantization_interval=quantization_interval, scale=scale)
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stageS), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stageD), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stageY), k=-rotations_count, dims=[-2, -1])
output /= 4*3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stageS size: {self.stageS.shape}" + \
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x2(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutR90x2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
):
scale = int(stage2_S.shape[-1])
quantization_interval = 256//(stage2_S.shape[0]-1)
lut_model = SDYLutR90x2(quantization_interval=quantization_interval, scale=scale)
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
x = round_func(output)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stage1_S size: {self.stage1_S.shape}" + \
f"\n stage1_D size: {self.stage1_D.shape}" + \
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
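# In the R90x2 variant stage 1 runs at scale 1 as a refinement pass and stage 2
# performs the actual upscaling; both stages average over the four rotations of
# three patterns each, hence the division by 4*3.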

File diff suppressed because it is too large

@ -1,276 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common import layers
from common.utils import round_func
from models.base import SRNetBase
class SRLut(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLut, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLut(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w).type(torch.float32)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
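# Usage sketch, assuming quantization_interval=16 and scale=4: the single
# 2x2 S-pattern table has shape (17, 17, 17, 17, 4, 4) and every channel is
# upscaled independently.
#
# lut = SRLut(quantization_interval=16, scale=4)
# sr = lut(torch.randint(0, 256, (1, 3, 24, 24)).float())  # -> (1, 3, 96, 96)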
class SRLutY(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutY, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
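# Only the luma channel passes through the LUT in SRLutY; Cb and Cr are
# upscaled with plain bilinear interpolation, so a single 4D table serves
# full RGB input.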
class SRLutR90(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
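# SRLutR90 reuses one 2x2 table on all four rotations of the input and
# averages the de-rotated results: four times the lookups of SRLut, but no
# extra table storage.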
class SRLutR90Y(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90Y, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90YCbCr(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90YCbCr, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
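# Unlike SRLutR90Y, SRLutR90YCbCr expects its input already in YCbCr order and
# returns YCbCr as well; no color-space conversion happens inside the model.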

File diff suppressed because it is too large

@ -87,7 +87,7 @@ if __name__ == "__main__":
for test_dataset_name in config.test_datasets:
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.config.upscale_factor}",
color_model=config.color_model,
reset_cache=config.reset_cache,
)

@ -26,33 +26,37 @@ torch.backends.cudnn.benchmark = True
import argparse
from schedulefree import AdamWScheduleFree
from datetime import datetime
from types import SimpleNamespace
import signal
class SignalHandler:
def __init__(self, signal_code):
self.is_on = False
self.count = 0
signal.signal(signal_code, self.exit_gracefully)
def exit_gracefully(self, signum, frame):
print("Early stopping.")
self.is_on = True
self.count += 1
if self.count == 3:
exit(1)
signal_interraption_handler = SignalHandler(signal.SIGINT)
class TrainOptions:
def __init__(self):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
parser.add_argument('--model_path', type=str, default=None, help="Path to a model checkpoint to finetune.")
parser.add_argument('--model', type=str, default='SRNet', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
# parser.add_argument('--model_path', type=str, default=None, help="Path to a model checkpoint to finetune.")
parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
parser.add_argument('--scale', '-r', type=int, default=4, help="upscale factor")
parser.add_argument('--upscale_factor', '-s', type=int, default=4, help="upscale factor")
parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers")
parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
parser.add_argument('--models_dir', type=str, default='../experiments/', help="experiment folder")
parser.add_argument('--experiment_dir', type=str, default='../experiments/', help="experiment folder")
parser.add_argument('--datasets_dir', type=str, default="../data/")
parser.add_argument('--start_iter', type=int, default=0, help='Set 0 to train from scratch; otherwise saved parameters are loaded and training continues')
parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
@ -69,25 +73,25 @@ class TrainOptions:
parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate')
parser.add_argument('--grad_step', type=int, default=1, help='Number of iterations to accumulate gradients before an optimizer step.')
self.parser = parser
def parse_args(self):
args = self.parser.parse_args()
args.datasets_dir = Path(args.datasets_dir).resolve()
args.models_dir = Path(args.models_dir).resolve()
args.model_path = Path(args.model_path) if not args.model_path is None else None
args.experiment_dir = Path(args.experiment_dir).resolve()
if Path(args.model).exists():
args.model = Path(args.model)
args.start_iter = int(args.model.stem.split("_")[-1])
args.train_datasets = args.train_datasets.split(',')
args.test_datasets = args.test_datasets.split(',')
if not args.model_path is None:
args.start_iter = int(args.model_path.stem.split("_")[-1])
args.quantization_interval = 2**(8-args.quantization_bits)
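# e.g. quantization_bits=4 gives interval 2**(8-4)=16, i.e. 256//16 + 1 = 17 LUT nodes per input axis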
return args
def save_config(self, config):
yaml.dump(config, open(config.exp_dir / "config.yaml", 'w'))
def __repr__(self):
config = self.parse_args()
config = self.parse_args() if self.config is None else self.config
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(config).items()):
@ -99,13 +103,24 @@ class TrainOptions:
message += '----------------- End -------------------'
return message
def prepare_config(self):
def prepare_experiment(self):
config = self.parse_args()
if isinstance(config.model, Path):
model = LoadCheckpoint(config.model)
config.model = model.__class__.__name__
else:
config_dict = vars(config)
model = AVAILABLE_MODELS[config.model](
config = SimpleNamespace(**{k:config_dict[k] for k in config_dict.keys() if k in ['hidden_dim', 'layers_count', 'upscale_factor', 'quantization_interval']})
)
model = model.to(torch.device(config.device))
# model = torch.compile(model)
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"One of the {config.test_datasets} was not found in {config.datasets_dir}."
config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
config.exp_dir = (config.experiment_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.upscale_factor}").resolve()
if not config.exp_dir.exists():
config.exp_dir.mkdir()
@ -122,29 +137,10 @@ class TrainOptions:
if not config.logs_dir.exists():
config.logs_dir.mkdir()
return config
if __name__ == "__main__":
# torch.set_float32_matmul_precision('high')
script_start_time = datetime.now()
config_inst = TrainOptions()
config = config_inst.prepare_config()
if not config.model_path is None:
model = LoadCheckpoint(config.model_path)
config.model = model.__class__.__name__
else:
if 'net' in config.model.lower():
model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale, layers_count=config.layers_count)
if 'lut' in config.model.lower():
model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
model = model.to(torch.device(config.device))
# model = torch.compile(model)
optimizer = AdamWScheduleFree(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.95))
print(optimizer)
config_inst.save_config(config)
self.save_config(config)
self.config = config
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=config.logs_dir)
@ -154,15 +150,25 @@ if __name__ == "__main__":
config.writer = writer
config.logger = logger
config.logger.info(config_inst)
config.logger.info(self)
config.logger.info(model)
config.logger.info(optimizer)
return config, model, optimizer
if __name__ == "__main__":
# torch.set_float32_matmul_precision('high')
script_start_time = datetime.now()
config_inst = TrainOptions()
config, model, optimizer = config_inst.prepare_experiment()
# Training dataset
train_datasets = []
for train_dataset_name in config.train_datasets:
train_datasets.append(SRTrainDataset(
hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.upscale_factor}",
patch_size = config.crop_size,
color_model = config.color_model,
reset_cache=config.reset_cache
@ -183,7 +189,7 @@ if __name__ == "__main__":
for test_dataset_name in config.test_datasets:
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.upscale_factor}",
color_model = config.color_model,
reset_cache=config.reset_cache
)
@ -194,7 +200,7 @@ if __name__ == "__main__":
# TRAINING
i = config.start_iter
if not config.model_path is None:
if isinstance(config.model, Path):
config.current_iter = i
test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
@ -216,9 +222,10 @@ if __name__ == "__main__":
start_time = time.time()
with torch.set_grad_enabled(True):
pred = model(x=lr_patch, config=config)
pred = model(x=lr_patch, script_config=config)
loss = loss_fn(pred=pred, target=hr_patch) / config.grad_step
loss.backward()
# torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
if i % config.grad_step == 0 and i > 0:
optimizer.step()
@ -260,18 +267,18 @@ if __name__ == "__main__":
print("Saved to ", model_path)
# check if it is network or lut
if hasattr(model, 'get_lut_model'):
link = Path(config.models_dir / f"last_trained_net.pth")
if link.exists():
if 'net' in model.__class__.__name__.lower():
link = Path(config.experiment_dir / f"last_trained_net.pth")
if link.exists(follow_symlinks=False):
link.unlink()
link.symlink_to(model_path)
else:
link = Path(config.models_dir / f"last_trained_lut.pth")
if link.exists():
link = Path(config.experiment_dir / f"last_trained_lut.pth")
if link.exists(follow_symlinks=False):
link.unlink()
link.symlink_to(model_path)
link = Path(config.models_dir / f"last.pth")
if link.exists():
link = Path(config.experiment_dir / f"last.pth")
if link.exists(follow_symlinks=False):
link.unlink()
link.symlink_to(model_path)

@ -15,11 +15,13 @@ torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
import models
from common.transferer import TRANSFERER
class TransferToLutOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', '-m', type=str, default='../models/last_trained_net.pth', help="model path folder")
self.parser.add_argument('--model_path', '-m', type=str, default='../experiments/last_trained_net.pth', help="model path folder")
self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
@ -31,7 +33,7 @@ class TransferToLutOptions():
return args
def __repr__(self):
config = self.parser.parse_args()
config = self.parse_args()
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(config).items()):
@ -53,16 +55,15 @@ if __name__ == "__main__":
print(config_inst)
model = models.LoadCheckpoint(config.model_path).cuda()
if getattr(model, 'get_lut_model', None) is None:
if not 'net' in model.__class__.__name__.lower():
print("Transfer to lut can be applied only to the network model.")
exit(1)
print(model)
print()
print("Transfering:")
lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
lut_model = TRANSFERER.transfer(model, quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
print()
print(lut_model)
lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth"
models.SaveCheckpoint(model=lut_model, path=lut_path)
