main
Vladimir Protsenko 7 months ago
parent 14a7f00245
commit 3b2c9ee8e3

@@ -2,11 +2,13 @@
 ```
 python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90
 python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
-python image_demo.py
+python train.py --model_path /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCLutx2_0.pth --train_datasets DIV2K_pillow_bicubic --total_iter 10000
+python image_demo.py -n /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCNetCentered_3x3_10000.pth -l /wd/lut_reproduce/models/RCLutCentered_3x3_DIV2K_pillow_bicubic/checkpoints/RCLutCentered_3x3_10000.pth
 python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
 ```
 Requirements:
 - [schedulefree](https://github.com/facebookresearch/schedule_free)
 - einops
+- ray

@@ -15,6 +15,7 @@ image_extensions = ['.jpg', '.png']
 def load_images_cached(images_dir_path):
     image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions])
     cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy"
+    cache_path = cache_path.resolve()
     if not Path(cache_path).exists():
         print("Caching to:", cache_path)
         value = {f:np.array(Image.open(f)) for f in image_paths}
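The cache is a plain dict of decoded images serialized with np.save, which pickles the dict as a 0-d object array; a minimal sketch of that round trip (file and key names hypothetical):

```python
import numpy as np
from pathlib import Path

cache_path = Path("Set5_cache.npy")  # hypothetical cache file
value = {"baby.png": np.zeros((64, 64, 3), dtype=np.uint8)}  # toy stand-in for decoded images
np.save(cache_path, value, allow_pickle=True)
# np.load returns a 0-d object array; .item() unwraps the original dict.
restored = np.load(cache_path, allow_pickle=True).item()
assert (restored["baby.png"] == value["baby.png"]).all()
```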

@@ -84,7 +84,7 @@ def forward_2x2_input_SxS_output(index, lut):
     b,c,hs,ws = index.shape
     scale = lut.shape[-1]
     index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') #?
-    out = select_index_4dlut_tetrahedral2(
+    out = select_index_4dlut_tetrahedral(
         ixA = index,
         ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
         ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
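The rolled copies of `index` hand the 4D LUT one 2x2 neighbourhood per pixel: after the replicate padding above, negative rolls align each pixel with its right, bottom, and (presumably, since the hunk is cut off here) bottom-right neighbours. A toy illustration:

```python
import torch

x = torch.arange(16.).reshape(1, 1, 4, 4)
ixA = x
ixB = torch.roll(x, shifts=[0, -1], dims=[2, 3])   # right neighbour
ixC = torch.roll(x, shifts=[-1, 0], dims=[2, 3])   # bottom neighbour
ixD = torch.roll(x, shifts=[-1, -1], dims=[2, 3])  # bottom-right (assumed fourth corner)
# The window at (0, 0) is the 2x2 patch [[0, 1], [4, 5]].
print(ixA[0, 0, 0, 0].item(), ixB[0, 0, 0, 0].item(),
      ixC[0, 0, 0, 0].item(), ixD[0, 0, 0, 0].item())  # 0.0 1.0 4.0 5.0
```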
@@ -101,7 +101,7 @@ def forward_rc_conv_centered(index, lut):
     window_size = lut.shape[0]
     index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
     window_indexes = lut.shape[:-1]
-    index = index.unsqueeze(-1)
+    # index = index.unsqueeze(-1)
     x = torch.zeros_like(index)
     for i in range(window_indexes[-2]):
         for j in range(window_indexes[-1]):
@@ -118,7 +118,7 @@ def forward_rc_conv_rot90(index, lut):
     window_size = lut.shape[0]
     index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
     window_indexes = lut.shape[:-1]
-    index = index.unsqueeze(-1)
+    # index = index.unsqueeze(-1)
     x = torch.zeros_like(index)
     for i in range(window_indexes[-2]):
         for j in range(window_indexes[-1]):
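The centered and rot90 receptive-field variants differ only in where the replicate padding goes: centered pads all four sides by `window_size//2`, while the rot90 variant pads right/bottom by `window_size-1` and lets the rotation ensemble cover the remaining corners. A quick check of the two schemes:

```python
import torch
import torch.nn.functional as F

x = torch.arange(9.).reshape(1, 1, 3, 3)
ws = 3
# F.pad takes [left, right, top, bottom] for NCHW inputs.
centered = F.pad(x, pad=[ws // 2] * 4, mode='replicate')  # 1 px on every side
corner = F.pad(x, pad=[0, ws - 1] * 2, mode='replicate')  # 2 px right/bottom only
print(centered.shape, corner.shape)  # both torch.Size([1, 1, 5, 5])
```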
@@ -135,166 +135,30 @@ def forward_rc_conv_rot90(index, lut):
 ##################### UTILS ##########################
 def select_index_1dlut_linear(ixA, lut):
-    dimA = lut.shape[0]
-    qA = 256/(dimA-1)
-    outDims = lut.shape[1:]
-    lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
-    index_loop_indexes = ixA.shape
-    lut_loop_indexes = lut.shape[:-1]
-    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    msbA = torch.floor_divide(ixA, qA).type(torch.int64)
-    msbB = torch.floor_divide(ixA, qA).type(torch.int64) + 1
-    lsb_index = ixA % qA
-    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
-    outA = torch.gather(input=lut, dim=-1, index=msbA)
-    outB = torch.gather(input=lut, dim=-1, index=msbB)
-    out = outA + (lsb_index/qA) * (outB-outA)
-    out = out.squeeze(-1)
-    return out
-def select_index_1dlut_msb(ixA, lut):
-    dimA = lut.shape[0]
-    outDims = lut.shape[1:]
-    lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
-    index_loop_indexes = ixA.shape
-    lut_loop_indexes = lut.shape[:-1]
-    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    msb_index = torch.floor_divide(ixA, 256/(dimA-1)).type(torch.int64) * dimA**0
-    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
-    out = torch.gather(input=lut, dim=-1, index=msb_index)
-    out = out.squeeze(-1)
-    return out
-def select_index_4dlut_msb(ixA, ixB, ixC, ixD, lut):
-    dimA, dimB, dimC, dimD = lut.shape[:4]
-    qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
-    outDims = lut.shape[4:]
-    lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
-    index_loop_indexes = ixA.shape
-    lut_loop_indexes = lut.shape[:-1]
-    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    msb_index = torch.floor_divide(ixA, qA) * dimA**3
-    msb_index += torch.floor_divide(ixB, qB) * dimB**2
-    msb_index += torch.floor_divide(ixC, qC) * dimC**1
-    msb_index += torch.floor_divide(ixD, qD) * dimD**0
-    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
-    out = torch.gather(input=lut, dim=-1, index=msb_index.type(torch.int64))
-    out = out.squeeze(-1)
-    return out
-def select_index_4dlut_linear(ixA, ixB, ixC, ixD, lut):
-    dimA, dimB, dimC, dimD = lut.shape[:4]
-    qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
-    outDims = lut.shape[4:]
-    lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
-    index_loop_indexes = ixA.shape
-    lut_loop_indexes = lut.shape[:-1]
-    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
-    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    msb_index = torch.floor_divide(ixA, qA).type(torch.int64) * dimA**3
-    msb_index += torch.floor_divide(ixB, qB).type(torch.int64) * dimB**2
-    msb_index += torch.floor_divide(ixC, qC).type(torch.int64) * dimC**1
-    msb_index += torch.floor_divide(ixD, qD).type(torch.int64) * dimD**0
-    outA = torch.gather(input=lut, dim=-1, index=msb_index)
-    msb_index = (torch.floor_divide(ixA, qA).type(torch.int64) + 1) * dimA**3
-    msb_index += (torch.floor_divide(ixB, qB).type(torch.int64) + 1) * dimB**2
-    msb_index += (torch.floor_divide(ixC, qC).type(torch.int64) + 1) * dimC**1
-    msb_index += (torch.floor_divide(ixD, qD).type(torch.int64) + 1) * dimD**0
-    outB = torch.gather(input=lut, dim=-1, index=msb_index)
-    lsb_coef = ((ixA+ixB+ixC+ixD)/4 % qA) / qA
-    out = outA + lsb_coef*(outB-outA)
-    out = out.squeeze(-1)
-    return out
-def barycentric_interpolate(masks, coefs, vertices):
-    i = torch.all(torch.stack(masks), dim=0, keepdim = False)
-    coefs = torch.stack(coefs) * i
-    vertices = torch.stack(vertices)
-    out = (coefs*vertices).sum(0)
-    return i, out
-def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
-    dimA, dimB, dimC, dimD = lut.shape[:4]
-    qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
-    outDims = lut.shape[4:]
-    lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
-    index_loop_indexes = ixA.shape
-    lut_loop_indexes = lut.shape[:-1]
-    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
-    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
-    msbA = torch.floor_divide(ixA, qA).type(torch.int64)
-    msbB = torch.floor_divide(ixB, qB).type(torch.int64)
-    msbC = torch.floor_divide(ixC, qC).type(torch.int64)
-    msbD = torch.floor_divide(ixD, qD).type(torch.int64)
-    fa, fb, fc, fd = ixA % qA, ixB % qB, ixC % qC, ixD % qD
-    fab, fac, fad, fbc, fbd, fcd = fa>fb, fa>fc, fa>fd, fb>fc, fb>fd, fc>fd
-    strides = torch.tensor([dimA**3, dimB**2, dimC**1, dimD**0], device=lut.device).view(-1, *((1,)*len(msbA.shape)))
-    p0000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD ])*strides).sum(0))
-    p0001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD+1])*strides).sum(0))
-    p0010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD ])*strides).sum(0))
-    p0011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD+1])*strides).sum(0))
-    p0100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD ])*strides).sum(0))
-    p0101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD+1])*strides).sum(0))
-    p0110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD ])*strides).sum(0))
-    p0111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD+1])*strides).sum(0))
-    p1000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD ])*strides).sum(0))
-    p1001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD+1])*strides).sum(0))
-    p1010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD ])*strides).sum(0))
-    p1011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD+1])*strides).sum(0))
-    p1100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD ])*strides).sum(0))
-    p1101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD+1])*strides).sum(0))
-    p1110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD ])*strides).sum(0))
-    p1111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD+1])*strides).sum(0))
-    i1, out1 = barycentric_interpolate([fab, fbc, fcd], [qA-fa, fa - fb, fb - fc, fc - fd, fd], [p0000, p1000, p1100, p1110, p1111])
-    i2, out2 = barycentric_interpolate([fab, fbc, fbd, ~(i1)], [qA-fa, fa - fb, fb - fd, fd - fc, fc], [p0000, p1000, p1100, p1101, p1111])
-    i3, out3 = barycentric_interpolate([fab, fbc, fad, ~(i1), ~(i2)], [qA-fa, fa - fd, fd - fb, fb - fc, fc], [p0000, p1000, p1001, p1101, p1111])
-    i4, out4 = barycentric_interpolate([fab, fbc, ~(i1), ~(i2), ~(i3)], [qA-fd, fd - fa, fa - fb, fb - fc, fc], [p0000, p0001, p1001, p1101, p1111])
-    i5, out5 = barycentric_interpolate([fab, fac, fbd, ~(fbc)], [qA-fa, fa - fc, fc - fb, fb - fd, fd], [p0000, p1000, p1010, p1110, p1111])
-    i6, out6 = barycentric_interpolate([fab, fac, fcd, ~(fbc), ~(i5)], [qA-fa, fa - fc, fc - fd, fd - fb, fb], [p0000, p1000, p1010, p1011, p1111])
-    i7, out7 = barycentric_interpolate([fab, fac, fad, ~(fbc), ~(i5), ~(i6)], [qA-fa, fa - fd, fd - fc, fc - fb, fb], [p0000, p1000, p1001, p1011, p1111])
-    i8, out8 = barycentric_interpolate([fab, fac, ~(fbc), ~(i5), ~(i6), ~(i7)], [qA-fd, fd - fa, fa - fc, fc - fb, fb], [p0000, p0001, p1001, p1011, p1111])
-    i9, out9 = barycentric_interpolate([fab, fbd, ~(fbc), ~(fac)], [qA-fc, fc - fa, fa - fb, fb - fd, fd], [p0000, p0010, p1010, p1110, p1111])
-    i10, out10 = barycentric_interpolate([fab, fad, ~(fbc), ~(fac), ~(i9)], [qA-fc, fc - fa, fa - fd, fd - fb, fb], [p0000, p0010, p1010, p1011, p1111])
-    i11, out11 = barycentric_interpolate([fab, fcd, ~(fbc), ~(fac), ~(i9), ~(i10)], [qA-fc, fc - fd, fd - fa, fa - fb, fb], [p0000, p0010, p0011, p1011, p1111])
-    i12, out12 = barycentric_interpolate([fab, ~(fbc), ~(fac), ~(i9), ~(i10), ~(i11)], [qA-fd, fd - fc, fc - fa, fa - fb, fb], [p0000, p0001, p0011, p1011, p1111])
-    i13, out13 = barycentric_interpolate([fac, fcd, ~(fab)], [qA-fb, fb - fa, fa - fc, fc - fd, fd], [p0000, p0100, p1100, p1110, p1111])
-    i14, out14 = barycentric_interpolate([fac, fad, ~(fab), ~(i13)], [qA-fb, fb - fa, fa - fd, fd - fc, fc], [p0000, p0100, p1100, p1101, p1111])
-    i15, out15 = barycentric_interpolate([fac, fbd, ~(fab), ~(i13), ~(i14)], [qA-fb, fb - fd, fd - fa, fa - fc, fc], [p0000, p0100, p0101, p1101, p1111])
-    i16, out16 = barycentric_interpolate([fac, ~(fab), ~(i13), ~(i14), ~(i15) ], [qA-fd, fd - fb, fb - fa, fa - fc, fc], [p0000, p0001, p0101, p1101, p1111])
-    i17, out17 = barycentric_interpolate([fbc, fad, ~(fab), ~(fac)], [qA-fb, fb - fc, fc - fa, fa - fd, fd], [p0000, p0100, p0110, p1110, p1111])
-    i18, out18 = barycentric_interpolate([fbc, fcd, ~(fab), ~(fac), ~(i17)], [qA-fb, fb - fc, fc - fd, fd - fa, fa], [p0000, p0100, p0110, p0111, p1111])
-    i19, out19 = barycentric_interpolate([fbc, fbd, ~(fab), ~(fac), ~(i17), ~(i18)], [qA-fb, fb - fd, fd - fc, fc - fa, fa], [p0000, p0100, p0101, p0111, p1111])
-    i20, out20 = barycentric_interpolate([fbc, ~(fab), ~(fac), ~(i17), ~(i18), ~(i19)], [qA-fd, fd - fb, fb - fc, fc - fa, fa], [p0000, p0001, p0101, p0111, p1111])
-    i21, out21 = barycentric_interpolate([fad, ~(fab), ~(fac), ~(fbc) ], [qA-fc, fc - fb, fb - fa, fa - fd, fd], [p0000, p0010, p0110, p1110, p1111])
-    i22, out22 = barycentric_interpolate([fbd, ~(fab), ~(fac), ~(fbc), ~(i21)], [qA-fc, fc - fb, fb - fd, fd - fa, fa], [p0000, p0010, p0110, p0111, p1111])
-    i23, out23 = barycentric_interpolate([fcd, ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22)], [qA-fc, fc - fd, fd - fb, fb - fa, fa], [p0000, p0010, p0011, p0111, p1111])
-    i24, out24 = barycentric_interpolate([ ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22), ~(i23)], [qA-fd, fd - fc, fc - fb, fb - fa, fa], [p0000, p0001, p0011, p0111, p1111])
-    out = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8 + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16 + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24
-    out /= qA
-    out = out.squeeze(-1)
-    return out
-def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd):
+    lut = torch.clamp(lut, 0, 255)
+    b,c,h,w = ixA.shape
+    ixA = ixA.flatten()
+    L = lut.shape[0]
+    Q = 256/(L-1)
+    msbA = torch.floor_divide(ixA, Q).type(torch.int64)
+    msbB = msbA + 1
+    msbA = msbA.flatten()
+    msbB = msbB.flatten()
+    lsb = ixA % Q
+    outA = lut[msbA]
+    outB = lut[msbB]
+    lsb_coef = lsb / Q
+    out = outA + lsb_coef*(outB-outA)
+    out = out.reshape((b,c,h,w))
+    return out
+def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd):
+    lut = torch.clamp(lut, 0, 255)
     dimA, dimB, dimC, dimD = lut.shape[:4]
     q = 256/(dimA-1)
     L = dimA
     upscale = lut.shape[-1]
-    weight = lut
+    weight = lut.reshape(L**4,upscale,upscale)
     img_a1 = torch.floor_divide(ixA, q).type(torch.int64)
     img_b1 = torch.floor_divide(ixB, q).type(torch.int64)
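The rewritten `select_index_1dlut_linear` is bucketed linear interpolation over the 256 input levels: find the left LUT knot, take the in-bucket remainder, and blend the two neighbouring entries. The same arithmetic on a toy identity LUT (a sketch, not the repo's API):

```python
import torch

def lut1d_linear(x, lut):
    # x: values in [0, 255]; lut: 1D table with L knots spaced 256/(L-1) apart.
    L = lut.shape[0]
    q = 256 / (L - 1)
    msb = torch.floor_divide(x, q).type(torch.int64)  # left knot index
    nxt = torch.clamp(msb + 1, max=L - 1)             # guard; redundant for L = 256//q + 1
    lsb = x % q                                       # offset inside the bucket
    return lut[msb] + (lsb / q) * (lut[nxt] - lut[msb])

lut = torch.linspace(0, 255, steps=17)   # identity-ish LUT, 16-level buckets
x = torch.tensor([0., 8., 255.])
print(lut1d_linear(x, lut))              # ≈ [0.0, 7.97, 254.0]; the 256-vs-255 scaling makes it inexact
```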
@@ -404,6 +268,7 @@ def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, ups
     i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
     out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
-    out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale))
+    # out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale))
     out = out / q
+    # print(out.shape)
     return out
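The commented-out permute/reshape pair is a standard depth-to-space fold: each low-res pixel owns an `upscale x upscale` patch, and interleaving the patch axes with the spatial axes tiles the high-res image. In isolation:

```python
import torch

b, c, h, w, up = 1, 1, 2, 3, 4
out = torch.randn(b, c, h, w, up, up)  # per-pixel upscale patches
hr = out.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h * up, w * up)
print(hr.shape)                        # torch.Size([1, 1, 8, 12])
```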

@@ -3,9 +3,9 @@ import numpy as np
 from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
 from pathlib import Path
 from PIL import Image
-import ray
-@ray.remote(num_cpus=1, num_gpus=0.3)
+# import ray
+# @ray.remote(num_cpus=1, num_gpus=0.3)
 def val_image_pair(model, hr_image, lr_image, output_image_path=None):
     with torch.no_grad():
         # prepare lr_image
@@ -30,7 +30,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None):
     return PSNR(left, right, model.scale), cal_ssim(left, right)
 def valid_steps(model, datasets, config, log_prefix=""):
-    ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
+    # ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
     dataset_names = list(datasets.keys())
     for i in range(len(dataset_names)):
@@ -45,22 +45,21 @@ def valid_steps(model, datasets, config, log_prefix=""):
         tasks = []
         for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
             output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None
-            task = val_image_pair.remote(model, hr_image, lr_image, output_image_path)
+            task = val_image_pair(model, hr_image, lr_image, output_image_path)
             tasks.append(task)
-        ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
-        while len(remaining_refs) > 0:
-            print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
-            ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
-        print("\r", end=" ")
-        tasks = [ray.get(task) for task in tasks]
+        # ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
+        # while len(remaining_refs) > 0:
+        #     print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
+        #     ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
+        # print("\r", end=" ")
+        # tasks = [ray.get(task) for task in tasks]
         for psnr, ssim in tasks:
             psnrs.append(psnr)
             ssims.append(ssim)
         config.logger.info(
-            '{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
+            '\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
         config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
         config.writer.flush()
+        print()
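The disabled Ray path fanned each image pair out as a remote task and polled `ray.wait` for progress. A toy sketch of that pattern, assuming Ray is installed (stand-in task, not the repo's `val_image_pair`):

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote(num_cpus=1)
def square(x):  # stand-in for a per-image validation task
    return x * x

tasks = [square.remote(i) for i in range(8)]
ready, remaining = ray.wait(tasks, num_returns=1)
while remaining:
    print(f"\rReady {len(ready)}/{len(tasks)}", end=" ")
    ready, remaining = ray.wait(tasks, num_returns=len(ready) + 1)
results = ray.get(tasks)  # ray.get preserves submission order
print(results)
```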

@@ -1,20 +1,21 @@
-from .rcnet import RCNetCentered_3x3, RCNetCentered_7x7, RCNetRot90_3x3, RCNetRot90_7x7, RCNetx1, RCNetx2
-from .rclut import RCLutCentered_3x3, RCLutCentered_7x7, RCLutRot90_3x3, RCLutRot90_7x7, RCLutx1, RCLutx2
-from .srnet import SRNet, SRNetRot90
-from .srlut import SRLut, SRLutRot90
+from . import rcnet
+from . import rclut
+from . import srnet
+from . import srlut
 import torch
 import numpy as np
 from pathlib import Path
 AVAILABLE_MODELS = {
-    'SRNet': SRNet, 'SRLut': SRLut,
-    'SRNetRot90': SRNetRot90, 'SRLutRot90': SRLutRot90,
-    'RCNetCentered_3x3': RCNetCentered_3x3, 'RCLutCentered_3x3': RCLutCentered_3x3,
-    'RCNetCentered_7x7': RCNetCentered_7x7, 'RCLutCentered_7x7': RCLutCentered_7x7,
-    'RCNetRot90_3x3': RCNetRot90_3x3, 'RCLutRot90_3x3': RCLutRot90_3x3,
-    'RCNetRot90_7x7': RCNetRot90_7x7, 'RCLutRot90_7x7': RCLutRot90_7x7,
-    'RCNetx1': RCNetx1, 'RCLutx1': RCLutx1,
-    'RCNetx2': RCNetx2, 'RCLutx2': RCLutx2,
+    'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut,
+    'SRNetRot90': srnet.SRNetRot90, 'SRLutRot90': srlut.SRLutRot90,
+    'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
+    'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
+    'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,
+    'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7,
+    'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1,
+    'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2,
+    'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered,
 }
 def SaveCheckpoint(model, path):
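Switching to module-qualified entries leaves the lookup unchanged for callers, which resolve a class from its CLI name; e.g., mirroring how train.py uses the registry (import path assumed):

```python
from models import AVAILABLE_MODELS  # assumes the package imports as `models`

model = AVAILABLE_MODELS['SRNetRot90'](hidden_dim=64, scale=4)
```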

@@ -3,24 +3,21 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from common.utils import round_func
-from common.lut import select_index_1dlut_msb, select_index_4dlut_msb, select_index_4dlut_tetrahedral, select_index_1dlut_linear, \
-    forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
+from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
 from pathlib import Path
 from einops import repeat
 class RCLutCentered_3x3(nn.Module):
     def __init__(
         self,
-        window_size,
         quantization_interval,
         scale
     ):
         super(RCLutCentered_3x3, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        self.window_size = window_size
         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
-        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
     @staticmethod
     def init_from_lut(
@@ -28,8 +25,7 @@ class RCLutCentered_3x3(nn.Module):
     ):
         scale = int(dense_conv_lut.shape[-1])
         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
-        window_size = rc_conv_luts.shape[0]
-        lut_model = RCLutCentered_3x3(window_size=window_size, quantization_interval=quantization_interval, scale=scale)
+        lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
         return lut_model
@@ -61,9 +57,8 @@ class RCLutCentered_7x7(nn.Module):
         super(RCLutCentered_7x7, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        self.window_size = window_size
         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
-        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
     @staticmethod
     def init_from_lut(
@@ -71,8 +66,7 @@ class RCLutCentered_7x7(nn.Module):
     ):
         scale = int(dense_conv_lut.shape[-1])
         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
-        window_size = rc_conv_luts.shape[0]
-        lut_model = RCLutCentered_7x7(window_size=window_size, quantization_interval=quantization_interval, scale=scale)
+        lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
         return lut_model
@@ -102,9 +96,8 @@ class RCLutRot90_3x3(nn.Module):
         super(RCLutRot90_3x3, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        window_size = 3
         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
-        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
     @staticmethod
     def init_from_lut(
@@ -112,7 +105,6 @@ class RCLutRot90_3x3(nn.Module):
     ):
         scale = int(dense_conv_lut.shape[-1])
         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
-        window_size = rc_conv_luts.shape[0]
         lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
@@ -149,8 +141,7 @@ class RCLutRot90_7x7(nn.Module):
         self.scale = scale
         self.quantization_interval = quantization_interval
         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
-        window_size = 7
-        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
     @staticmethod
     def init_from_lut(
@@ -194,12 +185,9 @@ class RCLutx1(nn.Module):
         super(RCLutx1, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        window_size = 3
-        self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
-        window_size = 5
-        self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
-        window_size = 7
-        self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
+        self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
         self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@@ -229,7 +217,6 @@ class RCLutx1(nn.Module):
     def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
         x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
         x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
-        print("lut:", x.min(), x.max(), x.mean())
         return x
     def forward(self, x):
@@ -392,3 +379,128 @@ class RCLutx2(nn.Module):
             f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
             f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
         ")"])
+class RCLutx2Centered(nn.Module):
+    def __init__(
+        self,
+        quantization_interval,
+        scale
+    ):
+        super(RCLutx2Centered, self).__init__()
+        self.scale = scale
+        self.quantization_interval = quantization_interval
+        self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
+        self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
+        self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
+        self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
+        self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
+        self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
+        self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
+        self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+        self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
+    @staticmethod
+    def init_from_lut(
+        s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
+        s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
+        s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
+        s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
+        s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
+        s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
+    ):
+        scale = int(s2_dense_conv_lut_3x3.shape[-1])
+        quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
+        lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
+        lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
+        lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
+        lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
+        lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
+        lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
+        lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
+        lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
+        lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
+        lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
+        lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
+        lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
+        lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
+        return lut_model
+    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
+        x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
+        x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
+        return x
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w).type(torch.float32)
+        output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            output += torch.rot90(
+                self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
+                k=-rotations_count,
+                dims=[2, 3]
+            )
+            output += torch.rot90(
+                self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
+                k=-rotations_count,
+                dims=[2, 3]
+            )
+            output += torch.rot90(
+                self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
+                k=-rotations_count,
+                dims=[2, 3]
+            )
+        output /= 3*4
+        x = output
+        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            output += torch.rot90(
+                self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
+                k=-rotations_count,
+                dims=[2, 3]
+            )
+            output += torch.rot90(
+                self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
+                k=-rotations_count,
+                dims=[2, 3]
+            )
+            output += torch.rot90(
+                self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
+                k=-rotations_count,
+                dims=[2, 3]
+            )
+        output /= 3*4
+        output = output.view(b, c, output.shape[-2], output.shape[-1])
+        return output
+    def __repr__(self):
+        return "\n".join([
+            f"{self.__class__.__name__}(",
+            f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
+            f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
+            f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
+            f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
+            f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
+            f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
+            f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
+            f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
+            f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
+            f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
+            f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
+            f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
+        ")"])

@@ -327,3 +327,66 @@ class RCNetx2(nn.Module):
         output /= 3*4
         output = output.view(b, c, h*self.scale, w*self.scale)
         return output
+class RCNetx2Centered(nn.Module):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(RCNetx2Centered, self).__init__()
+        self.scale = scale
+        self.hidden_dim = hidden_dim
+        self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
+        self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
+        self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
+        self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
+        self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
+        self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
+        s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
+        s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
+        s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
+        s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
+        s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
+        s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
+        s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
+        s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
+        s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
+        s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
+        s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = rclut.RCLutx2Centered.init_from_lut(
+            s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
+            s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
+            s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
+            s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
+            s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
+            s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
+        )
+        return lut_model
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w)
+        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
+        output /= 3*4
+        x = output
+        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
+            output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
+        output /= 3*4
+        output = output.view(b, c, h*self.scale, w*self.scale)
+        return output
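Both stages of the new centered models run the same rotation self-ensemble: apply the block to all four 90-degree rotations of the input, undo each rotation on the output, and average. Distilled into a helper (a sketch, with nearest-neighbour upsampling as a stand-in block):

```python
import torch
import torch.nn.functional as F

def rot_ensemble(block, x):
    # Average block() over the 4 rotations, undoing each rotation afterwards.
    outs = [torch.rot90(block(torch.rot90(x, k, dims=[2, 3])), -k, dims=[2, 3])
            for k in range(4)]
    return torch.stack(outs).mean(0)

sr = rot_ensemble(lambda t: F.interpolate(t, scale_factor=2), torch.rand(1, 1, 8, 8))
print(sr.shape)  # torch.Size([1, 1, 16, 16])
```

get_lut_model is, conceptually, exhaustive evaluation: sample the trained block at every quantized input combination and tabulate the outputs (the lut.transfer_* helpers' internals are not part of this diff). A hypothetical 1D illustration of the idea, assuming a net that maps 1x1 inputs pointwise:

```python
def transfer_1d(net, quantization_interval=16):
    # 256//16 + 1 = 17 knots, matching the LUT sizes used above.
    levels = torch.arange(0, 256 + 1, quantization_interval, dtype=torch.float32)
    with torch.no_grad():
        return torch.stack([net(v.view(1, 1, 1, 1)).flatten() for v in levels])
```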

@@ -8,8 +8,43 @@ import numpy as np
 import cv2
 from PIL import Image
 from datetime import datetime
+import argparse
-project_path = Path("../../").resolve()
+class DemoOptions():
+    def __init__(self):
+        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        self.parser.add_argument('--net_model_path', '-n', type=str, default=None, help="Net model path folder")
+        self.parser.add_argument('--lut_model_path', '-l', type=str, default=None, help="Lut model path folder")
+        self.parser.add_argument('--project_path', '-q', type=str, default="../../", help="Project path.")
+        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
+    def parse_args(self):
+        args = self.parser.parse_args()
+        args.project_path = Path(args.project_path).resolve()
+        if args.net_model_path is None:
+            args.net_model_path = args.project_path / "models" / "last_transfered_net.pth"
+        else:
+            args.net_model_path = Path(args.net_model_path).resolve()
+        if args.lut_model_path is None:
+            args.lut_model_path = args.project_path / "models" / "last_transfered_lut.pth"
+        else:
+            args.lut_model_path = Path(args.lut_model_path).resolve()
+        return args
+    def print_options(self, opt):
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+        print()
+config_inst = DemoOptions()
+config = config_inst.parse_args()
 start_script_time = datetime.now()
 # net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth")
@@ -24,14 +59,14 @@ start_script_time = datetime.now()
 # net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth")
 # lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth")
-net_model = LoadCheckpoint(project_path / "models" / "last_transfered_net.pth").cuda()
-lut_model = LoadCheckpoint(project_path / "models" / "last_transfered_lut.pth").cuda()
+net_model = LoadCheckpoint(config.net_model_path).cuda()
+lut_model = LoadCheckpoint(config.lut_model_path).cuda()
 print(net_model)
 print(lut_model)
-lr_image = cv2.imread(str(project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy()
-image_gt = cv2.imread(str(project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy()
+lr_image = cv2.imread(str(config.project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy()
+image_gt = cv2.imread(str(config.project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy()
 # lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy()
 # image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy()
@@ -46,6 +81,6 @@ image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_
 image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
 image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
-Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(project_path / "models" / 'last_transfered_demo.png')
+Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(config.project_path / "models" / 'last_transfered_demo.png')
 print(datetime.now() - start_script_time )

@@ -29,21 +29,21 @@ class TrainOptions:
         parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
         parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
         parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.")
+        parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
+        parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
         parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
         parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
-        parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder")
-        parser.add_argument('--batch_size', type=int, default=16)
         parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
+        parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
+        parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder")
         parser.add_argument('--datasets_dir', type=str, default="../../data/")
-        parser.add_argument('--train_datasets', type=str, default='DIV2K')
-        parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
         parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further')
         parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
         parser.add_argument('--display_step', type=int, default=100, help='display info every N iteration')
         parser.add_argument('--val_step', type=int, default=2000, help='validate every N iteration')
         parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
-        parser.add_argument('--worker_num', '-n', type=int, default=1)
-        parser.add_argument('--prefetch_factor', '-p', type=int, default=16)
+        parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
+        parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
         parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
         self.parser = parser
self.parser = parser self.parser = parser
@@ -72,20 +72,20 @@ def prepare_experiment_folder(config):
     assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
     assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"One of the {config.val_datasets} was not found in {config.datasets_dir}."
-    config.exp_dir = config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}"
+    config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve()
     if not config.exp_dir.exists():
         config.exp_dir.mkdir()
-    config.checkpoint_dir = config.exp_dir / "checkpoints"
+    config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve()
     if not config.checkpoint_dir.exists():
         config.checkpoint_dir.mkdir()
-    config.valout_dir = config.exp_dir / 'val'
+    config.valout_dir = (config.exp_dir / 'val').resolve()
     if not config.valout_dir.exists():
         config.valout_dir.mkdir()
-    config.logs_dir = config.exp_dir / 'logs'
+    config.logs_dir = (config.exp_dir / 'logs').resolve()
     if not config.logs_dir.exists():
         config.logs_dir.mkdir()
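The exists()/mkdir() pairs could equally be collapsed with `mkdir(parents=True, exist_ok=True)`; a compact equivalent (experiment path hypothetical):

```python
from pathlib import Path

exp_dir = Path("../../models/SRNet_DIV2K").resolve()  # hypothetical experiment dir
for sub in ("checkpoints", "val", "logs"):
    (exp_dir / sub).mkdir(parents=True, exist_ok=True)
```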
@@ -101,7 +101,7 @@ if __name__ == "__main__":
         config.model = model.__class__.__name__
     else:
         model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
-    # model = model.cuda()
+    model = model.cuda()
     optimizer = AdamWScheduleFree(model.parameters())
     prepare_experiment_folder(config)
@@ -150,7 +150,12 @@ if __name__ == "__main__":
     # TRAINING
     i = config.start_iter
+    if not config.model_path is None:
+        config.current_iter = i
+        valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
     for i in range(config.start_iter + 1, config.total_iter + 1):
+        # prof.step()
         torch.cuda.empty_cache()
         start_time = time.time()
         try:
@@ -158,15 +163,18 @@ if __name__ == "__main__":
         except StopIteration:
             train_iter = iter(train_loader)
             hr_patch, lr_patch = next(train_iter)
-        # hr_patch = hr_patch.cuda()
-        # lr_patch = lr_patch.cuda()
+        hr_patch = hr_patch.cuda()
+        lr_patch = lr_patch.cuda()
         prepare_data_time += time.time() - start_time
         start_time = time.time()
+        optimizer.zero_grad()
         pred = model(lr_patch)
         loss = F.mse_loss(pred/255, hr_patch/255)
         loss.backward()
         optimizer.step()
-        optimizer.zero_grad()
         forward_backward_time += time.time() - start_time
         # For monitoring
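Moving `optimizer.zero_grad()` ahead of the forward pass matches the usual schedule-free recipe; note that the schedulefree optimizers also expect explicit mode switches (`optimizer.train()` / `optimizer.eval()`) around training and validation. A minimal loop under those assumptions, with a stand-in model and random patches:

```python
import torch
import torch.nn.functional as F
from schedulefree import AdamWScheduleFree  # pip install schedulefree

model = torch.nn.Conv2d(1, 1, 3, padding=1)        # stand-in model
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3)
optimizer.train()                                   # required before stepping
for _ in range(10):
    lr_patch = torch.rand(4, 1, 48, 48) * 255
    hr_patch = torch.rand(4, 1, 48, 48) * 255       # toy targets
    optimizer.zero_grad()
    loss = F.mse_loss(model(lr_patch) / 255, hr_patch / 255)
    loss.backward()
    optimizer.step()
optimizer.eval()                                    # switch before validation
```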
@@ -178,7 +186,7 @@ if __name__ == "__main__":
             config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
             config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
-                config.exp_dir, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step,
+                model.__class__.__name__, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step,
                 forward_backward_time / config.display_step))
             l_accum = [0., 0., 0.]
             prepare_data_time = 0.
@@ -193,7 +201,5 @@ if __name__ == "__main__":
             config.current_iter = i
             valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
     total_script_time = datetime.now() - script_start_time
     config.logger.info(f"Completed after {total_script_time}")

@@ -20,7 +20,7 @@ class TransferToLutOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
         self.parser.add_argument('--model_path', '-m', type=str, default='', help="model path folder")
-        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets in 2**bits. Value is in range [1, 8].")
+        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
     def parse_args(self):
@@ -52,6 +52,9 @@ if __name__ == "__main__":
     config_inst.print_options(config)
     model = models.LoadCheckpoint(config.model_path).cuda()
+    if getattr(model, 'get_lut_model', None) is None:
+        print("Transfer to lut can be applied only to the network model.")
+        exit(1)
     print(model)
     print()

@@ -24,8 +24,8 @@ class ValOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
         self.parser.add_argument('--model_path', type=str, help="Model path.")
-        self.parser.add_argument('--datasets_dir', type=str, default="../../data/")
-        self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
+        self.parser.add_argument('--datasets_dir', type=str, default="../../data/", help="Path to datasets.")
+        self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
         self.parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
     def parse_args(self):
