main
protsenkovi 6 months ago
parent 55adc280be
commit 5c22148bed

@ -5,19 +5,35 @@ import numpy as np
from .utils import round_func
class PercievePattern():
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2):
assert window_size >= (np.max(receptive_field_idxes)+1)
receptive_field_idxes = np.array(receptive_field_idxes)
"""
Coordinates scheme: [channel, height, width]. The channel can be omitted, in which case the coordinate is applied to every channel.
Examples:
1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]
2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]]
"""
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1):
assert window_size >= (np.max([x for y in receptive_field_idxes for x in y])+1)
tmp = []
for coords in receptive_field_idxes:
if len(coords) < 3:
for i in range(channels):
tmp.append([i,] + coords)
if len(coords) == 3:
tmp.append(coords)
receptive_field_idxes = np.array(sorted(tmp))
self.window_size = window_size
self.center = center
self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size + receptive_field_idxes[i,1] for i in range(len(receptive_field_idxes))]
self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size*self.window_size + receptive_field_idxes[i,1]*self.window_size + receptive_field_idxes[i,2] for i in range(len(receptive_field_idxes))]
assert len(np.unique(self.receptive_field_idxes) == len(self.receptive_field_idxes)), "Duplicated coordinates found. Coordinates scheme: [channel, height, width]."
def __call__(self, x):
b,c,h,w = x.shape
x = F.pad(
x,
pad=[self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1],
pad=[
self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1
],
mode='replicate'
)
x = F.unfold(input=x, kernel_size=self.window_size)
@ -452,3 +468,62 @@ class UpscaleBlockEffKAN(nn.Module):
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias
return x
class ComplexGaborLayer2D(nn.Module):
    '''
    Implicit representation with complex Gabor nonlinearity with 2D activation function
    https://github.com/vishwa91/wire/blob/main/modules/wire2d.py

    Inputs:
        in_features: Input features
        out_features: Output features
        bias: if True, enable bias for the linear operations
        is_first: if True the linear layers use real (torch.float) weights,
            otherwise complex (torch.cfloat) weights
        omega0: Frequency of Gabor sinusoid term
        sigma0: Scaling of Gabor Gaussian term
        trainable: If True, omega and sigma are trainable parameters
    '''
    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega0=10.0, sigma0=10.0,
                 trainable=False):
        super().__init__()
        self.is_first = is_first
        self.in_features = in_features

        # Only the first layer sees real inputs; the others get complex ones.
        dtype = torch.float if self.is_first else torch.cfloat

        # Stored as parameters so they can be optimized when trainable=True.
        self.omega_0 = nn.Parameter(omega0*torch.ones(1), requires_grad=trainable)
        self.scale_0 = nn.Parameter(sigma0*torch.ones(1), requires_grad=trainable)

        self.linear = nn.Linear(in_features,
                                out_features,
                                bias=bias,
                                dtype=dtype)

        # Second Gaussian window
        self.scale_orth = nn.Linear(in_features,
                                    out_features,
                                    bias=bias,
                                    dtype=dtype)

    def forward(self, input):
        """Return exp(1j*omega_0*lin) * exp(-scale_0^2 * (|lin|^2 + |orth|^2))."""
        lin = self.linear(input)
        scale_x = lin
        scale_y = self.scale_orth(input)

        freq_term = torch.exp(1j*self.omega_0*lin)

        arg = scale_x.abs().square() + scale_y.abs().square()
        gauss_term = torch.exp(-self.scale_0*self.scale_0*arg)

        return freq_term*gauss_term

@ -92,7 +92,7 @@ def test_steps(model, datasets, config, log_prefix="", print_progress = False):
total_area += lr_area
row = [
dataset_name,
f"{dataset_name} {config.color_model}",
np.mean(psnrs),
np.mean(ssims),
np.mean(run_times_ns)*1e-9,

@ -8,14 +8,15 @@ class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
def forward_stage(self, x, percieve_pattern, stage):
    """
    Run one upscaling stage and assemble its output into a single-channel image.

    Args:
        x: 4-D input; its shape is unpacked as (b, c, h, w).
        percieve_pattern: callable applied to x before the stage.
        stage: callable with an `upscale_factor` attribute; its rounded
            output is reshaped to (b, 1, h, w, scale, scale).

    Returns:
        Tensor of shape (b, 1, h*scale, w*scale).
    """
    b, _, h, w = x.shape
    # The factor is taken from the stage itself, not passed by the caller.
    scale = stage.upscale_factor
    x = percieve_pattern(x)
    x = stage(x)
    x = round_func(x)
    x = x.reshape(b, 1, h, w, scale, scale)
    # Interleave the scale x scale patch of every pixel with the pixel grid.
    x = x.permute(0, 1, 2, 4, 3, 5)
    x = x.reshape(b, 1, h*scale, w*scale)
    return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):

