|
|
|
@ -3,74 +3,162 @@ import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import numpy as np
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from common.lut import forward_2x2_input_SxS_output
|
|
|
|
|
from common.lut import forward_2x2_input_SxS_output, forward_unfolded_2x2_input_SxS_output
|
|
|
|
|
from common.layers import PercievePattern
|
|
|
|
|
|
|
|
|
|
class SRLut2x2(nn.Module):
|
|
|
|
|
class SDYLutx1(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
quantization_interval,
|
|
|
|
|
scale
|
|
|
|
|
):
|
|
|
|
|
super(SRLut2x2, self).__init__()
|
|
|
|
|
super(SDYLutx1, self).__init__()
|
|
|
|
|
self.scale = scale
|
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
|
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
|
|
|
|
|
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
|
|
|
|
|
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
|
|
|
|
|
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def init_from_lut(
|
|
|
|
|
stage_lut
|
|
|
|
|
stageS, stageD, stageY
|
|
|
|
|
):
|
|
|
|
|
scale = int(stage_lut.shape[-1])
|
|
|
|
|
quantization_interval = 256//(stage_lut.shape[0]-1)
|
|
|
|
|
lut_model = SRLut2x2(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
|
|
|
|
scale = int(stageS.shape[-1])
|
|
|
|
|
quantization_interval = 256//(stageS.shape[0]-1)
|
|
|
|
|
lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
|
|
|
|
|
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
|
|
|
|
|
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
|
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
|
x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
|
|
|
|
|
x = x.view(b, c, x.shape[-2], x.shape[-1])
|
|
|
|
|
return x
|
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
|
|
|
|
|
rb,rc,rh,rw = rotated.shape
|
|
|
|
|
|
|
|
|
|
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS)
|
|
|
|
|
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
|
|
|
|
|
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += s
|
|
|
|
|
|
|
|
|
|
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD)
|
|
|
|
|
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
|
|
|
|
|
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += d
|
|
|
|
|
|
|
|
|
|
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY)
|
|
|
|
|
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
|
|
|
|
|
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += y
|
|
|
|
|
|
|
|
|
|
output /= 4*3
|
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
|
|
|
|
return f"{self.__class__.__name__}" + \
|
|
|
|
|
f"\n stageS size: {self.stageS.shape}" + \
|
|
|
|
|
f"\n stageD size: {self.stageD.shape}" + \
|
|
|
|
|
f"\n stageY size: {self.stageY.shape}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SRLut3x3(nn.Module):
|
|
|
|
|
class SDYLutx2(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
quantization_interval,
|
|
|
|
|
scale
|
|
|
|
|
):
|
|
|
|
|
super(SRLut3x3, self).__init__()
|
|
|
|
|
super(SDYLutx2, self).__init__()
|
|
|
|
|
self.scale = scale
|
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
|
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
|
|
|
|
|
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
|
|
|
|
|
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
|
|
|
|
|
self.stageS_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageD_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageY_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageS_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageD_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
self.stageY_2 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def init_from_lut(
|
|
|
|
|
stage_lut
|
|
|
|
|
stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2
|
|
|
|
|
):
|
|
|
|
|
scale = int(stage_lut.shape[-1])
|
|
|
|
|
quantization_interval = 256//(stage_lut.shape[0]-1)
|
|
|
|
|
lut_model = SRLut3x3(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
|
|
|
|
scale = int(stageS_2.shape[-1])
|
|
|
|
|
quantization_interval = 256//(stageS_2.shape[0]-1)
|
|
|
|
|
lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
lut_model.stageS_1 = nn.Parameter(torch.tensor(stageS_1).type(torch.float32))
|
|
|
|
|
lut_model.stageD_1 = nn.Parameter(torch.tensor(stageD_1).type(torch.float32))
|
|
|
|
|
lut_model.stageY_1 = nn.Parameter(torch.tensor(stageY_1).type(torch.float32))
|
|
|
|
|
lut_model.stageS_2 = nn.Parameter(torch.tensor(stageS_2).type(torch.float32))
|
|
|
|
|
lut_model.stageD_2 = nn.Parameter(torch.tensor(stageD_2).type(torch.float32))
|
|
|
|
|
lut_model.stageY_2 = nn.Parameter(torch.tensor(stageY_2).type(torch.float32))
|
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
|
x = x.view(b*c, 1, h, w)
|
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
|
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
|
|
|
|
|
rb,rc,rh,rw = rotated.shape
|
|
|
|
|
|
|
|
|
|
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_1)
|
|
|
|
|
s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
|
|
|
|
|
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += s
|
|
|
|
|
|
|
|
|
|
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_1)
|
|
|
|
|
d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
|
|
|
|
|
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += d
|
|
|
|
|
|
|
|
|
|
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_1)
|
|
|
|
|
y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
|
|
|
|
|
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += y
|
|
|
|
|
|
|
|
|
|
output /= 4*3
|
|
|
|
|
output = output.view(b, c, h, w)
|
|
|
|
|
x = output
|
|
|
|
|
|
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
|
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
|
|
|
|
|
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += unrotated_prediction
|
|
|
|
|
output /= 4
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
|
|
|
|
|
rb,rc,rh,rw = rotated.shape
|
|
|
|
|
|
|
|
|
|
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_2)
|
|
|
|
|
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
|
|
|
|
|
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += s
|
|
|
|
|
|
|
|
|
|
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_2)
|
|
|
|
|
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
|
|
|
|
|
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += d
|
|
|
|
|
|
|
|
|
|
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_2)
|
|
|
|
|
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
|
|
|
|
|
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
|
|
|
|
|
output += y
|
|
|
|
|
|
|
|
|
|
output /= 4*3
|
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
|
|
|
|
return f"{self.__class__.__name__}" + \
|
|
|
|
|
f"\n stageS_1 size: {self.stageS_1.shape}" + \
|
|
|
|
|
f"\n stageD_1 size: {self.stageD_1.shape}" + \
|
|
|
|
|
f"\n stageY_1 size: {self.stageY_1.shape}" + \
|
|
|
|
|
f"\n stageS_2 size: {self.stageS_2.shape}" + \
|
|
|
|
|
f"\n stageD_2 size: {self.stageD_2.shape}" + \
|
|
|
|
|
f"\n stageY_2 size: {self.stageY_2.shape}"
|