@@ -14,7 +14,7 @@ class Domain4DValues(Dataset):
        values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
        values1d = torch.cat([values1d, torch.tensor([256])])
        self.quantization_interval = quantization_interval
-       self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 2, 2)
+       self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 4)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
@@ -30,7 +30,7 @@ class Domain4DValues(Dataset):
        else:
            v = self.values[idx]
            ix = v[0]//self.quantization_interval
-           return ix[0,0], ix[0,1], ix[1,0], ix[1,1], v
+           return ix[0], ix[1], ix[2], ix[3], v

    def __len__(self):
        return len(self.values)
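
For orientation, the flattened layout above means Domain4DValues now enumerates every quantized combination of the four pixels of a 2x2 window as a row of four values. A minimal, standalone sketch of that enumeration (illustrative only; quantization_interval=16 is assumed here, giving 17 grid points per pixel and 17**4 = 83521 rows):

import torch

quantization_interval = 16
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])        # append the 256 endpoint so the top bin is closed
values = torch.cartesian_prod(*([values1d] * 4)).view(-1, 1, 4)
print(values.shape)  # torch.Size([83521, 1, 4]) -- one row per quantized 2x2 input window
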
@@ -56,7 +56,7 @@ def transfer_rc_conv(rc_conv, quantization_interval=1):

def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
    bucket_count = 256//quantization_interval
    scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
-   lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, 1, scale, scale), dtype=np.uint8, fill_value=255) # 4D LUT for a simple 2x2 input window
+   lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4D LUT for a simple 2x2 input window
    domain_values = Domain4DValues(quantization_interval=quantization_interval)
    domain_values_loader = DataLoader(
        domain_values,
@@ -68,88 +68,55 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
    for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
        inputs = batch.type(torch.float32).cuda()
        with torch.no_grad():
-           outputs = block(inputs)[:,:,:]
-       lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8)[:,:,:scale,:scale]
+           outputs = block(inputs)
+       lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.reshape(-1, scale, scale).cpu().numpy().astype(np.uint8)
        counter += inputs.shape[0]
        print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
    print()
-   lut = lut.squeeze(-3)
    return lut
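
A possible usage sketch for the updated transfer function (illustrative only: `SomeLutableBlock` is a hypothetical stand-in for a CUDA-resident module with an `upscale_factor` attribute whose outputs reshape to (N, scale, scale); nothing in this diff defines it, and the printed shape assumes quantization_interval=16 and scale=4):

# `SomeLutableBlock` is a hypothetical placeholder, not part of this diff.
block = SomeLutableBlock(upscale_factor=4).cuda().eval()
lut = transfer_2x2_input_SxS_output(block, quantization_interval=16)
print(lut.shape)              # (17, 17, 17, 17, 4, 4): one SxS patch per quantized 2x2 window
np.save("lut_2x2.npy", lut)   # uint8 table, ready to be loaded for LUT inference
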
##################### FORWARD ##########################

def forward_2x2_input_SxS_output(index, lut):
    b,c,hs,ws = index.shape
    scale = lut.shape[-1]
    index = F.pad(input=index, pad=[0,1,0,1], mode='replicate')
    out = select_index_4dlut_tetrahedral(
        ixA = index,
        ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
        ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
        ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
        lut = lut
    )
    out = out[:,:,0:-1,0:-1,:,:] # unpad
    # Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
    out = out.permute(0,1,2,4,3,5).reshape(b*c,1,hs*scale,ws*scale)
    out = round_func(out)
    return out
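
The permute/reshape pair above is a hand-rolled pixel shuffle; a tiny standalone check of the shape bookkeeping from the comment (illustrative only):

import torch

b, c, h, w, scale = 3, 1, 126, 126, 4
out = torch.zeros(b, c, h, w, scale, scale)    # one SxS patch per input pixel
out = out.permute(0, 1, 2, 4, 3, 5).reshape(b * c, 1, h * scale, w * scale)
print(out.shape)                               # torch.Size([3, 1, 504, 504])
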
def forward_unfolded_2x2_input_SxS_output(index, lut):
    # Same lookup as above, but for pre-unfolded windows: `index` is expected to hold one
    # 2x2 window per item, so the result is a single SxS patch per window.
    b,c,hs,ws = index.shape
    scale = lut.shape[-1]
    out = select_index_4dlut_tetrahedral(
        ixA = index,
        ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
        ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
        ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
        lut = lut
    )
    out = out[:,:,0:-1,0:-1,:,:] # drop the positions whose rolled neighbours wrapped around
    # Pixel shuffle for a single window: [b, c, 1, 1, S, S] -> [b*c, 1, S, S]
    out = out.permute(0,1,2,4,3,5).reshape(b*c,1,scale,scale)
    out = round_func(out)
    return out
def forward_rc_conv_centered(index, lut):
    window_size = lut.shape[0]
    index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
    window_indexes = lut.shape[:-1]
    # index = index.unsqueeze(-1)
    x = torch.zeros_like(index)
    for i in range(window_indexes[-2]):
        for j in range(window_indexes[-1]):
            shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j
            shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
            x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
    x /= window_indexes[-2]*window_indexes[-1]
    x = x.squeeze(-1)
    x = round_func(x)
    x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1]
    return x

def forward_rc_conv_rot90(index, lut):
    window_size = lut.shape[0]
    index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
    window_indexes = lut.shape[:-1]
    # index = index.unsqueeze(-1)
    x = torch.zeros_like(index)
    for i in range(window_indexes[-2]):
        for j in range(window_indexes[-1]):
            shift_i, shift_j = i, j
            shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
            x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
    x /= window_indexes[-2]*window_indexes[-1]
    x = x.squeeze(-1)
    x = round_func(x)
    x = x[:,:,:-(window_size-1),:-(window_size-1)]
    return x
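
Both RC-conv forward passes follow the same recipe: each offset (i, j) of the receptive-field window owns its own 1D LUT, the index map is shifted by that offset, looked up, and the per-offset results are averaged; they differ only in how the window is anchored (centered vs. rot90-style) and in the padding/cropping around the loop. A stripped-down sketch of that averaging loop, assuming a simplified lut of shape (window, window, 256) and a nearest-neighbour lookup in place of select_index_1dlut_linear (illustrative only):

import torch

def rc_lookup_sketch(index, lut):
    # index: (b, c, h, w) integer pixel values in [0, 255]; lut: (win, win, 256) float table
    win_h, win_w = lut.shape[0], lut.shape[1]
    acc = torch.zeros(index.shape, dtype=lut.dtype)
    for i in range(win_h):
        for j in range(win_w):
            shifted = torch.roll(index, shifts=[i, j], dims=[-2, -1])  # rot90-style anchoring
            acc += lut[i, j][shifted.long()]                           # one lookup per window offset
    return acc / (win_h * win_w)                                       # average over the window

# e.g. rc_lookup_sketch(torch.randint(0, 256, (1, 1, 8, 8)), torch.rand(3, 3, 256))
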
# def forward_rc_conv_centered(index, lut):
#     window_size = lut.shape[0]
#     index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
#     window_indexes = lut.shape[:-1]
#     # index = index.unsqueeze(-1)
#     x = torch.zeros_like(index)
#     for i in range(window_indexes[-2]):
#         for j in range(window_indexes[-1]):
#             shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j
#             shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
#             x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
#     x /= window_indexes[-2]*window_indexes[-1]
#     x = x.squeeze(-1)
#     x = round_func(x)
#     x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1]
#     return x

# def forward_rc_conv_rot90(index, lut):
#     window_size = lut.shape[0]
#     index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
#     window_indexes = lut.shape[:-1]
#     # index = index.unsqueeze(-1)
#     x = torch.zeros_like(index)
#     for i in range(window_indexes[-2]):
#         for j in range(window_indexes[-1]):
#             shift_i, shift_j = i, j
#             shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
#             x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
#     x /= window_indexes[-2]*window_indexes[-1]
#     x = x.squeeze(-1)
#     x = round_func(x)
#     x = x[:,:,:-(window_size-1),:-(window_size-1)]
#     return x


##################### UTILS ##########################

# TODO rewrite for unfolded
def select_index_1dlut_linear(ixA, lut):
    lut = torch.clamp(lut, 0, 255)
    b,c,h,w = ixA.shape
@@ -168,75 +135,71 @@ def select_index_1dlut_linear(ixA, lut):
    out = out.reshape((b,c,h,w))
    return out

-def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
+def select_index_4dlut_tetrahedral(index, lut):
+   b, hw, c = index.shape
    lut = torch.clamp(lut, 0, 255)
    dimA, dimB, dimC, dimD = lut.shape[:4]
    q = 256/(dimA-1)
    L = dimA
    upscale = lut.shape[-1]
-   weight = lut.reshape(L**4,upscale,upscale)
-   img_a1 = torch.floor_divide(ixA, q).type(torch.int64)
-   img_b1 = torch.floor_divide(ixB, q).type(torch.int64)
-   img_c1 = torch.floor_divide(ixC, q).type(torch.int64)
-   img_d1 = torch.floor_divide(ixD, q).type(torch.int64)

-   # Extract LSBs
-   fa = ixA % q
-   fb = ixB % q
-   fc = ixC % q
-   fd = ixD % q

-   img_a2 = img_a1 + 1
-   img_b2 = img_b1 + 1
-   img_c2 = img_c1 + 1
-   img_d2 = img_d1 + 1

-   p0000 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0001 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0010 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0011 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0100 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0101 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0110 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p0111 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))

-   p1000 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1001 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1010 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1011 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1100 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1101 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1110 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-   p1111 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))

-   out = torch.zeros((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale), dtype=weight.dtype).to(device=weight.device)
-   sz = img_a1.shape[0] * img_a1.shape[1] * img_a1.shape[2] * img_a1.shape[3]
-   out = out.reshape(sz, -1)

-   p0000 = p0000.reshape(sz, -1)
-   p0100 = p0100.reshape(sz, -1)
-   p1000 = p1000.reshape(sz, -1)
-   p1100 = p1100.reshape(sz, -1)
-   fa = fa.reshape(-1, 1)

-   p0001 = p0001.reshape(sz, -1)
-   p0101 = p0101.reshape(sz, -1)
-   p1001 = p1001.reshape(sz, -1)
-   p1101 = p1101.reshape(sz, -1)
-   fb = fb.reshape(-1, 1)
-   fc = fc.reshape(-1, 1)

-   p0010 = p0010.reshape(sz, -1)
-   p0110 = p0110.reshape(sz, -1)
-   p1010 = p1010.reshape(sz, -1)
-   p1110 = p1110.reshape(sz, -1)
-   fd = fd.reshape(-1, 1)

-   p0011 = p0011.reshape(sz, -1)
-   p0111 = p0111.reshape(sz, -1)
-   p1011 = p1011.reshape(sz, -1)
-   p1111 = p1111.reshape(sz, -1)
+   weight = lut.reshape(L**4, upscale, upscale)

+   msbA = torch.floor_divide(index, q).type(torch.int64)
+   msbB = msbA + 1
+   lsb = index % q

+   img_a1 = msbA[:,:,0].reshape(b*hw, 1)
+   img_b1 = msbA[:,:,1].reshape(b*hw, 1)
+   img_c1 = msbA[:,:,2].reshape(b*hw, 1)
+   img_d1 = msbA[:,:,3].reshape(b*hw, 1)
+   img_a2 = msbB[:,:,0].reshape(b*hw, 1)
+   img_b2 = msbB[:,:,1].reshape(b*hw, 1)
+   img_c2 = msbB[:,:,2].reshape(b*hw, 1)
+   img_d2 = msbB[:,:,3].reshape(b*hw, 1)
+   fa = lsb[:,:,0].reshape(b*hw, 1)
+   fb = lsb[:,:,1].reshape(b*hw, 1)
+   fc = lsb[:,:,2].reshape(b*hw, 1)
+   fd = lsb[:,:,3].reshape(b*hw, 1)

+   p0000 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c1 * L + img_d1]
+   p0001 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c1 * L + img_d2]
+   p0010 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c2 * L + img_d1]
+   p0011 = weight[img_a1 * L * L * L + img_b1 * L * L + img_c2 * L + img_d2]
+   p0100 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c1 * L + img_d1]
+   p0101 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c1 * L + img_d2]
+   p0110 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c2 * L + img_d1]
+   p0111 = weight[img_a1 * L * L * L + img_b2 * L * L + img_c2 * L + img_d2]

+   p1000 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c1 * L + img_d1]
+   p1001 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c1 * L + img_d2]
+   p1010 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c2 * L + img_d1]
+   p1011 = weight[img_a2 * L * L * L + img_b1 * L * L + img_c2 * L + img_d2]
+   p1100 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c1 * L + img_d1]
+   p1101 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c1 * L + img_d2]
+   p1110 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c2 * L + img_d1]
+   p1111 = weight[img_a2 * L * L * L + img_b2 * L * L + img_c2 * L + img_d2]

+   out = torch.zeros((b*hw, upscale*upscale), dtype=weight.dtype).to(device=weight.device)

+   p0000 = p0000.reshape(b*hw, upscale*upscale)
+   p0100 = p0100.reshape(b*hw, upscale*upscale)
+   p1000 = p1000.reshape(b*hw, upscale*upscale)
+   p1100 = p1100.reshape(b*hw, upscale*upscale)

+   p0001 = p0001.reshape(b*hw, upscale*upscale)
+   p0101 = p0101.reshape(b*hw, upscale*upscale)
+   p1001 = p1001.reshape(b*hw, upscale*upscale)
+   p1101 = p1101.reshape(b*hw, upscale*upscale)

+   p0010 = p0010.reshape(b*hw, upscale*upscale)
+   p0110 = p0110.reshape(b*hw, upscale*upscale)
+   p1010 = p1010.reshape(b*hw, upscale*upscale)
+   p1110 = p1110.reshape(b*hw, upscale*upscale)

+   p0011 = p0011.reshape(b*hw, upscale*upscale)
+   p0111 = p0111.reshape(b*hw, upscale*upscale)
+   p1011 = p1011.reshape(b*hw, upscale*upscale)
+   p1111 = p1111.reshape(b*hw, upscale*upscale)

    fab = fa > fb;
    fac = fa > fc;

@@ -257,14 +220,7 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
    i8 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i9 = i = torch.all(torch.cat([~(fbc), ~(fac), fab, fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
-   # Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
-   # i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], fab, fcd], dim=1), dim=1)
-   # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
-   # i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad], dim=1), dim=1)
-   # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
-   # c > a > d > b
    i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], fab, fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
-   # c > d > a > b
    i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i12 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]

@@ -283,6 +239,6 @@ def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
    i23 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
    i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]

-   out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
+   out = out.reshape((b, hw, upscale*upscale))
    out = out / q
    return out
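
Every branch in this function follows the same tetrahedral-interpolation pattern: the boolean comparisons (fab, fac, ...) sort the four fractional offsets, and for an ordering f_x1 >= f_x2 >= f_x3 >= f_x4 the selected output is

    out = (q - f_x1)*p_0000 + (f_x1 - f_x2)*p_x1 + (f_x2 - f_x3)*p_x1x2 + (f_x3 - f_x4)*p_x1x2x3 + f_x4*p_1111

where each step flips one more corner index (x1, then x2, then x3) from its lower grid point to its upper one. The five weights sum to q, which the final `out = out / q` normalizes away. Branch i24 above, for instance, covers fd >= fc >= fb >= fa and accordingly walks the corners p0000 -> p0001 -> p0011 -> p0111 -> p1111.
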