@ -113,11 +113,12 @@ class HDBNet(SRNetBase):
return loss_fn
class HDBNetv2(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, rotations = 4):
def __init__(self, hidden_dim = 64, layers_count = 2, scale = 4, rotations = 4):
super(HDBNetv2, self).__init__()
assert scale == 4
self.scale = scale
self.rotations = rotations
self.layers_count = layers_count
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
@ -208,8 +209,55 @@ class HDBNetv2(SRNetBase):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNet(SRNetBase):
    """
    Single-stage network working on two parts of every input value:
    lsb = x % 16 and msb = x - lsb. The msb part goes through the 3H, 3D
    and 3B branches (averaged), the lsb part through the 3L branch.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(HDBLNet, self).__init__()
        self.scale = scale

        def make_stage():
            return layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)

        # Stages are created in the order H, D, B, L.
        self.stage1_3H = make_stage()
        self.stage1_3D = make_stage()
        self.stage1_3B = make_stage()
        self.stage1_3L = make_stage()

        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)

    def forward_stage(self, x, scale, percieve_pattern, stage):
        """Apply pattern, stage and rounding, then reshape to (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = round_func(stage(percieve_pattern(x)))
        x = x.reshape(b, c, h, w, scale, scale)
        x = x.permute(0, 1, 2, 4, 3, 5)
        return x.reshape(b, c, h*scale, w*scale)

    def forward(self, x, config=None):
        """
        Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale).

        When config is given and config.current_iter is a multiple of
        config.display_step, histograms of both partial outputs are
        written to config.writer.
        """
        b, c, h, w = x.shape
        # Every channel is processed as an independent single-channel image.
        x = x.reshape(b*c, 1, h, w)
        lsb = x % 16
        msb = x - lsb
        output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
                     self.forward_stage(msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
                     self.forward_stage(msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
        output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
        output_msb /= 3
        if config is not None and config.current_iter % config.display_step == 0:
            config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
            config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
        output = output_msb + output_lsb
        x = output.clamp(0, 255)
        x = x.reshape(b, c, h*self.scale, w*self.scale)
        return x

    def get_loss_fn(self):
        """Return the MSE loss computed on prediction and target divided by 255."""
        def loss_fn(pred, target):
            return F.mse_loss(pred/255, target/255)
        return loss_fn
class HDBLNetR90(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNet, self).__init__()
self.scale = scale
@ -243,17 +291,75 @@ class HDBLNet(nn.Module):
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L), k=-rotations_count, dims=[2, 3])
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4
output_msb = round_func((output_msb / 255) * 16) * 15
output_lsb = (output_lsb / 255) * 15
output = output_msb + output_lsb
x = output.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
x = output_msb + output_lsb
def get_loss_fn(self):
    """Build the loss callable: mean squared error of pred/255 against target/255."""
    def loss_fn(pred, target):
        scaled_pred = pred/255
        scaled_target = target/255
        return F.mse_loss(scaled_pred, scaled_target)
    return loss_fn
class HDBLNetR90KAN(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
    """Create the four ChebyKAN upscale stages (3H, 3D, 3B, 3L) and their input patterns."""
    super(HDBLNetR90KAN, self).__init__()
    self.scale = scale

    def build_stage():
        return layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)

    # Creation order H, D, B, L is the same as the attribute order below.
    self.stage1_3H = build_stage()
    self.stage1_3D = build_stage()
    self.stage1_3B = build_stage()
    self.stage1_3L = build_stage()

    origin = [0,0]
    self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=origin, window_size=3)
    self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=origin, window_size=3)
    self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=origin, window_size=3)
    self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=origin, window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
    """
    Apply percieve_pattern, stage and round_func to x, then rearrange the
    result from (b, c, h, w, scale, scale) to (b, c, h*scale, w*scale).
    """
    batch, channels, height, width = x.shape
    y = round_func(stage(percieve_pattern(x)))
    y = y.reshape(batch, channels, height, width, scale, scale)
    # Put each scale x scale patch next to the pixel it belongs to.
    y = y.permute(0, 1, 2, 4, 3, 5)
    return y.reshape(batch, channels, height*scale, width*scale)
