|
|
@ -443,7 +443,33 @@ class RCNetx2Unlutable(nn.Module):
|
|
|
|
self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
|
|
|
self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
|
|
|
|
|
|
|
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
raise NotImplementedError
|
|
|
|
s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
|
|
|
|
|
|
|
s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
|
|
|
|
|
|
|
s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
|
|
|
|
|
|
|
s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
|
|
|
|
|
|
|
s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
|
|
|
|
|
|
|
s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
|
|
|
|
|
|
|
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lut_model = rclut.RCLutx2.init_from_lut(
|
|
|
|
|
|
|
|
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
|
|
|
|
|
|
|
|
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
|
|
|
|
|
|
|
|
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
|
|
|
|
|
|
|
|
s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
|
|
|
|
|
|
|
|
s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
|
|
|
|
|
|
|
|
s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
b,c,h,w = x.shape
|
|
|
@ -522,7 +548,33 @@ class RCNetx2CenteredUnlutable(nn.Module):
|
|
|
|
self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
|
|
|
self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
|
|
|
|
|
|
|
|
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
raise NotImplementedError
|
|
|
|
s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
|
|
|
|
|
|
|
s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
|
|
|
|
|
|
|
s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
|
|
|
|
|
|
|
s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
|
|
|
|
|
|
|
|
s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
|
|
|
|
|
|
|
|
s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
|
|
|
|
|
|
|
|
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lut_model = rclut.RCLutx2Centered.init_from_lut(
|
|
|
|
|
|
|
|
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
|
|
|
|
|
|
|
|
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
|
|
|
|
|
|
|
|
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
|
|
|
|
|
|
|
|
s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
|
|
|
|
|
|
|
|
s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
|
|
|
|
|
|
|
|
s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
b,c,h,w = x.shape
|
|
|
@ -545,69 +597,3 @@ class RCNetx2CenteredUnlutable(nn.Module):
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
output = output.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
return output
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReconstructedConvCenteredv2(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, hidden_dim, window_size=7):
|
|
|
|
|
|
|
|
super(ReconstructedConvCenteredv2, self).__init__()
|
|
|
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
|
|
|
self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
|
|
|
|
|
|
|
self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pixel_wise_forward(self, x):
|
|
|
|
|
|
|
|
x = (x-127.5)/127.5
|
|
|
|
|
|
|
|
out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2)
|
|
|
|
|
|
|
|
out = torch.tanh(out)
|
|
|
|
|
|
|
|
out = out*127.5 + 127.5
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
original_shape = x.shape
|
|
|
|
|
|
|
|
x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
|
|
|
|
|
|
|
x = F.unfold(x, self.window_size)
|
|
|
|
|
|
|
|
x = self.pixel_wise_forward(x)
|
|
|
|
|
|
|
|
x = (x.mean(1) - (x.max().values - x.min().values))/2
|
|
|
|
|
|
|
|
x = x.reshape(*original_shape)
|
|
|
|
|
|
|
|
x = round_func(x)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
|
|
|
return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RCBlockCenteredv2(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
|
|
|
|
|
|
|
|
super(RCBlockCenteredv2, self).__init__()
|
|
|
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
|
|
|
self.rc_conv = ReconstructedConvCenteredv2(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
|
|
|
|
x = self.rc_conv(x)
|
|
|
|
|
|
|
|
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
|
|
|
|
|
|
|
|
x = self.dense_conv_block(x)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RCNetCentered_7x7v2(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
|
|
|
|
|
|
|
super(RCNetCentered_7x7v2, self).__init__()
|
|
|
|
|
|
|
|
self.hidden_dim = hidden_dim
|
|
|
|
|
|
|
|
self.layers_count = layers_count
|
|
|
|
|
|
|
|
self.scale = scale
|
|
|
|
|
|
|
|
window_size = 7
|
|
|
|
|
|
|
|
self.stage = RCBlockCenteredv2(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
|
|
|
|
x = x.view(b*c, 1, h, w)
|
|
|
|
|
|
|
|
x = self.stage(x)
|
|
|
|
|
|
|
|
x = x.view(b, c, h*self.scale, w*self.scale)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
|
|
|
|
|
|
|
window_size = self.stage.rc_conv.window_size
|
|
|
|
|
|
|
|
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
|
|
|
|
|
|
|
|
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
|
|
|
|
|
|
|
|
lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut)
|
|
|
|
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|