main
protsenkovi 5 months ago
parent 0f54e2748b
commit df56917042

@ -367,8 +367,8 @@ class ChebyKANLayer(nn.Module):
self.outdim = out_features self.outdim = out_features
self.degree = degree self.degree = degree
self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1)) self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1))
nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1))) nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1)))
self.register_buffer("arange", torch.arange(0, degree + 1, 1)) self.register_buffer("arange", torch.arange(0, degree + 1, 1))
def forward(self, x): def forward(self, x):
@ -390,6 +390,7 @@ class ChebyKANLayer(nn.Module):
y = torch.einsum( y = torch.einsum(
"btid,iod->bto", x, self.cheby_coeffs "btid,iod->bto", x, self.cheby_coeffs
) # shape = (batch_size, hw, outdim) ) # shape = (batch_size, hw, outdim)
y = y.view(b, hw, self.outdim) y = y.view(b, hw, self.outdim)
return y return y
@ -404,10 +405,10 @@ class UpscaleBlockChebyKAN(nn.Module):
self.linear_projections = [] self.linear_projections = []
for i in range(layers_count): for i in range(layers_count):
self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, bias=True)) self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim))
self.linear_projections = nn.ModuleList(self.linear_projections) self.linear_projections = nn.ModuleList(self.linear_projections)
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
self.in_bias = self.in_scale = input_max_value/2 self.in_bias = self.in_scale = input_max_value/2
self.out_bias = self.out_scale = output_max_value/2 self.out_bias = self.out_scale = output_max_value/2

@ -113,7 +113,7 @@ def transfer_rc_conv(rc_conv, quantization_interval=1):
print() print()
return lut return lut
def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255): def transfer_4_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1 scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2 lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2

@ -5,48 +5,17 @@ from . import srlut
from . import sdynet from . import sdynet
from . import hdbnet from . import hdbnet
from . import hdblut from . import hdblut
from models.base import SRNetBase
from common import losses from common import losses
import torch import torch
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import inspect, sys
AVAILABLE_MODELS = { AVAILABLE_MODELS = {}
'SRNet': srnet.SRNet, for module_name in sys.modules.keys():
'SRLut': srlut.SRLut, if 'models.' in module_name:
'SRNetR90': srnet.SRNetR90, AVAILABLE_MODELS.update({k:v for k,v in inspect.getmembers(sys.modules[module_name], lambda x: inspect.isclass(x) and SRNetBase in x.__bases__)})
'SRLutR90': srlut.SRLutR90,
'SRNetR90Y': srnet.SRNetR90Y,
'SRLutR90Y': srlut.SRLutR90Y,
'SDYNetx1': sdynet.SDYNetx1,
'SDYLutx1': sdylut.SDYLutx1,
'SDYNetx2': sdynet.SDYNetx2,
'SDYLutx2': sdylut.SDYLutx2,
'SDYNetx3': sdynet.SDYNetx3,
'SDYLutx3': sdylut.SDYLutx3,
'SDYNetR90x1': sdynet.SDYNetR90x1,
'SDYLutR90x1': sdylut.SDYLutR90x1,
'SDYNetR90x2': sdynet.SDYNetR90x2,
'SDYLutR90x2': sdylut.SDYLutR90x2,
'SRNetY': srnet.SRNetY,
'SRLutY': srlut.SRLutY,
'HDBNet': hdbnet.HDBNet,
'HDBLut': hdblut.HDBLut,
'HDBLNet': hdbnet.HDBLNet,
'HDBHNet': hdbnet.HDBHNet,
'SRMsbLsbNet': srnet.SRMsbLsbNet,
'SRMsbLsbShiftNet': srnet.SRMsbLsbShiftNet,
'SRMsbLsbR90Net': srnet.SRMsbLsbR90Net,
'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net,
# 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
# 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
# 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,
# 'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7,
# 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1,
# 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2,
# 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered,
# 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable,
# 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable,
}
def SaveCheckpoint(model, path): def SaveCheckpoint(model, path):
model_container = { model_container = {

@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -6,8 +6,9 @@ from pathlib import Path
from common.lut import select_index_4dlut_tetrahedral from common.lut import select_index_4dlut_tetrahedral
from common import layers from common import layers
from common.utils import round_func from common.utils import round_func
from models.base import SRNetBase
class SRLut(nn.Module): class SRLut(SRNetBase):
def __init__( def __init__(
self, self,
quantization_interval, quantization_interval,
@ -54,7 +55,7 @@ class SRLut(nn.Module):
return F.mse_loss(pred/255, target/255) return F.mse_loss(pred/255, target/255)
return loss_fn return loss_fn
class SRLutY(nn.Module): class SRLutY(SRNetBase):
def __init__( def __init__(
self, self,
quantization_interval, quantization_interval,
@ -108,7 +109,7 @@ class SRLutY(nn.Module):
return F.mse_loss(pred/255, target/255) return F.mse_loss(pred/255, target/255)
return loss_fn return loss_fn
class SRLutR90(nn.Module): class SRLutR90(SRNetBase):
def __init__( def __init__(
self, self,
quantization_interval, quantization_interval,
@ -161,7 +162,7 @@ class SRLutR90(nn.Module):
return loss_fn return loss_fn
class SRLutR90Y(nn.Module): class SRLutR90Y(SRNetBase):
def __init__( def __init__(
self, self,
quantization_interval, quantization_interval,
@ -219,3 +220,57 @@ class SRLutR90Y(nn.Module):
def loss_fn(pred, target): def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255) return F.mse_loss(pred/255, target/255)
return loss_fn return loss_fn
class SRLutR90YCbCr(SRNetBase):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutR90YCbCr, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -8,35 +8,37 @@ from pathlib import Path
from . import srlut from . import srlut
from common import layers from common import layers
from common import losses from common import losses
from models.base import SRNetBase
class SRNetBase(nn.Module): class SRNet(SRNetBase):
def __init__(self): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetBase, self).__init__() super(SRNet, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlock(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage): def forward(self, x, config=None):
b,c,h,w = x.shape b,c,h,w = x.shape
x = percieve_pattern(x) x = x.reshape(b*c, 1, h, w)
x = stage(x) x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
x = round_func(x) x = x.reshape(b, c, h*self.scale, w*self.scale)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
def get_loss_fn(self): return lut_model
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNet(SRNetBase): class SRNetChebyKan(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__() super(SRNetChebyKan, self).__init__()
self.scale = scale self.scale = scale
self.stage1_S = layers.UpscaleBlock( self.stage1_S = layers.UpscaleBlockChebyKAN(
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
layers_count=layers_count, layers_count=layers_count,
upscale_factor=self.scale upscale_factor=self.scale
@ -51,7 +53,7 @@ class SRNet(SRNetBase):
return x return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_numpy(stage_lut) lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model return lut_model
@ -83,7 +85,7 @@ class SRNetY(SRNetBase):
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutY.init_from_numpy(stage_lut) lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
return lut_model return lut_model
@ -103,8 +105,7 @@ class SRNetR90(SRNetBase):
b,c,h,w = x.shape b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w) x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) for rotations_count in range(4):
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4 output /= 4
@ -112,10 +113,36 @@ class SRNetR90(SRNetBase):
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut) lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model return lut_model
class SRNetChebyKanR90(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetChebyKanR90, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlockChebyKAN(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model
class SRNetR90Y(SRNetBase): class SRNetR90Y(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
@ -140,8 +167,7 @@ class SRNetR90Y(SRNetBase):
x = y.view(b, 1, h, w) x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) for rotations_count in range(4):
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4 output /= 4
@ -149,6 +175,41 @@ class SRNetR90Y(SRNetBase):
output = self.ycbcr_to_rgb(output).clamp(0, 255) output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_4_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Yq.init_from_numpy(stage_lut)
return lut_model
class SRNetR90Ycbcr(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90Ycbcr, self).__init__()
self.scale = scale
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
def forward(self, x, config=None):
b,c,h,w = x.shape
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut) lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)

@ -26,7 +26,7 @@ class ValOptions():
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.") self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.") self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.") self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar') self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')

Loading…
Cancel
Save