diff --git a/src/common/layers.py b/src/common/layers.py index 8c88d31..d1601ee 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -67,7 +67,6 @@ class UpscaleBlock(nn.Module): x = x.reshape(b, c, h*self.upscale_factor, w*self.upscale_factor) return x -# https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105 class RgbToYcbcr(nn.Module): r"""Convert an image from RGB to YCbCr. @@ -84,6 +83,8 @@ class RgbToYcbcr(nn.Module): >>> input = torch.rand(2, 3, 4, 5) >>> ycbcr = RgbToYcbcr() >>> output = ycbcr(input) # 2x3x4x5 + + https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105 """ def forward(self, image): @@ -113,6 +114,8 @@ class YcbcrToRgb(nn.Module): >>> input = torch.rand(2, 3, 4, 5) >>> rgb = YcbcrToRgb() >>> output = rgb(input) # 2x3x4x5 + + https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127 """ def forward(self, image): diff --git a/src/common/lut.py b/src/common/lut.py index 3de99b5..42b91ea 100644 --- a/src/common/lut.py +++ b/src/common/lut.py @@ -83,7 +83,7 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2* def forward_2x2_input_SxS_output(index, lut): b,c,hs,ws = index.shape scale = lut.shape[-1] - index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') #? + index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') out = select_index_4dlut_tetrahedral( ixA = index, ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]), @@ -168,7 +168,7 @@ def select_index_1dlut_linear(ixA, lut): out = out.reshape((b,c,h,w)) return out -def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd): +def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): lut = torch.clamp(lut, 0, 255) dimA, dimB, dimC, dimD = lut.shape[:4] q = 256/(dimA-1) @@ -284,7 +284,5 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): #self, weight, upsc i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - # out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale)) out = out / q - # print(out.shape) return out diff --git a/src/models/rclut.py b/src/models/rclut.py index 61d780b..a4af1a0 100644 --- a/src/models/rclut.py +++ b/src/models/rclut.py @@ -19,7 +19,7 @@ class RCLutCentered_3x3(nn.Module): self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( rc_conv_luts, dense_conv_lut ): scale = int(dense_conv_lut.shape[-1]) @@ -60,7 +60,7 @@ class RCLutCentered_7x7(nn.Module): self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( rc_conv_luts, dense_conv_lut ): scale = int(dense_conv_lut.shape[-1]) @@ -99,7 +99,7 @@ class RCLutRot90_3x3(nn.Module): self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( rc_conv_luts, dense_conv_lut ): scale = int(dense_conv_lut.shape[-1]) @@ -143,7 +143,7 @@ class RCLutRot90_7x7(nn.Module): self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( rc_conv_luts, dense_conv_lut ): scale = int(dense_conv_lut.shape[-1]) @@ -192,7 +192,7 @@ class RCLutx1(nn.Module): self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( rc_conv_luts_3x3, dense_conv_lut_3x3, rc_conv_luts_5x5, dense_conv_lut_5x5, rc_conv_luts_7x7, dense_conv_lut_7x7 @@ -279,7 +279,7 @@ class RCLutx2(nn.Module): self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, @@ -404,7 +404,7 @@ class RCLutx2Centered(nn.Module): self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, diff --git a/src/models/rcnet.py b/src/models/rcnet.py index d15c23a..70b73a6 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -68,7 +68,7 @@ class RCNetCentered_3x3(nn.Module): window_size = self.stage.rc_conv.window_size rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutCentered_3x3.init_from_lut(rc_conv_luts, dense_conv_lut) + lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut) return lut_model class RCNetCentered_7x7(nn.Module): @@ -91,7 +91,7 @@ class RCNetCentered_7x7(nn.Module): window_size = self.stage.rc_conv.window_size rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) + lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut) return lut_model @@ -149,7 +149,7 @@ class RCNetRot90_3x3(nn.Module): window_size = self.stage.rc_conv.window_size rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutRot90_3x3.init_from_lut(rc_conv_luts, dense_conv_lut) + lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut) return lut_model def forward(self, x): @@ -177,7 +177,7 @@ class RCNetRot90_7x7(nn.Module): window_size = self.stage.rc_conv.window_size rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutRot90_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) + lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut) return lut_model def forward(self, x): @@ -212,7 +212,7 @@ class RCNetx1(nn.Module): rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx1.init_from_lut( + lut_model = rclut.RCLutx1.init_from_numpy( rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3, rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5, rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7 @@ -265,7 +265,7 @@ class RCNetx2(nn.Module): s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2.init_from_lut( + lut_model = rclut.RCLutx2.init_from_numpy( s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, @@ -327,7 +327,7 @@ class RCNetx2Centered(nn.Module): s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2Centered.init_from_lut( + lut_model = rclut.RCLutx2Centered.init_from_numpy( s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, @@ -430,7 +430,7 @@ class RCNetx2Unlutable(nn.Module): s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2.init_from_lut( + lut_model = rclut.RCLutx2.init_from_numpy( s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, @@ -535,7 +535,7 @@ class RCNetx2CenteredUnlutable(nn.Module): s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2Centered.init_from_lut( + lut_model = rclut.RCLutx2Centered.init_from_numpy( s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, diff --git a/src/models/sdylut.py b/src/models/sdylut.py index db4fae7..2fbf128 100644 --- a/src/models/sdylut.py +++ b/src/models/sdylut.py @@ -23,7 +23,7 @@ class SDYLutx1(nn.Module): self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( stageS, stageD, stageY ): scale = int(stageS.shape[-1]) @@ -89,7 +89,7 @@ class SDYLutx2(nn.Module): self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y ): scale = int(stage2_S.shape[-1]) @@ -182,7 +182,7 @@ class SDYLutCenteredx1(nn.Module): self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( stageS, stageD, stageY ): scale = int(stageS.shape[-1]) diff --git a/src/models/sdynet.py b/src/models/sdynet.py index 4ec03d1..64f46db 100644 --- a/src/models/sdynet.py +++ b/src/models/sdynet.py @@ -41,7 +41,7 @@ class SDYNetx1(nn.Module): stageS = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stageD = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) stageY = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = sdylut.SDYLutx1.init_from_lut(stageS, stageD, stageY) + lut_model = sdylut.SDYLutx1.init_from_numpy(stageS, stageD, stageY) return lut_model class SDYNetx2(nn.Module): @@ -93,5 +93,5 @@ class SDYNetx2(nn.Module): stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size) stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size) stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = sdylut.SDYLutx2.init_from_lut(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) + lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) return lut_model \ No newline at end of file diff --git a/src/models/srlut.py b/src/models/srlut.py index 0f8e2fb..4a03c42 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -18,7 +18,7 @@ class SRLut(nn.Module): self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( stage_lut ): scale = int(stage_lut.shape[-1]) @@ -51,7 +51,7 @@ class SRLutR90(nn.Module): self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @staticmethod - def init_from_lut( + def init_from_numpy( stage_lut ): scale = int(stage_lut.shape[-1]) @@ -91,7 +91,7 @@ class SRLutR90Y(nn.Module): self.ycbcr_to_rgb = layers.YcbcrToRgb() @staticmethod - def init_from_lut( + def init_from_numpy( stage_lut ): scale = int(stage_lut.shape[-1]) diff --git a/src/models/srnet.py b/src/models/srnet.py index c73030f..e70125d 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -31,7 +31,7 @@ class SRNet(nn.Module): def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLut.init_from_lut(stage_lut) + lut_model = srlut.SRLut.init_from_numpy(stage_lut) return lut_model class SRNetR90(nn.Module): @@ -62,7 +62,7 @@ class SRNetR90(nn.Module): def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLutR90.init_from_lut(stage_lut) + lut_model = srlut.SRLutR90.init_from_numpy(stage_lut) return lut_model class SRNetR90Y(nn.Module): @@ -95,11 +95,11 @@ class SRNetR90Y(nn.Module): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) output /= 4 - output = torch.cat([output, cbcr_scaled], dim=1) + output = torch.cat([output, cbcr_scaled], dim=1) output = self.ycbcr_to_rgb(output).clamp(0, 255) return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLutR90Y.init_from_lut(stage_lut) - return lut_model \ No newline at end of file + lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut) + return lut_model \ No newline at end of file