diff --git a/readme.md b/readme.md index 9af4c4f..b526e2f 100644 --- a/readme.md +++ b/readme.md @@ -4,7 +4,7 @@ Example python train.py --model SRNetRot90 python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth -python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth +python transfer_to_lut.py python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000 python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth diff --git a/src/models/__init__.py b/src/models/__init__.py index 33b588d..900d147 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -20,7 +20,6 @@ AVAILABLE_MODELS = { 'SDYNetx1': sdynet.SDYNetx1, 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, - 'RCNetCentered_7x7v2': rcnet.RCNetCentered_7x7v2 } def SaveCheckpoint(model, path): diff --git a/src/models/rcnet.py b/src/models/rcnet.py index c08bff0..4fbc181 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -443,7 +443,33 @@ class RCNetx2Unlutable(nn.Module): self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) def get_lut_model(self, quantization_interval=16, batch_size=2**10): - raise NotImplementedError + s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + lut_model = rclut.RCLutx2.init_from_lut( + s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, + s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, + s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, + s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, + s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, + s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 + ) + return lut_model def forward(self, x): b,c,h,w = x.shape @@ -522,7 +548,33 @@ class RCNetx2CenteredUnlutable(nn.Module): self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) def get_lut_model(self, quantization_interval=16, batch_size=2**10): - raise NotImplementedError + s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + lut_model = rclut.RCLutx2Centered.init_from_lut( + s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, + s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, + s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, + s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, + s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, + s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 + ) + return lut_model def forward(self, x): b,c,h,w = x.shape @@ -545,69 +597,3 @@ class RCNetx2CenteredUnlutable(nn.Module): output = output.view(b, c, h*self.scale, w*self.scale) return output - - -class ReconstructedConvCenteredv2(nn.Module): - def __init__(self, hidden_dim, window_size=7): - super(ReconstructedConvCenteredv2, self).__init__() - self.window_size = window_size - self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - - def pixel_wise_forward(self, x): - x = (x-127.5)/127.5 - out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2) - out = torch.tanh(out) - out = out*127.5 + 127.5 - return out - - def forward(self, x): - original_shape = x.shape - x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') - x = F.unfold(x, self.window_size) - x = self.pixel_wise_forward(x) - x = (x.mean(1) - (x.max().values - x.min().values))/2 - x = x.reshape(*original_shape) - x = round_func(x) - return x - - def __repr__(self): - return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" - -class RCBlockCenteredv2(nn.Module): - def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): - super(RCBlockCenteredv2, self).__init__() - self.window_size = window_size - self.rc_conv = ReconstructedConvCenteredv2(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) - - def forward(self, x): - b,c,hs,ws = x.shape - x = self.rc_conv(x) - x = F.pad(x, pad=[0,1,0,1], mode='replicate') - x = self.dense_conv_block(x) - return x - - -class RCNetCentered_7x7v2(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetCentered_7x7v2, self).__init__() - self.hidden_dim = hidden_dim - self.layers_count = layers_count - self.scale = scale - window_size = 7 - self.stage = RCBlockCenteredv2(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - x = self.stage(x) - x = x.view(b, c, h*self.scale, w*self.scale) - return x - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - window_size = self.stage.rc_conv.window_size - rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) - dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) - return lut_model diff --git a/src/scripts/transfer_to_lut.py b/src/scripts/transfer_to_lut.py index 4e314fe..b06475d 100644 --- a/src/scripts/transfer_to_lut.py +++ b/src/scripts/transfer_to_lut.py @@ -19,7 +19,7 @@ import models class TransferToLutOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - self.parser.add_argument('--model_path', '-m', type=str, default='', help="model path folder") + self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth', help="model path folder") self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")