new framework impl. added chebykan and linear models.
parent
67bd678763
commit
d34cc7833e
File diff suppressed because one or more lines are too long
@ -0,0 +1,24 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
from common import layers
|
||||
import copy
|
||||
|
||||
class SRBase(nn.Module):
|
||||
def __init__(self):
|
||||
super(SRBase, self).__init__()
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
# def get_loss_fn(self):
|
||||
# ssim_loss = losses.SSIM(data_range=255)
|
||||
# l1_loss = losses.CharbonnierLoss()
|
||||
# def loss_fn(pred, target):
|
||||
# return ssim_loss(pred, target) + l1_loss(pred, target)
|
||||
# return loss_fn
|
||||
|
@ -0,0 +1,26 @@
|
||||
from common import layers
|
||||
|
||||
class Transferer():
|
||||
def __init__(self):
|
||||
self.registered_types = {}
|
||||
|
||||
def register(self, input_class, output_class):
|
||||
self.registered_types[input_class] = output_class
|
||||
|
||||
def transfer(self, input_model, batch_size, quantization_interval):
|
||||
input_class = input_model.__class__
|
||||
if not input_class in self.registered_types:
|
||||
raise Exception(f"No transfer function is registered for class {input_class}")
|
||||
transfered_model = self.transfer_model(input_model, self.registered_types[input_class], batch_size, quantization_interval)
|
||||
return transfered_model
|
||||
|
||||
def transfer_model(self, model, output_model_class, batch_size, quantization_interval):
|
||||
qmodel = output_model_class(config = model.config)
|
||||
model.config.quantization_interval = quantization_interval
|
||||
for attr, value in model.named_children():
|
||||
if isinstance(value, layers.UpscaleBlock):
|
||||
getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
return qmodel
|
||||
|
||||
TRANSFERER = Transferer()
|
@ -1,28 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
|
||||
class SRNetBase(nn.Module):
|
||||
def __init__(self):
|
||||
super(SRNetBase, self).__init__()
|
||||
|
||||
def forward_stage(self, x, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
scale = stage.upscale_factor
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, 1, h, w, scale, scale)
|
||||
x = x.permute(0, 1, 2, 4, 3, 5)
|
||||
x = x.reshape(b, 1, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
@ -1,134 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from common.lut import select_index_4dlut_tetrahedral
|
||||
from common import layers
|
||||
from common.utils import round_func
|
||||
|
||||
class HDBLut(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(HDBLut, self).__init__()
|
||||
assert scale == 4
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
|
||||
self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
|
||||
self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
|
||||
self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
|
||||
self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
|
||||
self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
|
||||
|
||||
self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
|
||||
self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
|
||||
self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
|
||||
self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
|
||||
self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
|
||||
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
|
||||
):
|
||||
# quantization_interval = 256//(stage1_3H.shape[0]-1)
|
||||
quantization_interval = 16
|
||||
lut_model = HDBLut(quantization_interval=quantization_interval, scale=4)
|
||||
lut_model.stage1_3H = nn.Parameter(torch.tensor(stage1_3H).type(torch.float32))
|
||||
lut_model.stage1_3D = nn.Parameter(torch.tensor(stage1_3D).type(torch.float32))
|
||||
lut_model.stage1_3B = nn.Parameter(torch.tensor(stage1_3B).type(torch.float32))
|
||||
lut_model.stage1_2H = nn.Parameter(torch.tensor(stage1_2H).type(torch.float32))
|
||||
lut_model.stage1_2D = nn.Parameter(torch.tensor(stage1_2D).type(torch.float32))
|
||||
|
||||
lut_model.stage2_3H = nn.Parameter(torch.tensor(stage2_3H).type(torch.float32))
|
||||
lut_model.stage2_3D = nn.Parameter(torch.tensor(stage2_3D).type(torch.float32))
|
||||
lut_model.stage2_3B = nn.Parameter(torch.tensor(stage2_3B).type(torch.float32))
|
||||
lut_model.stage2_2H = nn.Parameter(torch.tensor(stage2_2H).type(torch.float32))
|
||||
lut_model.stage2_2D = nn.Parameter(torch.tensor(stage2_2D).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
print(np.prod(x.shape))
|
||||
x = percieve_pattern(x)
|
||||
shifts = torch.tensor([lut.shape[0]**d for d in range(len(lut.shape)-2)], device=x.device).flip(0).reshape(1,1,len(lut.shape)-2)
|
||||
print(x.shape, x.min(), x.max())
|
||||
x = torch.sum(x * shifts, dim=-1)
|
||||
print(x.shape)
|
||||
lut = torch.clamp(lut, 0, 255)
|
||||
lut = lut.reshape(-1, scale, scale)
|
||||
x = x.flatten().type(torch.int64)
|
||||
x = lut[x]
|
||||
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
print(x.shape)
|
||||
# raise RuntimeError
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4*2
|
||||
print(output_msb.min(), output_msb.max())
|
||||
print(output_lsb.min(), output_lsb.max())
|
||||
|
||||
output_msb = output_msb + output_lsb
|
||||
x = output_msb
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
print("STAGE2", msb.min(), msb.max(), lsb.min(), lsb.max())
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4*2
|
||||
output_msb = output_msb + output_lsb
|
||||
x = output_msb
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}" + \
|
||||
f"\n stage1_3H size: {self.stage1_3H.shape}" + \
|
||||
f"\n stage1_3D size: {self.stage1_3D.shape}" + \
|
||||
f"\n stage1_3B size: {self.stage1_3B.shape}" + \
|
||||
f"\n stage1_2H size: {self.stage1_2H.shape}" + \
|
||||
f"\n stage1_2D size: {self.stage1_2D.shape}" + \
|
||||
f"\n stage2_3H size: {self.stage2_3H.shape}" + \
|
||||
f"\n stage2_3D size: {self.stage2_3D.shape}" + \
|
||||
f"\n stage2_3B size: {self.stage2_3B.shape}" + \
|
||||
f"\n stage2_2H size: {self.stage2_2H.shape}" + \
|
||||
f"\n stage2_2D size: {self.stage2_2D.shape}"
|
@ -1,441 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
from common import lut
|
||||
from pathlib import Path
|
||||
from . import hdblut
|
||||
from common import layers
|
||||
from itertools import cycle
|
||||
from models.base import SRNetBase
|
||||
|
||||
class HDBNet(SRNetBase):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, rotations = 4):
|
||||
super(HDBNet, self).__init__()
|
||||
assert scale == 4
|
||||
self.scale = scale
|
||||
self.rotations = rotations
|
||||
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
|
||||
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(self.rotations):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
|
||||
output_msb /= 3
|
||||
output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
|
||||
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
|
||||
output_lsb /= 2
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
|
||||
output /= self.rotations
|
||||
x = output
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(self.rotations):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
|
||||
output_msb /= 3
|
||||
output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
|
||||
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
|
||||
output_lsb /= 2
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s2_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s2_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
|
||||
output /= self.rotations
|
||||
x = output
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
lut_model = hdblut.HDBLut.init_from_numpy(
|
||||
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
|
||||
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
|
||||
)
|
||||
return lut_model
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
class HDBNetv2(SRNetBase):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 2, scale = 4, rotations = 4):
|
||||
super(HDBNetv2, self).__init__()
|
||||
assert scale == 4
|
||||
self.scale = scale
|
||||
self.rotations = rotations
|
||||
self.layers_count = layers_count
|
||||
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
|
||||
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
|
||||
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(self.rotations):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
|
||||
output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
|
||||
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= self.rotations*3
|
||||
output_lsb /= self.rotations*2
|
||||
output = output_msb + output_lsb
|
||||
x = output.clamp(0, 255)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(self.rotations):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
|
||||
self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
|
||||
output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
|
||||
self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s2_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s2_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= self.rotations*3
|
||||
output_lsb /= self.rotations*2
|
||||
output = output_msb + output_lsb
|
||||
x = output.clamp(0, 255)
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
lut_model = hdblut.HDBLut.init_from_numpy(
|
||||
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
|
||||
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
|
||||
)
|
||||
return lut_model
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
class HDBLNet(SRNetBase):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
super(HDBLNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
|
||||
self.forward_stage(msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
|
||||
self.forward_stage(msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
|
||||
output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
|
||||
output_msb /= 3
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
output = output_msb + output_lsb
|
||||
x = output.clamp(0, 255)
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class HDBLNetR90(SRNetBase):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
super(HDBLNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
|
||||
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
|
||||
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
|
||||
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4
|
||||
output = output_msb + output_lsb
|
||||
x = output.clamp(0, 255)
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class HDBLNetR90KAN(SRNetBase):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
super(HDBLNetR90KAN, self).__init__()
|
||||
self.scale = scale
|
||||
self.stage1_3H = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3D = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3B = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
self.stage1_3L = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
|
||||
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
|
||||
self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
|
||||
output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4
|
||||
output = output_msb + output_lsb
|
||||
x = output.clamp(0, 255)
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class HDBHNet(SRNetBase):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
super(HDBHNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.hidden_dim = hidden_dim
|
||||
self.layers_count = layers_count
|
||||
|
||||
self.msb_fns = SRNetBaseList([layers.UpscaleBlock(
|
||||
in_features=4,
|
||||
hidden_dim=hidden_dim,
|
||||
layers_count=layers_count,
|
||||
upscale_factor=self.scale
|
||||
) for x in range(1)])
|
||||
self.lsb_fns = SRNetBaseList([layers.UpscaleBlock(
|
||||
in_features=4,
|
||||
hidden_dim=hidden_dim,
|
||||
layers_count=layers_count,
|
||||
upscale_factor=self.scale
|
||||
) for x in range(1)])
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
|
||||
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
|
||||
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
|
||||
output_msb_r = round_func((output_msb_r / 255)*16) * 15
|
||||
output_lsb_r = (output_lsb_r / 255) * 15
|
||||
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4
|
||||
output_lsb /= 4
|
||||
if not config is None and config.current_iter % config.display_step == 0:
|
||||
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
|
||||
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
|
||||
x = output_msb + output_lsb
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_loss_fn(self):
|
||||
fourier_loss_fn = FocalFrequencyLoss()
|
||||
high_frequency_loss_fn = FourierLoss()
|
||||
def loss_fn(pred, target):
|
||||
a = fourier_loss_fn(pred/255, target/255) * 1e8
|
||||
# b = F.mse_loss(pred/255, target/255) #* 1e3
|
||||
# c = high_frequency_loss_fn(pred/255, target/255) * 1e6
|
||||
return a #+ b #+ c
|
||||
return loss_fn
|
@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
from common import lut
|
||||
from pathlib import Path
|
||||
from common import layers
|
||||
from common import losses
|
||||
from common.base import SRBase
|
||||
from common.transferer import TRANSFERER
|
||||
|
||||
class SRNetBase(SRBase):
|
||||
def __init__(self):
|
||||
super(SRNetBase, self).__init__()
|
||||
self.config = None
|
||||
self.stage1_S = layers.UpscaleBlock(None)
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward(self, x, script_config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
x = self.stage1_S(x, self._extract_pattern_S)
|
||||
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
|
||||
return x
|
||||
|
||||
class SRNet(SRNetBase):
|
||||
def __init__(self, config):
|
||||
super(SRNet, self).__init__()
|
||||
self.config = config
|
||||
self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
|
||||
|
||||
class SRLut(SRNetBase):
|
||||
def __init__(self, config):
|
||||
super(SRLut, self).__init__()
|
||||
self.config = config
|
||||
self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
|
||||
|
||||
TRANSFERER.register(SRNet, SRLut)
|
||||
|
||||
|
||||
class ChebyKANBase(SRBase):
|
||||
def __init__(self):
|
||||
super(ChebyKANBase, self).__init__()
|
||||
self.config = None
|
||||
self.stage1_S = layers.UpscaleBlock(None)
|
||||
window_size = 7
|
||||
self._extract_pattern = layers.PercievePattern(
|
||||
receptive_field_idxes=[[i,j] for i in range(window_size) for j in range(window_size)],
|
||||
center=[window_size//2,window_size//2],
|
||||
window_size=window_size
|
||||
)
|
||||
|
||||
def forward(self, x, script_config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
x = self.stage1_S(x, self._extract_pattern)
|
||||
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
|
||||
return x
|
||||
|
||||
class ChebyKANNet(ChebyKANBase):
|
||||
def __init__(self, config):
|
||||
super(ChebyKANNet, self).__init__()
|
||||
self.config = config
|
||||
window_size = 7
|
||||
self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
|
||||
in_features=window_size*window_size,
|
||||
out_channels=1,
|
||||
hidden_dim=16,
|
||||
layers_count=self.config.layers_count,
|
||||
upscale_factor=self.config.upscale_factor,
|
||||
degree=8
|
||||
)
|
||||
|
||||
class ChebyKANLut(ChebyKANBase):
|
||||
def __init__(self, config):
|
||||
super(ChebyKANLut, self).__init__()
|
||||
self.config = config
|
||||
window_size = 7
|
||||
self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
|
||||
in_features=window_size*window_size,
|
||||
out_channels=1,
|
||||
hidden_dim=16,
|
||||
layers_count=self.config.layers_count,
|
||||
upscale_factor=self.config.upscale_factor
|
||||
).get_lut_model(quantization_interval=self.config.quantization_interval)
|
||||
|
||||
def get_loss_fn(self):
|
||||
ssim_loss = losses.SSIM(data_range=255)
|
||||
l1_loss = losses.CharbonnierLoss()
|
||||
def loss_fn(pred, target):
|
||||
return ssim_loss(pred, target) + l1_loss(pred, target)
|
||||
return loss_fn
|
||||
|
||||
TRANSFERER.register(ChebyKANNet, ChebyKANLut)
|
@ -1,505 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
# from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
|
||||
# from pathlib import Path
|
||||
|
||||
# class RCLutCentered_3x3(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutCentered_3x3, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# rc_conv_luts, dense_conv_lut
|
||||
# ):
|
||||
# scale = int(dense_conv_lut.shape[-1])
|
||||
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
||||
# lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
||||
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
||||
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
|
||||
# x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
|
||||
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
|
||||
# x = x.view(b, c, x.shape[-2], x.shape[-1])
|
||||
# return x
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
||||
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
||||
# ")"])
|
||||
|
||||
# class RCLutCentered_7x7(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# window_size,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutCentered_7x7, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# rc_conv_luts, dense_conv_lut
|
||||
# ):
|
||||
# scale = int(dense_conv_lut.shape[-1])
|
||||
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
||||
# lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
||||
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
|
||||
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
|
||||
# # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
|
||||
# x = x.view(b, c, x.shape[-2], x.shape[-1])
|
||||
# return x
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
||||
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
||||
# ")"])
|
||||
|
||||
# class RCLutRot90_3x3(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutRot90_3x3, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# rc_conv_luts, dense_conv_lut
|
||||
# ):
|
||||
# scale = int(dense_conv_lut.shape[-1])
|
||||
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
||||
# lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
||||
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
|
||||
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
|
||||
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
||||
# output += unrotated_prediction
|
||||
# output /= 4
|
||||
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
||||
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
||||
# ")"])
|
||||
|
||||
# class RCLutRot90_7x7(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutRot90_7x7, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# rc_conv_luts, dense_conv_lut
|
||||
# ):
|
||||
# scale = int(dense_conv_lut.shape[-1])
|
||||
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
||||
# window_size = rc_conv_luts.shape[0]
|
||||
# lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
||||
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
|
||||
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
|
||||
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
||||
# output += unrotated_prediction
|
||||
# output /= 4
|
||||
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
||||
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
||||
# ")"])
|
||||
|
||||
# class RCLutx1(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutx1, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# rc_conv_luts_3x3, dense_conv_lut_3x3,
|
||||
# rc_conv_luts_5x5, dense_conv_lut_5x5,
|
||||
# rc_conv_luts_7x7, dense_conv_lut_7x7
|
||||
# ):
|
||||
# scale = int(dense_conv_lut_3x3.shape[-1])
|
||||
# quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
|
||||
|
||||
# lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
|
||||
|
||||
# lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
|
||||
# lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
|
||||
|
||||
# lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
|
||||
# lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
|
||||
|
||||
# lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
|
||||
# lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
|
||||
|
||||
# return lut_model
|
||||
|
||||
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
||||
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
|
||||
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
|
||||
# f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
|
||||
# f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
|
||||
# f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
|
||||
# f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
|
||||
# f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
|
||||
# ")"])
|
||||
|
||||
|
||||
|
||||
# class RCLutx2(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutx2, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
|
||||
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
|
||||
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
|
||||
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
|
||||
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
|
||||
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
|
||||
# ):
|
||||
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
|
||||
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
|
||||
|
||||
# lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
|
||||
|
||||
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
|
||||
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
|
||||
|
||||
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
|
||||
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
|
||||
|
||||
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
|
||||
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
|
||||
|
||||
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
|
||||
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
|
||||
|
||||
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
|
||||
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
|
||||
|
||||
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
|
||||
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
|
||||
|
||||
# return lut_model
|
||||
|
||||
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
||||
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
|
||||
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output /= 3*4
|
||||
# x = output
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
|
||||
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
|
||||
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
|
||||
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
|
||||
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
|
||||
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
|
||||
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
|
||||
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
|
||||
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
|
||||
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
|
||||
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
|
||||
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
|
||||
# ")"])
|
||||
|
||||
|
||||
|
||||
# class RCLutx2Centered(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(RCLutx2Centered, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
||||
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
|
||||
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
|
||||
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
|
||||
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
|
||||
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
|
||||
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
|
||||
# ):
|
||||
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
|
||||
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
|
||||
|
||||
# lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
|
||||
|
||||
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
|
||||
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
|
||||
|
||||
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
|
||||
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
|
||||
|
||||
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
|
||||
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
|
||||
|
||||
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
|
||||
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
|
||||
|
||||
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
|
||||
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
|
||||
|
||||
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
|
||||
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
|
||||
|
||||
# return lut_model
|
||||
|
||||
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
||||
# x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
|
||||
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output /= 3*4
|
||||
# x = output
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output += torch.rot90(
|
||||
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
|
||||
# k=-rotations_count,
|
||||
# dims=[2, 3]
|
||||
# )
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return "\n".join([
|
||||
# f"{self.__class__.__name__}(",
|
||||
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
|
||||
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
|
||||
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
|
||||
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
|
||||
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
|
||||
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
|
||||
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
|
||||
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
|
||||
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
|
||||
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
|
||||
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
|
||||
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
|
||||
# ")"])
|
@ -1,568 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
from pathlib import Path
|
||||
from common import lut
|
||||
from . import rclut
|
||||
from common import layers
|
||||
|
||||
# class ReconstructedConvCentered(nn.Module):
|
||||
# def __init__(self, hidden_dim, window_size=7):
|
||||
# super(ReconstructedConvCentered, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
|
||||
# def pixel_wise_forward(self, x):
|
||||
# x = (x-127.5)/127.5
|
||||
# out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2)
|
||||
# out = torch.tanh(out)
|
||||
# out = out*127.5 + 127.5
|
||||
# return out
|
||||
|
||||
# def forward(self, x):
|
||||
# original_shape = x.shape
|
||||
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
||||
# x = F.unfold(x, self.window_size)
|
||||
# x = self.pixel_wise_forward(x)
|
||||
# x = x.mean(1)
|
||||
# x = x.reshape(*original_shape)
|
||||
# x = round_func(x)
|
||||
# return x
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
||||
|
||||
# class RCBlockCentered(nn.Module):
|
||||
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
||||
# super(RCBlockCentered, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
|
||||
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,hs,ws = x.shape
|
||||
# x = self.rc_conv(x)
|
||||
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
||||
# x = self.dense_conv_block(x)
|
||||
# return x
|
||||
|
||||
# class RCNetCentered_3x3(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetCentered_3x3, self).__init__()
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.layers_count = layers_count
|
||||
# self.scale = scale
|
||||
# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# x = self.stage(x)
|
||||
# x = x.view(b, c, h*self.scale, w*self.scale)
|
||||
# return x
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# window_size = self.stage.rc_conv.window_size
|
||||
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
|
||||
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
|
||||
# return lut_model
|
||||
|
||||
# class RCNetCentered_7x7(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetCentered_7x7, self).__init__()
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.layers_count = layers_count
|
||||
# self.scale = scale
|
||||
# window_size = 7
|
||||
# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# x = self.stage(x)
|
||||
# x = x.view(b, c, h*self.scale, w*self.scale)
|
||||
# return x
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# window_size = self.stage.rc_conv.window_size
|
||||
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
|
||||
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
|
||||
# return lut_model
|
||||
|
||||
|
||||
# class ReconstructedConvRot90(nn.Module):
|
||||
# def __init__(self, hidden_dim, window_size=7):
|
||||
# super(ReconstructedConvRot90, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
|
||||
# def pixel_wise_forward(self, x):
|
||||
# x = (x-127.5)/127.5
|
||||
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
|
||||
# out = torch.tanh(out)
|
||||
# out = out*127.5 + 127.5
|
||||
# return out
|
||||
|
||||
# def forward(self, x):
|
||||
# original_shape = x.shape
|
||||
# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
|
||||
# x = F.unfold(x, self.window_size)
|
||||
# x = self.pixel_wise_forward(x)
|
||||
# x = x.mean(1)
|
||||
# x = x.reshape(*original_shape)
|
||||
# x = round_func(x) # quality likely suffer from this
|
||||
# return x
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
||||
|
||||
# class RCBlockRot90(nn.Module):
|
||||
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
||||
# super(RCBlockRot90, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
|
||||
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,hs,ws = x.shape
|
||||
# x = self.rc_conv(x)
|
||||
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
||||
# x = self.dense_conv_block(x)
|
||||
|
||||
# return x
|
||||
|
||||
# class RCNetRot90_3x3(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetRot90_3x3, self).__init__()
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.layers_count = layers_count
|
||||
# self.scale = scale
|
||||
# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# window_size = self.stage.rc_conv.window_size
|
||||
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
|
||||
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# rotated_prediction = self.stage(rotated)
|
||||
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
||||
# output += unrotated_prediction
|
||||
# output /= 4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
# class RCNetRot90_7x7(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetRot90_7x7, self).__init__()
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.layers_count = layers_count
|
||||
# self.scale = scale
|
||||
# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# window_size = self.stage.rc_conv.window_size
|
||||
# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
|
||||
# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# rotated_prediction = self.stage(rotated)
|
||||
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
||||
# output += unrotated_prediction
|
||||
# output /= 4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
# class RCNetx1(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetx1, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
||||
# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# lut_model = rclut.RCLutx1.init_from_numpy(
|
||||
# rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3,
|
||||
# rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5,
|
||||
# rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7
|
||||
# )
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
|
||||
# class RCNetx2(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetx2, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
|
||||
# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
|
||||
# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
|
||||
# self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
# self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
||||
# self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# lut_model = rclut.RCLutx2.init_from_numpy(
|
||||
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
|
||||
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
|
||||
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
|
||||
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
|
||||
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
|
||||
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
|
||||
# )
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# x = output
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
# class RCNetx2Centered(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetx2Centered, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
|
||||
# self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
|
||||
# self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
|
||||
# self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
# self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
||||
# self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# lut_model = rclut.RCLutx2Centered.init_from_numpy(
|
||||
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
|
||||
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
|
||||
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
|
||||
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
|
||||
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
|
||||
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
|
||||
# )
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# x = output
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
# class ReconstructedConvRot90Unlutable(nn.Module):
|
||||
# def __init__(self, hidden_dim, window_size=7):
|
||||
# super(ReconstructedConvRot90Unlutable, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
|
||||
# def pixel_wise_forward(self, x):
|
||||
# x = (x-127.5)/127.5
|
||||
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
|
||||
# out = torch.tanh(out)
|
||||
# out = out*127.5 + 127.5
|
||||
# return out
|
||||
|
||||
# def forward(self, x):
|
||||
# original_shape = x.shape
|
||||
# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
|
||||
# x = F.unfold(x, self.window_size)
|
||||
# x = self.pixel_wise_forward(x)
|
||||
# x = x.mean(1)
|
||||
# x = x.reshape(*original_shape)
|
||||
# # x = round_func(x) # quality likely suffer from this
|
||||
# return x
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
||||
|
||||
# class RCBlockRot90Unlutable(nn.Module):
|
||||
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
||||
# super(RCBlockRot90Unlutable, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
|
||||
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,hs,ws = x.shape
|
||||
# x = self.rc_conv(x)
|
||||
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
||||
# x = self.dense_conv_block(x)
|
||||
# return x
|
||||
|
||||
# class RCNetx2Unlutable(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetx2Unlutable, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
|
||||
# self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
|
||||
# self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
|
||||
# self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
# self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
||||
# self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# lut_model = rclut.RCLutx2.init_from_numpy(
|
||||
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
|
||||
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
|
||||
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
|
||||
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
|
||||
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
|
||||
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
|
||||
# )
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# x = output
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
|
||||
|
||||
# class ReconstructedConvCenteredUnlutable(nn.Module):
|
||||
# def __init__(self, hidden_dim, window_size=7):
|
||||
# super(ReconstructedConvCenteredUnlutable, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
||||
|
||||
# def pixel_wise_forward(self, x):
|
||||
# x = (x-127.5)/127.5
|
||||
# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
|
||||
# out = torch.tanh(out)
|
||||
# out = out*127.5 + 127.5
|
||||
# return out
|
||||
|
||||
# def forward(self, x):
|
||||
# original_shape = x.shape
|
||||
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
||||
# x = F.unfold(x, self.window_size)
|
||||
# x = self.pixel_wise_forward(x)
|
||||
# x = x.mean(1)
|
||||
# x = x.reshape(*original_shape)
|
||||
# # x = round_func(x) # quality likely suffer from this
|
||||
# return x
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
||||
|
||||
# class RCBlockCenteredUnlutable(nn.Module):
|
||||
# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
||||
# super(RCBlockRot90Unlutable, self).__init__()
|
||||
# self.window_size = window_size
|
||||
# self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
|
||||
# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,hs,ws = x.shape
|
||||
# x = self.rc_conv(x)
|
||||
# x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
||||
# x = self.dense_conv_block(x)
|
||||
# return x
|
||||
|
||||
# class RCNetx2CenteredUnlutable(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(RCNetx2CenteredUnlutable, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.hidden_dim = hidden_dim
|
||||
# self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
|
||||
# self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
|
||||
# self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
|
||||
# self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
||||
# self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
||||
# self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
||||
# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
||||
# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
||||
# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
|
||||
# lut_model = rclut.RCLutx2Centered.init_from_numpy(
|
||||
# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
|
||||
# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
|
||||
# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
|
||||
# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
|
||||
# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
|
||||
# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
|
||||
# )
|
||||
# return lut_model
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.view(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# x = output
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# for rotations_count in range(4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 3*4
|
||||
# output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
@ -1,395 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from common.lut import select_index_4dlut_tetrahedral
|
||||
from common.layers import PercievePattern
|
||||
from common.utils import round_func
|
||||
|
||||
class SDYLutx1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SDYLutx1, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stageS, stageD, stageY
|
||||
):
|
||||
scale = int(stageS.shape[-1])
|
||||
quantization_interval = 256//(stageS.shape[0]-1)
|
||||
lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
|
||||
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
|
||||
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
|
||||
output /= 3
|
||||
output = round_func(output)
|
||||
output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}" + \
|
||||
f"\n stageS size: {self.stageS.shape}" + \
|
||||
f"\n stageD size: {self.stageD.shape}" + \
|
||||
f"\n stageY size: {self.stageY.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
class SDYLutx2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SDYLutx2, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
|
||||
):
|
||||
scale = int(stage2_S.shape[-1])
|
||||
quantization_interval = 256//(stage2_S.shape[0]-1)
|
||||
lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
|
||||
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
|
||||
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
|
||||
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
|
||||
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
|
||||
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
|
||||
output /= 3
|
||||
output = round_func(output)
|
||||
x = output
|
||||
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
|
||||
output /= 3
|
||||
output = round_func(output)
|
||||
output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}" + \
|
||||
f"\n stage1_S size: {self.stage1_S.shape}" + \
|
||||
f"\n stage1_D size: {self.stage1_D.shape}" + \
|
||||
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
|
||||
f"\n stage2_S size: {self.stage2_S.shape}" + \
|
||||
f"\n stage2_D size: {self.stage2_D.shape}" + \
|
||||
f"\n stage2_Y size: {self.stage2_Y.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class SDYLutx3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SDYLutx3, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage3_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stage3_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stage3_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y
|
||||
):
|
||||
scale = int(stage3_S.shape[-1])
|
||||
quantization_interval = 256//(stage3_S.shape[0]-1)
|
||||
lut_model = SDYLutx3(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
|
||||
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
|
||||
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
|
||||
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
|
||||
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
|
||||
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
|
||||
lut_model.stage3_S = nn.Parameter(torch.tensor(stage3_S).type(torch.float32))
|
||||
lut_model.stage3_D = nn.Parameter(torch.tensor(stage3_D).type(torch.float32))
|
||||
lut_model.stage3_Y = nn.Parameter(torch.tensor(stage3_Y).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
|
||||
output /= 3
|
||||
output = round_func(output)
|
||||
x = output
|
||||
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage2_S)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage2_D)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage2_Y)
|
||||
output /= 3
|
||||
output = round_func(output)
|
||||
x = output
|
||||
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage3_S)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage3_D)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage3_Y)
|
||||
output /= 3
|
||||
output = round_func(output)
|
||||
output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}" + \
|
||||
f"\n stage1_S size: {self.stage1_S.shape}" + \
|
||||
f"\n stage1_D size: {self.stage1_D.shape}" + \
|
||||
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
|
||||
f"\n stage2_S size: {self.stage2_S.shape}" + \
|
||||
f"\n stage2_D size: {self.stage2_D.shape}" + \
|
||||
f"\n stage2_Y size: {self.stage2_Y.shape}" + \
|
||||
f"\n stage3_S size: {self.stage3_S.shape}" + \
|
||||
f"\n stage3_D size: {self.stage3_D.shape}" + \
|
||||
f"\n stage3_Y size: {self.stage3_Y.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class SDYLutR90x1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SDYLutR90x1, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stageS, stageD, stageY
|
||||
):
|
||||
scale = int(stageS.shape[-1])
|
||||
quantization_interval = 256//(stageS.shape[0]-1)
|
||||
lut_model = SDYLutR90x1(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
|
||||
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
|
||||
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
|
||||
for rotations_count in range(1, 4):
|
||||
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stageS), k=-rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stageD), k=-rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stageY), k=-rotations_count, dims=[-2, -1])
|
||||
output /= 4*3
|
||||
output = round_func(output)
|
||||
output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}" + \
|
||||
f"\n stageS size: {self.stageS.shape}" + \
|
||||
f"\n stageD size: {self.stageD.shape}" + \
|
||||
f"\n stageY size: {self.stageY.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class SDYLutR90x2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SDYLutR90x2, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
||||
self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
|
||||
):
|
||||
scale = int(stage2_S.shape[-1])
|
||||
quantization_interval = 256//(stage2_S.shape[0]-1)
|
||||
lut_model = SDYLutR90x2(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
|
||||
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
|
||||
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
|
||||
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
|
||||
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
|
||||
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.view(b*c, 1, h, w).type(torch.float32)
|
||||
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
|
||||
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
|
||||
for rotations_count in range(1, 4):
|
||||
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
|
||||
output /= 4*3
|
||||
x = round_func(output)
|
||||
|
||||
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
|
||||
for rotations_count in range(1, 4):
|
||||
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
|
||||
output /= 4*3
|
||||
output = round_func(output)
|
||||
output = output.view(b, c, h*self.scale, w*self.scale)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}" + \
|
||||
f"\n stage1_S size: {self.stage1_S.shape}" + \
|
||||
f"\n stage1_D size: {self.stage1_D.shape}" + \
|
||||
f"\n stage1_Y size: {self.stage1_Y.shape}" + \
|
||||
f"\n stage2_S size: {self.stage2_S.shape}" + \
|
||||
f"\n stage2_D size: {self.stage2_D.shape}" + \
|
||||
f"\n stage2_Y size: {self.stage2_Y.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
File diff suppressed because it is too large
Load Diff
@ -1,276 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from common.lut import select_index_4dlut_tetrahedral
|
||||
from common import layers
|
||||
from common.utils import round_func
|
||||
from models.base import SRNetBase
|
||||
|
||||
class SRLut(SRNetBase):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SRLut, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage_lut
|
||||
):
|
||||
scale = int(stage_lut.shape[-1])
|
||||
quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
lut_model = SRLut(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w).type(torch.float32)
|
||||
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
class SRLutY(SRNetBase):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SRLutY, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self.rgb_to_ycbcr = layers.RgbToYcbcr()
|
||||
self.ycbcr_to_rgb = layers.YcbcrToRgb()
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage_lut
|
||||
):
|
||||
scale = int(stage_lut.shape[-1])
|
||||
quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = self.rgb_to_ycbcr(x)
|
||||
y = x[:,0:1,:,:]
|
||||
cbcr = x[:,1:,:,:]
|
||||
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
|
||||
output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
output = torch.cat([output, cbcr_scaled], dim=1)
|
||||
output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
class SRLutR90(SRNetBase):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SRLutR90, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage_lut
|
||||
):
|
||||
scale = int(stage_lut.shape[-1])
|
||||
quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
for rotations_count in range(1, 4):
|
||||
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
|
||||
output /= 4
|
||||
output = output.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
|
||||
class SRLutR90Y(SRNetBase):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SRLutR90Y, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
self.rgb_to_ycbcr = layers.RgbToYcbcr()
|
||||
self.ycbcr_to_rgb = layers.YcbcrToRgb()
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage_lut
|
||||
):
|
||||
scale = int(stage_lut.shape[-1])
|
||||
quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
x = self.rgb_to_ycbcr(x)
|
||||
y = x[:,0:1,:,:]
|
||||
cbcr = x[:,1:,:,:]
|
||||
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
|
||||
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
for rotations_count in range(1,4):
|
||||
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
|
||||
output /= 4
|
||||
output = torch.cat([output, cbcr_scaled], dim=1)
|
||||
output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
||||
|
||||
class SRLutR90YCbCr(SRNetBase):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(SRLutR90YCbCr, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage_lut
|
||||
):
|
||||
scale = int(stage_lut.shape[-1])
|
||||
quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x, config=None):
|
||||
b,c,h,w = x.shape
|
||||
y = x[:,0:1,:,:]
|
||||
cbcr = x[:,1:,:,:]
|
||||
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
for rotations_count in range(1,4):
|
||||
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
|
||||
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
|
||||
output /= 4
|
||||
output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
def get_loss_fn(self):
|
||||
def loss_fn(pred, target):
|
||||
return F.mse_loss(pred/255, target/255)
|
||||
return loss_fn
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue