renaming and comments

main
vlpr 6 months ago committed by protsenkovi
parent 801402503a
commit c5259398a7

@ -67,7 +67,6 @@ class UpscaleBlock(nn.Module):
x = x.reshape(b, c, h*self.upscale_factor, w*self.upscale_factor) x = x.reshape(b, c, h*self.upscale_factor, w*self.upscale_factor)
return x return x
# https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
class RgbToYcbcr(nn.Module): class RgbToYcbcr(nn.Module):
r"""Convert an image from RGB to YCbCr. r"""Convert an image from RGB to YCbCr.
@ -84,6 +83,8 @@ class RgbToYcbcr(nn.Module):
>>> input = torch.rand(2, 3, 4, 5) >>> input = torch.rand(2, 3, 4, 5)
>>> ycbcr = RgbToYcbcr() >>> ycbcr = RgbToYcbcr()
>>> output = ycbcr(input) # 2x3x4x5 >>> output = ycbcr(input) # 2x3x4x5
https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
""" """
def forward(self, image): def forward(self, image):
@ -113,6 +114,8 @@ class YcbcrToRgb(nn.Module):
>>> input = torch.rand(2, 3, 4, 5) >>> input = torch.rand(2, 3, 4, 5)
>>> rgb = YcbcrToRgb() >>> rgb = YcbcrToRgb()
>>> output = rgb(input) # 2x3x4x5 >>> output = rgb(input) # 2x3x4x5
https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127
""" """
def forward(self, image): def forward(self, image):

@ -83,7 +83,7 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2*
def forward_2x2_input_SxS_output(index, lut): def forward_2x2_input_SxS_output(index, lut):
b,c,hs,ws = index.shape b,c,hs,ws = index.shape
scale = lut.shape[-1] scale = lut.shape[-1]
index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') #? index = F.pad(input=index, pad=[0,1,0,1], mode='replicate')
out = select_index_4dlut_tetrahedral( out = select_index_4dlut_tetrahedral(
ixA = index, ixA = index,
ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]), ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
@ -168,7 +168,7 @@ def select_index_1dlut_linear(ixA, lut):
out = out.reshape((b,c,h,w)) out = out.reshape((b,c,h,w))
return out return out
def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd): def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
lut = torch.clamp(lut, 0, 255) lut = torch.clamp(lut, 0, 255)
dimA, dimB, dimC, dimD = lut.shape[:4] dimA, dimB, dimC, dimD = lut.shape[:4]
q = 256/(dimA-1) q = 256/(dimA-1)
@ -284,7 +284,5 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): #self, weight, upsc
i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
# out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale))
out = out / q out = out / q
# print(out.shape)
return out return out

@ -19,7 +19,7 @@ class RCLutCentered_3x3(nn.Module):
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
rc_conv_luts, dense_conv_lut rc_conv_luts, dense_conv_lut
): ):
scale = int(dense_conv_lut.shape[-1]) scale = int(dense_conv_lut.shape[-1])
@ -60,7 +60,7 @@ class RCLutCentered_7x7(nn.Module):
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
rc_conv_luts, dense_conv_lut rc_conv_luts, dense_conv_lut
): ):
scale = int(dense_conv_lut.shape[-1]) scale = int(dense_conv_lut.shape[-1])
@ -99,7 +99,7 @@ class RCLutRot90_3x3(nn.Module):
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
rc_conv_luts, dense_conv_lut rc_conv_luts, dense_conv_lut
): ):
scale = int(dense_conv_lut.shape[-1]) scale = int(dense_conv_lut.shape[-1])
@ -143,7 +143,7 @@ class RCLutRot90_7x7(nn.Module):
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
rc_conv_luts, dense_conv_lut rc_conv_luts, dense_conv_lut
): ):
scale = int(dense_conv_lut.shape[-1]) scale = int(dense_conv_lut.shape[-1])
@ -192,7 +192,7 @@ class RCLutx1(nn.Module):
self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
rc_conv_luts_3x3, dense_conv_lut_3x3, rc_conv_luts_3x3, dense_conv_lut_3x3,
rc_conv_luts_5x5, dense_conv_lut_5x5, rc_conv_luts_5x5, dense_conv_lut_5x5,
rc_conv_luts_7x7, dense_conv_lut_7x7 rc_conv_luts_7x7, dense_conv_lut_7x7
@ -279,7 +279,7 @@ class RCLutx2(nn.Module):
self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
@ -404,7 +404,7 @@ class RCLutx2Centered(nn.Module):
self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,

@ -68,7 +68,7 @@ class RCNetCentered_3x3(nn.Module):
window_size = self.stage.rc_conv.window_size window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutCentered_3x3.init_from_lut(rc_conv_luts, dense_conv_lut) lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
return lut_model return lut_model
class RCNetCentered_7x7(nn.Module): class RCNetCentered_7x7(nn.Module):
@ -91,7 +91,7 @@ class RCNetCentered_7x7(nn.Module):
window_size = self.stage.rc_conv.window_size window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
return lut_model return lut_model
@ -149,7 +149,7 @@ class RCNetRot90_3x3(nn.Module):
window_size = self.stage.rc_conv.window_size window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutRot90_3x3.init_from_lut(rc_conv_luts, dense_conv_lut) lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
return lut_model return lut_model
def forward(self, x): def forward(self, x):
@ -177,7 +177,7 @@ class RCNetRot90_7x7(nn.Module):
window_size = self.stage.rc_conv.window_size window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutRot90_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
return lut_model return lut_model
def forward(self, x): def forward(self, x):
@ -212,7 +212,7 @@ class RCNetx1(nn.Module):
rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx1.init_from_lut( lut_model = rclut.RCLutx1.init_from_numpy(
rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3, rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3,
rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5, rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5,
rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7 rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7
@ -265,7 +265,7 @@ class RCNetx2(nn.Module):
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx2.init_from_lut( lut_model = rclut.RCLutx2.init_from_numpy(
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
@ -327,7 +327,7 @@ class RCNetx2Centered(nn.Module):
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx2Centered.init_from_lut( lut_model = rclut.RCLutx2Centered.init_from_numpy(
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
@ -430,7 +430,7 @@ class RCNetx2Unlutable(nn.Module):
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx2.init_from_lut( lut_model = rclut.RCLutx2.init_from_numpy(
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
@ -535,7 +535,7 @@ class RCNetx2CenteredUnlutable(nn.Module):
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx2Centered.init_from_lut( lut_model = rclut.RCLutx2Centered.init_from_numpy(
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,

@ -23,7 +23,7 @@ class SDYLutx1(nn.Module):
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
stageS, stageD, stageY stageS, stageD, stageY
): ):
scale = int(stageS.shape[-1]) scale = int(stageS.shape[-1])
@ -89,7 +89,7 @@ class SDYLutx2(nn.Module):
self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
): ):
scale = int(stage2_S.shape[-1]) scale = int(stage2_S.shape[-1])
@ -182,7 +182,7 @@ class SDYLutCenteredx1(nn.Module):
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
stageS, stageD, stageY stageS, stageD, stageY
): ):
scale = int(stageS.shape[-1]) scale = int(stageS.shape[-1])

@ -41,7 +41,7 @@ class SDYNetx1(nn.Module):
stageS = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stageS = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
stageD = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) stageD = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
stageY = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) stageY = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx1.init_from_lut(stageS, stageD, stageY) lut_model = sdylut.SDYLutx1.init_from_numpy(stageS, stageD, stageY)
return lut_model return lut_model
class SDYNetx2(nn.Module): class SDYNetx2(nn.Module):
@ -93,5 +93,5 @@ class SDYNetx2(nn.Module):
stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size) stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size) stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size) stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx2.init_from_lut(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
return lut_model return lut_model

@ -18,7 +18,7 @@ class SRLut(nn.Module):
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
stage_lut stage_lut
): ):
scale = int(stage_lut.shape[-1]) scale = int(stage_lut.shape[-1])
@ -51,7 +51,7 @@ class SRLutR90(nn.Module):
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
stage_lut stage_lut
): ):
scale = int(stage_lut.shape[-1]) scale = int(stage_lut.shape[-1])
@ -91,7 +91,7 @@ class SRLutR90Y(nn.Module):
self.ycbcr_to_rgb = layers.YcbcrToRgb() self.ycbcr_to_rgb = layers.YcbcrToRgb()
@staticmethod @staticmethod
def init_from_lut( def init_from_numpy(
stage_lut stage_lut
): ):
scale = int(stage_lut.shape[-1]) scale = int(stage_lut.shape[-1])

@ -31,7 +31,7 @@ class SRNet(nn.Module):
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_lut(stage_lut) lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model return lut_model
class SRNetR90(nn.Module): class SRNetR90(nn.Module):
@ -62,7 +62,7 @@ class SRNetR90(nn.Module):
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90.init_from_lut(stage_lut) lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model return lut_model
class SRNetR90Y(nn.Module): class SRNetR90Y(nn.Module):
@ -95,11 +95,11 @@ class SRNetR90Y(nn.Module):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])
output /= 4 output /= 4
output = torch.cat([output, cbcr_scaled], dim=1) output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255) output = self.ycbcr_to_rgb(output).clamp(0, 255)
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Y.init_from_lut(stage_lut) lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)
return lut_model return lut_model
Loading…
Cancel
Save