model rewrite

main
protsenkovi 6 months ago
parent 57fae95619
commit 53fa827515

@ -36,10 +36,9 @@ class PercievePattern():
return x
class UpscaleBlock(nn.Module):
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(UpscaleBlock, self).__init__()
assert layers_count > 0
self.percieve_pattern = PercievePattern(receptive_field_idxes=receptive_field_idxes, center=center, window_size=window_size)
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
@ -52,19 +51,13 @@ class UpscaleBlock(nn.Module):
self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
def forward(self, x):
b,c,h,w = x.shape
x = (x-127.5)/127.5
x = self.percieve_pattern(x)
x = torch.relu(self.embed(x))
for linear_projection in self.linear_projections:
x = torch.cat([x, torch.relu(linear_projection(x))], dim=2)
x = self.project_channels(x)
x = torch.tanh(x)
x = x*127.5 + 127.5
x = round_func(x)
x = x.reshape(b, c, h, w, self.upscale_factor, self.upscale_factor)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*self.upscale_factor, w*self.upscale_factor)
return x
class RgbToYcbcr(nn.Module):

@ -14,7 +14,7 @@ class Domain4DValues(Dataset):
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 2, 2)
self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 4)
def __getitem__(self, idx):
if isinstance(idx, slice):
@ -30,7 +30,7 @@ class Domain4DValues(Dataset):
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0,0], ix[0,1], ix[1,0], ix[1,1], v
return ix[0], ix[1], ix[2], ix[3], v
def __len__(self):
return len(self.values)
@ -56,7 +56,7 @@ def transfer_rc_conv(rc_conv, quantization_interval=1):
def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, 1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval)
domain_values_loader = DataLoader(
domain_values,
@ -68,88 +68,55 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2*
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)[:,:,:]
lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8)[:,:,:scale,:scale]
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.reshape(-1, scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
lut = lut.squeeze(-3)
return lut
##################### FORWARD ##########################
def forward_2x2_input_SxS_output(index, lut):
b,c,hs,ws = index.shape
scale = lut.shape[-1]
index = F.pad(input=index, pad=[0,1,0,1], mode='replicate')
out = select_index_4dlut_tetrahedral(
ixA = index,
ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
lut = lut
)
out = out[:,:,0:-1,0:-1,:,:] # unpad
# Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
out = out.permute(0,1,2,4,3,5).reshape(b*c,1,hs*scale,ws*scale)
out = round_func(out)
return out
def forward_unfolded_2x2_input_SxS_output(index, lut):
b,c,hs,ws = index.shape
scale = lut.shape[-1]
out = select_index_4dlut_tetrahedral(
ixA = index,
ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
lut = lut
)
out = out[:,:,0:-1,0:-1,:,:] # unpad
# Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
out = out.permute(0,1,2,4,3,5).reshape(b*c,1,scale,scale)
out = round_func(out)
return out
def forward_rc_conv_centered(index, lut):
window_size = lut.shape[0]
index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
window_indexes = lut.shape[:-1]
# index = index.unsqueeze(-1)
x = torch.zeros_like(index)
for i in range(window_indexes[-2]):
for j in range(window_indexes[-1]):
shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j
shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
x /= window_indexes[-2]*window_indexes[-1]
x = x.squeeze(-1)
x = round_func(x)
x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1]
return x
def forward_rc_conv_rot90(index, lut):
window_size = lut.shape[0]
index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
window_indexes = lut.shape[:-1]
# index = index.unsqueeze(-1)
x = torch.zeros_like(index)
for i in range(window_indexes[-2]):
for j in range(window_indexes[-1]):
shift_i, shift_j = i, j
shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
x /= window_indexes[-2]*window_indexes[-1]
x = x.squeeze(-1)
x = round_func(x)
x = x[:,:,:-(window_size-1),:-(window_size-1)]
return x
# def forward_rc_conv_centered(index, lut):
# window_size = lut.shape[0]
# index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
# window_indexes = lut.shape[:-1]
# # index = index.unsqueeze(-1)
# x = torch.zeros_like(index)
# for i in range(window_indexes[-2]):
# for j in range(window_indexes[-1]):
# shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j
# shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
# x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
# x /= window_indexes[-2]*window_indexes[-1]
# x = x.squeeze(-1)
# x = round_func(x)
# x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1]
# return x
# def forward_rc_conv_rot90(index, lut):
# window_size = lut.shape[0]
# index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
# window_indexes = lut.shape[:-1]
# # index = index.unsqueeze(-1)
# x = torch.zeros_like(index)
# for i in range(window_indexes[-2]):
# for j in range(window_indexes[-1]):
# shift_i, shift_j = i, j
# shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
# x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
# x /= window_indexes[-2]*window_indexes[-1]
# x = x.squeeze(-1)
# x = round_func(x)
# x = x[:,:,:-(window_size-1),:-(window_size-1)]
# return x
##################### UTILS ##########################
# TODO rewrite for unfolded
def select_index_1dlut_linear(ixA, lut):
lut = torch.clamp(lut, 0, 255)
b,c,h,w = ixA.shape
@ -168,75 +135,71 @@ def select_index_1dlut_linear(ixA, lut):
out = out.reshape((b,c,h,w))
return out
def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
def select_index_4dlut_tetrahedral(index, lut):
b, hw, c = index.shape
lut = torch.clamp(lut, 0, 255)
dimA, dimB, dimC, dimD = lut.shape[:4]
q = 256/(dimA-1)
L = dimA
upscale = lut.shape[-1]
weight = lut.reshape(L**4,upscale,upscale)
img_a1 = torch.floor_divide(ixA, q).type(torch.int64)
img_b1 = torch.floor_divide(ixB, q).type(torch.int64)
img_c1 = torch.floor_divide(ixC, q).type(torch.int64)
img_d1 = torch.floor_divide(ixD, q).type(torch.int64)
# Extract LSBs
fa = ixA % q
fb = ixB % q
fc = ixC % q
fd = ixD % q
img_a2 = img_a1 + 1
img_b2 = img_b1 + 1
img_c2 = img_c1 + 1
img_d2 = img_d1 + 1
p0000 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0001 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0010 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0011 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0100 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0101 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0110 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0111 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1000 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1001 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1010 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1011 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1100 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1101 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1110 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1111 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
out = torch.zeros((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale), dtype=weight.dtype).to(device=weight.device)
sz = img_a1.shape[0] * img_a1.shape[1] * img_a1.shape[2] * img_a1.shape[3]
out = out.reshape(sz, -1)
p0000 = p0000.reshape(sz, -1)
p0100 = p0100.reshape(sz, -1)
p1000 = p1000.reshape(sz, -1)
p1100 = p1100.reshape(sz, -1)
fa = fa.reshape(-1, 1)
p0001 = p0001.reshape(sz, -1)
p0101 = p0101.reshape(sz, -1)
p1001 = p1001.reshape(sz, -1)
p1101 = p1101.reshape(sz, -1)
fb = fb.reshape(-1, 1)
fc = fc.reshape(-1, 1)
p0010 = p0010.reshape(sz, -1)
p0110 = p0110.reshape(sz, -1)
p1010 = p1010.reshape(sz, -1)
p1110 = p1110.reshape(sz, -1)
fd = fd.reshape(-1, 1)
p0011 = p0011.reshape(sz, -1)
p0111 = p0111.reshape(sz, -1)
p1011 = p1011.reshape(sz, -1)
p1111 = p1111.reshape(sz, -1)
weight = lut.reshape(L**4, upscale, upscale)
msbA = torch.floor_divide(index, q).type(torch.int64)
msbB = msbA + 1
lsb = index % q
img_a1 = msbA[:,:,0].reshape(b*hw, 1)
img_b1 = msbA[:,:,1].reshape(b*hw, 1)
img_c1 = msbA[:,:,2].reshape(b*hw, 1)
img_d1 = msbA[:,:,3].reshape(b*hw, 1)
img_a2 = msbB[:,:,0].reshape(b*hw, 1)
img_b2 = msbB[:,:,1].reshape(b*hw, 1)
img_c2 = msbB[:,:,2].reshape(b*hw, 1)
img_d2 = msbB[:,:,3].reshape(b*hw, 1)
fa = lsb[:,:,0].reshape(b*hw, 1)
fb = lsb[:,:,1].reshape(b*hw, 1)
fc = lsb[:,:,2].reshape(b*hw, 1)
fd = lsb[:,:,3].reshape(b*hw, 1)
p0000 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c1 * L + img_d1]
p0001 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c1 * L + img_d2]
p0010 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c2 * L + img_d1]
p0011 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c2 * L + img_d2]
p0100 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c1 * L + img_d1]
p0101 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c1 * L + img_d2]
p0110 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c2 * L + img_d1]
p0111 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c2 * L + img_d2]
p1000 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c1 * L + img_d1]
p1001 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c1 * L + img_d2]
p1010 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c2 * L + img_d1]
p1011 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c2 * L + img_d2]
p1100 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c1 * L + img_d1]
p1101 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c1 * L + img_d2]
p1110 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c2 * L + img_d1]
p1111 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c2 * L + img_d2]
out = torch.zeros((b*hw, upscale*upscale), dtype=weight.dtype).to(device=weight.device)
p0000 = p0000.reshape(b*hw, upscale*upscale)
p0100 = p0100.reshape(b*hw, upscale*upscale)
p1000 = p1000.reshape(b*hw, upscale*upscale)
p1100 = p1100.reshape(b*hw, upscale*upscale)
p0001 = p0001.reshape(b*hw, upscale*upscale)
p0101 = p0101.reshape(b*hw, upscale*upscale)
p1001 = p1001.reshape(b*hw, upscale*upscale)
p1101 = p1101.reshape(b*hw, upscale*upscale)
p0010 = p0010.reshape(b*hw, upscale*upscale)
p0110 = p0110.reshape(b*hw, upscale*upscale)
p1010 = p1010.reshape(b*hw, upscale*upscale)
p1110 = p1110.reshape(b*hw, upscale*upscale)
p0011 = p0011.reshape(b*hw, upscale*upscale)
p0111 = p0111.reshape(b*hw, upscale*upscale)
p1011 = p1011.reshape(b*hw, upscale*upscale)
p1111 = p1111.reshape(b*hw, upscale*upscale)
fab = fa > fb;
fac = fa > fc;
@ -257,14 +220,7 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
i8 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i9 = i = torch.all(torch.cat([~(fbc), ~(fac), fab, fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
# Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
# i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], fab, fcd], dim=1), dim=1)
# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad], dim=1), dim=1)
# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# c > a > d > b
i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], fab, fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# c > d > a > b
i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i12 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
@ -283,6 +239,6 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
i23 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
out = out.reshape((b, hw, upscale*upscale))
out = out / q
return out

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -3,8 +3,9 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import forward_2x2_input_SxS_output, forward_unfolded_2x2_input_SxS_output
from common.lut import select_index_4dlut_tetrahedral
from common.layers import PercievePattern
from common.utils import round_func
class SDYLutx1(nn.Module):
def __init__(
@ -15,7 +16,7 @@ class SDYLutx1(nn.Module):
super(SDYLutx1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@ -34,30 +35,30 @@ class SDYLutx1(nn.Module):
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS)
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD)
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY)
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stageS), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stageD), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stageY), k=-rotations_count, dims=[-2, -1])
output /= 4*3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
@ -103,55 +104,43 @@ class SDYLutx2(nn.Module):
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage1_S)
s = s.view(rb, rc, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb, rc, rh, rw)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage1_D)
d = d.view(rb, rc, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb, rc, rh, rw)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage1_Y)
y = y.view(rb, rc, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb, rc, rh, rw)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
x = output
x = round_func(output)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage2_S)
s = s.view(rb, rc, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb, rc, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage2_D)
d = d.view(rb, rc, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb, rc, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage2_Y)
y = y.view(rb, rc, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb, rc, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
output = round_func(output)
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
@ -162,66 +151,3 @@ class SDYLutx2(nn.Module):
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
class SDYLutCenteredx1(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SDYLutCenteredx1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[1,1], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[1,1], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[1,1], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_numpy(
stageS, stageD, stageY
):
scale = int(stageS.shape[-1])
quantization_interval = 256//(stageS.shape[0]-1)
lut_model = SDYLutCenteredx1(quantization_interval=quantization_interval, scale=scale)
lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS)
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD)
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY)
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output /= 4*3
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}" + \
f"\n stageS size: {self.stageS.shape}" + \
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"

