added mulut

main
Vladimir Protsenko 8 months ago
parent 5e08fb3eef
commit b1f2f6d76b

@ -0,0 +1,28 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PercievePattern():
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]):
self.receptive_field_idxes = np.array(receptive_field_idxes)
self.window_size = np.max(self.receptive_field_idxes) + 1
self.receptive_field_idxes = [
self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1],
self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1],
self.receptive_field_idxes[2,0]*self.window_size + self.receptive_field_idxes[2,1],
self.receptive_field_idxes[3,0]*self.window_size + self.receptive_field_idxes[3,1],
]
def __call__(self, x):
b,c,h,w = x.shape
x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
x = F.unfold(input=x, kernel_size=self.window_size)
x = torch.stack([
x[:,self.receptive_field_idxes[0],:],
x[:,self.receptive_field_idxes[1],:],
x[:,self.receptive_field_idxes[2],:],
x[:,self.receptive_field_idxes[3],:]
], 2)
x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2)
return x

@ -93,7 +93,23 @@ def forward_2x2_input_SxS_output(index, lut):
) )
out = out[:,:,0:-1,0:-1,:,:] # unpad out = out[:,:,0:-1,0:-1,:,:] # unpad
# Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504] # Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
out = out.permute(0,1,2,4,3,5).reshape(b,1,hs*scale,ws*scale) out = out.permute(0,1,2,4,3,5).reshape(b*c,1,hs*scale,ws*scale)
out = round_func(out)
return out
def forward_unfolded_2x2_input_SxS_output(index, lut):
b,c,hs,ws = index.shape
scale = lut.shape[-1]
out = select_index_4dlut_tetrahedral(
ixA = index,
ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
lut = lut
)
out = out[:,:,0:-1,0:-1,:,:] # unpad
# Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
out = out.permute(0,1,2,4,3,5).reshape(b*c,1,scale,scale)
out = round_func(out) out = round_func(out)
return out return out

@ -17,7 +17,8 @@ AVAILABLE_MODELS = {
'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1, 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1,
'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2,
'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered,
'SDYNetx1': sdynet.SDYNetx1, 'SDYNetx1': sdynet.SDYNetx1, 'SDYLutx1': sdylut.SDYLutx1,
'SDYNetx2': sdynet.SDYNetx2, 'SDYLutx2': sdylut.SDYLutx2,
'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable,
'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable,
} }

@ -3,74 +3,162 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from common.lut import forward_2x2_input_SxS_output from common.lut import forward_2x2_input_SxS_output, forward_unfolded_2x2_input_SxS_output
from common.layers import PercievePattern
class SRLut2x2(nn.Module): class SDYLutx1(nn.Module):
def __init__( def __init__(
self, self,
quantization_interval, quantization_interval,
scale scale
): ):
super(SRLut2x2, self).__init__() super(SDYLutx1, self).__init__()
self.scale = scale self.scale = scale
self.quantization_interval = quantization_interval self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_lut(
stage_lut stageS, stageD, stageY
): ):
scale = int(stage_lut.shape[-1]) scale = int(stageS.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1) quantization_interval = 256//(stageS.shape[0]-1)
lut_model = SRLut2x2(quantization_interval=quantization_interval, scale=scale) lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model return lut_model
def forward(self, x): def forward(self, x):
b,c,h,w = x.shape b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32) x = x.view(b*c, 1, h, w).type(torch.float32)
x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
x = x.view(b, c, x.shape[-2], x.shape[-1]) for rotations_count in range(4):
return x rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS)
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD)
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY)
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output /= 4*3
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" return f"{self.__class__.__name__}" + \
f"\n stageS size: {self.stageS.shape}" + \
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
class SRLut3x3(nn.Module): class SDYLutx2(nn.Module):
def __init__( def __init__(
self, self,
quantization_interval, quantization_interval,
scale scale
): ):
super(SRLut3x3, self).__init__() super(SDYLutx2, self).__init__()
self.scale = scale self.scale = scale
self.quantization_interval = quantization_interval self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
self.stageS_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageS_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_lut(
stage_lut stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2
): ):
scale = int(stage_lut.shape[-1]) scale = int(stageS_2.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1) quantization_interval = 256//(stageS_2.shape[0]-1)
lut_model = SRLut3x3(quantization_interval=quantization_interval, scale=scale) lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) lut_model.stageS_1 = nn.Parameter(torch.tensor(stageS_1).type(torch.float32))
lut_model.stageD_1 = nn.Parameter(torch.tensor(stageD_1).type(torch.float32))
lut_model.stageY_1 = nn.Parameter(torch.tensor(stageY_1).type(torch.float32))
lut_model.stageS_2 = nn.Parameter(torch.tensor(stageS_2).type(torch.float32))
lut_model.stageD_2 = nn.Parameter(torch.tensor(stageD_2).type(torch.float32))
lut_model.stageY_2 = nn.Parameter(torch.tensor(stageY_2).type(torch.float32))
return lut_model return lut_model
def forward(self, x): def forward(self, x):
b,c,h,w = x.shape b,c,h,w = x.shape
x = x.view(b*c, 1, h, w) x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_1)
s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_1)
d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_1)
y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output /= 4*3
output = output.view(b, c, h, w)
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4): for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut) rb,rc,rh,rw = rotated.shape
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_2)
output /= 4 s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_2)
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_2)
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output /= 4*3
output = output.view(b, c, h*self.scale, w*self.scale) output = output.view(b, c, h*self.scale, w*self.scale)
return output return output
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" return f"{self.__class__.__name__}" + \
f"\n stageS_1 size: {self.stageS_1.shape}" + \
f"\n stageD_1 size: {self.stageD_1.shape}" + \
f"\n stageY_1 size: {self.stageY_1.shape}" + \
f"\n stageS_2 size: {self.stageS_2.shape}" + \
f"\n stageD_2 size: {self.stageD_2.shape}" + \
f"\n stageY_2 size: {self.stageY_2.shape}"

