diff --git a/src/common/layers.py b/src/common/layers.py index d1601ee..339efa2 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -36,10 +36,9 @@ class PercievePattern(): return x class UpscaleBlock(nn.Module): - def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1): + def __init__(self, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1): super(UpscaleBlock, self).__init__() assert layers_count > 0 - self.percieve_pattern = PercievePattern(receptive_field_idxes=receptive_field_idxes, center=center, window_size=window_size) self.upscale_factor = upscale_factor self.hidden_dim = hidden_dim self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True) @@ -52,19 +51,13 @@ class UpscaleBlock(nn.Module): self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) def forward(self, x): - b,c,h,w = x.shape - x = (x-127.5)/127.5 - x = self.percieve_pattern(x) + x = (x-127.5)/127.5 x = torch.relu(self.embed(x)) for linear_projection in self.linear_projections: x = torch.cat([x, torch.relu(linear_projection(x))], dim=2) x = self.project_channels(x) x = torch.tanh(x) x = x*127.5 + 127.5 - x = round_func(x) - x = x.reshape(b, c, h, w, self.upscale_factor, self.upscale_factor) - x = x.permute(0,1,2,4,3,5) - x = x.reshape(b, c, h*self.upscale_factor, w*self.upscale_factor) return x class RgbToYcbcr(nn.Module): diff --git a/src/common/lut.py b/src/common/lut.py index 42b91ea..006b37a 100644 --- a/src/common/lut.py +++ b/src/common/lut.py @@ -14,7 +14,7 @@ class Domain4DValues(Dataset): values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8) values1d = torch.cat([values1d, torch.tensor([256])]) self.quantization_interval = quantization_interval - self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 2, 2) + self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 4) def __getitem__(self, idx): if isinstance(idx, slice): @@ -30,7 +30,7 @@ class Domain4DValues(Dataset): else: v = self.values[idx] ix = v[0]//self.quantization_interval - return ix[0,0], ix[0,1], ix[1,0], ix[1,1], v + return ix[0], ix[1], ix[2], ix[3], v def __len__(self): return len(self.values) @@ -56,7 +56,7 @@ def transfer_rc_conv(rc_conv, quantization_interval=1): def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10): bucket_count = 256//quantization_interval scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1 - lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, 1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2 + lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2 domain_values = Domain4DValues(quantization_interval=quantization_interval) domain_values_loader = DataLoader( domain_values, @@ -68,88 +68,55 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2* for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader): inputs = batch.type(torch.float32).cuda() with torch.no_grad(): - outputs = block(inputs)[:,:,:] - lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8)[:,:,:scale,:scale] + outputs = block(inputs) + lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.reshape(-1, scale, scale).cpu().numpy().astype(np.uint8) counter += inputs.shape[0] print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ") print() - lut = lut.squeeze(-3) return lut ##################### FORWARD ########################## -def forward_2x2_input_SxS_output(index, lut): - b,c,hs,ws = index.shape - scale = lut.shape[-1] - index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') - out = select_index_4dlut_tetrahedral( - ixA = index, - ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]), - ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]), - ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]), - lut = lut - ) - out = out[:,:,0:-1,0:-1,:,:] # unpad - # Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504] - out = out.permute(0,1,2,4,3,5).reshape(b*c,1,hs*scale,ws*scale) - out = round_func(out) - return out - -def forward_unfolded_2x2_input_SxS_output(index, lut): - b,c,hs,ws = index.shape - scale = lut.shape[-1] - out = select_index_4dlut_tetrahedral( - ixA = index, - ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]), - ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]), - ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]), - lut = lut - ) - out = out[:,:,0:-1,0:-1,:,:] # unpad - # Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504] - out = out.permute(0,1,2,4,3,5).reshape(b*c,1,scale,scale) - out = round_func(out) - return out - -def forward_rc_conv_centered(index, lut): - window_size = lut.shape[0] - index = F.pad(index, pad=[window_size//2]*4, mode='replicate') - window_indexes = lut.shape[:-1] - # index = index.unsqueeze(-1) - x = torch.zeros_like(index) - for i in range(window_indexes[-2]): - for j in range(window_indexes[-1]): - shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j - shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1]) - x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j]) - x /= window_indexes[-2]*window_indexes[-1] - x = x.squeeze(-1) - x = round_func(x) - x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1] - return x - -def forward_rc_conv_rot90(index, lut): - window_size = lut.shape[0] - index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate') - window_indexes = lut.shape[:-1] - # index = index.unsqueeze(-1) - x = torch.zeros_like(index) - for i in range(window_indexes[-2]): - for j in range(window_indexes[-1]): - shift_i, shift_j = i, j - shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1]) - x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j]) - x /= window_indexes[-2]*window_indexes[-1] - x = x.squeeze(-1) - x = round_func(x) - x = x[:,:,:-(window_size-1),:-(window_size-1)] - return x +# def forward_rc_conv_centered(index, lut): +# window_size = lut.shape[0] +# index = F.pad(index, pad=[window_size//2]*4, mode='replicate') +# window_indexes = lut.shape[:-1] +# # index = index.unsqueeze(-1) +# x = torch.zeros_like(index) +# for i in range(window_indexes[-2]): +# for j in range(window_indexes[-1]): +# shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j +# shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1]) +# x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j]) +# x /= window_indexes[-2]*window_indexes[-1] +# x = x.squeeze(-1) +# x = round_func(x) +# x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1] +# return x + +# def forward_rc_conv_rot90(index, lut): +# window_size = lut.shape[0] +# index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate') +# window_indexes = lut.shape[:-1] +# # index = index.unsqueeze(-1) +# x = torch.zeros_like(index) +# for i in range(window_indexes[-2]): +# for j in range(window_indexes[-1]): +# shift_i, shift_j = i, j +# shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1]) +# x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j]) +# x /= window_indexes[-2]*window_indexes[-1] +# x = x.squeeze(-1) +# x = round_func(x) +# x = x[:,:,:-(window_size-1),:-(window_size-1)] +# return x ##################### UTILS ########################## +# TODO rewrite for unfolded def select_index_1dlut_linear(ixA, lut): lut = torch.clamp(lut, 0, 255) b,c,h,w = ixA.shape @@ -168,75 +135,71 @@ def select_index_1dlut_linear(ixA, lut): out = out.reshape((b,c,h,w)) return out -def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): +def select_index_4dlut_tetrahedral(index, lut): + b, hw, c = index.shape lut = torch.clamp(lut, 0, 255) dimA, dimB, dimC, dimD = lut.shape[:4] q = 256/(dimA-1) L = dimA upscale = lut.shape[-1] - weight = lut.reshape(L**4,upscale,upscale) - - img_a1 = torch.floor_divide(ixA, q).type(torch.int64) - img_b1 = torch.floor_divide(ixB, q).type(torch.int64) - img_c1 = torch.floor_divide(ixC, q).type(torch.int64) - img_d1 = torch.floor_divide(ixD, q).type(torch.int64) - - # Extract LSBs - fa = ixA % q - fb = ixB % q - fc = ixC % q - fd = ixD % q - - img_a2 = img_a1 + 1 - img_b2 = img_b1 + 1 - img_c2 = img_c1 + 1 - img_d2 = img_d1 + 1 - - p0000 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0001 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0010 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0011 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0100 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0101 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0110 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p0111 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - - p1000 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1001 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1010 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1011 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1100 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1101 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1110 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - p1111 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - - out = torch.zeros((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale), dtype=weight.dtype).to(device=weight.device) - sz = img_a1.shape[0] * img_a1.shape[1] * img_a1.shape[2] * img_a1.shape[3] - out = out.reshape(sz, -1) - - p0000 = p0000.reshape(sz, -1) - p0100 = p0100.reshape(sz, -1) - p1000 = p1000.reshape(sz, -1) - p1100 = p1100.reshape(sz, -1) - fa = fa.reshape(-1, 1) - - p0001 = p0001.reshape(sz, -1) - p0101 = p0101.reshape(sz, -1) - p1001 = p1001.reshape(sz, -1) - p1101 = p1101.reshape(sz, -1) - fb = fb.reshape(-1, 1) - fc = fc.reshape(-1, 1) - - p0010 = p0010.reshape(sz, -1) - p0110 = p0110.reshape(sz, -1) - p1010 = p1010.reshape(sz, -1) - p1110 = p1110.reshape(sz, -1) - fd = fd.reshape(-1, 1) - - p0011 = p0011.reshape(sz, -1) - p0111 = p0111.reshape(sz, -1) - p1011 = p1011.reshape(sz, -1) - p1111 = p1111.reshape(sz, -1) + weight = lut.reshape(L**4, upscale, upscale) + + msbA = torch.floor_divide(index, q).type(torch.int64) + msbB = msbA + 1 + lsb = index % q + + img_a1 = msbA[:,:,0].reshape(b*hw, 1) + img_b1 = msbA[:,:,1].reshape(b*hw, 1) + img_c1 = msbA[:,:,2].reshape(b*hw, 1) + img_d1 = msbA[:,:,3].reshape(b*hw, 1) + img_a2 = msbB[:,:,0].reshape(b*hw, 1) + img_b2 = msbB[:,:,1].reshape(b*hw, 1) + img_c2 = msbB[:,:,2].reshape(b*hw, 1) + img_d2 = msbB[:,:,3].reshape(b*hw, 1) + fa = lsb[:,:,0].reshape(b*hw, 1) + fb = lsb[:,:,1].reshape(b*hw, 1) + fc = lsb[:,:,2].reshape(b*hw, 1) + fd = lsb[:,:,3].reshape(b*hw, 1) + + p0000 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c1 * L + img_d1] + p0001 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c1 * L + img_d2] + p0010 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c2 * L + img_d1] + p0011 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c2 * L + img_d2] + p0100 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c1 * L + img_d1] + p0101 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c1 * L + img_d2] + p0110 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c2 * L + img_d1] + p0111 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c2 * L + img_d2] + + p1000 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c1 * L + img_d1] + p1001 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c1 * L + img_d2] + p1010 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c2 * L + img_d1] + p1011 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c2 * L + img_d2] + p1100 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c1 * L + img_d1] + p1101 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c1 * L + img_d2] + p1110 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c2 * L + img_d1] + p1111 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c2 * L + img_d2] + + out = torch.zeros((b*hw, upscale*upscale), dtype=weight.dtype).to(device=weight.device) + + p0000 = p0000.reshape(b*hw, upscale*upscale) + p0100 = p0100.reshape(b*hw, upscale*upscale) + p1000 = p1000.reshape(b*hw, upscale*upscale) + p1100 = p1100.reshape(b*hw, upscale*upscale) + + p0001 = p0001.reshape(b*hw, upscale*upscale) + p0101 = p0101.reshape(b*hw, upscale*upscale) + p1001 = p1001.reshape(b*hw, upscale*upscale) + p1101 = p1101.reshape(b*hw, upscale*upscale) + + p0010 = p0010.reshape(b*hw, upscale*upscale) + p0110 = p0110.reshape(b*hw, upscale*upscale) + p1010 = p1010.reshape(b*hw, upscale*upscale) + p1110 = p1110.reshape(b*hw, upscale*upscale) + + p0011 = p0011.reshape(b*hw, upscale*upscale) + p0111 = p0111.reshape(b*hw, upscale*upscale) + p1011 = p1011.reshape(b*hw, upscale*upscale) + p1111 = p1111.reshape(b*hw, upscale*upscale) fab = fa > fb; fac = fa > fc; @@ -257,14 +220,7 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): i8 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] i9 = i = torch.all(torch.cat([~(fbc), ~(fac), fab, fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] - # Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first! - # i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], fab, fcd], dim=1), dim=1) - # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i] - # i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad], dim=1), dim=1) - # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i] - # c > a > d > b i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], fab, fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] - # c > d > a > b i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] i12 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] @@ -283,6 +239,6 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): i23 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] - out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + out = out.reshape((b, hw, upscale*upscale)) out = out / q - return out + return out \ No newline at end of file diff --git a/src/models/rclut.py b/src/models/rclut.py index a4af1a0..f51413b 100644 --- a/src/models/rclut.py +++ b/src/models/rclut.py @@ -3,503 +3,503 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from common.utils import round_func -from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output -from pathlib import Path - -class RCLutCentered_3x3(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(RCLutCentered_3x3, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - rc_conv_luts, dense_conv_lut - ): - scale = int(dense_conv_lut.shape[-1]) - quantization_interval = 256//(dense_conv_lut.shape[0]-1) - lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale) - lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) - lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') - x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) - x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1] - x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) - x = x.view(b, c, x.shape[-2], x.shape[-1]) - return x +# from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output +# from pathlib import Path + +# class RCLutCentered_3x3(nn.Module): +# def __init__( +# self, +# quantization_interval, +# scale +# ): +# super(RCLutCentered_3x3, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# rc_conv_luts, dense_conv_lut +# ): +# scale = int(dense_conv_lut.shape[-1]) +# quantization_interval = 256//(dense_conv_lut.shape[0]-1) +# lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale) +# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) +# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') +# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) +# x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1] +# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) +# x = x.view(b, c, x.shape[-2], x.shape[-1]) +# return x - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" rc_conv_luts size: {self.rc_conv_luts.shape}", - f" dense_conv_lut size: {self.dense_conv_lut.shape}", - ")"]) - -class RCLutCentered_7x7(nn.Module): - def __init__( - self, - window_size, - quantization_interval, - scale - ): - super(RCLutCentered_7x7, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - rc_conv_luts, dense_conv_lut - ): - scale = int(dense_conv_lut.shape[-1]) - quantization_interval = 256//(dense_conv_lut.shape[0]-1) - lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale) - lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) - lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) - x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) - # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4) - x = x.view(b, c, x.shape[-2], x.shape[-1]) - return x +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" rc_conv_luts size: {self.rc_conv_luts.shape}", +# f" dense_conv_lut size: {self.dense_conv_lut.shape}", +# ")"]) + +# class RCLutCentered_7x7(nn.Module): +# def __init__( +# self, +# window_size, +# quantization_interval, +# scale +# ): +# super(RCLutCentered_7x7, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# rc_conv_luts, dense_conv_lut +# ): +# scale = int(dense_conv_lut.shape[-1]) +# quantization_interval = 256//(dense_conv_lut.shape[0]-1) +# lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale) +# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) +# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) +# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) +# # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4) +# x = x.view(b, c, x.shape[-2], x.shape[-1]) +# return x - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" rc_conv_luts size: {self.rc_conv_luts.shape}", - f" dense_conv_lut size: {self.dense_conv_lut.shape}", - ")"]) - -class RCLutRot90_3x3(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(RCLutRot90_3x3, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - rc_conv_luts, dense_conv_lut - ): - scale = int(dense_conv_lut.shape[-1]) - quantization_interval = 256//(dense_conv_lut.shape[0]-1) - lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale) - lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) - lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts) - rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction - output /= 4 - output = output.view(b, c, output.shape[-2], output.shape[-1]) - return output +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" rc_conv_luts size: {self.rc_conv_luts.shape}", +# f" dense_conv_lut size: {self.dense_conv_lut.shape}", +# ")"]) + +# class RCLutRot90_3x3(nn.Module): +# def __init__( +# self, +# quantization_interval, +# scale +# ): +# super(RCLutRot90_3x3, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# rc_conv_luts, dense_conv_lut +# ): +# scale = int(dense_conv_lut.shape[-1]) +# quantization_interval = 256//(dense_conv_lut.shape[0]-1) +# lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale) +# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) +# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts) +# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) +# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) +# output += unrotated_prediction +# output /= 4 +# output = output.view(b, c, output.shape[-2], output.shape[-1]) +# return output - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" rc_conv_luts size: {self.rc_conv_luts.shape}", - f" dense_conv_lut size: {self.dense_conv_lut.shape}", - ")"]) - -class RCLutRot90_7x7(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(RCLutRot90_7x7, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - rc_conv_luts, dense_conv_lut - ): - scale = int(dense_conv_lut.shape[-1]) - quantization_interval = 256//(dense_conv_lut.shape[0]-1) - window_size = rc_conv_luts.shape[0] - lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale) - lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) - lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts) - rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction - output /= 4 - output = output.view(b, c, output.shape[-2], output.shape[-1]) - return output +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" rc_conv_luts size: {self.rc_conv_luts.shape}", +# f" dense_conv_lut size: {self.dense_conv_lut.shape}", +# ")"]) + +# class RCLutRot90_7x7(nn.Module): +# def __init__( +# self, +# quantization_interval, +# scale +# ): +# super(RCLutRot90_7x7, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# rc_conv_luts, dense_conv_lut +# ): +# scale = int(dense_conv_lut.shape[-1]) +# quantization_interval = 256//(dense_conv_lut.shape[0]-1) +# window_size = rc_conv_luts.shape[0] +# lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale) +# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) +# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts) +# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) +# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) +# output += unrotated_prediction +# output /= 4 +# output = output.view(b, c, output.shape[-2], output.shape[-1]) +# return output - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" rc_conv_luts size: {self.rc_conv_luts.shape}", - f" dense_conv_lut size: {self.dense_conv_lut.shape}", - ")"]) - -class RCLutx1(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(RCLutx1, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) - self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - rc_conv_luts_3x3, dense_conv_lut_3x3, - rc_conv_luts_5x5, dense_conv_lut_5x5, - rc_conv_luts_7x7, dense_conv_lut_7x7 - ): - scale = int(dense_conv_lut_3x3.shape[-1]) - quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1) - - lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale) - - lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32)) - lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32)) - - lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32)) - lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32)) - - lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32)) - lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32)) - - return lut_model - - def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): - x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) - x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) - return x - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7), - k=-rotations_count, - dims=[2, 3] - ) - output /= 3*4 - output = output.view(b, c, output.shape[-2], output.shape[-1]) - return output +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" rc_conv_luts size: {self.rc_conv_luts.shape}", +# f" dense_conv_lut size: {self.dense_conv_lut.shape}", +# ")"]) + +# class RCLutx1(nn.Module): +# def __init__( +# self, +# quantization_interval, +# scale +# ): +# super(RCLutx1, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) +# self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) +# self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) +# self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# rc_conv_luts_3x3, dense_conv_lut_3x3, +# rc_conv_luts_5x5, dense_conv_lut_5x5, +# rc_conv_luts_7x7, dense_conv_lut_7x7 +# ): +# scale = int(dense_conv_lut_3x3.shape[-1]) +# quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1) + +# lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale) + +# lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32)) +# lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32)) + +# lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32)) +# lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32)) + +# lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32)) +# lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32)) + +# return lut_model + +# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): +# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) +# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) +# return x + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output /= 3*4 +# output = output.view(b, c, output.shape[-2], output.shape[-1]) +# return output - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}", - f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}", - f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}", - f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}", - f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}", - f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}", - ")"]) - - - -class RCLutx2(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(RCLutx2, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) - self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) - self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) - self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) - self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) - self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, - s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, - s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, - s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3, - s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5, - s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7 - ): - scale = int(s2_dense_conv_lut_3x3.shape[-1]) - quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1) - - lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale) - - lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32)) - lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32)) - - lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32)) - lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32)) - - lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32)) - lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32)) - - lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32)) - lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32)) - - lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32)) - lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32)) - - lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32)) - lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32)) - - return lut_model - - def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): - x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) - x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) - return x - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), - k=-rotations_count, - dims=[2, 3] - ) - output /= 3*4 - x = output - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), - k=-rotations_count, - dims=[2, 3] - ) - output /= 3*4 - output = output.view(b, c, output.shape[-2], output.shape[-1]) - return output +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}", +# f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}", +# f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}", +# f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}", +# f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}", +# f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}", +# ")"]) + + + +# class RCLutx2(nn.Module): +# def __init__( +# self, +# quantization_interval, +# scale +# ): +# super(RCLutx2, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) +# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) +# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) +# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) +# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) +# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) +# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) +# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) +# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) +# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, +# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, +# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, +# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3, +# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5, +# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7 +# ): +# scale = int(s2_dense_conv_lut_3x3.shape[-1]) +# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1) + +# lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale) + +# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32)) +# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32)) + +# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32)) +# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32)) + +# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32)) +# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32)) + +# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32)) +# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32)) + +# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32)) +# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32)) + +# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32)) +# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32)) + +# return lut_model + +# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): +# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) +# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) +# return x + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output /= 3*4 +# x = output +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output /= 3*4 +# output = output.view(b, c, output.shape[-2], output.shape[-1]) +# return output - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}", - f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}", - f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}", - f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}", - f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}", - f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}", - f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}", - f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}", - f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}", - f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}", - f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}", - f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}", - ")"]) - - - -class RCLutx2Centered(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(RCLutx2Centered, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) - self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) - self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) - self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) - self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) - self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) - self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) - self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, - s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, - s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, - s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3, - s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5, - s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7 - ): - scale = int(s2_dense_conv_lut_3x3.shape[-1]) - quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1) - - lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale) - - lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32)) - lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32)) - - lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32)) - lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32)) - - lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32)) - lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32)) - - lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32)) - lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32)) - - lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32)) - lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32)) - - lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32)) - lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32)) - - return lut_model - - def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): - x = forward_rc_conv_centered(index=index, lut=rc_conv_lut) - x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) - return x - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), - k=-rotations_count, - dims=[2, 3] - ) - output /= 3*4 - x = output - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), - k=-rotations_count, - dims=[2, 3] - ) - output += torch.rot90( - self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), - k=-rotations_count, - dims=[2, 3] - ) - output /= 3*4 - output = output.view(b, c, output.shape[-2], output.shape[-1]) - return output +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}", +# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}", +# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}", +# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}", +# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}", +# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}", +# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}", +# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}", +# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}", +# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}", +# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}", +# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}", +# ")"]) + + + +# class RCLutx2Centered(nn.Module): +# def __init__( +# self, +# quantization_interval, +# scale +# ): +# super(RCLutx2Centered, self).__init__() +# self.scale = scale +# self.quantization_interval = quantization_interval +# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) +# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) +# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) +# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) +# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) +# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) +# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) +# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) +# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) +# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) +# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + +# @staticmethod +# def init_from_numpy( +# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, +# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, +# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, +# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3, +# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5, +# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7 +# ): +# scale = int(s2_dense_conv_lut_3x3.shape[-1]) +# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1) + +# lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale) + +# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32)) +# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32)) + +# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32)) +# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32)) + +# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32)) +# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32)) + +# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32)) +# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32)) + +# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32)) +# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32)) + +# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32)) +# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32)) + +# return lut_model + +# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): +# x = forward_rc_conv_centered(index=index, lut=rc_conv_lut) +# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) +# return x + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w).type(torch.float32) +# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output /= 3*4 +# x = output +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output += torch.rot90( +# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), +# k=-rotations_count, +# dims=[2, 3] +# ) +# output /= 3*4 +# output = output.view(b, c, output.shape[-2], output.shape[-1]) +# return output - def __repr__(self): - return "\n".join([ - f"{self.__class__.__name__}(", - f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}", - f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}", - f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}", - f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}", - f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}", - f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}", - f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}", - f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}", - f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}", - f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}", - f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}", - f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}", - ")"]) \ No newline at end of file +# def __repr__(self): +# return "\n".join([ +# f"{self.__class__.__name__}(", +# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}", +# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}", +# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}", +# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}", +# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}", +# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}", +# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}", +# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}", +# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}", +# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}", +# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}", +# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}", +# ")"]) \ No newline at end of file diff --git a/src/models/rcnet.py b/src/models/rcnet.py index 70b73a6..9d64a3e 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -8,561 +8,561 @@ from common import lut from . import rclut from common import layers -class ReconstructedConvCentered(nn.Module): - def __init__(self, hidden_dim, window_size=7): - super(ReconstructedConvCentered, self).__init__() - self.window_size = window_size - self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - - def pixel_wise_forward(self, x): - x = (x-127.5)/127.5 - out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2) - out = torch.tanh(out) - out = out*127.5 + 127.5 - return out - - def forward(self, x): - original_shape = x.shape - x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') - x = F.unfold(x, self.window_size) - x = self.pixel_wise_forward(x) - x = x.mean(1) - x = x.reshape(*original_shape) - x = round_func(x) - return x - - def __repr__(self): - return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" - -class RCBlockCentered(nn.Module): - def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): - super(RCBlockCentered, self).__init__() - self.window_size = window_size - self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) - - def forward(self, x): - b,c,hs,ws = x.shape - x = self.rc_conv(x) - x = F.pad(x, pad=[0,1,0,1], mode='replicate') - x = self.dense_conv_block(x) - return x - -class RCNetCentered_3x3(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetCentered_3x3, self).__init__() - self.hidden_dim = hidden_dim - self.layers_count = layers_count - self.scale = scale - self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) +# class ReconstructedConvCentered(nn.Module): +# def __init__(self, hidden_dim, window_size=7): +# super(ReconstructedConvCentered, self).__init__() +# self.window_size = window_size +# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) +# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + +# def pixel_wise_forward(self, x): +# x = (x-127.5)/127.5 +# out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2) +# out = torch.tanh(out) +# out = out*127.5 + 127.5 +# return out + +# def forward(self, x): +# original_shape = x.shape +# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') +# x = F.unfold(x, self.window_size) +# x = self.pixel_wise_forward(x) +# x = x.mean(1) +# x = x.reshape(*original_shape) +# x = round_func(x) +# return x + +# def __repr__(self): +# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +# class RCBlockCentered(nn.Module): +# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): +# super(RCBlockCentered, self).__init__() +# self.window_size = window_size +# self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size) +# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + +# def forward(self, x): +# b,c,hs,ws = x.shape +# x = self.rc_conv(x) +# x = F.pad(x, pad=[0,1,0,1], mode='replicate') +# x = self.dense_conv_block(x) +# return x + +# class RCNetCentered_3x3(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetCentered_3x3, self).__init__() +# self.hidden_dim = hidden_dim +# self.layers_count = layers_count +# self.scale = scale +# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - x = self.stage(x) - x = x.view(b, c, h*self.scale, w*self.scale) - return x - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - window_size = self.stage.rc_conv.window_size - rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) - dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut) - return lut_model - -class RCNetCentered_7x7(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetCentered_7x7, self).__init__() - self.hidden_dim = hidden_dim - self.layers_count = layers_count - self.scale = scale - window_size = 7 - self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# x = self.stage(x) +# x = x.view(b, c, h*self.scale, w*self.scale) +# return x + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# window_size = self.stage.rc_conv.window_size +# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) +# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut) +# return lut_model + +# class RCNetCentered_7x7(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetCentered_7x7, self).__init__() +# self.hidden_dim = hidden_dim +# self.layers_count = layers_count +# self.scale = scale +# window_size = 7 +# self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - x = self.stage(x) - x = x.view(b, c, h*self.scale, w*self.scale) - return x - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - window_size = self.stage.rc_conv.window_size - rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) - dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut) - return lut_model - - -class ReconstructedConvRot90(nn.Module): - def __init__(self, hidden_dim, window_size=7): - super(ReconstructedConvRot90, self).__init__() - self.window_size = window_size - self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - - def pixel_wise_forward(self, x): - x = (x-127.5)/127.5 - out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) - out = torch.tanh(out) - out = out*127.5 + 127.5 - return out - - def forward(self, x): - original_shape = x.shape - x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') - x = F.unfold(x, self.window_size) - x = self.pixel_wise_forward(x) - x = x.mean(1) - x = x.reshape(*original_shape) - x = round_func(x) # quality likely suffer from this - return x - - def __repr__(self): - return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" - -class RCBlockRot90(nn.Module): - def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): - super(RCBlockRot90, self).__init__() - self.window_size = window_size - self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) - - def forward(self, x): - b,c,hs,ws = x.shape - x = self.rc_conv(x) - x = F.pad(x, pad=[0,1,0,1], mode='replicate') - x = self.dense_conv_block(x) +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# x = self.stage(x) +# x = x.view(b, c, h*self.scale, w*self.scale) +# return x + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# window_size = self.stage.rc_conv.window_size +# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) +# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut) +# return lut_model + + +# class ReconstructedConvRot90(nn.Module): +# def __init__(self, hidden_dim, window_size=7): +# super(ReconstructedConvRot90, self).__init__() +# self.window_size = window_size +# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) +# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + +# def pixel_wise_forward(self, x): +# x = (x-127.5)/127.5 +# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) +# out = torch.tanh(out) +# out = out*127.5 + 127.5 +# return out + +# def forward(self, x): +# original_shape = x.shape +# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') +# x = F.unfold(x, self.window_size) +# x = self.pixel_wise_forward(x) +# x = x.mean(1) +# x = x.reshape(*original_shape) +# x = round_func(x) # quality likely suffer from this +# return x + +# def __repr__(self): +# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +# class RCBlockRot90(nn.Module): +# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): +# super(RCBlockRot90, self).__init__() +# self.window_size = window_size +# self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size) +# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + +# def forward(self, x): +# b,c,hs,ws = x.shape +# x = self.rc_conv(x) +# x = F.pad(x, pad=[0,1,0,1], mode='replicate') +# x = self.dense_conv_block(x) - return x - -class RCNetRot90_3x3(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetRot90_3x3, self).__init__() - self.hidden_dim = hidden_dim - self.layers_count = layers_count - self.scale = scale - self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - window_size = self.stage.rc_conv.window_size - rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) - dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - rotated_prediction = self.stage(rotated) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction - output /= 4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output - -class RCNetRot90_7x7(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetRot90_7x7, self).__init__() - self.hidden_dim = hidden_dim - self.layers_count = layers_count - self.scale = scale - self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - window_size = self.stage.rc_conv.window_size - rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) - dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - rotated_prediction = self.stage(rotated) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction - output /= 4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output - -class RCNetx1(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetx1, self).__init__() - self.scale = scale - self.hidden_dim = hidden_dim - self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) - self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# return x + +# class RCNetRot90_3x3(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetRot90_3x3, self).__init__() +# self.hidden_dim = hidden_dim +# self.layers_count = layers_count +# self.scale = scale +# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# window_size = self.stage.rc_conv.window_size +# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) +# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# rotated_prediction = self.stage(rotated) +# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) +# output += unrotated_prediction +# output /= 4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output + +# class RCNetRot90_7x7(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetRot90_7x7, self).__init__() +# self.hidden_dim = hidden_dim +# self.layers_count = layers_count +# self.scale = scale +# self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# window_size = self.stage.rc_conv.window_size +# rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) +# dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# rotated_prediction = self.stage(rotated) +# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) +# output += unrotated_prediction +# output /= 4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output + +# class RCNetx1(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetx1, self).__init__() +# self.scale = scale +# self.hidden_dim = hidden_dim +# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) +# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) +# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx1.init_from_numpy( - rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3, - rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5, - rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7 - ) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# lut_model = rclut.RCLutx1.init_from_numpy( +# rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3, +# rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5, +# rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7 +# ) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output +# output /= 3*4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output -class RCNetx2(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetx2, self).__init__() - self.scale = scale - self.hidden_dim = hidden_dim - self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) - self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) - self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) - self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) - self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) +# class RCNetx2(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetx2, self).__init__() +# self.scale = scale +# self.hidden_dim = hidden_dim +# self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) +# self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) +# self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) +# self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) +# self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) +# self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2.init_from_numpy( - s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, - s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, - s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, - s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, - s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, - s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 - ) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - x = output - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output - -class RCNetx2Centered(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetx2Centered, self).__init__() - self.scale = scale - self.hidden_dim = hidden_dim - self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) - self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) - self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) - self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) - self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutx2.init_from_numpy( +# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, +# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, +# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, +# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, +# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, +# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 +# ) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# x = output +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output + +# class RCNetx2Centered(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetx2Centered, self).__init__() +# self.scale = scale +# self.hidden_dim = hidden_dim +# self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) +# self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) +# self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) +# self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) +# self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) +# self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2Centered.init_from_numpy( - s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, - s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, - s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, - s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, - s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, - s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 - ) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - x = output - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output - -class ReconstructedConvRot90Unlutable(nn.Module): - def __init__(self, hidden_dim, window_size=7): - super(ReconstructedConvRot90Unlutable, self).__init__() - self.window_size = window_size - self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - - def pixel_wise_forward(self, x): - x = (x-127.5)/127.5 - out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) - out = torch.tanh(out) - out = out*127.5 + 127.5 - return out - - def forward(self, x): - original_shape = x.shape - x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') - x = F.unfold(x, self.window_size) - x = self.pixel_wise_forward(x) - x = x.mean(1) - x = x.reshape(*original_shape) - # x = round_func(x) # quality likely suffer from this - return x - - def __repr__(self): - return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" - -class RCBlockRot90Unlutable(nn.Module): - def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): - super(RCBlockRot90Unlutable, self).__init__() - self.window_size = window_size - self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) - - def forward(self, x): - b,c,hs,ws = x.shape - x = self.rc_conv(x) - x = F.pad(x, pad=[0,1,0,1], mode='replicate') - x = self.dense_conv_block(x) - return x - -class RCNetx2Unlutable(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetx2Unlutable, self).__init__() - self.scale = scale - self.hidden_dim = hidden_dim - self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) - self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) - self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) - self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) - self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutx2Centered.init_from_numpy( +# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, +# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, +# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, +# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, +# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, +# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 +# ) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# x = output +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output + +# class ReconstructedConvRot90Unlutable(nn.Module): +# def __init__(self, hidden_dim, window_size=7): +# super(ReconstructedConvRot90Unlutable, self).__init__() +# self.window_size = window_size +# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) +# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + +# def pixel_wise_forward(self, x): +# x = (x-127.5)/127.5 +# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) +# out = torch.tanh(out) +# out = out*127.5 + 127.5 +# return out + +# def forward(self, x): +# original_shape = x.shape +# x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') +# x = F.unfold(x, self.window_size) +# x = self.pixel_wise_forward(x) +# x = x.mean(1) +# x = x.reshape(*original_shape) +# # x = round_func(x) # quality likely suffer from this +# return x + +# def __repr__(self): +# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +# class RCBlockRot90Unlutable(nn.Module): +# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): +# super(RCBlockRot90Unlutable, self).__init__() +# self.window_size = window_size +# self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size) +# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + +# def forward(self, x): +# b,c,hs,ws = x.shape +# x = self.rc_conv(x) +# x = F.pad(x, pad=[0,1,0,1], mode='replicate') +# x = self.dense_conv_block(x) +# return x + +# class RCNetx2Unlutable(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetx2Unlutable, self).__init__() +# self.scale = scale +# self.hidden_dim = hidden_dim +# self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) +# self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) +# self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) +# self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) +# self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) +# self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2.init_from_numpy( - s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, - s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, - s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, - s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, - s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, - s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 - ) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - x = output - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output - - - -class ReconstructedConvCenteredUnlutable(nn.Module): - def __init__(self, hidden_dim, window_size=7): - super(ReconstructedConvCenteredUnlutable, self).__init__() - self.window_size = window_size - self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) - - def pixel_wise_forward(self, x): - x = (x-127.5)/127.5 - out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) - out = torch.tanh(out) - out = out*127.5 + 127.5 - return out - - def forward(self, x): - original_shape = x.shape - x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') - x = F.unfold(x, self.window_size) - x = self.pixel_wise_forward(x) - x = x.mean(1) - x = x.reshape(*original_shape) - # x = round_func(x) # quality likely suffer from this - return x - - def __repr__(self): - return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" - -class RCBlockCenteredUnlutable(nn.Module): - def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): - super(RCBlockRot90Unlutable, self).__init__() - self.window_size = window_size - self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) - - def forward(self, x): - b,c,hs,ws = x.shape - x = self.rc_conv(x) - x = F.pad(x, pad=[0,1,0,1], mode='replicate') - x = self.dense_conv_block(x) - return x - -class RCNetx2CenteredUnlutable(nn.Module): - def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(RCNetx2CenteredUnlutable, self).__init__() - self.scale = scale - self.hidden_dim = hidden_dim - self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) - self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) - self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) - self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) - self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) - self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) - - def get_lut_model(self, quantization_interval=16, batch_size=2**10): - s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) - s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) - s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - - s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) - s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) +# lut_model = rclut.RCLutx2.init_from_numpy( +# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, +# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, +# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, +# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, +# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, +# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 +# ) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# x = output +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output + + + +# class ReconstructedConvCenteredUnlutable(nn.Module): +# def __init__(self, hidden_dim, window_size=7): +# super(ReconstructedConvCenteredUnlutable, self).__init__() +# self.window_size = window_size +# self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) +# self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + +# def pixel_wise_forward(self, x): +# x = (x-127.5)/127.5 +# out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) +# out = torch.tanh(out) +# out = out*127.5 + 127.5 +# return out + +# def forward(self, x): +# original_shape = x.shape +# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') +# x = F.unfold(x, self.window_size) +# x = self.pixel_wise_forward(x) +# x = x.mean(1) +# x = x.reshape(*original_shape) +# # x = round_func(x) # quality likely suffer from this +# return x + +# def __repr__(self): +# return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +# class RCBlockCenteredUnlutable(nn.Module): +# def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): +# super(RCBlockRot90Unlutable, self).__init__() +# self.window_size = window_size +# self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size) +# self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + +# def forward(self, x): +# b,c,hs,ws = x.shape +# x = self.rc_conv(x) +# x = F.pad(x, pad=[0,1,0,1], mode='replicate') +# x = self.dense_conv_block(x) +# return x + +# class RCNetx2CenteredUnlutable(nn.Module): +# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): +# super(RCNetx2CenteredUnlutable, self).__init__() +# self.scale = scale +# self.hidden_dim = hidden_dim +# self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) +# self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) +# self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) +# self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) +# self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) +# self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + +# def get_lut_model(self, quantization_interval=16, batch_size=2**10): +# s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) +# s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) +# s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + +# s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) +# s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = rclut.RCLutx2Centered.init_from_numpy( - s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, - s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, - s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, - s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, - s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, - s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 - ) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - x = output - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) - output /= 3*4 - output = output.view(b, c, h*self.scale, w*self.scale) - return output +# lut_model = rclut.RCLutx2Centered.init_from_numpy( +# s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, +# s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, +# s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, +# s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, +# s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, +# s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 +# ) +# return lut_model + +# def forward(self, x): +# b,c,h,w = x.shape +# x = x.view(b*c, 1, h, w) +# output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# x = output +# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) +# for rotations_count in range(4): +# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) +# output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) +# output /= 3*4 +# output = output.view(b, c, h*self.scale, w*self.scale) +# return output diff --git a/src/models/sdylut.py b/src/models/sdylut.py index 2fbf128..f229078 100644 --- a/src/models/sdylut.py +++ b/src/models/sdylut.py @@ -3,8 +3,9 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from pathlib import Path -from common.lut import forward_2x2_input_SxS_output, forward_unfolded_2x2_input_SxS_output +from common.lut import select_index_4dlut_tetrahedral from common.layers import PercievePattern +from common.utils import round_func class SDYLutx1(nn.Module): def __init__( @@ -15,7 +16,7 @@ class SDYLutx1(nn.Module): super(SDYLutx1, self).__init__() self.scale = scale self.quantization_interval = quantization_interval - self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3) + self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3) self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @@ -34,30 +35,30 @@ class SDYLutx1(nn.Module): lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32)) return lut_model + def forward_stage(self, x, scale, percieve_pattern, lut): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = select_index_4dlut_tetrahedral(index=x, lut=lut) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + def forward(self, x): b,c,h,w = x.shape x = x.view(b*c, 1, h, w).type(torch.float32) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): + output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS) + output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD) + output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY) + for rotations_count in range(1, 4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - rb,rc,rh,rw = rotated.shape - - s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS) - s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) - s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) - output += s - - d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD) - d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) - d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) - output += d - - y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY) - y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) - y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) - output += y - + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stageS), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stageD), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stageY), k=-rotations_count, dims=[-2, -1]) output /= 4*3 + output = round_func(output) output = output.view(b, c, h*self.scale, w*self.scale) return output @@ -103,55 +104,43 @@ class SDYLutx2(nn.Module): lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32)) return lut_model + def forward_stage(self, x, scale, percieve_pattern, lut): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = select_index_4dlut_tetrahedral(index=x, lut=lut) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + def forward(self, x): b,c,h,w = x.shape x = x.view(b*c, 1, h, w).type(torch.float32) output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) - for rotations_count in range(4): + output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S) + output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D) + output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y) + for rotations_count in range(1, 4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - rb,rc,rh,rw = rotated.shape - - s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage1_S) - s = s.view(rb, rc, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb, rc, rh, rw) - s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) - output += s - - d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage1_D) - d = d.view(rb, rc, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb, rc, rh, rw) - d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) - output += d - - y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage1_Y) - y = y.view(rb, rc, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb, rc, rh, rw) - y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) - output += y - + output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1]) output /= 4*3 - x = output + x = round_func(output) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): + output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S) + output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D) + output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y) + for rotations_count in range(1, 4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - rb,rc,rh,rw = rotated.shape - - s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage2_S) - s = s.view(rb, rc, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb, rc, rh*self.scale, rw*self.scale) - s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) - output += s - - d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage2_D) - d = d.view(rb, rc, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb, rc, rh*self.scale, rw*self.scale) - d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) - output += d - - y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage2_Y) - y = y.view(rb, rc, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb, rc, rh*self.scale, rw*self.scale) - y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) - output += y - + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1]) output /= 4*3 + output = round_func(output) output = output.view(b, c, h*self.scale, w*self.scale) - return output def __repr__(self): @@ -161,67 +150,4 @@ class SDYLutx2(nn.Module): f"\n stage1_Y size: {self.stage1_Y.shape}" + \ f"\n stage2_S size: {self.stage2_S.shape}" + \ f"\n stage2_D size: {self.stage2_D.shape}" + \ - f"\n stage2_Y size: {self.stage2_Y.shape}" - - - -class SDYLutCenteredx1(nn.Module): - def __init__( - self, - quantization_interval, - scale - ): - super(SDYLutCenteredx1, self).__init__() - self.scale = scale - self.quantization_interval = quantization_interval - self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[1,1], window_size=3) - self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[1,1], window_size=3) - self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[1,1], window_size=3) - self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - - @staticmethod - def init_from_numpy( - stageS, stageD, stageY - ): - scale = int(stageS.shape[-1]) - quantization_interval = 256//(stageS.shape[0]-1) - lut_model = SDYLutCenteredx1(quantization_interval=quantization_interval, scale=scale) - lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32)) - lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32)) - lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32)) - return lut_model - - def forward(self, x): - b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - rb,rc,rh,rw = rotated.shape - - s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS) - s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) - s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) - output += s - - d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD) - d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) - d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) - output += d - - y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY) - y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) - y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) - output += y - - output /= 4*3 - output = output.view(b, c, h*self.scale, w*self.scale) - return output - - def __repr__(self): - return f"{self.__class__.__name__}" + \ - f"\n stageS size: {self.stageS.shape}" + \ - f"\n stageD size: {self.stageD.shape}" + \ - f"\n stageY size: {self.stageY.shape}" \ No newline at end of file + f"\n stage2_Y size: {self.stage2_Y.shape}" \ No newline at end of file diff --git a/src/models/sdynet.py b/src/models/sdynet.py index c4b6f38..cc65f37 100644 --- a/src/models/sdynet.py +++ b/src/models/sdynet.py @@ -12,29 +12,39 @@ class SDYNetx1(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(SDYNetx1, self).__init__() self.scale = scale - s_pattern = [[0,0],[0,1],[1,0],[1,1]] - d_pattern = [[0,0],[2,0],[0,2],[2,2]] - y_pattern = [[0,0],[1,1],[1,2],[2,1]] - self.stage1_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) - self.stage1_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) - self.stage1_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) + self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3) + self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + + def forward_stage(self, x, scale, percieve_pattern, stage): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = stage(x) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x def forward(self, x): b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) + x = x.reshape(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.stage1_S(x) - output += self.stage1_D(x) - output += self.stage1_Y(x) + output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) + output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage1_D) + output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage1_Y) for rotations_count in range(1, 4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1]) - output += torch.rot90(self.stage1_D(rotated), k=-rotations_count, dims=[-2, -1]) - output += torch.rot90(self.stage1_Y(rotated), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1]) output /= 4*3 x = output x = round_func(x) - x = x.view(b, c, h*self.scale, w*self.scale) + x = x.reshape(b, c, h*self.scale, w*self.scale) return x def get_lut_model(self, quantization_interval=16, batch_size=2**10): @@ -48,39 +58,49 @@ class SDYNetx2(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(SDYNetx2, self).__init__() self.scale = scale - s_pattern = [[0,0],[0,1],[1,0],[1,1]] - d_pattern = [[0,0],[2,0],[0,2],[2,2]] - y_pattern = [[0,0],[1,1],[1,2],[2,1]] - self.stage1_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stage1_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stage1_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stage2_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) - self.stage2_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) - self.stage2_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) + self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3) + self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + + def forward_stage(self, x, scale, percieve_pattern, stage): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = stage(x) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x def forward(self, x): b,c,h,w = x.shape x = x.view(b*c, 1, h, w) output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) - output_1 += self.stage1_S(x) - output_1 += self.stage1_D(x) - output_1 += self.stage1_Y(x) + output_1 += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S) + output_1 += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D) + output_1 += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - output_1 += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1]) - output_1 += torch.rot90(self.stage1_D(rotated), k=-rotations_count, dims=[-2, -1]) - output_1 += torch.rot90(self.stage1_Y(rotated), k=-rotations_count, dims=[-2, -1]) + output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1]) + output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1]) + output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1]) output_1 /= 4*3 x = round_func(output_1) output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output_2 += self.stage2_S(x) - output_2 += self.stage2_D(x) - output_2 += self.stage2_Y(x) + output_2 += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S) + output_2 += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D) + output_2 += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - output_2 += torch.rot90(self.stage2_S(rotated), k=-rotations_count, dims=[-2, -1]) - output_2 += torch.rot90(self.stage2_D(rotated), k=-rotations_count, dims=[-2, -1]) - output_2 += torch.rot90(self.stage2_Y(rotated), k=-rotations_count, dims=[-2, -1]) + output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1]) + output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1]) + output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1]) output_2 /= 4*3 x = round_func(output_2) x = x.view(b, c, h*self.scale, w*self.scale) diff --git a/src/models/srlut.py b/src/models/srlut.py index 4a03c42..f8a591d 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -3,8 +3,9 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from pathlib import Path -from common.lut import forward_2x2_input_SxS_output +from common.lut import select_index_4dlut_tetrahedral from common import layers +from common.utils import round_func class SRLut(nn.Module): def __init__( @@ -16,6 +17,7 @@ class SRLut(nn.Module): self.scale = scale self.quantization_interval = quantization_interval self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) @staticmethod def init_from_numpy( @@ -27,11 +29,21 @@ class SRLut(nn.Module): lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model + def forward_stage(self, x, scale, percieve_pattern, lut): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = select_index_4dlut_tetrahedral(index=x, lut=lut) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + def forward(self, x): b,c,h,w = x.shape - x = x.view(b*c, 1, h, w).type(torch.float32) - x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut) - x = x.view(b, c, x.shape[-2], x.shape[-1]) + x = x.reshape(b*c, 1, h, w).type(torch.float32) + x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut) + x = x.reshape(b, c, h*self.scale, w*self.scale) return x def __repr__(self): @@ -49,6 +61,7 @@ class SRLutR90(nn.Module): self.scale = scale self.quantization_interval = quantization_interval self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) @staticmethod def init_from_numpy( @@ -60,17 +73,26 @@ class SRLutR90(nn.Module): lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model + def forward_stage(self, x, scale, percieve_pattern, lut): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = select_index_4dlut_tetrahedral(index=x, lut=lut) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + def forward(self, x): b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) + x = x.reshape(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): + output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut) + for rotations_count in range(1, 4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3]) output /= 4 - output = output.view(b, c, h*self.scale, w*self.scale) + output = output.reshape(b, c, h*self.scale, w*self.scale) return output def __repr__(self): @@ -87,6 +109,7 @@ class SRLutR90Y(nn.Module): self.scale = scale self.quantization_interval = quantization_interval self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) self.rgb_to_ycbcr = layers.RgbToYcbcr() self.ycbcr_to_rgb = layers.YcbcrToRgb() @@ -100,6 +123,16 @@ class SRLutR90Y(nn.Module): lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model + def forward_stage(self, x, scale, percieve_pattern, lut): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = select_index_4dlut_tetrahedral(index=x, lut=lut) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + def forward(self, x): b,c,h,w = x.shape x = self.rgb_to_ycbcr(x) @@ -108,11 +141,10 @@ class SRLutR90Y(nn.Module): cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear') output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) - for rotations_count in range(4): + output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut) + for rotations_count in range(1,4): rotated = torch.rot90(y, k=rotations_count, dims=[2, 3]) - rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction + output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3]) output /= 4 output = torch.cat([output, cbcr_scaled], dim=1) output = self.ycbcr_to_rgb(output).clamp(0, 255) diff --git a/src/models/srnet.py b/src/models/srnet.py index e70125d..c2af809 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -12,20 +12,27 @@ class SRNet(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(SRNet, self).__init__() self.scale = scale - s_pattern=[[0,0],[0,1],[1,0],[1,1]] self.stage1_S = layers.UpscaleBlock( - receptive_field_idxes=s_pattern, - center=[0,0], - window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale ) + self._unfold_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + + def forward_stage(self, x, scale, percieve_pattern, stage): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = stage(x) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x def forward(self, x): b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) - x = self.stage1_S(x) + x = x.reshape(b*c, 1, h, w) + x = self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S) x = x.reshape(b, c, h*self.scale, w*self.scale) return x @@ -38,26 +45,33 @@ class SRNetR90(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(SRNetR90, self).__init__() self.scale = scale - s_pattern=[[0,0],[0,1],[1,0],[1,1]] self.stage1_S = layers.UpscaleBlock( - receptive_field_idxes=s_pattern, - center=[0,0], - window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale ) + self._unfold_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + + def forward_stage(self, x, scale, percieve_pattern, stage): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = stage(x) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x def forward(self, x): b,c,h,w = x.shape - x = x.view(b*c, 1, h, w) + x = x.reshape(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.stage1_S(x) + output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output /= 4 - output = output.view(b, c, h*self.scale, w*self.scale) + output = output.reshape(b, c, h*self.scale, w*self.scale) return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): @@ -71,16 +85,24 @@ class SRNetR90Y(nn.Module): self.scale = scale s_pattern=[[0,0],[0,1],[1,0],[1,1]] self.stage1_S = layers.UpscaleBlock( - receptive_field_idxes=s_pattern, - center=[0,0], - window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale ) + self._unfold_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) self.rgb_to_ycbcr = layers.RgbToYcbcr() self.ycbcr_to_rgb = layers.YcbcrToRgb() + def forward_stage(self, x, scale, percieve_pattern, stage): + b,c,h,w = x.shape + x = percieve_pattern(x) + x = stage(x) + x = round_func(x) + x = x.reshape(b, c, h, w, scale, scale) + x = x.permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*scale, w*scale) + return x + def forward(self, x): b,c,h,w = x.shape x = self.rgb_to_ycbcr(x) @@ -90,10 +112,10 @@ class SRNetR90Y(nn.Module): x = y.view(b, 1, h, w) output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.stage1_S(x) + output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output /= 4 output = torch.cat([output, cbcr_scaled], dim=1) output = self.ycbcr_to_rgb(output).clamp(0, 255)