@ -12,29 +12,39 @@ class SDYNetx1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx1, self).__init__()
self.scale = scale
s_pattern = [[0,0],[0,1],[1,0],[1,1]]
d_pattern = [[0,0],[2,0],[0,2],[2,2]]
y_pattern = [[0,0],[1,1],[1,2],[2,1]]
self.stage1_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.stage1_S(x)
output += self.stage1_D(x)
output += self.stage1_Y(x)
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage1_D)
output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage1_Y)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.stage1_D(rotated), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.stage1_Y(rotated), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
output /= 4*3
x = output
x = round_func(x)
x = x.view(b, c, h*self.scale, w*self.scale)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
@ -48,39 +58,49 @@ class SDYNetx2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx2, self).__init__()
self.scale = scale
s_pattern = [[0,0],[0,1],[1,0],[1,1]]
d_pattern = [[0,0],[2,0],[0,2],[2,2]]
y_pattern = [[0,0],[1,1],[1,2],[2,1]]
self.stage1_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage1_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage1_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage2_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage2_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage2_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
output_1 += self.stage1_S(x)
output_1 += self.stage1_D(x)
output_1 += self.stage1_Y(x)
output_1 += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S)
output_1 += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
output_1 += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.stage1_D(rotated), k=-rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.stage1_Y(rotated), k=-rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
output_1 += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
output_1 /= 4*3
x = round_func(output_1)
output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_2 += self.stage2_S(x)
output_2 += self.stage2_D(x)
output_2 += self.stage2_Y(x)
output_2 += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
output_2 += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
output_2 += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.stage2_S(rotated), k=-rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.stage2_D(rotated), k=-rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.stage2_Y(rotated), k=-rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
output_2 += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
output_2 /= 4*3
x = round_func(output_2)
x = x.view(b, c, h*self.scale, w*self.scale)

