HKLUT and SDY LUT experiments.

main
protsenkovi 7 months ago
parent 9992763c9f
commit 0dd154d0cd

File diff suppressed because one or more lines are too long

@@ -7,16 +7,10 @@ from .utils import round_func
 class PercievePattern():
     def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2):
         assert window_size >= (np.max(receptive_field_idxes)+1)
-        assert len(receptive_field_idxes) == 4
-        self.receptive_field_idxes = np.array(receptive_field_idxes)
+        receptive_field_idxes = np.array(receptive_field_idxes)
         self.window_size = window_size
         self.center = center
-        self.receptive_field_idxes = [
-            self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1],
-            self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1],
-            self.receptive_field_idxes[2,0]*self.window_size + self.receptive_field_idxes[2,1],
-            self.receptive_field_idxes[3,0]*self.window_size + self.receptive_field_idxes[3,1],
-        ]
+        self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size + receptive_field_idxes[i,1] for i in range(len(receptive_field_idxes))]

     def __call__(self, x):
         b,c,h,w = x.shape
@@ -27,12 +21,7 @@ class PercievePattern():
             mode='replicate'
         )
         x = F.unfold(input=x, kernel_size=self.window_size)
-        x = torch.stack([
-            x[:,self.receptive_field_idxes[0],:],
-            x[:,self.receptive_field_idxes[1],:],
-            x[:,self.receptive_field_idxes[2],:],
-            x[:,self.receptive_field_idxes[3],:]
-        ], 2)
+        x = torch.stack([x[:,self.receptive_field_idxes[i],:] for i in range(len(self.receptive_field_idxes))], 2)
         return x

 class UpscaleBlock(nn.Module):
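
The refactor above drops the hard-coded four-tap assumption, so a pattern may now carry any number of pixels; the 3- and 2-pixel patterns introduced later in this commit rely on that. For reference, a minimal sketch of the index flattening (illustrative, not part of the commit):

import numpy as np

# [row, col] -> row*window_size + col: the column index used with F.unfold
receptive_field_idxes = np.array([[0,0],[0,1],[1,0],[1,1]])
window_size = 2
flat = [receptive_field_idxes[i,0]*window_size + receptive_field_idxes[i,1]
        for i in range(len(receptive_field_idxes))]
print(flat)  # [0, 1, 2, 3]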

@@ -36,9 +36,10 @@ class Domain4DValues(Dataset):
         return len(self.values)

     def __iter__(self):
-        for i in range(0, len(self.values)):
+        for i in range(len(self.values)):
            yield self.__getitem__(i)

 def transfer_rc_conv(rc_conv, quantization_interval=1):
     receptive_field_pixel_count = rc_conv.window_size**2
     bucket_count = 256//quantization_interval
@@ -62,14 +63,15 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
         domain_values,
         batch_size=batch_size if quantization_interval >= 16 else 2**16,
         pin_memory=True,
-        num_workers=1 if quantization_interval >= 16 else mp.cpu_count()
+        num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
+        shuffle=False,
     )
     counter = 0
     for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
         inputs = batch.type(torch.float32).cuda()
         with torch.no_grad():
             outputs = block(inputs)
-        lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.reshape(-1, scale, scale).cpu().numpy().astype(np.uint8)
+        lut[ix1s, ix2s, ix3s, ix4s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
         counter += inputs.shape[0]
         print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
     print()
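
The loader is now explicitly unshuffled, which matters because the (ix1s, ..., ix4s) bucket indices must stay aligned with the enumerated input domain. For scale reference, my arithmetic (not from the commit) on the table this transfer produces:

# Size of one 4-D table at quantization_interval=16 and scale=4
quantization_interval = 16
scale = 4
bucket_count = 256 // quantization_interval + 1   # 17 sample points per axis
entries = bucket_count ** 4                       # 83_521 quantized input combinations
total_bytes = entries * scale * scale             # one uint8 scale x scale patch per entry
print(entries, total_bytes)                       # 83521 1336336 (~1.3 MB)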
@@ -135,6 +137,70 @@ def select_index_1dlut_linear(ixA, lut):
     out = out.reshape((b,c,h,w))
     return out

+def select_index_3dlut_tetrahedral(index, lut):
+    # 3-D analogue of select_index_4dlut_tetrahedral below; the pasted draft
+    # still referenced the 4-D indices (img_d1/img_d2), fixed here to three axes.
+    b, hw, c = index.shape
+    lut = torch.clamp(lut, 0, 255)
+    dimA, dimB, dimC = lut.shape[:3]
+    q = 256/(dimA-1)
+    L = dimA
+    upscale = lut.shape[-1]
+    weight = lut.reshape(L**3, upscale, upscale)
+    msbA = torch.floor_divide(index, q).type(torch.int64)
+    msbB = msbA + 1
+    lsb = index % q
+    img_a1 = msbA[:,:,0].reshape(b*hw, 1)
+    img_b1 = msbA[:,:,1].reshape(b*hw, 1)
+    img_c1 = msbA[:,:,2].reshape(b*hw, 1)
+    img_a2 = msbB[:,:,0].reshape(b*hw, 1)
+    img_b2 = msbB[:,:,1].reshape(b*hw, 1)
+    img_c2 = msbB[:,:,2].reshape(b*hw, 1)
+    fa = lsb[:,:,0].reshape(b*hw, 1)
+    fb = lsb[:,:,1].reshape(b*hw, 1)
+    fc = lsb[:,:,2].reshape(b*hw, 1)
+    # the 8 vertices of the lattice cell around the query point, one bit per axis (a, b, c)
+    p000 = weight[img_a1 * L * L + img_b1 * L + img_c1].reshape(b*hw, upscale*upscale)
+    p001 = weight[img_a1 * L * L + img_b1 * L + img_c2].reshape(b*hw, upscale*upscale)
+    p010 = weight[img_a1 * L * L + img_b2 * L + img_c1].reshape(b*hw, upscale*upscale)
+    p011 = weight[img_a1 * L * L + img_b2 * L + img_c2].reshape(b*hw, upscale*upscale)
+    p100 = weight[img_a2 * L * L + img_b1 * L + img_c1].reshape(b*hw, upscale*upscale)
+    p101 = weight[img_a2 * L * L + img_b1 * L + img_c2].reshape(b*hw, upscale*upscale)
+    p110 = weight[img_a2 * L * L + img_b2 * L + img_c1].reshape(b*hw, upscale*upscale)
+    p111 = weight[img_a2 * L * L + img_b2 * L + img_c2].reshape(b*hw, upscale*upscale)
+    fab = fa > fb
+    fbc = fb > fc
+    fac = fa > fc
+    fableq = fa <= fb
+    fbcleq = fb <= fc
+    facleq = fa <= fc
+    out = torch.zeros((b*hw, upscale*upscale), dtype=weight.dtype).to(device=weight.device)
+    # one branch per ordering of the fractional offsets (6 tetrahedra per cell)
+    i1 = i = torch.all(torch.cat([fab, fbc], dim=1), dim=1)
+    out[i] = (q - fa[i]) * p000[i] + (fa[i] - fb[i]) * p100[i] + (fb[i] - fc[i]) * p110[i] + fc[i] * p111[i]
+    i2 = i = torch.all(torch.cat([fab, fac], dim=1), dim=1)
+    out[i] = (q - fa[i]) * p000[i] + (fa[i] - fc[i]) * p100[i] + (fc[i] - fb[i]) * p101[i] + fb[i] * p111[i]
+    i3 = i = torch.all(torch.cat([fab, ~i1[:, None], ~i2[:, None]], dim=1), dim=1)
+    out[i] = (q - fc[i]) * p000[i] + (fc[i] - fa[i]) * p001[i] + (fa[i] - fb[i]) * p101[i] + fb[i] * p111[i]
+    i4 = i = torch.all(torch.cat([~fab, ~fbcleq], dim=1), dim=1)
+    out[i] = (q - fc[i]) * p000[i] + (fc[i] - fb[i]) * p001[i] + (fb[i] - fa[i]) * p011[i] + fa[i] * p111[i]
+    i5 = i = torch.all(torch.cat([~fab, ~facleq], dim=1), dim=1)
+    out[i] = (q - fb[i]) * p000[i] + (fb[i] - fc[i]) * p010[i] + (fc[i] - fa[i]) * p011[i] + fa[i] * p111[i]
+    i6 = i = torch.all(torch.cat([~fab, ~fbc, ~fac], dim=1), dim=1)
+    out[i] = (q - fb[i]) * p000[i] + (fb[i] - fa[i]) * p010[i] + (fa[i] - fc[i]) * p110[i] + fc[i] * p111[i]
+    out = out.reshape((b, hw, upscale*upscale))
+    out = out / q
+    return out
+
 def select_index_4dlut_tetrahedral(index, lut):
     b, hw, c = index.shape
     lut = torch.clamp(lut, 0, 255)
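
A hypothetical usage sketch for the new 3-D lookup; shapes follow the signature above, values are random placeholders:

import torch

L, upscale = 17, 2                                # 17 = 256//16 + 1 lattice points per axis
lut3d = torch.randint(0, 255, (L, L, L, upscale, upscale)).float()
index = torch.randint(0, 256, (1, 8, 3)).float()  # b=1, hw=8 positions, 3 taps each
out = select_index_3dlut_tetrahedral(index=index, lut=lut3d)
print(out.shape)                                  # torch.Size([1, 8, 4])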

@@ -3,6 +3,8 @@ from . import rclut
 from . import srnet
 from . import srlut
 from . import sdynet
+from . import hdbnet
+from . import hdblut
 import torch
 import numpy as np
 from pathlib import Path
@@ -18,12 +20,16 @@ AVAILABLE_MODELS = {
     'SDYLutx1': sdylut.SDYLutx1,
     'SDYNetx2': sdynet.SDYNetx2,
     'SDYLutx2': sdylut.SDYLutx2,
+    'SDYNetx3': sdynet.SDYNetx3,
+    'SDYLutx3': sdylut.SDYLutx3,
     'SDYNetR90x1': sdynet.SDYNetR90x1,
     'SDYLutR90x1': sdylut.SDYLutR90x1,
     'SDYNetR90x2': sdynet.SDYNetR90x2,
     'SDYLutR90x2': sdylut.SDYLutR90x2,
     'SRNetY': srnet.SRNetY,
     'SRLutY': srlut.SRLutY,
+    'HDBNet': hdbnet.HDBNet,
+    'HDBLNet': hdbnet.HDBLNet,
     # 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
     # 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
     # 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,
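
Registry entries are constructed by name in the training script (see the last hunk of this diff); a hedged example, assuming this registry lives in the models package:

from models import AVAILABLE_MODELS

net = AVAILABLE_MODELS['SDYNetx3'](hidden_dim=64, layers_count=4, scale=4)
lut = AVAILABLE_MODELS['SDYLutx3'](quantization_interval=16, scale=4)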

@@ -0,0 +1,201 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from pathlib import Path
+from common.lut import select_index_4dlut_tetrahedral
+from common import layers
+from common.utils import round_func
+
+class HDBLut(nn.Module):
+    def __init__(
+        self,
+        quantization_interval,
+        scale
+    ):
+        super(HDBLut, self).__init__()
+        self.scale = scale
+        self.quantization_interval = quantization_interval
+        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+
+    @staticmethod
+    def init_from_numpy(
+        stage_lut
+    ):
+        scale = int(stage_lut.shape[-1])
+        quantization_interval = 256//(stage_lut.shape[0]-1)
+        lut_model = HDBLut(quantization_interval=quantization_interval, scale=scale)
+        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
+        return lut_model
+
+    def forward_stage(self, x, scale, percieve_pattern, lut):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w).type(torch.float32)
+        x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
+# class SRLutY(nn.Module):
+#     def __init__(
+#         self,
+#         quantization_interval,
+#         scale
+#     ):
+#         super(SRLutY, self).__init__()
+#         self.scale = scale
+#         self.quantization_interval = quantization_interval
+#         self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+#         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+#         self.rgb_to_ycbcr = layers.RgbToYcbcr()
+#         self.ycbcr_to_rgb = layers.YcbcrToRgb()
+#
+#     @staticmethod
+#     def init_from_numpy(
+#         stage_lut
+#     ):
+#         scale = int(stage_lut.shape[-1])
+#         quantization_interval = 256//(stage_lut.shape[0]-1)
+#         lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale)
+#         lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
+#         return lut_model
+#
+#     def forward_stage(self, x, scale, percieve_pattern, lut):
+#         b,c,h,w = x.shape
+#         x = percieve_pattern(x)
+#         x = select_index_4dlut_tetrahedral(index=x, lut=lut)
+#         x = round_func(x)
+#         x = x.reshape(b, c, h, w, scale, scale)
+#         x = x.permute(0,1,2,4,3,5)
+#         x = x.reshape(b, c, h*scale, w*scale)
+#         return x
+#
+#     def forward(self, x):
+#         b,c,h,w = x.shape
+#         x = self.rgb_to_ycbcr(x)
+#         y = x[:,0:1,:,:]
+#         cbcr = x[:,1:,:,:]
+#         cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
+#         output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
+#         output = torch.cat([output, cbcr_scaled], dim=1)
+#         output = self.ycbcr_to_rgb(output).clamp(0, 255)
+#         return output
+#
+#     def __repr__(self):
+#         return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
+
+# class SRLutR90(nn.Module):
+#     def __init__(
+#         self,
+#         quantization_interval,
+#         scale
+#     ):
+#         super(SRLutR90, self).__init__()
+#         self.scale = scale
+#         self.quantization_interval = quantization_interval
+#         self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+#         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+#
+#     @staticmethod
+#     def init_from_numpy(
+#         stage_lut
+#     ):
+#         scale = int(stage_lut.shape[-1])
+#         quantization_interval = 256//(stage_lut.shape[0]-1)
+#         lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
+#         lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
+#         return lut_model
+#
+#     def forward_stage(self, x, scale, percieve_pattern, lut):
+#         b,c,h,w = x.shape
+#         x = percieve_pattern(x)
+#         x = select_index_4dlut_tetrahedral(index=x, lut=lut)
+#         x = round_func(x)
+#         x = x.reshape(b, c, h, w, scale, scale)
+#         x = x.permute(0,1,2,4,3,5)
+#         x = x.reshape(b, c, h*scale, w*scale)
+#         return x
+#
+#     def forward(self, x):
+#         b,c,h,w = x.shape
+#         x = x.reshape(b*c, 1, h, w)
+#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
+#         output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
+#         for rotations_count in range(1, 4):
+#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+#             output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
+#         output /= 4
+#         output = output.reshape(b, c, h*self.scale, w*self.scale)
+#         return output
+#
+#     def __repr__(self):
+#         return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
+
+# class SRLutR90Y(nn.Module):
+#     def __init__(
+#         self,
+#         quantization_interval,
+#         scale
+#     ):
+#         super(SRLutR90Y, self).__init__()
+#         self.scale = scale
+#         self.quantization_interval = quantization_interval
+#         self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+#         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+#         self.rgb_to_ycbcr = layers.RgbToYcbcr()
+#         self.ycbcr_to_rgb = layers.YcbcrToRgb()
+#
+#     @staticmethod
+#     def init_from_numpy(
+#         stage_lut
+#     ):
+#         scale = int(stage_lut.shape[-1])
+#         quantization_interval = 256//(stage_lut.shape[0]-1)
+#         lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
+#         lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
+#         return lut_model
+#
+#     def forward_stage(self, x, scale, percieve_pattern, lut):
+#         b,c,h,w = x.shape
+#         x = percieve_pattern(x)
+#         x = select_index_4dlut_tetrahedral(index=x, lut=lut)
+#         x = round_func(x)
+#         x = x.reshape(b, c, h, w, scale, scale)
+#         x = x.permute(0,1,2,4,3,5)
+#         x = x.reshape(b, c, h*scale, w*scale)
+#         return x
+#
+#     def forward(self, x):
+#         b,c,h,w = x.shape
+#         x = self.rgb_to_ycbcr(x)
+#         y = x[:,0:1,:,:]
+#         cbcr = x[:,1:,:,:]
+#         cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
+#         output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
+#         output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
+#         for rotations_count in range(1,4):
+#             rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
+#             output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
+#         output /= 4
+#         output = torch.cat([output, cbcr_scaled], dim=1)
+#         output = self.ycbcr_to_rgb(output).clamp(0, 255)
+#         return output
+#
+#     def __repr__(self):
+#         return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"

@@ -0,0 +1,281 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from common.utils import round_func
+from common import lut
+from pathlib import Path
+# from . import srlut
+from common import layers
+
+class HDBNet(nn.Module):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(HDBNet, self).__init__()
+        self.scale = scale
+        # 3-pixel patterns (H/D/B) process the MSB half, 2-pixel patterns the LSB half
+        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
+        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward_stage(self, x, scale, percieve_pattern, stage):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = stage(x)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        # stage 1: split into 4-bit halves, upscale x2, average over 4 rotations
+        lsb = x % 16
+        msb = x - lsb
+        output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
+        output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
+        output_msb /= 4*3
+        output_lsb /= 4*2
+        output_msb = output_msb + output_lsb
+        x = output_msb
+        # stage 2: same scheme on the intermediate result for the second x2
+        lsb = x % 16
+        msb = x - lsb
+        output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
+        output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
+        output_msb /= 4*3
+        output_lsb /= 4*2
+        output_msb = output_msb + output_lsb
+        x = output_msb
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        # NOTE: carried over from SRNet and not yet functional for this model:
+        # HDBNet defines no self.stage1_S, srlut is not imported above, and the
+        # H/D/B stages take 3 or 2 inputs rather than a 2x2 patch.
+        stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = srlut.SRLut.init_from_numpy(stage_lut)
+        return lut_model
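
The MSB/LSB decomposition driving both stages, as a standalone check (illustrative, not from the commit):

import torch

x = torch.tensor([0., 37., 200., 255.])
lsb = x % 16          # low 4 bits, in [0, 16)
msb = x - lsb         # high 4 bits, multiples of 16
assert torch.equal(msb + lsb, x)
print(msb, lsb)       # tensor([  0.,  32., 192., 240.]) tensor([ 0.,  5.,  8., 15.])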
+class HDBLNet(nn.Module):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(HDBLNet, self).__init__()
+        self.scale = scale
+        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self.stage2_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
+        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=3)
+
+    def forward_stage(self, x, scale, percieve_pattern, stage):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = stage(x)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        lsb = x % 16
+        msb = x - lsb
+        output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
+        output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_3L, self.stage1_3L), k=-rotations_count, dims=[2, 3])
+        output_msb /= 4*3
+        output_lsb /= 4
+        output_msb = output_msb + output_lsb
+        x = output_msb
+        lsb = x % 16
+        msb = x - lsb
+        output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
+        output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
+            output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
+            output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_3L, self.stage2_3L), k=-rotations_count, dims=[2, 3])
+        output_msb /= 4*3
+        output_lsb /= 4
+        output_msb = output_msb + output_lsb
+        x = output_msb
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        # NOTE: same leftover stub as in HDBNet above; self.stage1_S and srlut
+        # do not exist here, so this path needs a dedicated transfer routine.
+        stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = srlut.SRLut.init_from_numpy(stage_lut)
+        return lut_model
+# class SRNetY(nn.Module):
+#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+#         super(SRNetY, self).__init__()
+#         self.scale = scale
+#         self.stage1_S = layers.UpscaleBlock(
+#             hidden_dim=hidden_dim,
+#             layers_count=layers_count,
+#             upscale_factor=self.scale
+#         )
+#         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+#         self.rgb_to_ycbcr = layers.RgbToYcbcr()
+#         self.ycbcr_to_rgb = layers.YcbcrToRgb()
+#
+#     def forward_stage(self, x, scale, percieve_pattern, stage):
+#         b,c,h,w = x.shape
+#         x = percieve_pattern(x)
+#         x = stage(x)
+#         x = round_func(x)
+#         x = x.reshape(b, c, h, w, scale, scale)
+#         x = x.permute(0,1,2,4,3,5)
+#         x = x.reshape(b, c, h*scale, w*scale)
+#         return x
+#
+#     def forward(self, x):
+#         b,c,h,w = x.shape
+#         x = self.rgb_to_ycbcr(x)
+#         y = x[:,0:1,:,:]
+#         cbcr = x[:,1:,:,:]
+#         cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
+#         x = y.view(b, 1, h, w)
+#         output = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
+#         output = torch.cat([output, cbcr_scaled], dim=1)
+#         output = self.ycbcr_to_rgb(output).clamp(0, 255)
+#         return output
+#
+#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+#         stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+#         lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
+#         return lut_model
+
+# class SRNetR90(nn.Module):
+#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+#         super(SRNetR90, self).__init__()
+#         self.scale = scale
+#         self.stage1_S = layers.UpscaleBlock(
+#             hidden_dim=hidden_dim,
+#             layers_count=layers_count,
+#             upscale_factor=self.scale
+#         )
+#         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+#
+#     def forward_stage(self, x, scale, percieve_pattern, stage):
+#         b,c,h,w = x.shape
+#         x = percieve_pattern(x)
+#         x = stage(x)
+#         x = round_func(x)
+#         x = x.reshape(b, c, h, w, scale, scale)
+#         x = x.permute(0,1,2,4,3,5)
+#         x = x.reshape(b, c, h*scale, w*scale)
+#         return x
+#
+#     def forward(self, x):
+#         b,c,h,w = x.shape
+#         x = x.reshape(b*c, 1, h, w)
+#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+#         output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
+#         for rotations_count in range(1,4):
+#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+#             output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
+#         output /= 4
+#         output = output.reshape(b, c, h*self.scale, w*self.scale)
+#         return output
+#
+#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+#         stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+#         lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
+#         return lut_model
+
+# class SRNetR90Y(nn.Module):
+#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+#         super(SRNetR90Y, self).__init__()
+#         self.scale = scale
+#         s_pattern=[[0,0],[0,1],[1,0],[1,1]]
+#         self.stage1_S = layers.UpscaleBlock(
+#             hidden_dim=hidden_dim,
+#             layers_count=layers_count,
+#             upscale_factor=self.scale
+#         )
+#         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+#         self.rgb_to_ycbcr = layers.RgbToYcbcr()
+#         self.ycbcr_to_rgb = layers.YcbcrToRgb()
+#
+#     def forward_stage(self, x, scale, percieve_pattern, stage):
+#         b,c,h,w = x.shape
+#         x = percieve_pattern(x)
+#         x = stage(x)
+#         x = round_func(x)
+#         x = x.reshape(b, c, h, w, scale, scale)
+#         x = x.permute(0,1,2,4,3,5)
+#         x = x.reshape(b, c, h*scale, w*scale)
+#         return x
+#
+#     def forward(self, x):
+#         b,c,h,w = x.shape
+#         x = self.rgb_to_ycbcr(x)
+#         y = x[:,0:1,:,:]
+#         cbcr = x[:,1:,:,:]
+#         cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
+#         x = y.view(b, 1, h, w)
+#         output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+#         output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
+#         for rotations_count in range(1,4):
+#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+#             output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
+#         output /= 4
+#         output = torch.cat([output, cbcr_scaled], dim=1)
+#         output = self.ycbcr_to_rgb(output).clamp(0, 255)
+#         return output
+#
+#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+#         stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+#         lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)
+#         return lut_model
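
A quick shape sanity check for the cascaded x2 stages, assuming the common.layers definitions from this repository are importable:

import torch

net = HDBNet(hidden_dim=64, layers_count=4, scale=4)
x = torch.randint(0, 256, (1, 1, 8, 8)).float()
with torch.no_grad():
    y = net(x)
print(y.shape)   # torch.Size([1, 1, 32, 32]) -- two cascaded x2 stages give x4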

@@ -75,9 +75,9 @@ class SDYLutx2(nn.Module):
         self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
         self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
         self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
-        self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
-        self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
-        self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
         self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@@ -86,8 +86,8 @@ class SDYLutx2(nn.Module):
     def init_from_numpy(
         stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
     ):
-        scale = int(stageS.shape[-1])
-        quantization_interval = 256//(stageS.shape[0]-1)
+        scale = int(stage2_S.shape[-1])
+        quantization_interval = 256//(stage2_S.shape[0]-1)
         lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
         lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
         lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
@@ -128,9 +128,101 @@ class SDYLutx2(nn.Module):
     def __repr__(self):
         return f"{self.__class__.__name__}" + \
-            f"\n stageS size: {self.stageS.shape}" + \
-            f"\n stageD size: {self.stageD.shape}" + \
-            f"\n stageY size: {self.stageY.shape}"
+            f"\n stage1_S size: {self.stage1_S.shape}" + \
+            f"\n stage1_D size: {self.stage1_D.shape}" + \
+            f"\n stage1_Y size: {self.stage1_Y.shape}" + \
+            f"\n stage2_S size: {self.stage2_S.shape}" + \
+            f"\n stage2_D size: {self.stage2_D.shape}" + \
+            f"\n stage2_Y size: {self.stage2_Y.shape}"
+
+class SDYLutx3(nn.Module):
+    def __init__(
+        self,
+        quantization_interval,
+        scale
+    ):
+        super(SDYLutx3, self).__init__()
+        self.scale = scale
+        self.quantization_interval = quantization_interval
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
+        self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.stage3_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stage3_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.stage3_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+
+    @staticmethod
+    def init_from_numpy(
+        stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y
+    ):
+        scale = int(stage3_S.shape[-1])
+        quantization_interval = 256//(stage3_S.shape[0]-1)
+        lut_model = SDYLutx3(quantization_interval=quantization_interval, scale=scale)
+        lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
+        lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
+        lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
+        lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
+        lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
+        lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
+        lut_model.stage3_S = nn.Parameter(torch.tensor(stage3_S).type(torch.float32))
+        lut_model.stage3_D = nn.Parameter(torch.tensor(stage3_D).type(torch.float32))
+        lut_model.stage3_Y = nn.Parameter(torch.tensor(stage3_Y).type(torch.float32))
+        return lut_model
+
+    def forward_stage(self, x, scale, percieve_pattern, lut):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w).type(torch.float32)
+        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
+        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
+        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
+        output /= 3
+        output = round_func(output)
+        x = output
+        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage2_S)
+        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage2_D)
+        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage2_Y)
+        output /= 3
+        output = round_func(output)
+        x = output
+        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage3_S)
+        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage3_D)
+        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage3_Y)
+        output /= 3
+        output = round_func(output)
+        output = output.view(b, c, h*self.scale, w*self.scale)
+        return output
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}" + \
+            f"\n stage1_S size: {self.stage1_S.shape}" + \
+            f"\n stage1_D size: {self.stage1_D.shape}" + \
+            f"\n stage1_Y size: {self.stage1_Y.shape}" + \
+            f"\n stage2_S size: {self.stage2_S.shape}" + \
+            f"\n stage2_D size: {self.stage2_D.shape}" + \
+            f"\n stage2_Y size: {self.stage2_Y.shape}" + \
+            f"\n stage3_S size: {self.stage3_S.shape}" + \
+            f"\n stage3_D size: {self.stage3_D.shape}" + \
+            f"\n stage3_Y size: {self.stage3_Y.shape}"
+
 class SDYLutR90x1(nn.Module):
     def __init__(

@@ -98,12 +98,79 @@ class SDYNetx2(nn.Module):
         stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
         stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
         stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
-        stage2_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
-        stage2_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
-        stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
         lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
         return lut_model
+
+class SDYNetx3(nn.Module):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYNetx3, self).__init__()
+        self.scale = scale
+        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+        self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
+        self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage3_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage3_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage3_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward_stage(self, x, scale, percieve_pattern, stage):
+        b,c,h,w = x.shape
+        x = percieve_pattern(x)
+        x = stage(x)
+        x = round_func(x)
+        x = x.reshape(b, c, h, w, scale, scale)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*scale, w*scale)
+        return x
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
+        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
+        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
+        output /= 3
+        x = output
+        x = round_func(x)
+        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage2_S)
+        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage2_D)
+        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage2_Y)
+        output /= 3
+        x = output
+        x = round_func(x)
+        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage3_S)
+        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage3_D)
+        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage3_Y)
+        output /= 3
+        x = output
+        x = round_func(x)
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage3_S = lut.transfer_2x2_input_SxS_output(self.stage3_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage3_D = lut.transfer_2x2_input_SxS_output(self.stage3_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage3_Y = lut.transfer_2x2_input_SxS_output(self.stage3_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = sdylut.SDYLutx3.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y)
+        return lut_model
+
 class SDYNetR90x1(nn.Module):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(SDYNetR90x1, self).__init__()
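
End to end, a trained SDYNetx3 is meant to be distilled into an SDYLutx3; a hedged sketch of that flow (the transfer routine above calls .cuda(), so a GPU is assumed):

net = SDYNetx3(hidden_dim=64, layers_count=4, scale=4).cuda()
# ... train net ...
lut_model = net.get_lut_model(quantization_interval=16, batch_size=2**10)
print(lut_model)   # SDYLutx3 holding one 17^4-entry table per stage and pattern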

@@ -33,6 +33,7 @@ class TrainOptions:
         parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
         parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
         parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
+        parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers")
         parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
         parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
         parser.add_argument('--models_dir', type=str, default='../models/', help="experiment folder")
@@ -105,7 +106,7 @@ if __name__ == "__main__":
         config.model = model.__class__.__name__
     else:
         if 'net' in config.model.lower():
-            model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
+            model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale, layers_count = config.layers_count)
         if 'lut' in config.model.lower():
             model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
         model = model.to(torch.device(config.device))