def forward(self, x, config=None):
    """
    Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale).

    The input is split into lsb = x % 16 and msb = x - lsb. Both parts are
    processed under four 90-degree rotations, rotated back and averaged;
    the msb part additionally averages its three branches.
    """
    b, c, h, w = x.shape
    x = x.reshape(b*c, 1, h, w)
    lsb = x % 16
    msb = x - lsb
    upscaled_shape = [b*c, 1, h*self.scale, w*self.scale]
    output_msb = torch.zeros(upscaled_shape, dtype=x.dtype, device=x.device)
    output_lsb = torch.zeros(upscaled_shape, dtype=x.dtype, device=x.device)
    msb_branches = (
        (self._extract_pattern_3H, self.stage1_3H),
        (self._extract_pattern_3D, self.stage1_3D),
        (self._extract_pattern_3B, self.stage1_3B),
    )
    for rotations_count in range(4):
        rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
        rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
        output_msbt = None
        for pattern, stage in msb_branches:
            branch = self.forward_stage(rotated_msb, self.scale, pattern, stage)
            output_msbt = branch if output_msbt is None else output_msbt + branch
        output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
        # Histograms are taken from the accumulators before this rotation is added.
        if config is not None and config.current_iter % config.display_step == 0:
            config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
            config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
        output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
        output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
    output_msb /= 4*3
    output_lsb /= 4
    output = output_msb + output_lsb
    x = output.clamp(0, 255)
    return x.reshape(b, c, h*self.scale, w*self.scale)