@ -4,6 +4,7 @@ import torch.nn.functional as F
import numpy as np import numpy as np
from common.utils import round_func from common.utils import round_func
from common import lut from common import lut
from common.layers import PercievePattern
from pathlib import Path from pathlib import Path
from . import sdylut from . import sdylut
@ -39,31 +40,6 @@ class DenseConvUpscaleBlock(nn.Module):
x = round_func(x*127.5 + 127.5) x = round_func(x*127.5 + 127.5)
return x return x
class PercievePattern():
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]):
self.receptive_field_idxes = np.array(receptive_field_idxes)
self.window_size = np.max(self.receptive_field_idxes) + 1
self.receptive_field_idxes = [
self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1],
self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1],
self.receptive_field_idxes[2,0]*self.window_size + self.receptive_field_idxes[2,1],
self.receptive_field_idxes[3,0]*self.window_size + self.receptive_field_idxes[3,1],
]
def __call__(self, x):
b,c,h,w = x.shape
x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
x = F.unfold(input=x, kernel_size=self.window_size)
x = torch.stack([
x[:,self.receptive_field_idxes[0],:],
x[:,self.receptive_field_idxes[1],:],
x[:,self.receptive_field_idxes[2],:],
x[:,self.receptive_field_idxes[3],:]
], 2)
x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2)
return x
class SDYNetx1(nn.Module): class SDYNetx1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx1, self).__init__() super(SDYNetx1, self).__init__()
@ -103,6 +79,84 @@ class SDYNetx1(nn.Module):
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stageS = lut.transfer_2x2_input_SxS_output(self.stageS, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx1.init_from_lut(stage_lut) stageD = lut.transfer_2x2_input_SxS_output(self.stageD, quantization_interval=quantization_interval, batch_size=batch_size)
stageY = lut.transfer_2x2_input_SxS_output(self.stageY, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx1.init_from_lut(stageS, stageD, stageY)
return lut_model
class SDYNetx2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx2, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageS_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageD_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageY_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = self.stageS_1(self._extract_pattern_S(rotated))
s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output_1 += s
d = self.stageD_1(self._extract_pattern_D(rotated))
d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output_1 += d
y = self.stageY_1(self._extract_pattern_Y(rotated))
y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output_1 += y
output_1 /= 4*3
output_1 = output_1.view(b, c, h, w)
x = output_1
output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = self.stageS_2(self._extract_pattern_S(rotated))
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output_2 += s
d = self.stageD_2(self._extract_pattern_D(rotated))
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output_2 += d
y = self.stageY_2(self._extract_pattern_Y(rotated))
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output_2 += y
output_2 /= 4*3
output_2 = output_2.view(b, c, h*self.scale, w*self.scale)
return output_2
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stageS_1 = lut.transfer_2x2_input_SxS_output(self.stageS_1, quantization_interval=quantization_interval, batch_size=batch_size)
stageD_1 = lut.transfer_2x2_input_SxS_output(self.stageD_1, quantization_interval=quantization_interval, batch_size=batch_size)
stageY_1 = lut.transfer_2x2_input_SxS_output(self.stageY_1, quantization_interval=quantization_interval, batch_size=batch_size)
stageS_2 = lut.transfer_2x2_input_SxS_output(self.stageS_2, quantization_interval=quantization_interval, batch_size=batch_size)
stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size)
stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2)
return lut_model return lut_model
Loading…
Cancel
Save