@ -3,8 +3,9 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import forward_2x2_input_SxS_output
from common.lut import select_index_4dlut_tetrahedral
from common import layers
from common.utils import round_func
class SRLut(nn.Module):
def __init__(
@ -16,6 +17,7 @@ class SRLut(nn.Module):
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
@ -27,11 +29,21 @@ class SRLut(nn.Module):
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
x = x.view(b, c, x.shape[-2], x.shape[-1])
x = x.reshape(b*c, 1, h, w).type(torch.float32)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def __repr__(self):
@ -49,6 +61,7 @@ class SRLutR90(nn.Module):
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
@staticmethod
def init_from_numpy(
@ -60,17 +73,26 @@ class SRLutR90(nn.Module):
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
@ -87,6 +109,7 @@ class SRLutR90Y(nn.Module):
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
@ -100,6 +123,16 @@ class SRLutR90Y(nn.Module):
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
@ -108,11 +141,10 @@ class SRLutR90Y(nn.Module):
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
for rotations_count in range(1,4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)

@ -12,20 +12,27 @@ class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
self.scale = scale
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
receptive_field_idxes=s_pattern,
center=[0,0],
window_size=2,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._unfold_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = self.stage1_S(x)
x = x.reshape(b*c, 1, h, w)
x = self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
@ -38,26 +45,33 @@ class SRNetR90(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90, self).__init__()
self.scale = scale
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
receptive_field_idxes=s_pattern,
center=[0,0],
window_size=2,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._unfold_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.stage1_S(x)
output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
output = output.reshape(b, c, h*self.scale, w*self.scale)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
@ -71,16 +85,24 @@ class SRNetR90Y(nn.Module):
self.scale = scale
s_pattern=[[0,0],[0,1],[1,0],[1,1]]
self.stage1_S = layers.UpscaleBlock(
receptive_field_idxes=s_pattern,
center=[0,0],
window_size=2,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._unfold_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
@ -90,10 +112,10 @@ class SRNetR90Y(nn.Module):
x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.stage1_S(x)
output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)

Loading…
Cancel
Save