From 3b2c9ee8e3e0e3f632e1983ea8cfdb6b55da4949 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 20 Apr 2024 08:57:59 +0000 Subject: [PATCH] update --- readme.md | 8 +- src/common/data.py | 1 + src/common/lut.py | 183 +++++---------------------------- src/common/validation.py | 25 +++-- src/models/__init__.py | 25 ++--- src/models/rclut.py | 158 +++++++++++++++++++++++----- src/models/rcnet.py | 63 ++++++++++++ src/scripts/image_demo.py | 49 +++++++-- src/scripts/train.py | 42 ++++---- src/scripts/transfer_to_lut.py | 9 +- src/scripts/validate.py | 4 +- 11 files changed, 327 insertions(+), 240 deletions(-) diff --git a/readme.md b/readme.md index 5a1b5de..16fb3ed 100644 --- a/readme.md +++ b/readme.md @@ -2,11 +2,13 @@ ``` python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90 python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth -python image_demo.py +python train.py --model_path /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCLutx2_0.pth --train_datasets DIV2K_pillow_bicubic --total_iter 10000 + +python image_demo.py -n /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCNetCentered_3x3_10000.pth -l /wd/lut_reproduce/models/RCLutCentered_3x3_DIV2K_pillow_bicubic/checkpoints/RCLutCentered_3x3_10000.pth + python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth ``` Requierements: - [shedulefree](https://github.com/facebookresearch/schedule_free) -- einops -- ray \ No newline at end of file +- einops \ No newline at end of file diff --git a/src/common/data.py b/src/common/data.py index 4e39b6b..a19c0c5 100644 --- a/src/common/data.py +++ b/src/common/data.py @@ -15,6 +15,7 @@ image_extensions = ['.jpg', '.png'] def load_images_cached(images_dir_path): image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions]) cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy" + cache_path = cache_path.resolve() if not Path(cache_path).exists(): print("Caching to:", cache_path) value = {f:np.array(Image.open(f)) for f in image_paths} diff --git a/src/common/lut.py b/src/common/lut.py index 78a9cd1..7e2ef13 100644 --- a/src/common/lut.py +++ b/src/common/lut.py @@ -84,7 +84,7 @@ def forward_2x2_input_SxS_output(index, lut): b,c,hs,ws = index.shape scale = lut.shape[-1] index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') #? - out = select_index_4dlut_tetrahedral2( + out = select_index_4dlut_tetrahedral( ixA = index, ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]), ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]), @@ -101,7 +101,7 @@ def forward_rc_conv_centered(index, lut): window_size = lut.shape[0] index = F.pad(index, pad=[window_size//2]*4, mode='replicate') window_indexes = lut.shape[:-1] - index = index.unsqueeze(-1) + # index = index.unsqueeze(-1) x = torch.zeros_like(index) for i in range(window_indexes[-2]): for j in range(window_indexes[-1]): @@ -118,7 +118,7 @@ def forward_rc_conv_rot90(index, lut): window_size = lut.shape[0] index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate') window_indexes = lut.shape[:-1] - index = index.unsqueeze(-1) + # index = index.unsqueeze(-1) x = torch.zeros_like(index) for i in range(window_indexes[-2]): for j in range(window_indexes[-1]): @@ -135,166 +135,30 @@ def forward_rc_conv_rot90(index, lut): ##################### UTILS ########################## def select_index_1dlut_linear(ixA, lut): - dimA = lut.shape[0] - qA = 256/(dimA-1) - outDims = lut.shape[1:] - lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) - index_loop_indexes = ixA.shape - lut_loop_indexes = lut.shape[:-1] - ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - msbA = torch.floor_divide(ixA, qA).type(torch.int64) - msbB = torch.floor_divide(ixA, qA).type(torch.int64) + 1 - lsb_index = ixA % qA - lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) - outA = torch.gather(input=lut, dim=-1, index=msbA) - outB = torch.gather(input=lut, dim=-1, index=msbB) - out = outA + (lsb_index/qA) * (outB-outA) - out = out.squeeze(-1) + lut = torch.clamp(lut, 0, 255) + b,c,h,w = ixA.shape + ixA = ixA.flatten() + L = lut.shape[0] + Q = 256/(L-1) + msbA = torch.floor_divide(ixA, Q).type(torch.int64) + msbB = msbA + 1 + msbA = msbA.flatten() + msbB = msbB.flatten() + lsb = ixA % Q + outA = lut[msbA] + outB = lut[msbB] + lsb_coef = lsb / Q + out = outA + lsb_coef*(outB-outA) + out = out.reshape((b,c,h,w)) return out -def select_index_1dlut_msb(ixA, lut): - dimA = lut.shape[0] - outDims = lut.shape[1:] - lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) - index_loop_indexes = ixA.shape - lut_loop_indexes = lut.shape[:-1] - ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - msb_index = torch.floor_divide(ixA, 256/(dimA-1)).type(torch.int64) * dimA**0 - lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) - out = torch.gather(input=lut, dim=-1, index=msb_index) - out = out.squeeze(-1) - return out - -def select_index_4dlut_msb(ixA, ixB, ixC, ixD, lut): - dimA, dimB, dimC, dimD = lut.shape[:4] - qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1) - outDims = lut.shape[4:] - lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) - index_loop_indexes = ixA.shape - lut_loop_indexes = lut.shape[:-1] - ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - msb_index = torch.floor_divide(ixA, qA) * dimA**3 - msb_index += torch.floor_divide(ixB, qB) * dimB**2 - msb_index += torch.floor_divide(ixC, qC) * dimC**1 - msb_index += torch.floor_divide(ixD, qD) * dimD**0 - lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) - out = torch.gather(input=lut, dim=-1, index=msb_index.type(torch.int64)) - out = out.squeeze(-1) - return out - -def select_index_4dlut_linear(ixA, ixB, ixC, ixD, lut): - dimA, dimB, dimC, dimD = lut.shape[:4] - qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1) - outDims = lut.shape[4:] - lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) - index_loop_indexes = ixA.shape - lut_loop_indexes = lut.shape[:-1] - lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) - ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - msb_index = torch.floor_divide(ixA, qA).type(torch.int64) * dimA**3 - msb_index += torch.floor_divide(ixB, qB).type(torch.int64) * dimB**2 - msb_index += torch.floor_divide(ixC, qC).type(torch.int64) * dimC**1 - msb_index += torch.floor_divide(ixD, qD).type(torch.int64) * dimD**0 - outA = torch.gather(input=lut, dim=-1, index=msb_index) - msb_index = (torch.floor_divide(ixA, qA).type(torch.int64) + 1) * dimA**3 - msb_index += (torch.floor_divide(ixB, qB).type(torch.int64) + 1) * dimB**2 - msb_index += (torch.floor_divide(ixC, qC).type(torch.int64) + 1) * dimC**1 - msb_index += (torch.floor_divide(ixD, qD).type(torch.int64) + 1) * dimD**0 - outB = torch.gather(input=lut, dim=-1, index=msb_index) - lsb_coef = ((ixA+ixB+ixC+ixD)/4 % qA) / qA - out = outA + lsb_coef*(outB-outA) - out = out.squeeze(-1) - return out - -def barycentric_interpolate(masks, coefs, vertices): - i = torch.all(torch.stack(masks), dim=0, keepdim = False) - coefs = torch.stack(coefs) * i - vertices = torch.stack(vertices) - out = (coefs*vertices).sum(0) - return i, out - -def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): - dimA, dimB, dimC, dimD = lut.shape[:4] - qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1) - outDims = lut.shape[4:] - lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) - index_loop_indexes = ixA.shape - lut_loop_indexes = lut.shape[:-1] - lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) - ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) - - msbA = torch.floor_divide(ixA, qA).type(torch.int64) - msbB = torch.floor_divide(ixB, qB).type(torch.int64) - msbC = torch.floor_divide(ixC, qC).type(torch.int64) - msbD = torch.floor_divide(ixD, qD).type(torch.int64) - fa, fb, fc, fd = ixA % qA, ixB % qB, ixC % qC, ixD % qD - fab, fac, fad, fbc, fbd, fcd = fa>fb, fa>fc, fa>fd, fb>fc, fb>fd, fc>fd - - strides = torch.tensor([dimA**3, dimB**2, dimC**1, dimD**0], device=lut.device).view(-1, *((1,)*len(msbA.shape))) - p0000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD ])*strides).sum(0)) - p0001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD+1])*strides).sum(0)) - p0010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD ])*strides).sum(0)) - p0011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD+1])*strides).sum(0)) - p0100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD ])*strides).sum(0)) - p0101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD+1])*strides).sum(0)) - p0110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD ])*strides).sum(0)) - p0111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD+1])*strides).sum(0)) - p1000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD ])*strides).sum(0)) - p1001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD+1])*strides).sum(0)) - p1010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD ])*strides).sum(0)) - p1011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD+1])*strides).sum(0)) - p1100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD ])*strides).sum(0)) - p1101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD+1])*strides).sum(0)) - p1110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD ])*strides).sum(0)) - p1111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD+1])*strides).sum(0)) - - i1, out1 = barycentric_interpolate([fab, fbc, fcd], [qA-fa, fa - fb, fb - fc, fc - fd, fd], [p0000, p1000, p1100, p1110, p1111]) - i2, out2 = barycentric_interpolate([fab, fbc, fbd, ~(i1)], [qA-fa, fa - fb, fb - fd, fd - fc, fc], [p0000, p1000, p1100, p1101, p1111]) - i3, out3 = barycentric_interpolate([fab, fbc, fad, ~(i1), ~(i2)], [qA-fa, fa - fd, fd - fb, fb - fc, fc], [p0000, p1000, p1001, p1101, p1111]) - i4, out4 = barycentric_interpolate([fab, fbc, ~(i1), ~(i2), ~(i3)], [qA-fd, fd - fa, fa - fb, fb - fc, fc], [p0000, p0001, p1001, p1101, p1111]) - i5, out5 = barycentric_interpolate([fab, fac, fbd, ~(fbc)], [qA-fa, fa - fc, fc - fb, fb - fd, fd], [p0000, p1000, p1010, p1110, p1111]) - i6, out6 = barycentric_interpolate([fab, fac, fcd, ~(fbc), ~(i5)], [qA-fa, fa - fc, fc - fd, fd - fb, fb], [p0000, p1000, p1010, p1011, p1111]) - i7, out7 = barycentric_interpolate([fab, fac, fad, ~(fbc), ~(i5), ~(i6)], [qA-fa, fa - fd, fd - fc, fc - fb, fb], [p0000, p1000, p1001, p1011, p1111]) - i8, out8 = barycentric_interpolate([fab, fac, ~(fbc), ~(i5), ~(i6), ~(i7)], [qA-fd, fd - fa, fa - fc, fc - fb, fb], [p0000, p0001, p1001, p1011, p1111]) - i9, out9 = barycentric_interpolate([fab, fbd, ~(fbc), ~(fac)], [qA-fc, fc - fa, fa - fb, fb - fd, fd], [p0000, p0010, p1010, p1110, p1111]) - i10, out10 = barycentric_interpolate([fab, fad, ~(fbc), ~(fac), ~(i9)], [qA-fc, fc - fa, fa - fd, fd - fb, fb], [p0000, p0010, p1010, p1011, p1111]) - i11, out11 = barycentric_interpolate([fab, fcd, ~(fbc), ~(fac), ~(i9), ~(i10)], [qA-fc, fc - fd, fd - fa, fa - fb, fb], [p0000, p0010, p0011, p1011, p1111]) - i12, out12 = barycentric_interpolate([fab, ~(fbc), ~(fac), ~(i9), ~(i10), ~(i11)], [qA-fd, fd - fc, fc - fa, fa - fb, fb], [p0000, p0001, p0011, p1011, p1111]) - - i13, out13 = barycentric_interpolate([fac, fcd, ~(fab)], [qA-fb, fb - fa, fa - fc, fc - fd, fd], [p0000, p0100, p1100, p1110, p1111]) - i14, out14 = barycentric_interpolate([fac, fad, ~(fab), ~(i13)], [qA-fb, fb - fa, fa - fd, fd - fc, fc], [p0000, p0100, p1100, p1101, p1111]) - i15, out15 = barycentric_interpolate([fac, fbd, ~(fab), ~(i13), ~(i14)], [qA-fb, fb - fd, fd - fa, fa - fc, fc], [p0000, p0100, p0101, p1101, p1111]) - i16, out16 = barycentric_interpolate([fac, ~(fab), ~(i13), ~(i14), ~(i15) ], [qA-fd, fd - fb, fb - fa, fa - fc, fc], [p0000, p0001, p0101, p1101, p1111]) - i17, out17 = barycentric_interpolate([fbc, fad, ~(fab), ~(fac)], [qA-fb, fb - fc, fc - fa, fa - fd, fd], [p0000, p0100, p0110, p1110, p1111]) - i18, out18 = barycentric_interpolate([fbc, fcd, ~(fab), ~(fac), ~(i17)], [qA-fb, fb - fc, fc - fd, fd - fa, fa], [p0000, p0100, p0110, p0111, p1111]) - i19, out19 = barycentric_interpolate([fbc, fbd, ~(fab), ~(fac), ~(i17), ~(i18)], [qA-fb, fb - fd, fd - fc, fc - fa, fa], [p0000, p0100, p0101, p0111, p1111]) - i20, out20 = barycentric_interpolate([fbc, ~(fab), ~(fac), ~(i17), ~(i18), ~(i19)], [qA-fd, fd - fb, fb - fc, fc - fa, fa], [p0000, p0001, p0101, p0111, p1111]) - i21, out21 = barycentric_interpolate([fad, ~(fab), ~(fac), ~(fbc) ], [qA-fc, fc - fb, fb - fa, fa - fd, fd], [p0000, p0010, p0110, p1110, p1111]) - i22, out22 = barycentric_interpolate([fbd, ~(fab), ~(fac), ~(fbc), ~(i21)], [qA-fc, fc - fb, fb - fd, fd - fa, fa], [p0000, p0010, p0110, p0111, p1111]) - i23, out23 = barycentric_interpolate([fcd, ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22)], [qA-fc, fc - fd, fd - fb, fb - fa, fa], [p0000, p0010, p0011, p0111, p1111]) - i24, out24 = barycentric_interpolate([ ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22), ~(i23)], [qA-fd, fd - fc, fc - fb, fb - fa, fa], [p0000, p0001, p0011, p0111, p1111]) - - out = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8 + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16 + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24 - out /= qA - out = out.squeeze(-1) - return out - - -def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd): +def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd): + lut = torch.clamp(lut, 0, 255) dimA, dimB, dimC, dimD = lut.shape[:4] q = 256/(dimA-1) L = dimA upscale = lut.shape[-1] - weight = lut + weight = lut.reshape(L**4,upscale,upscale) img_a1 = torch.floor_divide(ixA, q).type(torch.int64) img_b1 = torch.floor_divide(ixB, q).type(torch.int64) @@ -404,6 +268,7 @@ def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, ups i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) - out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale)) + # out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale)) out = out / q - return out \ No newline at end of file + # print(out.shape) + return out diff --git a/src/common/validation.py b/src/common/validation.py index 2772a80..7bea7f5 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -3,9 +3,9 @@ import numpy as np from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop from pathlib import Path from PIL import Image -import ray +# import ray -@ray.remote(num_cpus=1, num_gpus=0.3) +# @ray.remote(num_cpus=1, num_gpus=0.3) def val_image_pair(model, hr_image, lr_image, output_image_path=None): with torch.no_grad(): # prepare lr_image @@ -30,7 +30,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None): return PSNR(left, right, model.scale), cal_ssim(left, right) def valid_steps(model, datasets, config, log_prefix=""): - ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"}) + # ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"}) dataset_names = list(datasets.keys()) for i in range(len(dataset_names)): @@ -45,22 +45,21 @@ def valid_steps(model, datasets, config, log_prefix=""): tasks = [] for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None - task = val_image_pair.remote(model, hr_image, lr_image, output_image_path) + task = val_image_pair(model, hr_image, lr_image, output_image_path) tasks.append(task) - ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None) - while len(remaining_refs) > 0: - print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ") - ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None) - print("\r", end=" ") - tasks = [ray.get(task) for task in tasks] + # ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None) + # while len(remaining_refs) > 0: + # print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ") + # ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None) + # print("\r", end=" ") + # tasks = [ray.get(task) for task in tasks] for psnr, ssim in tasks: psnrs.append(psnr) ssims.append(ssim) config.logger.info( - '{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims)))) + '\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims)))) config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter) - config.writer.flush() - print() \ No newline at end of file + config.writer.flush() \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py index 90dfa6c..597a06e 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,20 +1,21 @@ -from .rcnet import RCNetCentered_3x3, RCNetCentered_7x7, RCNetRot90_3x3, RCNetRot90_7x7, RCNetx1, RCNetx2 -from .rclut import RCLutCentered_3x3, RCLutCentered_7x7, RCLutRot90_3x3, RCLutRot90_7x7, RCLutx1, RCLutx2 -from .srnet import SRNet, SRNetRot90 -from .srlut import SRLut, SRLutRot90 +from . import rcnet +from . import rclut +from . import srnet +from . import srlut import torch import numpy as np from pathlib import Path AVAILABLE_MODELS = { - 'SRNet': SRNet, 'SRLut': SRLut, - 'SRNetRot90': SRNetRot90, 'SRLutRot90': SRLutRot90, - 'RCNetCentered_3x3': RCNetCentered_3x3, 'RCLutCentered_3x3': RCLutCentered_3x3, - 'RCNetCentered_7x7': RCNetCentered_7x7, 'RCLutCentered_7x7': RCLutCentered_7x7, - 'RCNetRot90_3x3': RCNetRot90_3x3, 'RCLutRot90_3x3': RCLutRot90_3x3, - 'RCNetRot90_7x7': RCNetRot90_7x7, 'RCLutRot90_7x7': RCLutRot90_7x7, - 'RCNetx1': RCNetx1, 'RCLutx1': RCLutx1, - 'RCNetx2': RCNetx2, 'RCLutx2': RCLutx2, + 'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut, + 'SRNetRot90': srnet.SRNetRot90, 'SRLutRot90': srlut.SRLutRot90, + 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, + 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, + 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, + 'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7, + 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1, + 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, + 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, } def SaveCheckpoint(model, path): diff --git a/src/models/rclut.py b/src/models/rclut.py index 6909b71..4cf4b06 100644 --- a/src/models/rclut.py +++ b/src/models/rclut.py @@ -3,24 +3,21 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from common.utils import round_func -from common.lut import select_index_1dlut_msb, select_index_4dlut_msb, select_index_4dlut_tetrahedral, select_index_1dlut_linear, \ - forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output +from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output from pathlib import Path from einops import repeat class RCLutCentered_3x3(nn.Module): def __init__( self, - window_size, quantization_interval, scale ): super(RCLutCentered_3x3, self).__init__() self.scale = scale self.quantization_interval = quantization_interval - self.window_size = window_size self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) @staticmethod def init_from_lut( @@ -28,8 +25,7 @@ class RCLutCentered_3x3(nn.Module): ): scale = int(dense_conv_lut.shape[-1]) quantization_interval = 256//(dense_conv_lut.shape[0]-1) - window_size = rc_conv_luts.shape[0] - lut_model = RCLutCentered_3x3(window_size=window_size, quantization_interval=quantization_interval, scale=scale) + lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale) lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) return lut_model @@ -61,9 +57,8 @@ class RCLutCentered_7x7(nn.Module): super(RCLutCentered_7x7, self).__init__() self.scale = scale self.quantization_interval = quantization_interval - self.window_size = window_size self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) @staticmethod def init_from_lut( @@ -71,8 +66,7 @@ class RCLutCentered_7x7(nn.Module): ): scale = int(dense_conv_lut.shape[-1]) quantization_interval = 256//(dense_conv_lut.shape[0]-1) - window_size = rc_conv_luts.shape[0] - lut_model = RCLutCentered_7x7(window_size=window_size, quantization_interval=quantization_interval, scale=scale) + lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale) lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) return lut_model @@ -102,9 +96,8 @@ class RCLutRot90_3x3(nn.Module): super(RCLutRot90_3x3, self).__init__() self.scale = scale self.quantization_interval = quantization_interval - window_size = 3 self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) @staticmethod def init_from_lut( @@ -112,7 +105,6 @@ class RCLutRot90_3x3(nn.Module): ): scale = int(dense_conv_lut.shape[-1]) quantization_interval = 256//(dense_conv_lut.shape[0]-1) - window_size = rc_conv_luts.shape[0] lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale) lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) @@ -149,8 +141,7 @@ class RCLutRot90_7x7(nn.Module): self.scale = scale self.quantization_interval = quantization_interval self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) - window_size = 7 - self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) @staticmethod def init_from_lut( @@ -194,12 +185,9 @@ class RCLutx1(nn.Module): super(RCLutx1, self).__init__() self.scale = scale self.quantization_interval = quantization_interval - window_size = 3 - self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) - window_size = 5 - self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) - window_size = 7 - self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) + self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) @@ -229,7 +217,6 @@ class RCLutx1(nn.Module): def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) - print("lut:", x.min(), x.max(), x.mean()) return x def forward(self, x): @@ -376,6 +363,131 @@ class RCLutx2(nn.Module): output = output.view(b, c, output.shape[-2], output.shape[-1]) return output + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}", + f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}", + f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}", + f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}", + f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}", + f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}", + f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}", + f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}", + f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}", + f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}", + f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}", + f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}", + ")"]) + + + +class RCLutx2Centered(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(RCLutx2Centered, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) + self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) + self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) + self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) + self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) + self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) + self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) + self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) + self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) + self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, + s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, + s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, + s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3, + s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5, + s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7 + ): + scale = int(s2_dense_conv_lut_3x3.shape[-1]) + quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1) + + lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale) + + lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32)) + lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32)) + + lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32)) + lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32)) + + lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32)) + lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32)) + + lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32)) + lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32)) + + lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32)) + lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32)) + + lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32)) + lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32)) + + return lut_model + + def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): + x = forward_rc_conv_centered(index=index, lut=rc_conv_lut) + x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) + return x + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), + k=-rotations_count, + dims=[2, 3] + ) + output /= 3*4 + x = output + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), + k=-rotations_count, + dims=[2, 3] + ) + output /= 3*4 + output = output.view(b, c, output.shape[-2], output.shape[-1]) + return output + def __repr__(self): return "\n".join([ f"{self.__class__.__name__}(", diff --git a/src/models/rcnet.py b/src/models/rcnet.py index 3764d2c..1f93478 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -307,6 +307,69 @@ class RCNetx2(nn.Module): ) return lut_model + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + x = output + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + +class RCNetx2Centered(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetx2Centered, self).__init__() + self.scale = scale + self.hidden_dim = hidden_dim + self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) + self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) + self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) + self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) + self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + lut_model = rclut.RCLutx2Centered.init_from_lut( + s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, + s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, + s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, + s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, + s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, + s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 + ) + return lut_model + def forward(self, x): b,c,h,w = x.shape x = x.view(b*c, 1, h, w) diff --git a/src/scripts/image_demo.py b/src/scripts/image_demo.py index 8d5681d..1c0c49a 100644 --- a/src/scripts/image_demo.py +++ b/src/scripts/image_demo.py @@ -8,8 +8,43 @@ import numpy as np import cv2 from PIL import Image from datetime import datetime - -project_path = Path("../../").resolve() +import argparse +class DemoOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.parser.add_argument('--net_model_path', '-n', type=str, default=None, help="Net model path folder") + self.parser.add_argument('--lut_model_path', '-l', type=str, default=None, help="Lut model path folder") + self.parser.add_argument('--project_path', '-q', type=str, default="../../", help="Project path.") + self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.") + + def parse_args(self): + args = self.parser.parse_args() + args.project_path = Path(args.project_path).resolve() + if args.net_model_path is None: + args.project_path / "models" / "last_transfered_net.pth" + else: + args.net_model_path = Path(args.net_model_path).resolve() + if args.lut_model_path is None: + args.project_path / "models" / "last_transfered_lut.pth" + else: + args.lut_model_path = Path(args.lut_model_path).resolve() + return args + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + print() + +config_inst = DemoOptions() +config = config_inst.parse_args() start_script_time = datetime.now() # net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth") @@ -24,14 +59,14 @@ start_script_time = datetime.now() # net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth") # lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth") -net_model = LoadCheckpoint(project_path / "models" / "last_transfered_net.pth").cuda() -lut_model = LoadCheckpoint(project_path / "models" / "last_transfered_lut.pth").cuda() +net_model = LoadCheckpoint(config.net_model_path).cuda() +lut_model = LoadCheckpoint(config.lut_model_path).cuda() print(net_model) print(lut_model) -lr_image = cv2.imread(str(project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy() -image_gt = cv2.imread(str(project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy() +lr_image = cv2.imread(str(config.project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy() +image_gt = cv2.imread(str(config.project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy() # lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy() # image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy() @@ -46,6 +81,6 @@ image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_ image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) -Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(project_path / "models" / 'last_transfered_demo.png') +Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(config.project_path / "models" / 'last_transfered_demo.png') print(datetime.now() - start_script_time ) \ No newline at end of file diff --git a/src/scripts/train.py b/src/scripts/train.py index b183119..85a3010 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -29,21 +29,21 @@ class TrainOptions: parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False) parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}") parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.") + parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.") + parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.") parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor") parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers") - parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder") - parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size') - parser.add_argument('--datasets_dir', type=str, default="../../data/") - parser.add_argument('--train_datasets', type=str, default='DIV2K') - parser.add_argument('--val_datasets', type=str, default='Set5,Set14') + parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training") + parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder") + parser.add_argument('--datasets_dir', type=str, default="../../data/") parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further') parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations') parser.add_argument('--display_step', type=int, default=100, help='display info every N iteration') parser.add_argument('--val_step', type=int, default=2000, help='validate every N iteration') parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration') - parser.add_argument('--worker_num', '-n', type=int, default=1) - parser.add_argument('--prefetch_factor', '-p', type=int, default=16) + parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers") + parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.") parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') self.parser = parser @@ -72,20 +72,20 @@ def prepare_experiment_folder(config): assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}." - config.exp_dir = config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}" + config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve() if not config.exp_dir.exists(): config.exp_dir.mkdir() - config.checkpoint_dir = config.exp_dir / "checkpoints" + config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve() if not config.checkpoint_dir.exists(): config.checkpoint_dir.mkdir() - config.valout_dir = config.exp_dir / 'val' + config.valout_dir = (config.exp_dir / 'val').resolve() if not config.valout_dir.exists(): config.valout_dir.mkdir() - config.logs_dir = config.exp_dir / 'logs' + config.logs_dir = (config.exp_dir / 'logs').resolve() if not config.logs_dir.exists(): config.logs_dir.mkdir() @@ -101,7 +101,7 @@ if __name__ == "__main__": config.model = model.__class__.__name__ else: model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale) - # model = model.cuda() + model = model.cuda() optimizer = AdamWScheduleFree(model.parameters()) prepare_experiment_folder(config) @@ -150,7 +150,12 @@ if __name__ == "__main__": # TRAINING i = config.start_iter + if not config.model_path is None: + config.current_iter = i + valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") + for i in range(config.start_iter + 1, config.total_iter + 1): + # prof.step() torch.cuda.empty_cache() start_time = time.time() try: @@ -158,15 +163,18 @@ if __name__ == "__main__": except StopIteration: train_iter = iter(train_loader) hr_patch, lr_patch = next(train_iter) - # hr_patch = hr_patch.cuda() - # lr_patch = lr_patch.cuda() + hr_patch = hr_patch.cuda() + lr_patch = lr_patch.cuda() prepare_data_time += time.time() - start_time + start_time = time.time() - optimizer.zero_grad() + pred = model(lr_patch) loss = F.mse_loss(pred/255, hr_patch/255) loss.backward() optimizer.step() + optimizer.zero_grad() + forward_backward_time += time.time() - start_time # For monitoring @@ -178,7 +186,7 @@ if __name__ == "__main__": config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i) config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format( - config.exp_dir, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step, + model.__class__.__name__, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step, forward_backward_time / config.display_step)) l_accum = [0., 0., 0.] prepare_data_time = 0. @@ -193,7 +201,5 @@ if __name__ == "__main__": config.current_iter = i valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") - - total_script_time = datetime.now() - script_start_time config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file diff --git a/src/scripts/transfer_to_lut.py b/src/scripts/transfer_to_lut.py index b0af18c..e0ee5a8 100644 --- a/src/scripts/transfer_to_lut.py +++ b/src/scripts/transfer_to_lut.py @@ -20,7 +20,7 @@ class TransferToLutOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser.add_argument('--model_path', '-m', type=str, default='', help="model path folder") - self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets in 2**bits. Value is in range [1, 8].") + self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.") def parse_args(self): @@ -52,8 +52,11 @@ if __name__ == "__main__": config_inst.print_options(config) model = models.LoadCheckpoint(config.model_path).cuda() - print(model) - + if getattr(model, 'get_lut_model', None) is None: + print("Transfer to lut can be applied only to the network model.") + exit(1) + print(model) + print() print("Transfering:") lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size) diff --git a/src/scripts/validate.py b/src/scripts/validate.py index efd27cb..48c5266 100644 --- a/src/scripts/validate.py +++ b/src/scripts/validate.py @@ -24,8 +24,8 @@ class ValOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser.add_argument('--model_path', type=str, help="Model path.") - self.parser.add_argument('--datasets_dir', type=str, default="../../data/") - self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14') + self.parser.add_argument('--datasets_dir', type=str, default="../../data/", help="Path to datasets.") + self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.") self.parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') def parse_args(self):