@ -262,39 +368,21 @@ class HDBLNet(nn.Module):
return F.mse_loss(pred/255, target/255)
return loss_fn
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = hdblut.HDBLut.init_from_numpy(
# stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
# stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
# )
# return lut_model
class HDBHNet(nn.Module):
class HDBHNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBHNet, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
self.msb_fns = SRNetBaseList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(1)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
self.lsb_fns = SRNetBaseList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,

@ -92,6 +92,54 @@ class SDYNetx2(SRNetBase):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetx2Inv(SRNetBase):
    """
    Two-stage S/D/Y network: stage 1 uses upscale_factor=scale, stage 2 uses
    upscale_factor=1 and its averaged output is blended with the stage 1 result.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(SDYNetx2Inv, self).__init__()
        self.scale = scale
        origin = [0,0]
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=origin, window_size=2)
        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=origin, window_size=3)
        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=origin, window_size=3)

        def build_stage(factor):
            return layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=factor)

        # Creation order: stage 1 (S, D, Y), then stage 2 (S, D, Y).
        self.stage1_S = build_stage(scale)
        self.stage1_D = build_stage(scale)
        self.stage1_Y = build_stage(scale)
        self.stage2_S = build_stage(1)
        self.stage2_D = build_stage(1)
        self.stage2_Y = build_stage(1)

    def forward(self, x, config=None):
        """Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.reshape(b*c, 1, h, w)
        patterns = (self._extract_pattern_S, self._extract_pattern_D, self._extract_pattern_Y)

        output = 0.0
        for pattern, stage in zip(patterns, (self.stage1_S, self.stage1_D, self.stage1_Y)):
            output += self.forward_stage(x, pattern, stage)
        output /= 3
        x = output

        output = 0.0
        for pattern, stage in zip(patterns, (self.stage2_S, self.stage2_D, self.stage2_Y)):
            output += self.forward_stage(x, pattern, stage)
        output /= 3
        # Average the second stage with the first stage result.
        x = (output + x)/2
        return x.reshape(b, c, h*self.scale, w*self.scale)

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """Transfer all six stages with lut.transfer_2x2_input_SxS_output and build an SDYLutx2."""
        stages = (
            self.stage1_S, self.stage1_D, self.stage1_Y,
            self.stage2_S, self.stage2_D, self.stage2_Y,
        )
        tables = [
            lut.transfer_2x2_input_SxS_output(stage, quantization_interval=quantization_interval, batch_size=batch_size)
            for stage in stages
        ]
        lut_model = sdylut.SDYLutx2.init_from_numpy(*tables)
        return lut_model

    def get_loss_fn(self):
        """Return the MSE loss computed on prediction and target divided by 255."""
        def loss_fn(pred, target):
            return F.mse_loss(pred/255, target/255)
        return loss_fn
class SDYNetx3(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx3, self).__init__()
@ -151,7 +199,7 @@ class SDYNetx3(SRNetBase):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetR90x1(nn.Module):
class SDYNetR90x1(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetR90x1, self).__init__()
self.scale = scale
@ -162,16 +210,6 @@ class SDYNetR90x1(nn.Module):
self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
@ -219,10 +257,7 @@ class SDYNetR90x2(SRNetBase):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output_1 += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output_1 += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output_1 += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
for rotations_count in range(1,4):
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
@ -230,10 +265,7 @@ class SDYNetR90x2(SRNetBase):
output_1 /= 4*3
x = output_1
output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_2 += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output_2 += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output_2 += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
for rotations_count in range(1,4):
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
@ -257,3 +289,249 @@ class SDYNetR90x2(SRNetBase):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYEHONetR90x1(SRNetBase):
    """
    Single-stage network averaging six branches (S, D, Y, E, H, O) over
    four 90-degree rotations of the input.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(SDYEHONetR90x1, self).__init__()
        self.scale = scale
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
        self._extract_pattern_E = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,0],[0,3],[3,3]], center=[0,0], window_size=4)
        self._extract_pattern_H = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,2],[2,3],[3,2]], center=[0,0], window_size=4)
        self._extract_pattern_O = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,1],[2,2],[1,3]], center=[0,0], window_size=4)
        self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage1_E = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage1_H = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage1_O = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)

    def forward(self, x, config=None):
        """Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.view(b*c, 1, h, w)
        branches = (
            (self._extract_pattern_S, self.stage1_S),
            (self._extract_pattern_D, self.stage1_D),
            (self._extract_pattern_Y, self.stage1_Y),
            (self._extract_pattern_E, self.stage1_E),
            (self._extract_pattern_H, self.stage1_H),
            (self._extract_pattern_O, self.stage1_O),
        )
        output_1 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
            for pattern, stage in branches:
                # The inherited forward_stage takes (x, percieve_pattern, stage)
                # and reads the factor from stage.upscale_factor, so no scale
                # argument is passed here.
                upscaled = self.forward_stage(rotated, pattern, stage)
                output_1 += torch.rot90(upscaled, k=-rotations_count, dims=[-2, -1])
        output_1 /= 4*len(branches)
        x = output_1
        x = x.view(b, c, h*self.scale, w*self.scale)
        return x

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """LUT export is not available for this model."""
        raise NotImplementedError

    def get_loss_fn(self):
        """Return the MSE loss computed on prediction and target divided by 255."""
        def loss_fn(pred, target):
            return F.mse_loss(pred/255, target/255)
        return loss_fn
class SDYEHONetR90x2(SRNetBase):
    """
    Two-stage network; each stage averages six branches (S, D, Y, E, H, O)
    over four 90-degree rotations. Stage 1 keeps the resolution
    (upscale_factor=1), stage 2 upscales by `scale`.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(SDYEHONetR90x2, self).__init__()
        self.scale = scale
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
        self._extract_pattern_E = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,0],[0,3],[3,3]], center=[0,0], window_size=4)
        self._extract_pattern_H = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,2],[2,3],[3,2]], center=[0,0], window_size=4)
        self._extract_pattern_O = layers.PercievePattern(receptive_field_idxes=[[0,0],[3,1],[2,2],[1,3]], center=[0,0], window_size=4)
        self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
        self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
        self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
        self.stage1_E = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
        self.stage1_H = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
        self.stage1_O = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
        self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage2_E = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage2_H = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
        self.stage2_O = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)

    def _rotation_ensemble(self, x, stages, scale):
        """Average the six branches given by `stages` over four rotations of x."""
        n, _, h, w = x.shape
        patterns = (
            self._extract_pattern_S, self._extract_pattern_D, self._extract_pattern_Y,
            self._extract_pattern_E, self._extract_pattern_H, self._extract_pattern_O,
        )
        output = torch.zeros([n, 1, h*scale, w*scale], dtype=x.dtype, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
            for pattern, stage in zip(patterns, stages):
                # The inherited forward_stage takes (x, percieve_pattern, stage)
                # and reads the factor from stage.upscale_factor, so no scale
                # argument is passed here.
                upscaled = self.forward_stage(rotated, pattern, stage)
                output += torch.rot90(upscaled, k=-rotations_count, dims=[-2, -1])
        output /= 4*len(stages)
        return output

    def forward(self, x, config=None):
        """Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.view(b*c, 1, h, w)
        stage1 = (self.stage1_S, self.stage1_D, self.stage1_Y, self.stage1_E, self.stage1_H, self.stage1_O)
        stage2 = (self.stage2_S, self.stage2_D, self.stage2_Y, self.stage2_E, self.stage2_H, self.stage2_O)
        x = self._rotation_ensemble(x, stage1, 1)
        x = self._rotation_ensemble(x, stage2, self.scale)
        x = x.view(b, c, h*self.scale, w*self.scale)
        return x

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """LUT export is not available for this model."""
        raise NotImplementedError

    def get_loss_fn(self):
        """Return the MSE loss computed on prediction and target divided by 255."""
        def loss_fn(pred, target):
            return F.mse_loss(pred/255, target/255)
        return loss_fn
class SDYMixNetx1(SRNetBase):
    """
    Six pattern branches with upscale_factor=1 whose outputs are stacked
    per pixel and mixed by one ChebyKAN block that upscales by `scale`.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(SDYMixNetx1, self).__init__()
        self.scale = scale
        anchor = [1,1]
        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,2],[2,1],[2,2]], center=anchor, window_size=4)
        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,3],[3,1],[3,3]], center=anchor, window_size=4)
        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes=[[1,1],[2,2],[3,2],[2,3]], center=anchor, window_size=4)
        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes=[[1,1],[0,0],[0,1],[1,0]], center=anchor, window_size=4)
        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes=[[1,1],[0,2],[0,3]], center=anchor, window_size=4)
        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes=[[1,1],[2,0],[3,0]], center=anchor, window_size=4)

        def build_branch(in_features):
            return layers.UpscaleBlock(in_features=in_features, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)

        # Branches 1-4 read four positions, branches 5-6 read three.
        self.stage1_1 = build_branch(4)
        self.stage1_2 = build_branch(4)
        self.stage1_3 = build_branch(4)
        self.stage1_4 = build_branch(4)
        self.stage1_5 = build_branch(3)
        self.stage1_6 = build_branch(3)
        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=6, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)

    def forward(self, x, config=None):
        """Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.reshape(b*c, 1, h, w)
        branches = (
            (self._extract_pattern_1, self.stage1_1),
            (self._extract_pattern_2, self.stage1_2),
            (self._extract_pattern_3, self.stage1_3),
            (self._extract_pattern_4, self.stage1_4),
            (self._extract_pattern_5, self.stage1_5),
            (self._extract_pattern_6, self.stage1_6),
        )
        output = torch.cat([self.forward_stage(x, pattern, stage) for pattern, stage in branches], dim=1)
        # NOTE(review): the views below use b, not b*c, so this looks like it
        # only works for single-channel input - confirm.
        output = output.permute(0, 2, 3, 1).view(b, h*w, 6)
        output = self.stage1_Mix(output)
        output = output.view(b, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
        return output.reshape(b, c, h*self.scale, w*self.scale)

    def get_loss_fn(self):
        """Return the MSE loss computed on prediction and target divided by 255."""
        def loss_fn(pred, target):
            return F.mse_loss(pred/255, target/255)
        return loss_fn
class SDYMixNetx1v2(SRNetBase):
    """
    22
    12 23 32 21
    11 13 33 31
    10 14 34 30
    01 03 43 41
    00 04 44 40

    Each row above is the receptive field of one branch inside a 5x5 window
    centered at [2,2]; the six branch outputs are mixed by a ChebyKAN block.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(SDYMixNetx1v2, self).__init__()
        self.scale = scale
        middle = [2,2]
        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes=[[2,2]], center=middle, window_size=5)
        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes=[[1,2],[2,3],[3,2],[2,1]], center=middle, window_size=5)
        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes=[[1,1],[1,3],[3,3],[3,1]], center=middle, window_size=5)
        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes=[[1,0],[1,4],[3,4],[3,0]], center=middle, window_size=5)
        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes=[[0,1],[0,3],[4,3],[4,1]], center=middle, window_size=5)
        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,4],[4,4],[4,0]], center=middle, window_size=5)

        def build_branch(in_features):
            return layers.UpscaleBlock(in_features=in_features, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)

        # Branch 1 reads only the center position, the others read four.
        self.stage1_1 = build_branch(1)
        self.stage1_2 = build_branch(4)
        self.stage1_3 = build_branch(4)
        self.stage1_4 = build_branch(4)
        self.stage1_5 = build_branch(4)
        self.stage1_6 = build_branch(4)
        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=6, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)

    def forward(self, x, config=None):
        """Upscale x of shape (b, c, h, w) to (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        x = x.reshape(b*c, 1, h, w)
        branches = (
            (self._extract_pattern_1, self.stage1_1),
            (self._extract_pattern_2, self.stage1_2),
            (self._extract_pattern_3, self.stage1_3),
            (self._extract_pattern_4, self.stage1_4),
            (self._extract_pattern_5, self.stage1_5),
            (self._extract_pattern_6, self.stage1_6),
        )
        output = torch.cat([self.forward_stage(x, pattern, stage) for pattern, stage in branches], dim=1)
        # NOTE(review): the views below use b, not b*c, so this looks like it
        # only works for single-channel input - confirm.
        output = output.permute(0, 2, 3, 1).view(b, h*w, 6)
        output = self.stage1_Mix(output)
        output = output.view(b, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
        return output.reshape(b, c, h*self.scale, w*self.scale)

    def get_loss_fn(self):
        """Return the MSE loss computed on prediction and target divided by 255."""
        def loss_fn(pred, target):
            return F.mse_loss(pred/255, target/255)
        return loss_fn
class SDYMixNetx1v3(SRNetBase):
    """
    Six-branch network: each branch is built with ``upscale_factor=scale`` and
    the branch outputs are mixed afterwards by a ChebyKAN block built with
    ``upscale_factor=1``.

    Each row lists the "row col" coordinates (inside a 5x5 window centered at
    [2,2]) read by one branch, ordered as ``_extract_pattern_1`` ..
    ``_extract_pattern_6``:

             22
        12 23 32 21
        11 13 33 31
        10 14 34 30
        01 03 43 41
        00 04 44 40

    ``_extract_pattern_0`` uses a 1x1 window over 6 channels and feeds the
    mixing block.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super().__init__()
        self.scale = scale
        self._extract_pattern_0 = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=6)
        # One coordinate group per branch, in branch order.
        groups = [
            [[2,2]],
            [[1,2],[2,3],[3,2],[2,1]],
            [[1,1],[1,3],[3,3],[3,1]],
            [[1,0],[1,4],[3,4],[3,0]],
            [[0,1],[0,3],[4,3],[4,1]],
            [[0,0],[0,4],[4,4],[4,0]],
        ]
        for number, coords in enumerate(groups, start=1):
            setattr(self, f"_extract_pattern_{number}", layers.PercievePattern(receptive_field_idxes=coords, center=[2,2], window_size=5))
        # Branch blocks are created in ascending order, each taking as many
        # input features as its group has coordinates (1, then 4 five times).
        for number, coords in enumerate(groups, start=1):
            setattr(self, f"stage1_{number}", layers.UpscaleBlock(in_features=len(coords), hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale))
        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=6, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)

    def forward(self, x, config=None):
        """Run all six branches on every channel plane, mix them, and restore [b, c, h*scale, w*scale]."""
        batch, channels, height, width = x.shape
        planes = x.reshape(batch*channels, 1, height, width)
        branch_outputs = [
            self.forward_stage(planes, getattr(self, f"_extract_pattern_{number}"), getattr(self, f"stage1_{number}"))
            for number in range(1, 7)
        ]
        stacked = torch.cat(branch_outputs, dim=1)
        mixed = self.forward_stage(stacked, self._extract_pattern_0, self.stage1_Mix)
        return mixed.reshape(batch, channels, height*self.scale, width*self.scale)

    def get_loss_fn(self):
        """Return MSE evaluated after dividing prediction and target by 255."""
        def loss_fn(pred, target):
            scaled_pred = pred/255
            scaled_target = target/255
            return F.mse_loss(scaled_pred, scaled_target)
        return loss_fn

@ -24,7 +24,7 @@ class SRNet(SRNetBase):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
x = self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
@ -48,7 +48,7 @@ class SRNetChebyKan(SRNetBase):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
x = self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
@ -79,7 +79,7 @@ class SRNetY(SRNetBase):
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
x = y.view(b, 1, h, w)
output = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
output = self.forward_stage(x, self._extract_pattern_S, self.stage1_S)
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output
@ -107,7 +107,7 @@ class SRNetR90(SRNetBase):
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
@ -134,7 +134,7 @@ class SRNetChebyKanR90(SRNetBase):
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
@ -169,7 +169,7 @@ class SRNetR90Y(SRNetBase):
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)
@ -205,7 +205,7 @@ class SRNetR90Ycbcr(SRNetBase):
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
return output
@ -248,8 +248,8 @@ class SRMsbLsbNet(SRNetBase):
lsb = x % 16
msb = x - lsb
output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_S, self.msb_fn)
output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
output_msb = self.forward_stage(msb, self._extract_pattern_S, self.msb_fn)
output_lsb = self.forward_stage(lsb, self._extract_pattern_S, self.lsb_fn)
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
@ -262,6 +262,53 @@ class SRMsbLsbNet(SRNetBase):
raise NotImplementedError
class SRMsbLsbNetChebyKAN(SRNetBase):
    """
    Two-branch ChebyKAN model: the input is split into ``x % 16`` (lsb) and
    the remainder ``x - x % 16`` (msb); each part goes through its own
    upscale block and the two results are summed.
    """
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super().__init__()
        self.scale = scale
        self.hidden_dim = hidden_dim
        self.layers_count = layers_count
        # Settings common to both branches; only the input range differs.
        shared = dict(
            in_features=4,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale,
            output_max_value=255,
        )
        self.msb_fn = layers.UpscaleBlockChebyKAN(input_max_value=255, **shared)
        self.lsb_fn = layers.UpscaleBlockChebyKAN(input_max_value=15, **shared)
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)

    def forward(self, x, config=None):
        """
        Upscale ``x`` of shape [b, c, h, w] to [b, c, h*scale, w*scale].

        When ``config`` is given and ``config.current_iter`` is a multiple of
        ``config.display_step``, histograms of both branch outputs are written
        to ``config.writer``.
        """
        batch, channels, height, width = x.shape
        planes = x.reshape(batch*channels, 1, height, width)
        low_part = planes % 16
        high_part = planes - low_part
        output_msb = self.forward_stage(high_part, self._extract_pattern_S, self.msb_fn)
        output_lsb = self.forward_stage(low_part, self._extract_pattern_S, self.lsb_fn)
        should_log = config is not None and config.current_iter % config.display_step == 0
        if should_log:
            for tag, branch in (('output_lsb', output_lsb), ('output_msb', output_msb)):
                config.writer.add_histogram(tag, branch.detach().cpu().numpy(), config.current_iter)
        combined = output_msb + output_lsb
        return combined.reshape(batch, channels, height*self.scale, width*self.scale)

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """LUT conversion is not available for this model."""
        raise NotImplementedError
class SRMsbLsbShiftNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsbShiftNet, self).__init__()
@ -316,6 +363,59 @@ class SRMsbLsbShiftNet(SRNetBase):
raise NotImplementedError
class SRMsbLsbCenterShiftNet(SRNetBase):
    """
    MSB/LSB model that averages several branches, each evaluated with its
    output shifted by (+-2, +-2) pixels.

    The input is split into ``x % 16`` (lsb) and ``x - x % 16`` (msb). Each
    of the ``count`` branch pairs upscales both parts; the upscaled results
    are replicate-padded by 2 pixels and cropped at a branch-specific offset
    before being accumulated. Only four offsets are defined, so at most four
    branch pairs take part in the average.
    """
    # Crop offsets (rows, columns) applied to the padded branch outputs.
    _SHIFTS = ((-2,-2), (-2,2), (2,-2), (2,2))

    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, count = 4):
        super().__init__()
        self.scale = scale
        self.hidden_dim = hidden_dim
        self.layers_count = layers_count
        self.count = count
        self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
            in_features=4,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale,
            input_max_value=255,
            output_max_value=255
        ) for _ in range(self.count)])
        self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
            in_features=4,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale,
            input_max_value=15,
            output_max_value=255
        ) for _ in range(self.count)])
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)

    def forward(self, x, config=None):
        """
        Upscale ``x`` of shape [b, c, h, w] to [b, c, h*scale, w*scale].

        When ``config`` is given and ``config.current_iter`` is a multiple of
        ``config.display_step``, histograms of the averaged msb and lsb
        outputs are written to ``config.writer``.
        """
        b,c,h,w = x.shape
        x = x.reshape(b*c, 1, h, w)
        lsb = x % 16
        msb = x - lsb
        out_h, out_w = h*self.scale, w*self.scale
        output_msb = torch.zeros([b*c, 1, out_h, out_w], dtype=x.dtype, device=x.device)
        output_lsb = torch.zeros([b*c, 1, out_h, out_w], dtype=x.dtype, device=x.device)
        used = 0
        for (i,j), msb_fn, lsb_fn in zip(self._SHIFTS, self.msb_fns, self.lsb_fns):
            # forward_stage takes (input, pattern, block), matching every
            # other caller in this file; the scale is no longer passed.
            output_msb_s = self.forward_stage(msb, self._extract_pattern_S, msb_fn)
            output_lsb_s = self.forward_stage(lsb, self._extract_pattern_S, lsb_fn)
            # Pad by 2 on every side, then crop a window moved by (i, j).
            output_msb += torch.nn.functional.pad(output_msb_s, [2, 2, 2, 2], mode='replicate')[:,:,2+i:2+i+out_h,2+j:2+j+out_w]
            output_lsb += torch.nn.functional.pad(output_lsb_s, [2, 2, 2, 2], mode='replicate')[:,:,2+i:2+i+out_h,2+j:2+j+out_w]
            used += 1
        # Average over the branches that were actually evaluated: zip stops
        # at the shorter of the shift list and the module lists, so dividing
        # by self.count would under-scale the result when count > 4.
        output_msb /= used
        output_lsb /= used
        if config is not None and config.current_iter % config.display_step == 0:
            config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
            config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
        x = output_msb + output_lsb
        x = x.reshape(b, c, out_h, out_w)
        return x

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """LUT conversion is not available for this model."""
        raise NotImplementedError
class SRMsbLsbR90Net(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsbR90Net, self).__init__()

@ -24,27 +24,29 @@ import argparse
class TestOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
self.parser.add_argument('--model_path', type=str, default="../experiments/last.pth", help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')
self.parser.add_argument('--print_progress', type=bool, default=True, help='Show progres bar')
self.parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
self.parser.add_argument('--test_worker_num', default=1, type=int, help='Test parallelism. Use 1 for time measurement.')
def parse_args(self):
args = self.parser.parse_args()
args.datasets_dir = Path(args.datasets_dir).resolve()
args.test_datasets = args.test_datasets.split(',')
args.exp_dir = Path(args.model_path).resolve().parent.parent
print(args.exp_dir)
args.model_path = Path(args.model_path).resolve()
args.model_name = args.model_path.stem
args.test_dir = Path(args.exp_dir).resolve() / 'test'
if not args.test_dir.exists():
args.test_dir.mkdir()
args.current_iter = int(args.model_name.split('_')[-1])
args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.device}.csv')
args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.color_model}_{args.device}.csv')
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.test_dir)
logger_name = f'test_{args.model_path.stem}'
@ -90,7 +92,7 @@ if __name__ == "__main__":
reset_cache=config.reset_cache,
)
results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress,)
results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.print_progress)
results.to_csv(config.results_path)
print()

@ -17,6 +17,7 @@ from common.data import SRTrainDataset, SRTestDataset
from common.utils import logger_info
from common.metrics import PSNR, cal_ssim
from common.color import _rgb2ycbcr, PIL_CONVERT_COLOR
import yaml
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS
from common.test import test_steps
@ -38,20 +39,22 @@ class TrainOptions:
parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers")
parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
parser.add_argument('--models_dir', type=str, default='../models/', help="experiment folder")
parser.add_argument('--models_dir', type=str, default='../experiments/', help="experiment folder")
parser.add_argument('--datasets_dir', type=str, default="../data/")
parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further')
parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
parser.add_argument('--display_step', type=int, default=100, help='display info every N iteration')
parser.add_argument('--val_step', type=int, default=2000, help='validate every N iteration')
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--loader_worker_num', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--test_worker_num', type=int, default=1, help="Test parallelism. Use 1 for time measurement.")
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}")
parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate')
self.parser = parser
@ -66,6 +69,9 @@ class TrainOptions:
args.start_iter = int(args.model_path.stem.split("_")[-1])
return args
def save_config(self, config):
    """Write ``config`` as YAML to ``config.exp_dir / "config.yaml"``.

    The file is opened in a ``with`` block so the handle is flushed and
    closed even if serialization raises.
    """
    with open(config.exp_dir / "config.yaml", 'w') as config_file:
        yaml.dump(config, config_file)
def __repr__(self):
config = self.parse_args()
message = ''
@ -79,7 +85,9 @@ class TrainOptions:
message += '----------------- End -------------------'
return message
def prepare_experiment_folder(config):
def prepare_config(self):
config = self.parse_args()
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"On of the {config.test_datasets} was not found in {config.datasets_dir}."
@ -100,12 +108,14 @@ def prepare_experiment_folder(config):
if not config.logs_dir.exists():
config.logs_dir.mkdir()
return config
if __name__ == "__main__":
# torch.set_float32_matmul_precision('high')
script_start_time = datetime.now()
config_inst = TrainOptions()
config = config_inst.parse_args()
config = config_inst.prepare_config()
if not config.model_path is None:
model = LoadCheckpoint(config.model_path)
@ -117,10 +127,10 @@ if __name__ == "__main__":
model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
model = model.to(torch.device(config.device))
# model = torch.compile(model)
optimizer = AdamWScheduleFree(model.parameters(), betas=(0.9, 0.95))
optimizer = AdamWScheduleFree(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.95))
print(optimizer)
prepare_experiment_folder(config)
config_inst.save_config(config)
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=config.logs_dir)
@ -147,7 +157,7 @@ if __name__ == "__main__":
train_loader = DataLoader(
dataset = train_dataset,
batch_size = config.batch_size,
num_workers = config.worker_num,
num_workers = config.loader_worker_num,
shuffle = True,
drop_last = False,
pin_memory = True,
@ -172,7 +182,7 @@ if __name__ == "__main__":
i = config.start_iter
if not config.model_path is None:
config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
loss_fn = model.get_loss_fn()
for i in range(config.start_iter + 1, config.total_iter + 1):

Loading…
Cancel
Save