migrate srnet to linear backbone

main
vlpr 6 months ago
parent b63d0f98df
commit a0536fe79d

@ -6,45 +6,27 @@ from common.utils import round_func
from common import lut
from pathlib import Path
from . import srlut
from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock, RgbToYcbcr, YcbcrToRgb
from common import layers
class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = ConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = self._extract_pattern_S(x)
x = self.stage(x)
x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_lut(stage_lut)
return lut_model
class SRNetDense(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetDense, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
receptive_field_idxes=s_pattern,
center=[0,0],
window_size=2,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = self._extract_pattern_S(x)
x = self.stage(x)
x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*self.scale, w*self.scale)
x = self.stage1_S(x)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
@ -52,24 +34,28 @@ class SRNetDense(nn.Module):
lut_model = srlut.SRLut.init_from_lut(stage_lut)
return lut_model
class SRNetDenseRot90(nn.Module):
class SRNetR90(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetDenseRot90, self).__init__()
super(SRNetR90, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
receptive_field_idxes=s_pattern,
center=[0,0],
window_size=2,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rx = torch.rot90(x, k=rotations_count, dims=[2, 3])
_,_,rh,rw = rx.shape
rx = self._extract_pattern_S(rx)
rx = self.stage(rx)
rx = rx.view(b*c, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b*c, 1, rh*self.scale, rw*self.scale)
output += torch.rot90(rx, k=-rotations_count, dims=[2, 3])
output += self.stage1_S(rotated)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
@ -79,14 +65,21 @@ class SRNetDenseRot90(nn.Module):
lut_model = srlut.SRLutRot90.init_from_lut(stage_lut)
return lut_model
class SRNetDenseRot90Y(nn.Module):
class SRNetR90Y(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetDenseRot90Y, self).__init__()
super(SRNetR90Y, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.rgb_to_ycbcr = RgbToYcbcr()
self.ycbcr_to_rgb = YcbcrToRgb()
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
receptive_field_idxes=s_pattern,
center=[0,0],
window_size=2,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
def forward(self, x):
b,c,h,w = x.shape
@ -97,13 +90,10 @@ class SRNetDenseRot90Y(nn.Module):
x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rx = torch.rot90(x, k=rotations_count, dims=[2, 3])
_,_,rh,rw = rx.shape
rx = self._extract_pattern_S(rx)
rx = self.stage(rx)
rx = rx.view(b, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b, 1, rh*self.scale, rw*self.scale)
output += torch.rot90(rx, k=-rotations_count, dims=[2, 3])
output += self.stage1_S(rotated)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)

Loading…
Cancel
Save