From 14a7f002450ac847187eb2ff311d25e2ca41d084 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Fri, 19 Apr 2024 10:35:24 +0000 Subject: [PATCH] FC --- .gitignore | 160 +++++++++++++ data/.gitignore | 2 + models/.gitignore | 2 + readme.md | 12 + src/common/data.py | 112 +++++++++ src/common/lut.py | 409 +++++++++++++++++++++++++++++++++ src/common/utils.py | 112 +++++++++ src/common/validation.py | 66 ++++++ src/eval_bicubic_metrics.py | 193 ++++++++++++++++ src/models/__init__.py | 36 +++ src/models/mulut.py | 76 ++++++ src/models/munet.py | 65 ++++++ src/models/rclut.py | 394 +++++++++++++++++++++++++++++++ src/models/rcnet.py | 329 ++++++++++++++++++++++++++ src/models/srlut.py | 76 ++++++ src/models/srnet.py | 85 +++++++ src/scripts/image_demo.py | 51 ++++ src/scripts/train.py | 199 ++++++++++++++++ src/scripts/transfer_to_lut.py | 75 ++++++ src/scripts/validate.py | 86 +++++++ 20 files changed, 2540 insertions(+) create mode 100644 .gitignore create mode 100644 data/.gitignore create mode 100644 models/.gitignore create mode 100644 readme.md create mode 100644 src/common/data.py create mode 100644 src/common/lut.py create mode 100644 src/common/utils.py create mode 100644 src/common/validation.py create mode 100644 src/eval_bicubic_metrics.py create mode 100644 src/models/__init__.py create mode 100644 src/models/mulut.py create mode 100644 src/models/munet.py create mode 100644 src/models/rclut.py create mode 100644 src/models/rcnet.py create mode 100644 src/models/srlut.py create mode 100644 src/models/srnet.py create mode 100644 src/scripts/image_demo.py create mode 100644 src/scripts/train.py create mode 100644 src/scripts/transfer_to_lut.py create mode 100644 src/scripts/validate.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..5a1b5de --- /dev/null +++ b/readme.md @@ -0,0 +1,12 @@ + +``` +python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90 +python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth +python image_demo.py +python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth +``` + +Requierements: +- [shedulefree](https://github.com/facebookresearch/schedule_free) +- einops +- ray \ No newline at end of file diff --git a/src/common/data.py b/src/common/data.py new file mode 100644 index 0000000..4e39b6b --- /dev/null +++ b/src/common/data.py @@ -0,0 +1,112 @@ +import os +import random +import sys + +import numpy as np +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from pathlib import Path + +sys.path.insert(0, "../") # run under the project directory + +image_extensions = ['.jpg', '.png'] +def load_images_cached(images_dir_path): + image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions]) + cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy" + if not Path(cache_path).exists(): + print("Caching to:", cache_path) + value = {f:np.array(Image.open(f)) for f in image_paths} + np.save(cache_path, value, 
allow_pickle=True) + else: + value = np.load(cache_path, allow_pickle=True).item() + print("Loaded cache from:", cache_path) + return list(value.keys()), list(value.values()) + +class SRTrainDataset(Dataset): + def __init__(self, hr_dir_path, lr_dir_path, patch_size, rigid_aug=True): + super(SRTrainDataset, self).__init__() + self.sz = patch_size + self.rigid_aug = rigid_aug + self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path) + self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path) + assert len(self.hr_images) == len(self.lr_images) + + def __getitem__(self, idx): + if isinstance(idx, slice): + batch_hr, batch_lr = [], [] + for i in range(idx.start, idx.stop): + hr_patch, lr_patch = self.__getitem__(i) + batch_hr.append(hr_patch) + batch_lr.append(lr_patch) + return batch_hr, batch_lr + else: + idx = idx % len(self.hr_images) + hr_image = self.hr_images[idx] + lr_image = self.lr_images[idx] + scale = hr_image.shape[0]//lr_image.shape[0] + + i = random.randint(0, lr_image.shape[0] - self.sz) + j = random.randint(0, lr_image.shape[1] - self.sz) + c = random.choice([0, 1, 2]) + + hr_patch = hr_image[ + (i*scale):(i*scale + self.sz*scale), + (j*scale):(j*scale + self.sz*scale), + c + ] + lr_patch = lr_image[ + i:(i + self.sz), + j:(j + self.sz), + c + ] + + if self.rigid_aug: + if random.uniform(0, 1) < 0.5: + hr_patch = np.fliplr(hr_patch) + lr_patch = np.fliplr(lr_patch) + + if random.uniform(0, 1) < 0.5: + hr_patch = np.flipud(hr_patch) + lr_patch = np.flipud(lr_patch) + + k = random.choice([0, 1, 2, 3]) + hr_patch = np.rot90(hr_patch, k) + lr_patch = np.rot90(lr_patch, k) + + hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32) + lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32) + hr_patch = hr_patch.unsqueeze(0) + lr_patch = lr_patch.unsqueeze(0) + + return hr_patch, lr_patch + + def __len__(self): + return len(self.hr_images) + +class SRTestDataset(Dataset): + def __init__(self, hr_dir_path, lr_dir_path): + super(SRTestDataset, self).__init__() + self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path) + self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path) + assert len(self.hr_images) == len(self.lr_images) + + def __getitem__(self, idx): + if isinstance(idx, slice): + batch_hr, batch_lr = [], [] + for i in range(idx.start, idx.stop): + hr_image, lr_image = self.__getitem__(i) + batch_hr.append(hr_image) + batch_lr.append(lr_image) + return batch_hr, batch_lr + else: + idx = idx % len(self.hr_images) + return self.hr_images[idx], self.lr_images[idx], self.hr_image_paths[idx], self.lr_image_paths[idx] + + def __len__(self): + return len(self.hr_images) + + def __iter__(self): + for i in range(0, len(self.hr_images)): + yield self.__getitem__(i) diff --git a/src/common/lut.py b/src/common/lut.py new file mode 100644 index 0000000..78a9cd1 --- /dev/null +++ b/src/common/lut.py @@ -0,0 +1,409 @@ + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +import torch.multiprocessing as mp +from .utils import round_func +import numpy as np + +##################### TRANSFER ########################## + +class Domain4DValues(Dataset): + def __init__(self, quantization_interval=1): + super(Domain4DValues, self).__init__() + values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8) + values1d = torch.cat([values1d, torch.tensor([255])]) + self.quantization_interval = quantization_interval + self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 2, 2) 
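+        # Descriptive note: the grid above enumerates every quantized 2x2 input patch.
+        # cartesian_prod over four copies of the 1-D value grid gives (256//interval + 1)^4
+        # combinations, reshaped to (N, 1, 2, 2) so each sample is a single-channel 2x2 window.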
+ + def __getitem__(self, idx): + if isinstance(idx, slice): + ix1s, ix2s, ix3s, ix4s, batch = [], [], [], [], [] + for i in range(idx.start, idx.stop): + ix1, ix2, ix3, ix4, values = self.__getitem__(i) + ix1s.append(ix1) + ix2s.append(ix2) + ix3s.append(ix3) + ix4s.append(ix4) + batch.append(values) + return ix1s, ix2s, ix3s, ix4s, batch + else: + v = self.values[idx] + ix = v[0]//self.quantization_interval + return ix[0,0], ix[0,1], ix[1,0], ix[1,1], v + + def __len__(self): + return len(self.values) + + def __iter__(self): + for i in range(0, len(self.values)): + yield self.__getitem__(i) + +def transfer_rc_conv(rc_conv, quantization_interval=1): + receptive_field_pixel_count = rc_conv.window_size**2 + bucket_count = 256//quantization_interval + lut = np.full((receptive_field_pixel_count, bucket_count+1), dtype=np.uint8, fill_value=255) + for pixel_id in range(receptive_field_pixel_count): + for idx, value in enumerate(range(0, 256, quantization_interval)): + inputs = torch.tensor([value]).type(torch.float32).view(1,1,1).cuda() + with torch.no_grad(): + outputs = rc_conv.pixel_wise_forward(inputs) + lut[:,idx] = outputs.flatten().cpu().numpy().astype(np.uint8) + print(f"\r {rc_conv.__class__.__name__} {pixel_id*bucket_count + idx +1}/{receptive_field_pixel_count*bucket_count}", end=" ") + print() + return lut + +def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10): + bucket_count = 256//quantization_interval + scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1 + lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, 1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2 + domain_values = Domain4DValues(quantization_interval=quantization_interval) + domain_values_loader = DataLoader( + domain_values, + batch_size=batch_size if quantization_interval >= 16 else 2**16, + pin_memory=True, + num_workers=1 if quantization_interval >= 16 else mp.cpu_count() + ) + counter = 0 + for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader): + inputs = batch.type(torch.float32).cuda() + with torch.no_grad(): + outputs = block(inputs) + lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8) #[:,:,:scale,:scale] # TODO first layer automatically pad image + counter += inputs.shape[0] + print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ") + print() + lut = lut.squeeze(-3) + return lut + + + +##################### FORWARD ########################## + +def forward_2x2_input_SxS_output(index, lut): + b,c,hs,ws = index.shape + scale = lut.shape[-1] + index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') #? + out = select_index_4dlut_tetrahedral2( + ixA = index, + ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]), + ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]), + ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]), + lut = lut + ) + out = out[:,:,0:-1,0:-1,:,:] # unpad + # Pixel Shuffle. 
Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504] + out = out.permute(0,1,2,4,3,5).reshape(b,1,hs*scale,ws*scale) + out = round_func(out) + return out + +def forward_rc_conv_centered(index, lut): + window_size = lut.shape[0] + index = F.pad(index, pad=[window_size//2]*4, mode='replicate') + window_indexes = lut.shape[:-1] + index = index.unsqueeze(-1) + x = torch.zeros_like(index) + for i in range(window_indexes[-2]): + for j in range(window_indexes[-1]): + shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j + shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1]) + x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j]) + x /= window_indexes[-2]*window_indexes[-1] + x = x.squeeze(-1) + x = round_func(x) + x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1] + return x + +def forward_rc_conv_rot90(index, lut): + window_size = lut.shape[0] + index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate') + window_indexes = lut.shape[:-1] + index = index.unsqueeze(-1) + x = torch.zeros_like(index) + for i in range(window_indexes[-2]): + for j in range(window_indexes[-1]): + shift_i, shift_j = i, j + shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1]) + x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j]) + x /= window_indexes[-2]*window_indexes[-1] + x = x.squeeze(-1) + x = round_func(x) + x = x[:,:,:-(window_size-1),:-(window_size-1)] + return x + + +##################### UTILS ########################## + +def select_index_1dlut_linear(ixA, lut): + dimA = lut.shape[0] + qA = 256/(dimA-1) + outDims = lut.shape[1:] + lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) + index_loop_indexes = ixA.shape + lut_loop_indexes = lut.shape[:-1] + ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + msbA = torch.floor_divide(ixA, qA).type(torch.int64) + msbB = torch.floor_divide(ixA, qA).type(torch.int64) + 1 + lsb_index = ixA % qA + lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) + outA = torch.gather(input=lut, dim=-1, index=msbA) + outB = torch.gather(input=lut, dim=-1, index=msbB) + out = outA + (lsb_index/qA) * (outB-outA) + out = out.squeeze(-1) + return out + +def select_index_1dlut_msb(ixA, lut): + dimA = lut.shape[0] + outDims = lut.shape[1:] + lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) + index_loop_indexes = ixA.shape + lut_loop_indexes = lut.shape[:-1] + ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + msb_index = torch.floor_divide(ixA, 256/(dimA-1)).type(torch.int64) * dimA**0 + lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) + out = torch.gather(input=lut, dim=-1, index=msb_index) + out = out.squeeze(-1) + return out + +def select_index_4dlut_msb(ixA, ixB, ixC, ixD, lut): + dimA, dimB, dimC, dimD = lut.shape[:4] + qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1) + outDims = lut.shape[4:] + lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) + index_loop_indexes = ixA.shape + lut_loop_indexes = lut.shape[:-1] + ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixB = ixB.view(*(ixB.shape + 
(1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + msb_index = torch.floor_divide(ixA, qA) * dimA**3 + msb_index += torch.floor_divide(ixB, qB) * dimB**2 + msb_index += torch.floor_divide(ixC, qC) * dimC**1 + msb_index += torch.floor_divide(ixD, qD) * dimD**0 + lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) + out = torch.gather(input=lut, dim=-1, index=msb_index.type(torch.int64)) + out = out.squeeze(-1) + return out + +def select_index_4dlut_linear(ixA, ixB, ixC, ixD, lut): + dimA, dimB, dimC, dimD = lut.shape[:4] + qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1) + outDims = lut.shape[4:] + lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) + index_loop_indexes = ixA.shape + lut_loop_indexes = lut.shape[:-1] + lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) + ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + msb_index = torch.floor_divide(ixA, qA).type(torch.int64) * dimA**3 + msb_index += torch.floor_divide(ixB, qB).type(torch.int64) * dimB**2 + msb_index += torch.floor_divide(ixC, qC).type(torch.int64) * dimC**1 + msb_index += torch.floor_divide(ixD, qD).type(torch.int64) * dimD**0 + outA = torch.gather(input=lut, dim=-1, index=msb_index) + msb_index = (torch.floor_divide(ixA, qA).type(torch.int64) + 1) * dimA**3 + msb_index += (torch.floor_divide(ixB, qB).type(torch.int64) + 1) * dimB**2 + msb_index += (torch.floor_divide(ixC, qC).type(torch.int64) + 1) * dimC**1 + msb_index += (torch.floor_divide(ixD, qD).type(torch.int64) + 1) * dimD**0 + outB = torch.gather(input=lut, dim=-1, index=msb_index) + lsb_coef = ((ixA+ixB+ixC+ixD)/4 % qA) / qA + out = outA + lsb_coef*(outB-outA) + out = out.squeeze(-1) + return out + +def barycentric_interpolate(masks, coefs, vertices): + i = torch.all(torch.stack(masks), dim=0, keepdim = False) + coefs = torch.stack(coefs) * i + vertices = torch.stack(vertices) + out = (coefs*vertices).sum(0) + return i, out + +def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut): + dimA, dimB, dimC, dimD = lut.shape[:4] + qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1) + outDims = lut.shape[4:] + lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0) + index_loop_indexes = ixA.shape + lut_loop_indexes = lut.shape[:-1] + lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape) + ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixC = ixC.view(*(ixC.shape + 
(1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1) + + msbA = torch.floor_divide(ixA, qA).type(torch.int64) + msbB = torch.floor_divide(ixB, qB).type(torch.int64) + msbC = torch.floor_divide(ixC, qC).type(torch.int64) + msbD = torch.floor_divide(ixD, qD).type(torch.int64) + fa, fb, fc, fd = ixA % qA, ixB % qB, ixC % qC, ixD % qD + fab, fac, fad, fbc, fbd, fcd = fa>fb, fa>fc, fa>fd, fb>fc, fb>fd, fc>fd + + strides = torch.tensor([dimA**3, dimB**2, dimC**1, dimD**0], device=lut.device).view(-1, *((1,)*len(msbA.shape))) + p0000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD ])*strides).sum(0)) + p0001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD+1])*strides).sum(0)) + p0010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD ])*strides).sum(0)) + p0011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD+1])*strides).sum(0)) + p0100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD ])*strides).sum(0)) + p0101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD+1])*strides).sum(0)) + p0110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD ])*strides).sum(0)) + p0111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD+1])*strides).sum(0)) + p1000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD ])*strides).sum(0)) + p1001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD+1])*strides).sum(0)) + p1010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD ])*strides).sum(0)) + p1011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD+1])*strides).sum(0)) + p1100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD ])*strides).sum(0)) + p1101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD+1])*strides).sum(0)) + p1110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD ])*strides).sum(0)) + p1111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD+1])*strides).sum(0)) + + i1, out1 = barycentric_interpolate([fab, fbc, fcd], [qA-fa, fa - fb, fb - fc, fc - fd, fd], [p0000, p1000, p1100, p1110, p1111]) + i2, out2 = barycentric_interpolate([fab, fbc, fbd, ~(i1)], [qA-fa, fa - fb, fb - fd, fd - fc, fc], [p0000, p1000, p1100, p1101, p1111]) + i3, out3 = barycentric_interpolate([fab, fbc, fad, ~(i1), ~(i2)], [qA-fa, fa - fd, fd - fb, fb - fc, fc], [p0000, p1000, p1001, p1101, p1111]) + i4, out4 = barycentric_interpolate([fab, fbc, ~(i1), ~(i2), ~(i3)], [qA-fd, fd - fa, fa - fb, fb - fc, fc], [p0000, p0001, p1001, p1101, p1111]) + i5, out5 = barycentric_interpolate([fab, fac, fbd, ~(fbc)], [qA-fa, fa - fc, fc - fb, fb - fd, fd], [p0000, p1000, p1010, p1110, p1111]) + i6, out6 = barycentric_interpolate([fab, fac, fcd, ~(fbc), ~(i5)], [qA-fa, fa - fc, fc - fd, fd - fb, fb], [p0000, p1000, p1010, p1011, p1111]) + i7, out7 = barycentric_interpolate([fab, fac, fad, ~(fbc), ~(i5), ~(i6)], [qA-fa, fa - fd, fd - fc, fc - fb, fb], [p0000, p1000, p1001, p1011, p1111]) + i8, out8 = barycentric_interpolate([fab, fac, ~(fbc), ~(i5), ~(i6), ~(i7)], [qA-fd, fd - fa, fa - fc, fc - fb, fb], [p0000, p0001, 
p1001, p1011, p1111]) + i9, out9 = barycentric_interpolate([fab, fbd, ~(fbc), ~(fac)], [qA-fc, fc - fa, fa - fb, fb - fd, fd], [p0000, p0010, p1010, p1110, p1111]) + i10, out10 = barycentric_interpolate([fab, fad, ~(fbc), ~(fac), ~(i9)], [qA-fc, fc - fa, fa - fd, fd - fb, fb], [p0000, p0010, p1010, p1011, p1111]) + i11, out11 = barycentric_interpolate([fab, fcd, ~(fbc), ~(fac), ~(i9), ~(i10)], [qA-fc, fc - fd, fd - fa, fa - fb, fb], [p0000, p0010, p0011, p1011, p1111]) + i12, out12 = barycentric_interpolate([fab, ~(fbc), ~(fac), ~(i9), ~(i10), ~(i11)], [qA-fd, fd - fc, fc - fa, fa - fb, fb], [p0000, p0001, p0011, p1011, p1111]) + + i13, out13 = barycentric_interpolate([fac, fcd, ~(fab)], [qA-fb, fb - fa, fa - fc, fc - fd, fd], [p0000, p0100, p1100, p1110, p1111]) + i14, out14 = barycentric_interpolate([fac, fad, ~(fab), ~(i13)], [qA-fb, fb - fa, fa - fd, fd - fc, fc], [p0000, p0100, p1100, p1101, p1111]) + i15, out15 = barycentric_interpolate([fac, fbd, ~(fab), ~(i13), ~(i14)], [qA-fb, fb - fd, fd - fa, fa - fc, fc], [p0000, p0100, p0101, p1101, p1111]) + i16, out16 = barycentric_interpolate([fac, ~(fab), ~(i13), ~(i14), ~(i15) ], [qA-fd, fd - fb, fb - fa, fa - fc, fc], [p0000, p0001, p0101, p1101, p1111]) + i17, out17 = barycentric_interpolate([fbc, fad, ~(fab), ~(fac)], [qA-fb, fb - fc, fc - fa, fa - fd, fd], [p0000, p0100, p0110, p1110, p1111]) + i18, out18 = barycentric_interpolate([fbc, fcd, ~(fab), ~(fac), ~(i17)], [qA-fb, fb - fc, fc - fd, fd - fa, fa], [p0000, p0100, p0110, p0111, p1111]) + i19, out19 = barycentric_interpolate([fbc, fbd, ~(fab), ~(fac), ~(i17), ~(i18)], [qA-fb, fb - fd, fd - fc, fc - fa, fa], [p0000, p0100, p0101, p0111, p1111]) + i20, out20 = barycentric_interpolate([fbc, ~(fab), ~(fac), ~(i17), ~(i18), ~(i19)], [qA-fd, fd - fb, fb - fc, fc - fa, fa], [p0000, p0001, p0101, p0111, p1111]) + i21, out21 = barycentric_interpolate([fad, ~(fab), ~(fac), ~(fbc) ], [qA-fc, fc - fb, fb - fa, fa - fd, fd], [p0000, p0010, p0110, p1110, p1111]) + i22, out22 = barycentric_interpolate([fbd, ~(fab), ~(fac), ~(fbc), ~(i21)], [qA-fc, fc - fb, fb - fd, fd - fa, fa], [p0000, p0010, p0110, p0111, p1111]) + i23, out23 = barycentric_interpolate([fcd, ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22)], [qA-fc, fc - fd, fd - fb, fb - fa, fa], [p0000, p0010, p0011, p0111, p1111]) + i24, out24 = barycentric_interpolate([ ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22), ~(i23)], [qA-fd, fd - fc, fc - fb, fb - fa, fa], [p0000, p0001, p0011, p0111, p1111]) + + out = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8 + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16 + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24 + out /= qA + out = out.squeeze(-1) + return out + + +def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, upscale, mode, img_in, bd): + dimA, dimB, dimC, dimD = lut.shape[:4] + q = 256/(dimA-1) + L = dimA + upscale = lut.shape[-1] + weight = lut + + img_a1 = torch.floor_divide(ixA, q).type(torch.int64) + img_b1 = torch.floor_divide(ixB, q).type(torch.int64) + img_c1 = torch.floor_divide(ixC, q).type(torch.int64) + img_d1 = torch.floor_divide(ixD, q).type(torch.int64) + + # Extract LSBs + fa = ixA % q + fb = ixB % q + fc = ixC % q + fd = ixD % q + + img_a2 = img_a1 + 1 + img_b2 = img_b1 + 1 + img_c2 = img_c1 + 1 + img_d2 = img_d1 + 1 + + p0000 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + 
p0001 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p0010 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p0011 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p0100 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p0101 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p0110 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p0111 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + + p1000 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1001 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1010 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1011 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1100 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1101 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1110 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + p1111 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + + out = torch.zeros((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale), dtype=weight.dtype).to(device=weight.device) + sz = img_a1.shape[0] * img_a1.shape[1] * img_a1.shape[2] * img_a1.shape[3] + out = out.reshape(sz, -1) + + p0000 = p0000.reshape(sz, -1) + p0100 = p0100.reshape(sz, -1) + p1000 = p1000.reshape(sz, -1) + p1100 = 
p1100.reshape(sz, -1) + fa = fa.reshape(-1, 1) + + p0001 = p0001.reshape(sz, -1) + p0101 = p0101.reshape(sz, -1) + p1001 = p1001.reshape(sz, -1) + p1101 = p1101.reshape(sz, -1) + fb = fb.reshape(-1, 1) + fc = fc.reshape(-1, 1) + + p0010 = p0010.reshape(sz, -1) + p0110 = p0110.reshape(sz, -1) + p1010 = p1010.reshape(sz, -1) + p1110 = p1110.reshape(sz, -1) + fd = fd.reshape(-1, 1) + + p0011 = p0011.reshape(sz, -1) + p0111 = p0111.reshape(sz, -1) + p1011 = p1011.reshape(sz, -1) + p1111 = p1111.reshape(sz, -1) + + fab = fa > fb; + fac = fa > fc; + fad = fa > fd + + fbc = fb > fc; + fbd = fb > fd; + fcd = fc > fd + + i1 = i = torch.all(torch.cat([fab, fbc, fcd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] + i2 = i = torch.all(torch.cat([~i1[:, None], fab, fbc, fbd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i] + i3 = i = torch.all(torch.cat([~i1[:, None], ~i2[:, None], fab, fbc, fad], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i] + i4 = i = torch.all(torch.cat([~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i] + + i5 = i = torch.all(torch.cat([~(fbc), fab, fac, fbd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] + i6 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], fab, fac, fcd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] + i7 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] + i8 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] + + i9 = i = torch.all(torch.cat([~(fbc), ~(fac), fab, fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] + # Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first! 
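+    # fab..fcd hold the pairwise ordering of the four fractional offsets; each of the 24
+    # branches below handles one ordering (one tetrahedron of the 4D cell) and blends five
+    # LUT vertices on the path from p0000 to p1111 with barycentric weights that sum to q.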
+ # i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], fab, fcd], dim=1), dim=1) + # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i] + # i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad], dim=1), dim=1) + # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i] + # c > a > d > b + i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], fab, fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] + # c > d > a > b + i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] + i12 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i] + + i13 = i = torch.all(torch.cat([~(fab), fac, fcd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] + i14 = i = torch.all(torch.cat([~(fab), ~i13[:, None], fac, fad], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i] + i15 = i = torch.all(torch.cat([~(fab), ~i13[:, None], ~i14[:, None], fac, fbd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i] + i16 = i = torch.all(torch.cat([~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i] + + i17 = i = torch.all(torch.cat([~(fab), ~(fac), fbc, fad], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] + i18 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], fbc, fcd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] + i19 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] + i20 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] + + i21 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i] + i22 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * 
p1111[i] + i23 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] + i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i] + + out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale)) + out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale)) + out = out / q + return out \ No newline at end of file diff --git a/src/common/utils.py b/src/common/utils.py new file mode 100644 index 0000000..01a4e66 --- /dev/null +++ b/src/common/utils.py @@ -0,0 +1,112 @@ +import logging + +import cv2 +import numpy as np +from scipy import signal +import torch +import os + +def round_func(input): + # Backward Pass Differentiable Approximation (BPDA) + # This is equivalent to replacing round function (non-differentiable) + # with an identity function (differentiable) only when backward, + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + +def logger_info(logger_name, log_path='default_logger.log'): + log = logging.getLogger(logger_name) + if log.hasHandlers(): + print('LogHandlers exist!') + else: + print('LogHandlers setup!') + level = logging.INFO + formatter = logging.Formatter( + '%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') + fh = logging.FileHandler(log_path, mode='a') + fh.setFormatter(formatter) + log.setLevel(level) + log.addHandler(fh) + # print(len(log.handlers)) + + sh = logging.StreamHandler() + sh.setFormatter(formatter) + log.addHandler(sh) + + +def modcrop(image, modulo): + if len(image.shape) == 2: + sz = image.shape + sz = sz - np.mod(sz, modulo) + image = image[0:sz[0], 0:sz[1]] + elif image.shape[2] == 3: + sz = image.shape[0:2] + sz = sz - np.mod(sz, modulo) + image = image[0:sz[0], 0:sz[1], :] + else: + raise NotImplementedError + return image + + +def _rgb2ycbcr(img, maxVal=255): + O = np.array([[16], + [128], + [128]]) + T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941], + [-0.148223529411765, -0.290992156862745, 0.439215686274510], + [0.439215686274510, -0.367788235294118, -0.071427450980392]]) + + if maxVal == 1: + O = O / 255.0 + + t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2])) + t = np.dot(t, np.transpose(T)) + t[:, 0] += O[0] + t[:, 1] += O[1] + t[:, 2] += O[2] + ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]]) + + return ycbcr + + +def PSNR(y_true, y_pred, shave_border=4): + target_data = np.array(y_true, dtype=np.float32) + ref_data = np.array(y_pred, dtype=np.float32) + + diff = ref_data - target_data + if shave_border > 0: + diff = diff[shave_border:-shave_border, shave_border:-shave_border] + rmse = np.sqrt(np.mean(np.power(diff, 2))) + + return 20 * np.log10(255. 
/ rmse) + + +def cal_ssim(img1, img2): + K = [0.01, 0.03] + L = 255 + kernelX = cv2.getGaussianKernel(11, 1.5) + window = kernelX * kernelX.T + + M, N = np.shape(img1) + + C1 = (K[0] * L) ** 2 + C2 = (K[1] * L) ** 2 + img1 = np.float64(img1) + img2 = np.float64(img2) + + mu1 = signal.convolve2d(img1, window, 'valid') + mu2 = signal.convolve2d(img2, window, 'valid') + + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + + sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq + sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq + sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + mssim = np.mean(ssim_map) + return mssim + diff --git a/src/common/validation.py b/src/common/validation.py new file mode 100644 index 0000000..2772a80 --- /dev/null +++ b/src/common/validation.py @@ -0,0 +1,66 @@ +import torch +import numpy as np +from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop +from pathlib import Path +from PIL import Image +import ray + +@ray.remote(num_cpus=1, num_gpus=0.3) +def val_image_pair(model, hr_image, lr_image, output_image_path=None): + with torch.no_grad(): + # prepare lr_image + lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1) + lr_image = lr_image.unsqueeze(0).cuda() + b, c, h, w = lr_image.shape + lr_image = lr_image.reshape(b*c, 1, h, w) + # predict + pred_lr_image = model(lr_image) + # postprocess + pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8) + pred_lr_image = pred_lr_image.cpu().numpy() + torch.cuda.empty_cache() + + if not output_image_path is None: + Image.fromarray(pred_lr_image).save(output_image_path) + + # metrics + hr_image = modcrop(hr_image, model.scale) + left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] + torch.cuda.empty_cache() + return PSNR(left, right, model.scale), cal_ssim(left, right) + +def valid_steps(model, datasets, config, log_prefix=""): + ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"}) + dataset_names = list(datasets.keys()) + + for i in range(len(dataset_names)): + dataset_name = dataset_names[i] + psnrs, ssims = [], [] + + predictions_path = config.valout_dir / dataset_name + if not predictions_path.exists(): + predictions_path.mkdir() + + test_dataset = datasets[dataset_name] + tasks = [] + for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: + output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None + task = val_image_pair.remote(model, hr_image, lr_image, output_image_path) + tasks.append(task) + + ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None) + while len(remaining_refs) > 0: + print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ") + ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None) + print("\r", end=" ") + tasks = [ray.get(task) for task in tasks] + + for psnr, ssim in tasks: + psnrs.append(psnr) + ssims.append(ssim) + + config.logger.info( + '{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims)))) + config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), 
config.current_iter) + config.writer.flush() + print() \ No newline at end of file diff --git a/src/eval_bicubic_metrics.py b/src/eval_bicubic_metrics.py new file mode 100644 index 0000000..18b0dc3 --- /dev/null +++ b/src/eval_bicubic_metrics.py @@ -0,0 +1,193 @@ +import os +from pathlib import Path +import numpy as np +import cv2 +from scipy import signal +from skimage.metrics import structural_similarity +from PIL import Image +import argparse + +import time +from datetime import datetime +import ray +ray.init(num_cpus=16, num_gpus=0, ignore_reinit_error=True, log_to_driver=False) + +parser = argparse.ArgumentParser() +parser.add_argument("path_to_dataset", type=str) +parser.add_argument("--scale", type=int, default=4) +args = parser.parse_args() + +def cal_ssim(img1, img2): + K = [0.01, 0.03] + L = 255 + kernelX = cv2.getGaussianKernel(11, 1.5) + window = kernelX * kernelX.T + + M, N = np.shape(img1) + + C1 = (K[0] * L) ** 2 + C2 = (K[1] * L) ** 2 + img1 = np.float64(img1) + img2 = np.float64(img2) + + mu1 = signal.convolve2d(img1, window, 'valid') + mu2 = signal.convolve2d(img2, window, 'valid') + + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + + sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq + sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq + sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + mssim = np.mean(ssim_map) + return mssim + +def PSNR(y_true, y_pred, shave_border=4): + target_data = np.array(y_true, dtype=np.float32) + ref_data = np.array(y_pred, dtype=np.float32) + + diff = ref_data - target_data + if shave_border > 0: + diff = diff[shave_border:-shave_border, shave_border:-shave_border] + rmse = np.sqrt(np.mean(np.power(diff, 2))) + + return 20 * np.log10(255. 
/ rmse) + +def _rgb2ycbcr(img, maxVal=255): + O = np.array([[16], + [128], + [128]]) + T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941], + [-0.148223529411765, -0.290992156862745, 0.439215686274510], + [0.439215686274510, -0.367788235294118, -0.071427450980392]]) + + if maxVal == 1: + O = O / 255.0 + + t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2])) + t = np.dot(t, np.transpose(T)) + t[:, 0] += O[0] + t[:, 1] += O[1] + t[:, 2] += O[2] + ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]]) + + return ycbcr + + +def modcrop(image, modulo): + if len(image.shape) == 2: + sz = image.shape + sz = sz - np.mod(sz, modulo) + image = image[0:sz[0], 0:sz[1]] + elif image.shape[2] == 3: + sz = image.shape[0:2] + sz = sz - np.mod(sz, modulo) + image = image[0:sz[0], 0:sz[1], :] + else: + raise NotImplementedError + return image + +scale = args.scale + +dataset_path = Path(args.path_to_dataset) +hr_path = dataset_path / "HR/" +lr_path = dataset_path / f"LR_bicubic/X{scale}/" + + +print(hr_path, lr_path) + +hr_files = os.listdir(hr_path) +lr_files = os.listdir(lr_path) + +@ray.remote(num_cpus=1) +def benchmark_image_pair(hr_image_path, lr_image_path, interpolation_function): + hr_image = cv2.imread(hr_image_path) + lr_image = cv2.imread(lr_image_path) + + hr_image = hr_image[:,:,::-1] # BGR -> RGB + lr_image = lr_image[:,:,::-1] # BGR -> RGB + + start_time = datetime.now() + upscaled_lr_image = interpolation_function(lr_image, scale) + processing_time = datetime.now() - start_time + + hr_image = modcrop(hr_image, scale) + upscaled_lr_image = upscaled_lr_image + + psnr = PSNR(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0]) + cpsnr = PSNR(hr_image, upscaled_lr_image) + + cv2_psnr = cv2.PSNR(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0]) + cv2_cpsnr = cv2.PSNR(hr_image, upscaled_lr_image) + + ssim = cal_ssim(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0]) + cv2_ssim = cal_ssim(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0]) + ssim_scikit, diff = structural_similarity(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0], full=True, data_range=255.0) + cv2_scikit_ssim, diff = structural_similarity(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], full=True, data_range=255.0) + + return ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time.total_seconds() + + +def benchmark_interpolation(interpolation_function): + psnrs, cpsnrs, ssims = [], [], [] + cv2_psnrs, cv2_cpsnrs, scikit_ssims = [], [], [] + cv2_scikit_ssims = [] + cv2_ssims = [] + tasks = [] + for hr_name, lr_name in zip(hr_files, lr_files): + hr_image_path = str(hr_path / hr_name) + lr_image_path = str(lr_path / lr_name) + tasks.append(benchmark_image_pair.remote(hr_image_path, lr_image_path, interpolation_function)) + + ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None) + while len(remaining_refs) > 0: + print(f"\rReady {len(ready_refs)}/{len(hr_files)}", end=" ") + ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None) + + for task in tasks: + ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time = ray.get(task) + ssims.append(ssim) + cv2_ssims.append(cv2_ssim) + scikit_ssims.append(ssim_scikit) + 
cv2_scikit_ssims.append(cv2_scikit_ssim) + psnrs.append(psnr) + cpsnrs.append(cpsnr) + cv2_psnrs.append(cv2_psnr) + cv2_cpsnrs.append(cv2_cpsnr) + processing_times.append(processing_time) + + print() + print(f"AVG PSNR: {np.mean(psnrs):.2f} PSNR + _rgb2ycbcr") + print(f"AVG PSNR: {np.mean(cv2_psnrs):.2f} cv2.PSNR + cv2.cvtColor") + print(f"AVG cPSNR: {np.mean(cpsnrs):.2f} PSNR") + print(f"AVG cPSNR: {np.mean(cv2_cpsnrs):.2f} cv2.PSNR ") + print(f"AVG SSIM: {np.mean(ssims):.4f} cal_ssim + _rgb2ycbcr") + print(f"AVG SSIM: {np.mean(cv2_ssims):.4f} cal_ssim + cv2.cvtColor") + print(f"AVG SSIM: {np.mean(scikit_ssims):.4f} structural_similarity + _rgb2ycbcr") + print(f"AVG SSIM: {np.mean(cv2_scikit_ssims):.4f} structural_similarity + cv2.cvtColor") + print(f"AVG Time s: {np.percentile(processing_times, q=0.9)}") + print(f"{np.mean(psnrs):.2f},{np.mean(cv2_psnrs):.2f},{np.mean(cpsnrs):.2f},{np.mean(cv2_cpsnrs):.2f},{np.mean(ssims):.4f},{np.mean(cv2_ssims):.4f},{np.mean(scikit_ssims):.4f},{np.mean(cv2_scikit_ssims):.4f},{np.percentile(processing_times, q=0.9)}") + +def cv2_interpolation(image, scale): + scaled_image = cv2.resize( + image, + None, None, + fx=scale, fy=scale, + interpolation=cv2.INTER_CUBIC + ) + return scaled_image + +def pillow_interpolation(image, scale): + image = Image.fromarray(image[:,:,::-1]) + width, height = int(image.width * scale), int(image.height * scale) + scaled_image = image.resize((width, height), resample=Image.Resampling.BICUBIC) + return np.array(scaled_image)[:,:,::-1] + +print("cv2 bicubic interpolation") +benchmark_interpolation(cv2_interpolation) +print() +print("pillow bicubic interpolation") +benchmark_interpolation(pillow_interpolation) \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..90dfa6c --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,36 @@ +from .rcnet import RCNetCentered_3x3, RCNetCentered_7x7, RCNetRot90_3x3, RCNetRot90_7x7, RCNetx1, RCNetx2 +from .rclut import RCLutCentered_3x3, RCLutCentered_7x7, RCLutRot90_3x3, RCLutRot90_7x7, RCLutx1, RCLutx2 +from .srnet import SRNet, SRNetRot90 +from .srlut import SRLut, SRLutRot90 +import torch +import numpy as np +from pathlib import Path + +AVAILABLE_MODELS = { + 'SRNet': SRNet, 'SRLut': SRLut, + 'SRNetRot90': SRNetRot90, 'SRLutRot90': SRLutRot90, + 'RCNetCentered_3x3': RCNetCentered_3x3, 'RCLutCentered_3x3': RCLutCentered_3x3, + 'RCNetCentered_7x7': RCNetCentered_7x7, 'RCLutCentered_7x7': RCLutCentered_7x7, + 'RCNetRot90_3x3': RCNetRot90_3x3, 'RCLutRot90_3x3': RCLutRot90_3x3, + 'RCNetRot90_7x7': RCNetRot90_7x7, 'RCLutRot90_7x7': RCLutRot90_7x7, + 'RCNetx1': RCNetx1, 'RCLutx1': RCLutx1, + 'RCNetx2': RCNetx2, 'RCLutx2': RCLutx2, +} + +def SaveCheckpoint(model, path): + model_container = { + 'model': model.__class__.__name__, + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + } + torch.save(model_container, path) + +def LoadCheckpoint(model_path): + model_path = Path(model_path).absolute() + if model_path.exists(): + model_container = torch.load(model_path) + model = AVAILABLE_MODELS[model_container['model']](**{k:v for k,v in model_container.items() if k != "model" and k != "state_dict"}) + model.load_state_dict(model_container['state_dict'], strict=True) + return model + else: + raise Exception(f"Path {model_path} does not exist.") \ No newline at end of file diff --git a/src/models/mulut.py b/src/models/mulut.py new file mode 100644 index 0000000..ae0269e 
--- /dev/null +++ b/src/models/mulut.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pathlib import Path +from common.lut import forward_2x2_input_SxS_output + +class SRLut2x2(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(SRLut2x2, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + stage_lut + ): + scale = int(stage_lut.shape[-1]) + quantization_interval = 256//(stage_lut.shape[0]-1) + lut_model = SRLut2x2(quantization_interval=quantization_interval, scale=scale) + lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut) + x = x.view(b, c, x.shape[-2], x.shape[-1]) + return x + + def __repr__(self): + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" + + + +class SRLut3x3(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(SRLut3x3, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + stage_lut + ): + scale = int(stage_lut.shape[-1]) + quantization_interval = 256//(stage_lut.shape[0]-1) + lut_model = SRLut3x3(quantization_interval=quantization_interval, scale=scale) + lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + def __repr__(self): + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" \ No newline at end of file diff --git a/src/models/munet.py b/src/models/munet.py new file mode 100644 index 0000000..1d0cd32 --- /dev/null +++ b/src/models/munet.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from common.utils import round_func +from common import lut +from pathlib import Path +# from .mulut import MuLutx1, MuLutx2 + +# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2017. – С. 4700-4708. 
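+# In DenseConvUpscaleBlock below, each 1x1 conv receives the concatenation of all previous
+# feature maps (dense connectivity), and the final 1x1 conv emits upscale_factor^2 channels
+# that PixelShuffle rearranges into an upscale_factor x upscale_factor spatial block per pixel.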
+# https://ar5iv.labs.arxiv.org/html/1608.06993 +# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40 +class DenseConvUpscaleBlock(nn.Module): + def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1): + super(DenseConvUpscaleBlock, self).__init__() + assert layers_count > 0 + self.upscale_factor = upscale_factor + + self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True) + self.convs = [] + for i in range(layers_count): + self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)) + self.convs = nn.ModuleList(self.convs) + + for name, p in self.named_parameters(): + if "weight" in name: nn.init.kaiming_normal_(p) + if "bias" in name: nn.init.constant_(p, 0) + + self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True) + self.shuffle = nn.PixelShuffle(upscale_factor) + + def forward(self, x): + x = (x-127.5)/127.5 + x = torch.relu(self.percieve(x)) + for conv in self.convs: + x = torch.cat([x, torch.relu(conv(x))], dim=1) + x = self.shuffle(self.project_channels(x)) + x = torch.tanh(x) + x = round_func(x*127.5 + 127.5) + return x + +class MuNetx1(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNet2x2, self).__init__() + self.scale = scale + self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate') + rotated_prediction = self.stage(rotated_padded) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = MuLutx1.init_from_lut(stage_lut) + return lut_model \ No newline at end of file diff --git a/src/models/rclut.py b/src/models/rclut.py new file mode 100644 index 0000000..6909b71 --- /dev/null +++ b/src/models/rclut.py @@ -0,0 +1,394 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from common.utils import round_func +from common.lut import select_index_1dlut_msb, select_index_4dlut_msb, select_index_4dlut_tetrahedral, select_index_1dlut_linear, \ + forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output +from pathlib import Path +from einops import repeat + +class RCLutCentered_3x3(nn.Module): + def __init__( + self, + window_size, + quantization_interval, + scale + ): + super(RCLutCentered_3x3, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.window_size = window_size + self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, 
size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + + @staticmethod + def init_from_lut( + rc_conv_luts, dense_conv_lut + ): + scale = int(dense_conv_lut.shape[-1]) + quantization_interval = 256//(dense_conv_lut.shape[0]-1) + window_size = rc_conv_luts.shape[0] + lut_model = RCLutCentered_3x3(window_size=window_size, quantization_interval=quantization_interval, scale=scale) + lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) + lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') + x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) + x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1] + x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) + x = x.view(b, c, x.shape[-2], x.shape[-1]) + return x + + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" rc_conv_luts size: {self.rc_conv_luts.shape}", + f" dense_conv_lut size: {self.dense_conv_lut.shape}", + ")"]) + +class RCLutCentered_7x7(nn.Module): + def __init__( + self, + window_size, + quantization_interval, + scale + ): + super(RCLutCentered_7x7, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.window_size = window_size + self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + + @staticmethod + def init_from_lut( + rc_conv_luts, dense_conv_lut + ): + scale = int(dense_conv_lut.shape[-1]) + quantization_interval = 256//(dense_conv_lut.shape[0]-1) + window_size = rc_conv_luts.shape[0] + lut_model = RCLutCentered_7x7(window_size=window_size, quantization_interval=quantization_interval, scale=scale) + lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) + lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) + x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) + # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4) + x = x.view(b, c, x.shape[-2], x.shape[-1]) + return x + + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" rc_conv_luts size: {self.rc_conv_luts.shape}", + f" dense_conv_lut size: {self.dense_conv_lut.shape}", + ")"]) + +class RCLutRot90_3x3(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(RCLutRot90_3x3, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + window_size = 3 + self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + + @staticmethod + def init_from_lut( + rc_conv_luts, dense_conv_lut + ): + scale = int(dense_conv_lut.shape[-1]) + quantization_interval = 256//(dense_conv_lut.shape[0]-1) 
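+        # Note: the hyperparameters are recovered from the LUT tensors themselves:
+        # the trailing dims of the dense LUT give the upscale factor, and the number
+        # of sample points per input pixel fixes the quantization interval.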
+ window_size = rc_conv_luts.shape[0] + lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale) + lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) + lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts) + rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, output.shape[-2], output.shape[-1]) + return output + + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" rc_conv_luts size: {self.rc_conv_luts.shape}", + f" dense_conv_lut size: {self.dense_conv_lut.shape}", + ")"]) + +class RCLutRot90_7x7(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(RCLutRot90_7x7, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + window_size = 7 + self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + + @staticmethod + def init_from_lut( + rc_conv_luts, dense_conv_lut + ): + scale = int(dense_conv_lut.shape[-1]) + quantization_interval = 256//(dense_conv_lut.shape[0]-1) + window_size = rc_conv_luts.shape[0] + lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale) + lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32)) + lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts) + rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, output.shape[-2], output.shape[-1]) + return output + + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" rc_conv_luts size: {self.rc_conv_luts.shape}", + f" dense_conv_lut size: {self.dense_conv_lut.shape}", + ")"]) + +class RCLutx1(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(RCLutx1, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + window_size = 3 + self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + window_size = 5 + self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 
256//quantization_interval+1)).type(torch.float32)) + window_size = 7 + self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32)) + self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + rc_conv_luts_3x3, dense_conv_lut_3x3, + rc_conv_luts_5x5, dense_conv_lut_5x5, + rc_conv_luts_7x7, dense_conv_lut_7x7 + ): + scale = int(dense_conv_lut_3x3.shape[-1]) + quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1) + + lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale) + + lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32)) + lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32)) + + lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32)) + lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32)) + + lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32)) + lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32)) + + return lut_model + + def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): + x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) + x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) + print("lut:", x.min(), x.max(), x.mean()) + return x + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7), + k=-rotations_count, + dims=[2, 3] + ) + output /= 3*4 + output = output.view(b, c, output.shape[-2], output.shape[-1]) + return output + + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}", + f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}", + f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}", + f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}", + f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}", + f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}", + ")"]) + + + +class RCLutx2(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(RCLutx2, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 
256//quantization_interval+1)).type(torch.float32)) + self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) + self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) + self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) + self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) + self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32)) + self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32)) + self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32)) + self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32)) + self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3, + s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5, + s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7, + s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3, + s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5, + s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7 + ): + scale = int(s2_dense_conv_lut_3x3.shape[-1]) + quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1) + + lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale) + + lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32)) + lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32)) + + lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32)) + lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32)) + + lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32)) + lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32)) + + lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32)) + lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32)) + + lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32)) + lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32)) + + lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32)) + lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32)) + + return lut_model + + def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut): + x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut) + x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut) + return x + + def forward(self, x): + 
b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), + k=-rotations_count, + dims=[2, 3] + ) + output /= 3*4 + x = output + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), + k=-rotations_count, + dims=[2, 3] + ) + output += torch.rot90( + self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), + k=-rotations_count, + dims=[2, 3] + ) + output /= 3*4 + output = output.view(b, c, output.shape[-2], output.shape[-1]) + return output + + def __repr__(self): + return "\n".join([ + f"{self.__class__.__name__}(", + f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}", + f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}", + f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}", + f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}", + f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}", + f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}", + f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}", + f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}", + f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}", + f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}", + f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}", + f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}", + ")"]) \ No newline at end of file diff --git a/src/models/rcnet.py b/src/models/rcnet.py new file mode 100644 index 0000000..3764d2c --- /dev/null +++ b/src/models/rcnet.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from common.utils import round_func +from pathlib import Path +from common import lut +from . import rclut + +# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2017. – С. 4700-4708. 
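Every Rot90 model above (and the RCNet variants that follow) applies the same rotation self-ensemble: rotate the input by 0/90/180/270 degrees, run the stage or LUT on each rotation, rotate the predictions back, and average the four results. A minimal standalone sketch of that pattern, with a hypothetical `predict` callable standing in for a stage or LUT lookup:

```
import torch

def rot_ensemble(x, predict, scale=1):
    # x: (N, 1, H, W); predict maps (N, 1, H, W) to (N, 1, H * scale, W * scale)
    n, _, h, w = x.shape
    out = torch.zeros(n, 1, h * scale, w * scale, dtype=x.dtype, device=x.device)
    for k in range(4):
        rotated = torch.rot90(x, k=k, dims=[2, 3])
        prediction = predict(rotated)
        out += torch.rot90(prediction, k=-k, dims=[2, 3])
    return out / 4

# sanity check: with an identity "predictor" the ensemble returns the input unchanged
x = torch.rand(2, 1, 8, 8)
assert torch.allclose(rot_ensemble(x, lambda t: t), x)
```

As far as this patch shows, the intent is to gain rotational consistency at the cost of four forward passes, without enlarging the LUT itself.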
+# https://ar5iv.labs.arxiv.org/html/1608.06993 +# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40 +class DenseConvUpscaleBlock(nn.Module): + def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1): + super(DenseConvUpscaleBlock, self).__init__() + assert layers_count > 0 + self.upscale_factor = upscale_factor + + self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True) + self.convs = [] + for i in range(layers_count): + self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)) + self.convs = nn.ModuleList(self.convs) + + for name, p in self.named_parameters(): + if "weight" in name: nn.init.kaiming_normal_(p) + if "bias" in name: nn.init.constant_(p, 0) + + self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True) + self.shuffle = nn.PixelShuffle(upscale_factor) + + def forward(self, x): + x = (x-127.5)/127.5 + x = torch.relu(self.percieve(x)) + for conv in self.convs: + x = torch.cat([x, torch.relu(conv(x))], dim=1) + x = self.shuffle(self.project_channels(x)) + x = torch.tanh(x) + x = round_func(x*127.5 + 127.5) + return x + +class ReconstructedConvCentered(nn.Module): + def __init__(self, hidden_dim, window_size=7): + super(ReconstructedConvCentered, self).__init__() + self.window_size = window_size + self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + + def pixel_wise_forward(self, x): + x = (x-127.5)/127.5 + out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) + out = torch.tanh(out) + out = out*127.5 + 127.5 + return out + + def forward(self, x): + original_shape = x.shape + x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') + x = F.unfold(x, self.window_size) + x = self.pixel_wise_forward(x) + x = x.mean(1) + x = x.reshape(*original_shape) + x = round_func(x) + return x + + def __repr__(self): + return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +class RCBlockCentered(nn.Module): + def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): + super(RCBlockCentered, self).__init__() + self.window_size = window_size + self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size) + self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + + def forward(self, x): + b,c,hs,ws = x.shape + x = self.rc_conv(x) + x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self.dense_conv_block(x) + return x + +class RCNetCentered_3x3(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetCentered_3x3, self).__init__() + self.hidden_dim = hidden_dim + self.layers_count = layers_count + self.scale = scale + window_size = 3 + self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + x = self.stage(x) + x = x.view(b, c, h*self.scale, w*self.scale) + return x + + def get_lut_model(self, 
quantization_interval=16, batch_size=2**10): + window_size = self.stage.rc_conv.window_size + rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) + dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = rclut.RCLutCentered_3x3.init_from_lut(rc_conv_luts, dense_conv_lut) + return lut_model + +class RCNetCentered_7x7(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetCentered_7x7, self).__init__() + self.hidden_dim = hidden_dim + self.layers_count = layers_count + self.scale = scale + window_size = 7 + self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + x = self.stage(x) + x = x.view(b, c, h*self.scale, w*self.scale) + return x + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + window_size = self.stage.rc_conv.window_size + rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) + dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) + return lut_model + + +class ReconstructedConvRot90(nn.Module): + def __init__(self, hidden_dim, window_size=7): + super(ReconstructedConvRot90, self).__init__() + self.window_size = window_size + self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size) + + def pixel_wise_forward(self, x): + x = (x-127.5)/127.5 + out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) + out = torch.tanh(out) + out = out*127.5 + 127.5 + return out + + def forward(self, x): + original_shape = x.shape + x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') + x = F.unfold(x, self.window_size) + x = self.pixel_wise_forward(x) + x = x.mean(1) + x = x.reshape(*original_shape) + x = round_func(x) # quality likely suffer from this + return x + + def __repr__(self): + return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}" + +class RCBlockRot90(nn.Module): + def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4): + super(RCBlockRot90, self).__init__() + self.window_size = window_size + self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size) + self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + + def forward(self, x): + b,c,hs,ws = x.shape + x = self.rc_conv(x) + x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self.dense_conv_block(x) + + return x + +class RCNetRot90_3x3(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetRot90_3x3, self).__init__() + self.hidden_dim = hidden_dim + self.layers_count = layers_count + self.scale = scale + self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + + def get_lut_model(self, 
quantization_interval=16, batch_size=2**10): + window_size = self.stage.rc_conv.window_size + rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) + dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = rclut.RCLutRot90_3x3.init_from_lut(rc_conv_luts, dense_conv_lut) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated_prediction = self.stage(rotated) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + +class RCNetRot90_7x7(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetRot90_7x7, self).__init__() + self.hidden_dim = hidden_dim + self.layers_count = layers_count + self.scale = scale + self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + window_size = self.stage.rc_conv.window_size + rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1) + dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = rclut.RCLutRot90_7x7.init_from_lut(rc_conv_luts, dense_conv_lut) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated_prediction = self.stage(rotated) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + +class RCNetx1(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetx1, self).__init__() + self.scale = scale + self.hidden_dim = hidden_dim + self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) + self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, 
quantization_interval=quantization_interval, batch_size=batch_size) + + rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + lut_model = rclut.RCLutx1.init_from_lut( + rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3, + rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5, + rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7 + ) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) + + output /= 3*4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + +class RCNetx2(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(RCNetx2, self).__init__() + self.scale = scale + self.hidden_dim = hidden_dim + self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3) + self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5) + self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7) + self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) + self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5) + self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7) + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1) + s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1) + s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, 
quantization_interval=quantization_interval).reshape(5,5,-1) + s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1) + s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size) + + lut_model = rclut.RCLutx2.init_from_lut( + s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3, + s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5, + s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7, + s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3, + s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5, + s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7 + ) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + x = output + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3]) + output /= 3*4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output \ No newline at end of file diff --git a/src/models/srlut.py b/src/models/srlut.py new file mode 100644 index 0000000..4a35e5c --- /dev/null +++ b/src/models/srlut.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pathlib import Path +from common.lut import forward_2x2_input_SxS_output + +class SRLut(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(SRLut, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + stage_lut + ): + scale = int(stage_lut.shape[-1]) + quantization_interval = 256//(stage_lut.shape[0]-1) + lut_model = SRLut(quantization_interval=quantization_interval, scale=scale) + lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut) + x = x.view(b, c, x.shape[-2], x.shape[-1]) + return x + + def __repr__(self): + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" + + + +class SRLutRot90(nn.Module): + def __init__( + self, + quantization_interval, + 
scale + ): + super(SRLutRot90, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + stage_lut + ): + scale = int(stage_lut.shape[-1]) + quantization_interval = 256//(stage_lut.shape[0]-1) + lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale) + lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + def __repr__(self): + return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" \ No newline at end of file diff --git a/src/models/srnet.py b/src/models/srnet.py new file mode 100644 index 0000000..9501e1a --- /dev/null +++ b/src/models/srnet.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from common.utils import round_func +from common import lut +from pathlib import Path +from .srlut import SRLut, SRLutRot90 + +# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2017. – С. 4700-4708. 
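The network models convert a trained stage into one of these LUTs through `lut.transfer_2x2_input_SxS_output`, called from their `get_lut_model` methods. That helper lives in common/lut.py and is not part of this patch; the sketch below is only my assumption of what such a transfer involves for a 2x2-receptive-field stage (the name `transfer_2x2_stage` is hypothetical, and `stage` stands for any module, like the DenseConvUpscaleBlock, that maps a (1, 1, 2, 2) float patch in [0, 255] to a (1, 1, scale, scale) block):

```
import itertools
import torch

@torch.no_grad()
def transfer_2x2_stage(stage, quantization_interval=16, scale=4):
    # Sample values per input pixel: 0, q, 2q, ..., 256, clamped to the valid 8-bit range.
    values = torch.arange(0, 257, quantization_interval).clamp(max=255).float()
    n = len(values)                                   # 256 // q + 1 sample points
    lut = torch.zeros((n, n, n, n, scale, scale))
    # The real helper accepts a batch_size argument and evaluates many patches at once;
    # this unbatched loop (17**4 = 83521 patches for q = 16) is only for clarity.
    for a, b, c, d in itertools.product(range(n), repeat=4):
        patch = torch.stack([values[a], values[b], values[c], values[d]]).view(1, 1, 2, 2)
        lut[a, b, c, d] = stage(patch).view(scale, scale)
    return lut
```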
+# https://ar5iv.labs.arxiv.org/html/1608.06993 +# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40 +class DenseConvUpscaleBlock(nn.Module): + def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1): + super(DenseConvUpscaleBlock, self).__init__() + assert layers_count > 0 + self.upscale_factor = upscale_factor + + self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True) + self.convs = [] + for i in range(layers_count): + self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)) + self.convs = nn.ModuleList(self.convs) + + for name, p in self.named_parameters(): + if "weight" in name: nn.init.kaiming_normal_(p) + if "bias" in name: nn.init.constant_(p, 0) + + self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True) + self.shuffle = nn.PixelShuffle(upscale_factor) + + def forward(self, x): + x = (x-127.5)/127.5 + x = torch.relu(self.percieve(x)) + for conv in self.convs: + x = torch.cat([x, torch.relu(conv(x))], dim=1) + x = self.shuffle(self.project_channels(x)) + x = torch.tanh(x) + x = round_func(x*127.5 + 127.5) + return x + +class SRNet(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNet, self).__init__() + self.scale = scale + self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self.stage(x) + x = x.view(b, c, h*self.scale, w*self.scale) + return x + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = SRLut.init_from_lut(stage_lut) + return lut_model + + +class SRNetRot90(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNetRot90, self).__init__() + self.scale = scale + self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) + rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate') + rotated_prediction = self.stage(rotated_padded) + unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) + output += unrotated_prediction + output /= 4 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = SRLutRot90.init_from_lut(stage_lut) + return lut_model \ No newline at end of file diff --git a/src/scripts/image_demo.py b/src/scripts/image_demo.py new file mode 100644 index 0000000..8d5681d --- /dev/null +++ b/src/scripts/image_demo.py @@ -0,0 +1,51 @@ +from pathlib import Path +import sys +sys.path.insert(0, str(Path("../").resolve()) + "/") + +from models import 
LoadCheckpoint +import torch +import numpy as np +import cv2 +from PIL import Image +from datetime import datetime + +project_path = Path("../../").resolve() + +start_script_time = datetime.now() +# net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth") +# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCLutCentered_0.pth") + +# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCNetRot90_7x7_10000.pth") +# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCLutRot90_7x7_0.pth") + +# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCNetRot90_3x3_10000.pth") +# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCLutRot90_3x3_0.pth") + +# net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth") +# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth") + +net_model = LoadCheckpoint(project_path / "models" / "last_transfered_net.pth").cuda() +lut_model = LoadCheckpoint(project_path / "models" / "last_transfered_lut.pth").cuda() + +print(net_model) +print(lut_model) + +lr_image = cv2.imread(str(project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy() +image_gt = cv2.imread(str(project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy() + +# lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy() +# image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy() + +input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() + +net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() +lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() + + +image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) +image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) +image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) + +Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(project_path / "models" / 'last_transfered_demo.png') + +print(datetime.now() - start_script_time ) \ No newline at end of file diff --git a/src/scripts/train.py b/src/scripts/train.py new file mode 100644 index 0000000..b183119 --- /dev/null +++ b/src/scripts/train.py @@ -0,0 +1,199 @@ +import sys +sys.path.insert(0, "../") # run under the project directory +from pickle import dump +import logging +import math +import os +import time +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from PIL import Image +from pathlib import Path +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import Dataset, DataLoader +from common.data import SRTrainDataset, SRTestDataset +from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr + +from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS +from common.validation import valid_steps + +torch.backends.cudnn.benchmark = True +import argparse +from schedulefree import AdamWScheduleFree +from datetime import datetime 
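Later in this script the optimizer is an AdamWScheduleFree instance. A side note based on the schedulefree package's documented interface rather than on anything in this patch: schedule-free optimizers keep separate train/eval parameter states and are meant to be switched explicitly around training and validation, roughly as below (the Linear model is just a stand-in):

```
import torch
from schedulefree import AdamWScheduleFree

model = torch.nn.Linear(4, 4)                  # stand-in for the SR model
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()                              # switch to training mode before steps
loss = model(torch.rand(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

optimizer.eval()                               # switch before validation or saving weights
```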
+
+class TrainOptions:
+    def __init__(self):
+        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
+        parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
+        parser.add_argument('--model_path', type=str, default=None, help="Path to a model checkpoint to fine-tune.")
+        parser.add_argument('--scale', '-r', type=int, default=4, help="upscaling factor")
+        parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters in the convolutional layers")
+        parser.add_argument('--models_dir', type=str, default='../../models/', help="root folder for experiment outputs")
+        parser.add_argument('--batch_size', type=int, default=16)
+        parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
+        parser.add_argument('--datasets_dir', type=str, default="../../data/")
+        parser.add_argument('--train_datasets', type=str, default='DIV2K')
+        parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
+        parser.add_argument('--start_iter', type=int, default=0, help='Set 0 to train from scratch; otherwise saved parameters are loaded and training continues')
+        parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
+        parser.add_argument('--display_step', type=int, default=100, help='display info every N iterations')
+        parser.add_argument('--val_step', type=int, default=2000, help='validate every N iterations')
+        parser.add_argument('--save_step', type=int, default=2000, help='save models every N iterations')
+        parser.add_argument('--worker_num', '-n', type=int, default=1)
+        parser.add_argument('--prefetch_factor', '-p', type=int, default=16)
+        parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
+        self.parser = parser
+
+    def parse_args(self):
+        args = self.parser.parse_args()
+        args.models_dir = Path(args.models_dir)
+        args.model_path = Path(args.model_path) if args.model_path is not None else None
+        args.train_datasets = args.train_datasets.split(',')
+        args.val_datasets = args.val_datasets.split(',')
+        return args
+
+    def print_options(self, opt):
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+        print()
+
+def prepare_experiment_folder(config):
+    assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
+    assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"One of the {config.val_datasets} was not found in {config.datasets_dir}."
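+    # Experiment layout created below:
+    #   <models_dir>/<model>_<train_datasets>/
+    #     checkpoints/   periodic <ModelName>_<iter>.pth snapshots (every save_step)
+    #     val/           validation metrics and optional saved predictions
+    #     logs/          train.log and TensorBoard event files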
+ + config.exp_dir = config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}" + + if not config.exp_dir.exists(): + config.exp_dir.mkdir() + + config.checkpoint_dir = config.exp_dir / "checkpoints" + if not config.checkpoint_dir.exists(): + config.checkpoint_dir.mkdir() + + config.valout_dir = config.exp_dir / 'val' + if not config.valout_dir.exists(): + config.valout_dir.mkdir() + + config.logs_dir = config.exp_dir / 'logs' + if not config.logs_dir.exists(): + config.logs_dir.mkdir() + + +if __name__ == "__main__": + script_start_time = datetime.now() + + config_inst = TrainOptions() + config = config_inst.parse_args() + + if not config.model_path is None: + model = LoadCheckpoint(config.model_path) + config.model = model.__class__.__name__ + else: + model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale) + # model = model.cuda() + optimizer = AdamWScheduleFree(model.parameters()) + + prepare_experiment_folder(config) + + # Tensorboard for monitoring + writer = SummaryWriter(log_dir=config.logs_dir) + logger_name = 'train' + logger_info(logger_name, os.path.join(config.logs_dir, logger_name + '.log')) + logger = logging.getLogger(logger_name) + config.writer = writer + config.logger = logger + + config.logger.info(config_inst.print_options(config)) + print(model) + + # Training dataset + train_datasets = [] + for train_dataset_name in config.train_datasets: + train_datasets.append(SRTrainDataset( + hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR", + lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}", + patch_size = config.crop_size + )) + train_dataset = torch.utils.data.ConcatDataset(train_datasets) + train_loader = DataLoader( + dataset = train_dataset, + batch_size = config.batch_size, + num_workers = config.worker_num, + shuffle = False, + drop_last = False, + pin_memory = True, + prefetch_factor = config.prefetch_factor + ) + train_iter = iter(train_loader) + + test_datasets = {} + for test_dataset_name in config.val_datasets: + test_datasets[test_dataset_name] = SRTestDataset( + hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", + lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}", + ) + l_accum = [0., 0., 0.] + prepare_data_time = 0. + forward_backward_time = 0. + accum_samples = 0 + + # TRAINING + i = config.start_iter + for i in range(config.start_iter + 1, config.total_iter + 1): + torch.cuda.empty_cache() + start_time = time.time() + try: + hr_patch, lr_patch = next(train_iter) + except StopIteration: + train_iter = iter(train_loader) + hr_patch, lr_patch = next(train_iter) + # hr_patch = hr_patch.cuda() + # lr_patch = lr_patch.cuda() + prepare_data_time += time.time() - start_time + start_time = time.time() + optimizer.zero_grad() + pred = model(lr_patch) + loss = F.mse_loss(pred/255, hr_patch/255) + loss.backward() + optimizer.step() + forward_backward_time += time.time() - start_time + + # For monitoring + accum_samples += config.batch_size + l_accum[0] += loss.item() + + # Show information + if i % config.display_step == 0: + config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i) + + config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format( + config.exp_dir, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step, + forward_backward_time / config.display_step)) + l_accum = [0., 0., 0.] 
+ prepare_data_time = 0. + forward_backward_time = 0. + + # Save models + if i % config.save_step == 0: + SaveCheckpoint(model=model, path=Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth") + + # Validation + if i % config.val_step == 0: + config.current_iter = i + valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") + + + + total_script_time = datetime.now() - script_start_time + config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file diff --git a/src/scripts/transfer_to_lut.py b/src/scripts/transfer_to_lut.py new file mode 100644 index 0000000..b0af18c --- /dev/null +++ b/src/scripts/transfer_to_lut.py @@ -0,0 +1,75 @@ +import sys +sys.path.insert(0, "../") # run under the project directory +import logging +import math +import os +import time +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from PIL import Image +from pathlib import Path +from torch.utils.tensorboard import SummaryWriter +torch.backends.cudnn.benchmark = True +from datetime import datetime +import argparse +import models + +class TransferToLutOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.parser.add_argument('--model_path', '-m', type=str, default='', help="model path folder") + self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets in 2**bits. Value is in range [1, 8].") + self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.") + + def parse_args(self): + args = self.parser.parse_args() + args.model_path = Path(args.model_path) + args.checkpoint_dir = Path(args.model_path).absolute().parent + return args + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + print() + + +if __name__ == "__main__": + start_time = datetime.now() + print(start_time) + config_inst = TransferToLutOptions() + config = config_inst.parse_args() + + config_inst.print_options(config) + + model = models.LoadCheckpoint(config.model_path).cuda() + print(model) + + print() + print("Transfering:") + lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size) + print() + print(lut_model) + + lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth" + models.SaveCheckpoint(model=lut_model, path=lut_path) + + lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()]) + + print() + print(datetime.now()-start_time) + print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB") + + models.SaveCheckpoint(model=model, path=Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_net.pth") + models.SaveCheckpoint(model=lut_model, path=Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_lut.pth") + print("Updated", Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_net.pth") + print("Updated", Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_lut.pth") \ No newline at end of 
file diff --git a/src/scripts/validate.py b/src/scripts/validate.py new file mode 100644 index 0000000..efd27cb --- /dev/null +++ b/src/scripts/validate.py @@ -0,0 +1,86 @@ +import sys +sys.path.insert(0, "../") # run under the project directory +import logging +import math +import os +import time +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from PIL import Image +from pathlib import Path +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import Dataset, DataLoader +from common.data import SRTrainDataset, SRTestDataset +from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop +from common.validation import valid_steps +from models import LoadCheckpoint +torch.backends.cudnn.benchmark = True +from datetime import datetime +import argparse + +class ValOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.parser.add_argument('--model_path', type=str, help="Model path.") + self.parser.add_argument('--datasets_dir', type=str, default="../../data/") + self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14') + self.parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') + + def parse_args(self): + args = self.parser.parse_args() + args.datasets_dir = Path(args.datasets_dir).resolve() + args.val_datasets = args.val_datasets.split(',') + args.exp_dir = Path(args.model_path).absolute().parent.parent + args.model_path = Path(args.model_path) + args.model_name = args.model_path.stem + args.valout_dir = Path(args.exp_dir)/ 'val' + if not args.valout_dir.exists(): + args.valout_dir.mkdir() + args.current_iter = args.model_name.split('_')[-1] + # Tensorboard for monitoring + writer = SummaryWriter(log_dir=args.valout_dir) + logger_name = f'val_{args.model_path.stem}' + logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log')) + logger = logging.getLogger(logger_name) + args.writer = writer + args.logger = logger + + return args + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + print() + + +# TODO with unified save/load function any model file of net or lut can be tested with the same script. +if __name__ == "__main__": + config_inst = ValOptions() + config = config_inst.parse_args() + + config.logger.info(config_inst.print_options(config)) + + model = LoadCheckpoint(config.model_path) + model = model.cuda() + print(model) + + test_datasets = {} + for test_dataset_name in config.val_datasets: + test_datasets[test_dataset_name] = SRTestDataset( + hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", + lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", + ) + + valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") + + config.logger.info("Complete") \ No newline at end of file
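
Validation reports PSNR and SSIM through PSNR, cal_ssim and _rgb2ycbcr in common/utils.py, none of which appear in this patch. For reference, the standard PSNR definition on 8-bit images, which those helpers presumably follow (typically after converting to the Y channel and cropping a border of scale pixels), is short enough to restate:

```
import numpy as np

def psnr_uint8(prediction: np.ndarray, ground_truth: np.ndarray) -> float:
    """Peak signal-to-noise ratio between two 8-bit images of the same shape."""
    mse = np.mean((prediction.astype(np.float64) - ground_truth.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(255.0 ** 2 / mse)

# identical images give an infinite PSNR
image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
print(psnr_uint8(image, image))
```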