main
protsenkovi 6 months ago
parent 0f54e2748b
commit df56917042

@ -367,8 +367,8 @@ class ChebyKANLayer(nn.Module):
self.outdim = out_features
self.degree = degree
self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1))
nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1)))
self.register_buffer("arange", torch.arange(0, degree + 1, 1))
def forward(self, x):
@ -390,6 +390,7 @@ class ChebyKANLayer(nn.Module):
y = torch.einsum(
"btid,iod->bto", x, self.cheby_coeffs
) # shape = (batch_size, hw, outdim)
y = y.view(b, hw, self.outdim)
return y
@ -404,10 +405,10 @@ class UpscaleBlockChebyKAN(nn.Module):
self.linear_projections = []
for i in range(layers_count):
self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, bias=True))
self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim))
self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2

@ -113,7 +113,7 @@ def transfer_rc_conv(rc_conv, quantization_interval=1):
print()
return lut
def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2

@ -5,48 +5,17 @@ from . import srlut
from . import sdynet
from . import hdbnet
from . import hdblut
from models.base import SRNetBase
from common import losses
import torch
import numpy as np
from pathlib import Path
import inspect, sys
AVAILABLE_MODELS = {
'SRNet': srnet.SRNet,
'SRLut': srlut.SRLut,
'SRNetR90': srnet.SRNetR90,
'SRLutR90': srlut.SRLutR90,
'SRNetR90Y': srnet.SRNetR90Y,
'SRLutR90Y': srlut.SRLutR90Y,
'SDYNetx1': sdynet.SDYNetx1,
'SDYLutx1': sdylut.SDYLutx1,
'SDYNetx2': sdynet.SDYNetx2,
'SDYLutx2': sdylut.SDYLutx2,
'SDYNetx3': sdynet.SDYNetx3,
'SDYLutx3': sdylut.SDYLutx3,
'SDYNetR90x1': sdynet.SDYNetR90x1,
'SDYLutR90x1': sdylut.SDYLutR90x1,
'SDYNetR90x2': sdynet.SDYNetR90x2,
'SDYLutR90x2': sdylut.SDYLutR90x2,
'SRNetY': srnet.SRNetY,
'SRLutY': srlut.SRLutY,
'HDBNet': hdbnet.HDBNet,
'HDBLut': hdblut.HDBLut,
'HDBLNet': hdbnet.HDBLNet,
'HDBHNet': hdbnet.HDBHNet,
'SRMsbLsbNet': srnet.SRMsbLsbNet,
'SRMsbLsbShiftNet': srnet.SRMsbLsbShiftNet,
'SRMsbLsbR90Net': srnet.SRMsbLsbR90Net,
'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net,
# 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
# 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
# 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,
# 'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7,
# 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1,
# 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2,
# 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered,
# 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable,
# 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable,
}
AVAILABLE_MODELS = {}
for module_name in sys.modules.keys():
if 'models.' in module_name:
AVAILABLE_MODELS.update({k:v for k,v in inspect.getmembers(sys.modules[module_name], lambda x: inspect.isclass(x) and SRNetBase in x.__bases__)})
def SaveCheckpoint(model, path):
model_container = {

@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -6,8 +6,9 @@ from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral
from common import layers
from common.utils import round_func
from models.base import SRNetBase
class SRLut(nn.Module):
class SRLut(SRNetBase):
def __init__(
self,
quantization_interval,
@ -54,7 +55,7 @@ class SRLut(nn.Module):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutY(nn.Module):
class SRLutY(SRNetBase):
def __init__(
self,
quantization_interval,
@ -108,7 +109,7 @@ class SRLutY(nn.Module):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90(nn.Module):
class SRLutR90(SRNetBase):
def __init__(
self,
quantization_interval,
@ -161,7 +162,7 @@ class SRLutR90(nn.Module):
return loss_fn
class SRLutR90Y(nn.Module):
class SRLutR90Y(SRNetBase):
def __init__(
self,
quantization_interval,
@ -219,3 +220,57 @@ class SRLutR90Y(nn.Module):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90YCbCr(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90YCbCr, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -8,35 +8,37 @@ from pathlib import Path
from . import srlut
from common import layers
from common import losses
from models.base import SRNetBase
class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
class SRNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlock(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
x = x.reshape(b*c, 1, h, w)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model
class SRNet(SRNetBase):
class SRNetChebyKan(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
super(SRNetChebyKan, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlock(
self.stage1_S = layers.UpscaleBlockChebyKAN(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
@ -51,7 +53,7 @@ class SRNet(SRNetBase):
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model
@ -83,7 +85,7 @@ class SRNetY(SRNetBase):
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
return lut_model
@ -103,8 +105,7 @@ class SRNetR90(SRNetBase):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
for rotations_count in range(1,4):
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
@ -112,10 +113,36 @@ class SRNetR90(SRNetBase):
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model
class SRNetChebyKanR90(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetChebyKanR90, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlockChebyKAN(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model
class SRNetR90Y(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
@ -140,8 +167,7 @@ class SRNetR90Y(SRNetBase):
x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
for rotations_count in range(1,4):
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
@ -149,6 +175,41 @@ class SRNetR90Y(SRNetBase):
output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Yq.init_from_numpy(stage_lut)
return lut_model
class SRNetR90Ycbcr(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90Ycbcr, self).__init__()
self.scale = scale
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
def forward(self, x, config=None):
b,c,h,w = x.shape
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)

@ -26,7 +26,7 @@ class ValOptions():
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')

Loading…
Cancel
Save