|
|
|
@ -389,3 +389,160 @@ class RCNetx2Centered(nn.Module):
|
|
|
|
|
output /= 3*4
|
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReconstructedConvRot90Unlutable(nn.Module):
|
|
|
|
|
def __init__(self, hidden_dim, window_size=7):
|
|
|
|
|
super(ReconstructedConvRot90Unlutable, self).__init__()
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
|
|
|
|
self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
|
|
|
|
|
|
|
|
|
def pixel_wise_forward(self, x):
|
|
|
|
|
x = (x-127.5)/127.5
|
|
|
|
|
out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
|
|
|
|
|
out = torch.tanh(out)
|
|
|
|
|
out = out*127.5 + 127.5
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
original_shape = x.shape
|
|
|
|
|
x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
|
|
|
|
|
x = F.unfold(x, self.window_size)
|
|
|
|
|
x = self.pixel_wise_forward(x)
|
|
|
|
|
x = x.mean(1)
|
|
|
|
|
x = x.reshape(*original_shape)
|
|
|
|
|
# x = round_func(x) # quality likely suffer from this
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
|
|
|
|
|
|
|
|
|
class RCBlockRot90Unlutable(nn.Module):
|
|
|
|
|
def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
|
|
|
|
super(RCBlockRot90Unlutable, self).__init__()
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
|
x = self.rc_conv(x)
|
|
|
|
|
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
|
|
|
|
x = self.dense_conv_block(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
class RCNetx2Unlutable(nn.Module):
|
|
|
|
|
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
|
|
|
|
super(RCNetx2Unlutable, self).__init__()
|
|
|
|
|
self.scale = scale
|
|
|
|
|
self.hidden_dim = hidden_dim
|
|
|
|
|
self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
|
|
|
|
|
self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
|
|
|
|
|
self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
|
|
|
|
|
self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
|
|
|
|
self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
|
|
|
|
self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
|
|
|
|
|
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
|
x = x.view(b*c, 1, h, w)
|
|
|
|
|
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output /= 3*4
|
|
|
|
|
x = output
|
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output /= 3*4
|
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReconstructedConvCenteredUnlutable(nn.Module):
|
|
|
|
|
def __init__(self, hidden_dim, window_size=7):
|
|
|
|
|
super(ReconstructedConvCenteredUnlutable, self).__init__()
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
|
|
|
|
self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
|
|
|
|
|
|
|
|
|
def pixel_wise_forward(self, x):
|
|
|
|
|
x = (x-127.5)/127.5
|
|
|
|
|
out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
|
|
|
|
|
out = torch.tanh(out)
|
|
|
|
|
out = out*127.5 + 127.5
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
original_shape = x.shape
|
|
|
|
|
x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
|
|
|
|
x = F.unfold(x, self.window_size)
|
|
|
|
|
x = self.pixel_wise_forward(x)
|
|
|
|
|
x = x.mean(1)
|
|
|
|
|
x = x.reshape(*original_shape)
|
|
|
|
|
# x = round_func(x) # quality likely suffer from this
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
|
|
|
|
|
|
|
|
|
class RCBlockCenteredUnlutable(nn.Module):
|
|
|
|
|
def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
|
|
|
|
super(RCBlockRot90Unlutable, self).__init__()
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
|
x = self.rc_conv(x)
|
|
|
|
|
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
|
|
|
|
x = self.dense_conv_block(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
class RCNetx2CenteredUnlutable(nn.Module):
|
|
|
|
|
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
|
|
|
|
super(RCNetx2CenteredUnlutable, self).__init__()
|
|
|
|
|
self.scale = scale
|
|
|
|
|
self.hidden_dim = hidden_dim
|
|
|
|
|
self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
|
|
|
|
|
self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
|
|
|
|
|
self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
|
|
|
|
|
self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
|
|
|
|
|
self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
|
|
|
|
|
self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
|
|
|
|
|
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
|
x = x.view(b*c, 1, h, w)
|
|
|
|
|
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output /= 3*4
|
|
|
|
|
x = output
|
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
|
|
|
|
for rotations_count in range(4):
|
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
|
|
|
|
|
output /= 3*4
|
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
|
return output
|