Vladimir Protsenko 7 months ago
commit 14a7f00245

.gitignore

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

data/.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore

models/.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore

@@ -0,0 +1,12 @@
```
python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90
python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
python image_demo.py
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
```
Requirements:
- [schedulefree](https://github.com/facebookresearch/schedule_free)
- einops
- ray
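
After the transfer step, the LUT checkpoint can be used directly. A minimal inference sketch (the checkpoint path and output LUT filename are illustrative; `LoadCheckpoint` comes from this repo's `models` package):
```
import numpy as np
import torch
from PIL import Image
from models import LoadCheckpoint

# Illustrative path: use the LUT checkpoint written by transfer_to_lut.py.
model = LoadCheckpoint("/wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRLut_10000.pth")

lr = torch.tensor(np.array(Image.open("lr.png"))).type(torch.float32)
lr = lr.permute(2, 0, 1).unsqueeze(0)     # HWC -> 1xCxHxW
b, c, h, w = lr.shape
sr = model(lr.reshape(b * c, 1, h, w))    # per-channel forward, as in validation
sr = sr.reshape(b, c, h * model.scale, w * model.scale)
Image.fromarray(sr.squeeze(0).permute(1, 2, 0).type(torch.uint8).numpy()).save("sr.png")
```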

@@ -0,0 +1,112 @@
import os
import random
import sys
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
sys.path.insert(0, "../") # run under the project directory
image_extensions = ['.jpg', '.png']
def load_images_cached(images_dir_path):
image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions])
cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy"
if not Path(cache_path).exists():
print("Caching to:", cache_path)
value = {f:np.array(Image.open(f)) for f in image_paths}
np.save(cache_path, value, allow_pickle=True)
else:
value = np.load(cache_path, allow_pickle=True).item()
print("Loaded cache from:", cache_path)
return list(value.keys()), list(value.values())
class SRTrainDataset(Dataset):
def __init__(self, hr_dir_path, lr_dir_path, patch_size, rigid_aug=True):
super(SRTrainDataset, self).__init__()
self.sz = patch_size
self.rigid_aug = rigid_aug
self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path)
self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path)
assert len(self.hr_images) == len(self.lr_images)
def __getitem__(self, idx):
if isinstance(idx, slice):
batch_hr, batch_lr = [], []
for i in range(idx.start, idx.stop):
hr_patch, lr_patch = self.__getitem__(i)
batch_hr.append(hr_patch)
batch_lr.append(lr_patch)
return batch_hr, batch_lr
else:
idx = idx % len(self.hr_images)
hr_image = self.hr_images[idx]
lr_image = self.lr_images[idx]
scale = hr_image.shape[0]//lr_image.shape[0]
i = random.randint(0, lr_image.shape[0] - self.sz)
j = random.randint(0, lr_image.shape[1] - self.sz)
c = random.choice([0, 1, 2])
hr_patch = hr_image[
(i*scale):(i*scale + self.sz*scale),
(j*scale):(j*scale + self.sz*scale),
c
]
lr_patch = lr_image[
i:(i + self.sz),
j:(j + self.sz),
c
]
if self.rigid_aug:
if random.uniform(0, 1) < 0.5:
hr_patch = np.fliplr(hr_patch)
lr_patch = np.fliplr(lr_patch)
if random.uniform(0, 1) < 0.5:
hr_patch = np.flipud(hr_patch)
lr_patch = np.flipud(lr_patch)
k = random.choice([0, 1, 2, 3])
hr_patch = np.rot90(hr_patch, k)
lr_patch = np.rot90(lr_patch, k)
hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
hr_patch = hr_patch.unsqueeze(0)
lr_patch = lr_patch.unsqueeze(0)
return hr_patch, lr_patch
def __len__(self):
return len(self.hr_images)
class SRTestDataset(Dataset):
def __init__(self, hr_dir_path, lr_dir_path):
super(SRTestDataset, self).__init__()
self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path)
self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path)
assert len(self.hr_images) == len(self.lr_images)
def __getitem__(self, idx):
        if isinstance(idx, slice):
            batch_hr, batch_lr, batch_hr_paths, batch_lr_paths = [], [], [], []
            for i in range(idx.start, idx.stop):
                # the scalar branch below returns a 4-tuple, so unpack all four values
                hr_image, lr_image, hr_path, lr_path = self.__getitem__(i)
                batch_hr.append(hr_image)
                batch_lr.append(lr_image)
                batch_hr_paths.append(hr_path)
                batch_lr_paths.append(lr_path)
            return batch_hr, batch_lr, batch_hr_paths, batch_lr_paths
else:
idx = idx % len(self.hr_images)
return self.hr_images[idx], self.lr_images[idx], self.hr_image_paths[idx], self.lr_image_paths[idx]
def __len__(self):
return len(self.hr_images)
def __iter__(self):
for i in range(0, len(self.hr_images)):
yield self.__getitem__(i)
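def _demo_train_loader():
    # Illustrative usage sketch (not part of the pipeline); directory paths are
    # placeholders for a DIV2K-style layout with HR/LR images in matching sorted order.
    train_dataset = SRTrainDataset(
        hr_dir_path="/wd/data/DIV2K/HR",
        lr_dir_path="/wd/data/DIV2K/LR_bicubic/X4",
        patch_size=48,
    )
    loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    hr_batch, lr_batch = next(iter(loader))
    # For scale 4: hr_batch is 16x1x192x192, lr_batch is 16x1x48x48.
    return hr_batch.shape, lr_batch.shape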

@@ -0,0 +1,409 @@
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from .utils import round_func
import numpy as np
##################### TRANSFER ##########################
class Domain4DValues(Dataset):
def __init__(self, quantization_interval=1):
super(Domain4DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([255])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 2, 2)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, ix3s, ix4s, batch = [], [], [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, ix3, ix4, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
ix3s.append(ix3)
ix4s.append(ix4)
batch.append(values)
return ix1s, ix2s, ix3s, ix4s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0,0], ix[0,1], ix[1,0], ix[1,1], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(0, len(self.values)):
yield self.__getitem__(i)
def transfer_rc_conv(rc_conv, quantization_interval=1):
receptive_field_pixel_count = rc_conv.window_size**2
bucket_count = 256//quantization_interval
lut = np.full((receptive_field_pixel_count, bucket_count+1), dtype=np.uint8, fill_value=255)
for pixel_id in range(receptive_field_pixel_count):
for idx, value in enumerate(range(0, 256, quantization_interval)):
inputs = torch.tensor([value]).type(torch.float32).view(1,1,1).cuda()
with torch.no_grad():
outputs = rc_conv.pixel_wise_forward(inputs)
lut[:,idx] = outputs.flatten().cpu().numpy().astype(np.uint8)
print(f"\r {rc_conv.__class__.__name__} {pixel_id*bucket_count + idx +1}/{receptive_field_pixel_count*bucket_count}", end=" ")
print()
return lut
def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, 1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count()
)
counter = 0
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8) #[:,:,:scale,:scale] # TODO first layer automatically pad image
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
lut = lut.squeeze(-3)
return lut
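# Size note: with quantization_interval=16 the sampled grid has
# 256//16 + 1 = 17 points per input pixel, so a 2x2-input LUT at scale 4
# holds 17**4 * 4 * 4 = 1,336,336 entries (~1.3 MB stored as uint8).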
##################### FORWARD ##########################
def forward_2x2_input_SxS_output(index, lut):
b,c,hs,ws = index.shape
scale = lut.shape[-1]
    index = F.pad(input=index, pad=[0,1,0,1], mode='replicate') # pad right/bottom so every pixel has a full 2x2 window; cropped back after the lookup
out = select_index_4dlut_tetrahedral2(
ixA = index,
ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
lut = lut
)
out = out[:,:,0:-1,0:-1,:,:] # unpad
# Pixel Shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
out = out.permute(0,1,2,4,3,5).reshape(b,1,hs*scale,ws*scale)
out = round_func(out)
return out
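def _demo_pixel_shuffle_equivalence():
    # Illustrative check (not used by the pipeline): the permute/reshape in
    # forward_2x2_input_SxS_output is a depth-to-space step equivalent to
    # F.pixel_shuffle after folding the two scale dims into channels.
    t = torch.randn(2, 1, 5, 7, 3, 3)  # b, c, h, w, scale, scale
    a = t.permute(0, 1, 2, 4, 3, 5).reshape(2, 1, 5 * 3, 7 * 3)
    b = F.pixel_shuffle(t.permute(0, 1, 4, 5, 2, 3).reshape(2, 9, 5, 7), 3)
    assert torch.allclose(a, b)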
def forward_rc_conv_centered(index, lut):
window_size = lut.shape[0]
index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
window_indexes = lut.shape[:-1]
index = index.unsqueeze(-1)
x = torch.zeros_like(index)
for i in range(window_indexes[-2]):
for j in range(window_indexes[-1]):
shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j
shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
x /= window_indexes[-2]*window_indexes[-1]
x = x.squeeze(-1)
x = round_func(x)
x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1]
return x
def forward_rc_conv_rot90(index, lut):
window_size = lut.shape[0]
index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
window_indexes = lut.shape[:-1]
index = index.unsqueeze(-1)
x = torch.zeros_like(index)
for i in range(window_indexes[-2]):
for j in range(window_indexes[-1]):
shift_i, shift_j = i, j
shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i,j])
x /= window_indexes[-2]*window_indexes[-1]
x = x.squeeze(-1)
x = round_func(x)
x = x[:,:,:-(window_size-1),:-(window_size-1)]
return x
##################### UTILS ##########################
def select_index_1dlut_linear(ixA, lut):
dimA = lut.shape[0]
qA = 256/(dimA-1)
outDims = lut.shape[1:]
lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
index_loop_indexes = ixA.shape
lut_loop_indexes = lut.shape[:-1]
ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
msbA = torch.floor_divide(ixA, qA).type(torch.int64)
msbB = torch.floor_divide(ixA, qA).type(torch.int64) + 1
lsb_index = ixA % qA
lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
outA = torch.gather(input=lut, dim=-1, index=msbA)
outB = torch.gather(input=lut, dim=-1, index=msbB)
out = outA + (lsb_index/qA) * (outB-outA)
out = out.squeeze(-1)
return out
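def _demo_1dlut_linear_identity():
    # Illustrative check: with an identity ramp LUT sampled every 16 levels,
    # linear lookup reproduces the input exactly, including between bins.
    lut = torch.arange(0, 257, 16).type(torch.float32)  # 17 entries: 0, 16, ..., 256
    x = torch.tensor([8.0, 100.0, 255.0])
    assert torch.allclose(select_index_1dlut_linear(ixA=x, lut=lut), x)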
def select_index_1dlut_msb(ixA, lut):
dimA = lut.shape[0]
outDims = lut.shape[1:]
lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
index_loop_indexes = ixA.shape
lut_loop_indexes = lut.shape[:-1]
ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
msb_index = torch.floor_divide(ixA, 256/(dimA-1)).type(torch.int64) * dimA**0
lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
out = torch.gather(input=lut, dim=-1, index=msb_index)
out = out.squeeze(-1)
return out
def select_index_4dlut_msb(ixA, ixB, ixC, ixD, lut):
dimA, dimB, dimC, dimD = lut.shape[:4]
qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
outDims = lut.shape[4:]
lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
index_loop_indexes = ixA.shape
lut_loop_indexes = lut.shape[:-1]
ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
msb_index = torch.floor_divide(ixA, qA) * dimA**3
msb_index += torch.floor_divide(ixB, qB) * dimB**2
msb_index += torch.floor_divide(ixC, qC) * dimC**1
msb_index += torch.floor_divide(ixD, qD) * dimD**0
lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
out = torch.gather(input=lut, dim=-1, index=msb_index.type(torch.int64))
out = out.squeeze(-1)
return out
def select_index_4dlut_linear(ixA, ixB, ixC, ixD, lut):
dimA, dimB, dimC, dimD = lut.shape[:4]
qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
outDims = lut.shape[4:]
lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
index_loop_indexes = ixA.shape
lut_loop_indexes = lut.shape[:-1]
lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
msb_index = torch.floor_divide(ixA, qA).type(torch.int64) * dimA**3
msb_index += torch.floor_divide(ixB, qB).type(torch.int64) * dimB**2
msb_index += torch.floor_divide(ixC, qC).type(torch.int64) * dimC**1
msb_index += torch.floor_divide(ixD, qD).type(torch.int64) * dimD**0
outA = torch.gather(input=lut, dim=-1, index=msb_index)
msb_index = (torch.floor_divide(ixA, qA).type(torch.int64) + 1) * dimA**3
msb_index += (torch.floor_divide(ixB, qB).type(torch.int64) + 1) * dimB**2
msb_index += (torch.floor_divide(ixC, qC).type(torch.int64) + 1) * dimC**1
msb_index += (torch.floor_divide(ixD, qD).type(torch.int64) + 1) * dimD**0
outB = torch.gather(input=lut, dim=-1, index=msb_index)
lsb_coef = ((ixA+ixB+ixC+ixD)/4 % qA) / qA
out = outA + lsb_coef*(outB-outA)
out = out.squeeze(-1)
return out
def barycentric_interpolate(masks, coefs, vertices):
i = torch.all(torch.stack(masks), dim=0, keepdim = False)
coefs = torch.stack(coefs) * i
vertices = torch.stack(vertices)
out = (coefs*vertices).sum(0)
return i, out
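# Tetrahedral (simplex) interpolation over a 4D LUT: each cell is split into
# 4! = 24 simplices according to the ordering of the fractional parts
# (fa, fb, fc, fd). Each branch below selects one simplex via comparison masks
# and blends its 5 vertices with barycentric weights that sum to qA
# (hence the final division by qA).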
def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
dimA, dimB, dimC, dimD = lut.shape[:4]
qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
outDims = lut.shape[4:]
lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
index_loop_indexes = ixA.shape
lut_loop_indexes = lut.shape[:-1]
lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
msbA = torch.floor_divide(ixA, qA).type(torch.int64)
msbB = torch.floor_divide(ixB, qB).type(torch.int64)
msbC = torch.floor_divide(ixC, qC).type(torch.int64)
msbD = torch.floor_divide(ixD, qD).type(torch.int64)
fa, fb, fc, fd = ixA % qA, ixB % qB, ixC % qC, ixD % qD
fab, fac, fad, fbc, fbd, fcd = fa>fb, fa>fc, fa>fd, fb>fc, fb>fd, fc>fd
strides = torch.tensor([dimA**3, dimB**2, dimC**1, dimD**0], device=lut.device).view(-1, *((1,)*len(msbA.shape)))
p0000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD ])*strides).sum(0))
p0001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD+1])*strides).sum(0))
p0010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD ])*strides).sum(0))
p0011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD+1])*strides).sum(0))
p0100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD ])*strides).sum(0))
p0101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD+1])*strides).sum(0))
p0110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD ])*strides).sum(0))
p0111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD+1])*strides).sum(0))
p1000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD ])*strides).sum(0))
p1001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD+1])*strides).sum(0))
p1010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD ])*strides).sum(0))
p1011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD+1])*strides).sum(0))
p1100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD ])*strides).sum(0))
p1101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD+1])*strides).sum(0))
p1110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD ])*strides).sum(0))
p1111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD+1])*strides).sum(0))
i1, out1 = barycentric_interpolate([fab, fbc, fcd], [qA-fa, fa - fb, fb - fc, fc - fd, fd], [p0000, p1000, p1100, p1110, p1111])
i2, out2 = barycentric_interpolate([fab, fbc, fbd, ~(i1)], [qA-fa, fa - fb, fb - fd, fd - fc, fc], [p0000, p1000, p1100, p1101, p1111])
i3, out3 = barycentric_interpolate([fab, fbc, fad, ~(i1), ~(i2)], [qA-fa, fa - fd, fd - fb, fb - fc, fc], [p0000, p1000, p1001, p1101, p1111])
i4, out4 = barycentric_interpolate([fab, fbc, ~(i1), ~(i2), ~(i3)], [qA-fd, fd - fa, fa - fb, fb - fc, fc], [p0000, p0001, p1001, p1101, p1111])
i5, out5 = barycentric_interpolate([fab, fac, fbd, ~(fbc)], [qA-fa, fa - fc, fc - fb, fb - fd, fd], [p0000, p1000, p1010, p1110, p1111])
i6, out6 = barycentric_interpolate([fab, fac, fcd, ~(fbc), ~(i5)], [qA-fa, fa - fc, fc - fd, fd - fb, fb], [p0000, p1000, p1010, p1011, p1111])
i7, out7 = barycentric_interpolate([fab, fac, fad, ~(fbc), ~(i5), ~(i6)], [qA-fa, fa - fd, fd - fc, fc - fb, fb], [p0000, p1000, p1001, p1011, p1111])
i8, out8 = barycentric_interpolate([fab, fac, ~(fbc), ~(i5), ~(i6), ~(i7)], [qA-fd, fd - fa, fa - fc, fc - fb, fb], [p0000, p0001, p1001, p1011, p1111])
i9, out9 = barycentric_interpolate([fab, fbd, ~(fbc), ~(fac)], [qA-fc, fc - fa, fa - fb, fb - fd, fd], [p0000, p0010, p1010, p1110, p1111])
i10, out10 = barycentric_interpolate([fab, fad, ~(fbc), ~(fac), ~(i9)], [qA-fc, fc - fa, fa - fd, fd - fb, fb], [p0000, p0010, p1010, p1011, p1111])
i11, out11 = barycentric_interpolate([fab, fcd, ~(fbc), ~(fac), ~(i9), ~(i10)], [qA-fc, fc - fd, fd - fa, fa - fb, fb], [p0000, p0010, p0011, p1011, p1111])
i12, out12 = barycentric_interpolate([fab, ~(fbc), ~(fac), ~(i9), ~(i10), ~(i11)], [qA-fd, fd - fc, fc - fa, fa - fb, fb], [p0000, p0001, p0011, p1011, p1111])
i13, out13 = barycentric_interpolate([fac, fcd, ~(fab)], [qA-fb, fb - fa, fa - fc, fc - fd, fd], [p0000, p0100, p1100, p1110, p1111])
i14, out14 = barycentric_interpolate([fac, fad, ~(fab), ~(i13)], [qA-fb, fb - fa, fa - fd, fd - fc, fc], [p0000, p0100, p1100, p1101, p1111])
i15, out15 = barycentric_interpolate([fac, fbd, ~(fab), ~(i13), ~(i14)], [qA-fb, fb - fd, fd - fa, fa - fc, fc], [p0000, p0100, p0101, p1101, p1111])
i16, out16 = barycentric_interpolate([fac, ~(fab), ~(i13), ~(i14), ~(i15) ], [qA-fd, fd - fb, fb - fa, fa - fc, fc], [p0000, p0001, p0101, p1101, p1111])
i17, out17 = barycentric_interpolate([fbc, fad, ~(fab), ~(fac)], [qA-fb, fb - fc, fc - fa, fa - fd, fd], [p0000, p0100, p0110, p1110, p1111])
i18, out18 = barycentric_interpolate([fbc, fcd, ~(fab), ~(fac), ~(i17)], [qA-fb, fb - fc, fc - fd, fd - fa, fa], [p0000, p0100, p0110, p0111, p1111])
i19, out19 = barycentric_interpolate([fbc, fbd, ~(fab), ~(fac), ~(i17), ~(i18)], [qA-fb, fb - fd, fd - fc, fc - fa, fa], [p0000, p0100, p0101, p0111, p1111])
i20, out20 = barycentric_interpolate([fbc, ~(fab), ~(fac), ~(i17), ~(i18), ~(i19)], [qA-fd, fd - fb, fb - fc, fc - fa, fa], [p0000, p0001, p0101, p0111, p1111])
i21, out21 = barycentric_interpolate([fad, ~(fab), ~(fac), ~(fbc) ], [qA-fc, fc - fb, fb - fa, fa - fd, fd], [p0000, p0010, p0110, p1110, p1111])
i22, out22 = barycentric_interpolate([fbd, ~(fab), ~(fac), ~(fbc), ~(i21)], [qA-fc, fc - fb, fb - fd, fd - fa, fa], [p0000, p0010, p0110, p0111, p1111])
i23, out23 = barycentric_interpolate([fcd, ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22)], [qA-fc, fc - fd, fd - fb, fb - fa, fa], [p0000, p0010, p0011, p0111, p1111])
i24, out24 = barycentric_interpolate([ ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22), ~(i23)], [qA-fd, fd - fc, fc - fb, fb - fa, fa], [p0000, p0001, p0011, p0111, p1111])
out = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8 + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16 + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24
out /= qA
out = out.squeeze(-1)
return out
def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): # adapted from the SR-LUT reference implementation (see the overflow-fix note below)
dimA, dimB, dimC, dimD = lut.shape[:4]
q = 256/(dimA-1)
L = dimA
upscale = lut.shape[-1]
weight = lut
img_a1 = torch.floor_divide(ixA, q).type(torch.int64)
img_b1 = torch.floor_divide(ixB, q).type(torch.int64)
img_c1 = torch.floor_divide(ixC, q).type(torch.int64)
img_d1 = torch.floor_divide(ixD, q).type(torch.int64)
# Extract LSBs
fa = ixA % q
fb = ixB % q
fc = ixC % q
fd = ixD % q
img_a2 = img_a1 + 1
img_b2 = img_b1 + 1
img_c2 = img_c1 + 1
img_d2 = img_d1 + 1
p0000 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0001 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0010 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0011 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0100 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0101 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0110 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p0111 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1000 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1001 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1010 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1011 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1100 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1101 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1110 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
p1111 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
out = torch.zeros((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale), dtype=weight.dtype).to(device=weight.device)
sz = img_a1.shape[0] * img_a1.shape[1] * img_a1.shape[2] * img_a1.shape[3]
out = out.reshape(sz, -1)
p0000 = p0000.reshape(sz, -1)
p0100 = p0100.reshape(sz, -1)
p1000 = p1000.reshape(sz, -1)
p1100 = p1100.reshape(sz, -1)
fa = fa.reshape(-1, 1)
p0001 = p0001.reshape(sz, -1)
p0101 = p0101.reshape(sz, -1)
p1001 = p1001.reshape(sz, -1)
p1101 = p1101.reshape(sz, -1)
fb = fb.reshape(-1, 1)
fc = fc.reshape(-1, 1)
p0010 = p0010.reshape(sz, -1)
p0110 = p0110.reshape(sz, -1)
p1010 = p1010.reshape(sz, -1)
p1110 = p1110.reshape(sz, -1)
fd = fd.reshape(-1, 1)
p0011 = p0011.reshape(sz, -1)
p0111 = p0111.reshape(sz, -1)
p1011 = p1011.reshape(sz, -1)
p1111 = p1111.reshape(sz, -1)
    fab = fa > fb
    fac = fa > fc
    fad = fa > fd
    fbc = fb > fc
    fbd = fb > fd
    fcd = fc > fd
i1 = i = torch.all(torch.cat([fab, fbc, fcd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
i2 = i = torch.all(torch.cat([~i1[:, None], fab, fbc, fbd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
i3 = i = torch.all(torch.cat([~i1[:, None], ~i2[:, None], fab, fbc, fad], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
i4 = i = torch.all(torch.cat([~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
i5 = i = torch.all(torch.cat([~(fbc), fab, fac, fbd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
i6 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], fab, fac, fcd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i7 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i8 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i9 = i = torch.all(torch.cat([~(fbc), ~(fac), fab, fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
# Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
# i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], fab, fcd], dim=1), dim=1)
# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad], dim=1), dim=1)
# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# c > a > d > b
i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], fab, fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# c > d > a > b
i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i12 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i13 = i = torch.all(torch.cat([~(fab), fac, fcd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
i14 = i = torch.all(torch.cat([~(fab), ~i13[:, None], fac, fad], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
i15 = i = torch.all(torch.cat([~(fab), ~i13[:, None], ~i14[:, None], fac, fbd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
i16 = i = torch.all(torch.cat([~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
i17 = i = torch.all(torch.cat([~(fab), ~(fac), fbc, fad], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
i18 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], fbc, fcd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
i19 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
i20 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
i21 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
i22 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
i23 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale))
out = out / q
return out

@@ -0,0 +1,112 @@
import logging
import cv2
import numpy as np
from scipy import signal
import torch
import os
def round_func(input):
    # Backward Pass Differentiable Approximation (BPDA):
    # the forward pass uses torch.round (non-differentiable), while the
    # backward pass treats it as the identity, so gradients pass through unchanged.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out
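def _demo_round_func_gradient():
    # Illustrative check (not used by the pipeline): forward equals torch.round,
    # backward is the identity (straight-through estimator).
    x = torch.tensor([0.4, 1.6], requires_grad=True)
    round_func(x).sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))
    assert torch.equal(round_func(x).detach(), torch.tensor([0., 2.]))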
def logger_info(logger_name, log_path='default_logger.log'):
log = logging.getLogger(logger_name)
if log.hasHandlers():
print('LogHandlers exist!')
else:
print('LogHandlers setup!')
level = logging.INFO
formatter = logging.Formatter(
'%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
fh = logging.FileHandler(log_path, mode='a')
fh.setFormatter(formatter)
log.setLevel(level)
log.addHandler(fh)
# print(len(log.handlers))
sh = logging.StreamHandler()
sh.setFormatter(formatter)
log.addHandler(sh)
def modcrop(image, modulo):
if len(image.shape) == 2:
sz = image.shape
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1]]
elif image.shape[2] == 3:
sz = image.shape[0:2]
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1], :]
else:
raise NotImplementedError
return image
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim
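def _demo_modcrop():
    # Illustrative: modcrop trims H and W down to multiples of `modulo`,
    # e.g. a 7x9 image becomes 4x8 for modulo=4.
    assert modcrop(np.zeros((7, 9), dtype=np.uint8), 4).shape == (4, 8)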

@@ -0,0 +1,66 @@
import torch
import numpy as np
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from pathlib import Path
from PIL import Image
import ray
@ray.remote(num_cpus=1, num_gpus=0.3)
def val_image_pair(model, hr_image, lr_image, output_image_path=None):
with torch.no_grad():
# prepare lr_image
lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
lr_image = lr_image.unsqueeze(0).cuda()
b, c, h, w = lr_image.shape
lr_image = lr_image.reshape(b*c, 1, h, w)
# predict
pred_lr_image = model(lr_image)
# postprocess
pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8)
pred_lr_image = pred_lr_image.cpu().numpy()
torch.cuda.empty_cache()
        if output_image_path is not None:
Image.fromarray(pred_lr_image).save(output_image_path)
# metrics
hr_image = modcrop(hr_image, model.scale)
left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
torch.cuda.empty_cache()
return PSNR(left, right, model.scale), cal_ssim(left, right)
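# Note: num_gpus=0.3 in the decorator above lets Ray co-schedule up to three
# val_image_pair workers per physical GPU; ray.init below grants the driver
# 16 CPUs and 1 GPU in total.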
def valid_steps(model, datasets, config, log_prefix=""):
ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
    dataset_names = list(datasets.keys())
    for dataset_name in dataset_names:
psnrs, ssims = [], []
predictions_path = config.valout_dir / dataset_name
if not predictions_path.exists():
predictions_path.mkdir()
test_dataset = datasets[dataset_name]
tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None
task = val_image_pair.remote(model, hr_image, lr_image, output_image_path)
tasks.append(task)
ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
while len(remaining_refs) > 0:
print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
print("\r", end=" ")
tasks = [ray.get(task) for task in tasks]
for psnr, ssim in tasks:
psnrs.append(psnr)
ssims.append(ssim)
        config.logger.info(
            '{} | Dataset {} | AVG Val PSNR: {:.2f}, AVG Val SSIM: {:.4f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
config.writer.flush()
print()

@@ -0,0 +1,193 @@
import os
from pathlib import Path
import numpy as np
import cv2
from scipy import signal
from skimage.metrics import structural_similarity
from PIL import Image
import argparse
import time
from datetime import datetime
import ray
ray.init(num_cpus=16, num_gpus=0, ignore_reinit_error=True, log_to_driver=False)
parser = argparse.ArgumentParser()
parser.add_argument("path_to_dataset", type=str)
parser.add_argument("--scale", type=int, default=4)
args = parser.parse_args()
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
def modcrop(image, modulo):
if len(image.shape) == 2:
sz = image.shape
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1]]
elif image.shape[2] == 3:
sz = image.shape[0:2]
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1], :]
else:
raise NotImplementedError
return image
scale = args.scale
dataset_path = Path(args.path_to_dataset)
hr_path = dataset_path / "HR/"
lr_path = dataset_path / f"LR_bicubic/X{scale}/"
print(hr_path, lr_path)
hr_files = sorted(os.listdir(hr_path))  # sorted so HR/LR pairs line up when zipped below
lr_files = sorted(os.listdir(lr_path))
@ray.remote(num_cpus=1)
def benchmark_image_pair(hr_image_path, lr_image_path, interpolation_function):
hr_image = cv2.imread(hr_image_path)
lr_image = cv2.imread(lr_image_path)
hr_image = hr_image[:,:,::-1] # BGR -> RGB
lr_image = lr_image[:,:,::-1] # BGR -> RGB
start_time = datetime.now()
upscaled_lr_image = interpolation_function(lr_image, scale)
processing_time = datetime.now() - start_time
    hr_image = modcrop(hr_image, scale)
psnr = PSNR(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
cpsnr = PSNR(hr_image, upscaled_lr_image)
cv2_psnr = cv2.PSNR(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
cv2_cpsnr = cv2.PSNR(hr_image, upscaled_lr_image)
ssim = cal_ssim(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
cv2_ssim = cal_ssim(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
ssim_scikit, diff = structural_similarity(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0], full=True, data_range=255.0)
cv2_scikit_ssim, diff = structural_similarity(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], full=True, data_range=255.0)
return ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time.total_seconds()
def benchmark_interpolation(interpolation_function):
    psnrs, cpsnrs, ssims = [], [], []
    cv2_psnrs, cv2_cpsnrs, scikit_ssims = [], [], []
    cv2_scikit_ssims = []
    cv2_ssims = []
    processing_times = []  # was missing; collects per-image timings appended below
tasks = []
for hr_name, lr_name in zip(hr_files, lr_files):
hr_image_path = str(hr_path / hr_name)
lr_image_path = str(lr_path / lr_name)
tasks.append(benchmark_image_pair.remote(hr_image_path, lr_image_path, interpolation_function))
ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
while len(remaining_refs) > 0:
print(f"\rReady {len(ready_refs)}/{len(hr_files)}", end=" ")
ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
for task in tasks:
ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time = ray.get(task)
ssims.append(ssim)
cv2_ssims.append(cv2_ssim)
scikit_ssims.append(ssim_scikit)
cv2_scikit_ssims.append(cv2_scikit_ssim)
psnrs.append(psnr)
cpsnrs.append(cpsnr)
cv2_psnrs.append(cv2_psnr)
cv2_cpsnrs.append(cv2_cpsnr)
processing_times.append(processing_time)
print()
print(f"AVG PSNR: {np.mean(psnrs):.2f} PSNR + _rgb2ycbcr")
print(f"AVG PSNR: {np.mean(cv2_psnrs):.2f} cv2.PSNR + cv2.cvtColor")
print(f"AVG cPSNR: {np.mean(cpsnrs):.2f} PSNR")
print(f"AVG cPSNR: {np.mean(cv2_cpsnrs):.2f} cv2.PSNR ")
print(f"AVG SSIM: {np.mean(ssims):.4f} cal_ssim + _rgb2ycbcr")
print(f"AVG SSIM: {np.mean(cv2_ssims):.4f} cal_ssim + cv2.cvtColor")
print(f"AVG SSIM: {np.mean(scikit_ssims):.4f} structural_similarity + _rgb2ycbcr")
print(f"AVG SSIM: {np.mean(cv2_scikit_ssims):.4f} structural_similarity + cv2.cvtColor")
print(f"AVG Time s: {np.percentile(processing_times, q=0.9)}")
print(f"{np.mean(psnrs):.2f},{np.mean(cv2_psnrs):.2f},{np.mean(cpsnrs):.2f},{np.mean(cv2_cpsnrs):.2f},{np.mean(ssims):.4f},{np.mean(cv2_ssims):.4f},{np.mean(scikit_ssims):.4f},{np.mean(cv2_scikit_ssims):.4f},{np.percentile(processing_times, q=0.9)}")
def cv2_interpolation(image, scale):
scaled_image = cv2.resize(
image,
None, None,
fx=scale, fy=scale,
interpolation=cv2.INTER_CUBIC
)
return scaled_image
def pillow_interpolation(image, scale):
image = Image.fromarray(image[:,:,::-1])
width, height = int(image.width * scale), int(image.height * scale)
scaled_image = image.resize((width, height), resample=Image.Resampling.BICUBIC)
return np.array(scaled_image)[:,:,::-1]
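# Note: OpenCV's INTER_CUBIC and Pillow's BICUBIC use different cubic kernel
# parameters, so the two baselines yield slightly different metrics; both are
# benchmarked and reported separately below.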
print("cv2 bicubic interpolation")
benchmark_interpolation(cv2_interpolation)
print()
print("pillow bicubic interpolation")
benchmark_interpolation(pillow_interpolation)

@@ -0,0 +1,36 @@
from .rcnet import RCNetCentered_3x3, RCNetCentered_7x7, RCNetRot90_3x3, RCNetRot90_7x7, RCNetx1, RCNetx2
from .rclut import RCLutCentered_3x3, RCLutCentered_7x7, RCLutRot90_3x3, RCLutRot90_7x7, RCLutx1, RCLutx2
from .srnet import SRNet, SRNetRot90
from .srlut import SRLut, SRLutRot90
import torch
import numpy as np
from pathlib import Path
AVAILABLE_MODELS = {
'SRNet': SRNet, 'SRLut': SRLut,
'SRNetRot90': SRNetRot90, 'SRLutRot90': SRLutRot90,
'RCNetCentered_3x3': RCNetCentered_3x3, 'RCLutCentered_3x3': RCLutCentered_3x3,
'RCNetCentered_7x7': RCNetCentered_7x7, 'RCLutCentered_7x7': RCLutCentered_7x7,
'RCNetRot90_3x3': RCNetRot90_3x3, 'RCLutRot90_3x3': RCLutRot90_3x3,
'RCNetRot90_7x7': RCNetRot90_7x7, 'RCLutRot90_7x7': RCLutRot90_7x7,
'RCNetx1': RCNetx1, 'RCLutx1': RCLutx1,
'RCNetx2': RCNetx2, 'RCLutx2': RCLutx2,
}
def SaveCheckpoint(model, path):
model_container = {
'model': model.__class__.__name__,
'state_dict': model.state_dict(),
**{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
}
torch.save(model_container, path)
def LoadCheckpoint(model_path):
model_path = Path(model_path).absolute()
if model_path.exists():
model_container = torch.load(model_path)
model = AVAILABLE_MODELS[model_container['model']](**{k:v for k,v in model_container.items() if k != "model" and k != "state_dict"})
model.load_state_dict(model_container['state_dict'], strict=True)
return model
else:
raise Exception(f"Path {model_path} does not exist.")
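def _demo_checkpoint_roundtrip(model, path="/tmp/checkpoint.pth"):
    # Illustrative sketch (path is a placeholder): any model listed in
    # AVAILABLE_MODELS survives a save/load cycle because SaveCheckpoint stores
    # the class name, the state_dict, and every public attribute the
    # constructor needs.
    SaveCheckpoint(model, path)
    restored = LoadCheckpoint(path)
    assert restored.__class__ is model.__class__
    return restored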

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import forward_2x2_input_SxS_output
class SRLut2x2(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SRLut2x2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_lut(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLut2x2(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
x = x.view(b, c, x.shape[-2], x.shape[-1])
return x
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
class SRLut3x3(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SRLut3x3, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_lut(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLut3x3(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
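        # Self-ensemble over four 90-degree rotations: the 2x2 LUT window,
        # applied in all four orientations and averaged, covers a 3x3
        # neighborhood in total (hence the class name).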
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from .mulut import MuLutx1, MuLutx2 # required by MuNetx1.get_lut_model below
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. С. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x
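def _demo_dense_block_shapes():
    # Illustrative shape check: the 2x2 'valid' perception conv shrinks H and W
    # by one, and PixelShuffle multiplies the remainder by upscale_factor.
    block = DenseConvUpscaleBlock(hidden_dim=8, layers_count=2, upscale_factor=4)
    assert block(torch.zeros(1, 1, 49, 49)).shape == (1, 1, 192, 192)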
class MuNetx1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
        super(MuNetx1, self).__init__()
self.scale = scale
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate')
rotated_prediction = self.stage(rotated_padded)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = MuLutx1.init_from_lut(stage_lut)
return lut_model

@@ -0,0 +1,394 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common.lut import select_index_1dlut_msb, select_index_4dlut_msb, select_index_4dlut_tetrahedral, select_index_1dlut_linear, \
forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
from pathlib import Path
from einops import repeat
class RCLutCentered_3x3(nn.Module):
def __init__(
self,
window_size,
quantization_interval,
scale
):
super(RCLutCentered_3x3, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.window_size = window_size
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
@staticmethod
def init_from_lut(
rc_conv_luts, dense_conv_lut
):
scale = int(dense_conv_lut.shape[-1])
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
window_size = rc_conv_luts.shape[0]
lut_model = RCLutCentered_3x3(window_size=window_size, quantization_interval=quantization_interval, scale=scale)
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
x = x.view(b, c, x.shape[-2], x.shape[-1])
return x
def __repr__(self):
return "\n".join([
f"{self.__class__.__name__}(",
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
")"])
class RCLutCentered_7x7(nn.Module):
def __init__(
self,
window_size,
quantization_interval,
scale
):
super(RCLutCentered_7x7, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.window_size = window_size
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
@staticmethod
def init_from_lut(
rc_conv_luts, dense_conv_lut
):
scale = int(dense_conv_lut.shape[-1])
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
window_size = rc_conv_luts.shape[0]
lut_model = RCLutCentered_7x7(window_size=window_size, quantization_interval=quantization_interval, scale=scale)
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
# x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
x = x.view(b, c, x.shape[-2], x.shape[-1])
return x
def __repr__(self):
return "\n".join([
f"{self.__class__.__name__}(",
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
")"])
class RCLutRot90_3x3(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(RCLutRot90_3x3, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
window_size = 3
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
@staticmethod
def init_from_lut(
rc_conv_luts, dense_conv_lut
):
scale = int(dense_conv_lut.shape[-1])
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
window_size = rc_conv_luts.shape[0]
lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, output.shape[-2], output.shape[-1])
return output
def __repr__(self):
return "\n".join([
f"{self.__class__.__name__}(",
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
")"])
class RCLutRot90_7x7(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(RCLutRot90_7x7, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
window_size = 7
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
@staticmethod
def init_from_lut(
rc_conv_luts, dense_conv_lut
):
scale = int(dense_conv_lut.shape[-1])
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
window_size = rc_conv_luts.shape[0]
lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, output.shape[-2], output.shape[-1])
return output
def __repr__(self):
return "\n".join([
f"{self.__class__.__name__}(",
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
")"])
class RCLutx1(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(RCLutx1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
window_size = 3
self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
window_size = 5
self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
window_size = 7
self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_lut(
rc_conv_luts_3x3, dense_conv_lut_3x3,
rc_conv_luts_5x5, dense_conv_lut_5x5,
rc_conv_luts_7x7, dense_conv_lut_7x7
):
scale = int(dense_conv_lut_3x3.shape[-1])
quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
return lut_model
def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
print("lut:", x.min(), x.max(), x.mean())
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
k=-rotations_count,
dims=[2, 3]
)
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
k=-rotations_count,
dims=[2, 3]
)
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
k=-rotations_count,
dims=[2, 3]
)
output /= 3*4
output = output.view(b, c, output.shape[-2], output.shape[-1])
return output
def __repr__(self):
return "\n".join([
f"{self.__class__.__name__}(",
f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
")"])
class RCLutx2(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(RCLutx2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_lut(
s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
):
scale = int(s2_dense_conv_lut_3x3.shape[-1])
quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
return lut_model
def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
k=-rotations_count,
dims=[2, 3]
)
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
k=-rotations_count,
dims=[2, 3]
)
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
k=-rotations_count,
dims=[2, 3]
)
output /= 3*4
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
k=-rotations_count,
dims=[2, 3]
)
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
k=-rotations_count,
dims=[2, 3]
)
output += torch.rot90(
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
k=-rotations_count,
dims=[2, 3]
)
output /= 3*4
output = output.view(b, c, output.shape[-2], output.shape[-1])
return output
def __repr__(self):
return "\n".join([
f"{self.__class__.__name__}(",
f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
")"])

@ -0,0 +1,329 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from pathlib import Path
from common import lut
from . import rclut
# Huang G. et al. Densely connected convolutional networks // Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. pp. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x
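# DenseConvUpscaleBlock consumes a 2x2 receptive field: the 'valid' 2x2 conv
# shrinks (h+1, w+1) inputs to (h, w), the 1x1 dense convs mix channels without
# growing the receptive field, and PixelShuffle turns upscale_factor**2 channels
# into an SxS patch per pixel. Inputs and outputs stay in the [0, 255] range.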
class ReconstructedConvCentered(nn.Module):
def __init__(self, hidden_dim, window_size=7):
super(ReconstructedConvCentered, self).__init__()
self.window_size = window_size
self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
def pixel_wise_forward(self, x):
x = (x-127.5)/127.5
out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
out = torch.tanh(out)
out = out*127.5 + 127.5
return out
def forward(self, x):
original_shape = x.shape
x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
x = F.unfold(x, self.window_size)
x = self.pixel_wise_forward(x)
x = x.mean(1)
x = x.reshape(*original_shape)
x = round_func(x)
return x
def __repr__(self):
return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
class RCBlockCentered(nn.Module):
def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
super(RCBlockCentered, self).__init__()
self.window_size = window_size
self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
def forward(self, x):
b,c,hs,ws = x.shape
x = self.rc_conv(x)
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
x = self.dense_conv_block(x)
return x
class RCNetCentered_3x3(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(RCNetCentered_3x3, self).__init__()
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.scale = scale
window_size = 3
self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = self.stage(x)
x = x.view(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutCentered_3x3.init_from_lut(rc_conv_luts, dense_conv_lut)
return lut_model
class RCNetCentered_7x7(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(RCNetCentered_7x7, self).__init__()
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.scale = scale
window_size = 7
self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = self.stage(x)
x = x.view(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutCentered_7x7.init_from_lut(rc_conv_luts, dense_conv_lut)
return lut_model
class ReconstructedConvRot90(nn.Module):
def __init__(self, hidden_dim, window_size=7):
super(ReconstructedConvRot90, self).__init__()
self.window_size = window_size
self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
def pixel_wise_forward(self, x):
x = (x-127.5)/127.5
out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2)
out = torch.tanh(out)
out = out*127.5 + 127.5
return out
def forward(self, x):
original_shape = x.shape
x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
x = F.unfold(x, self.window_size)
x = self.pixel_wise_forward(x)
x = x.mean(1)
x = x.reshape(*original_shape)
x = round_func(x) # quality likely suffers from this rounding
return x
def __repr__(self):
return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
class RCBlockRot90(nn.Module):
def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
super(RCBlockRot90, self).__init__()
self.window_size = window_size
self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
def forward(self, x):
b,c,hs,ws = x.shape
x = self.rc_conv(x)
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
x = self.dense_conv_block(x)
return x
class RCNetRot90_3x3(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(RCNetRot90_3x3, self).__init__()
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.scale = scale
self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutRot90_3x3.init_from_lut(rc_conv_luts, dense_conv_lut)
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_prediction = self.stage(rotated)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
class RCNetRot90_7x7(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(RCNetRot90_7x7, self).__init__()
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.scale = scale
self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
window_size = self.stage.rc_conv.window_size
rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutRot90_7x7.init_from_lut(rc_conv_luts, dense_conv_lut)
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_prediction = self.stage(rotated)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
class RCNetx1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(RCNetx1, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx1.init_from_lut(
rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3,
rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5,
rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7
)
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
output /= 3*4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
class RCNetx2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(RCNetx2, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = rclut.RCLutx2.init_from_lut(
s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
)
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
output /= 3*4
x = output
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])
output /= 3*4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
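# A minimal shape check, assuming `common` is importable as in the imports
# above (hidden_dim/layers_count values are illustrative):
if __name__ == "__main__":
    net = RCNetx1(hidden_dim=32, layers_count=2, scale=4)
    x = torch.randint(0, 256, size=(1, 1, 16, 16)).type(torch.float32)
    assert net(x).shape == (1, 1, 64, 64)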

@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import forward_2x2_input_SxS_output
class SRLut(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SRLut, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_lut(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLut(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
x = x.view(b, c, x.shape[-2], x.shape[-1])
return x
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
class SRLutRot90(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutRot90, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@staticmethod
def init_from_lut(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"

@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from .srlut import SRLut, SRLutRot90
# Huang G. et al. Densely connected convolutional networks // Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. pp. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x
class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
self.scale = scale
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
x = self.stage(x)
x = x.view(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = SRLut.init_from_lut(stage_lut)
return lut_model
class SRNetRot90(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetRot90, self).__init__()
self.scale = scale
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate')
rotated_prediction = self.stage(rotated_padded)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = SRLutRot90.init_from_lut(stage_lut)
return lut_model
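# A quick shape check for the rotation-ensembled network (illustrative sizes):
if __name__ == "__main__":
    net = SRNetRot90(hidden_dim=32, layers_count=2, scale=4)
    x = torch.randint(0, 256, size=(1, 1, 8, 8)).type(torch.float32)
    assert net(x).shape == (1, 1, 32, 32)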

@ -0,0 +1,51 @@
from pathlib import Path
import sys
sys.path.insert(0, str(Path("../").resolve()) + "/")
from models import LoadCheckpoint
import torch
import numpy as np
import cv2
from PIL import Image
from datetime import datetime
project_path = Path("../../").resolve()
start_script_time = datetime.now()
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCLutCentered_0.pth")
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCNetRot90_7x7_10000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCLutRot90_7x7_0.pth")
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCNetRot90_3x3_10000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCLutRot90_3x3_0.pth")
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth")
net_model = LoadCheckpoint(project_path / "models" / "last_transfered_net.pth").cuda()
lut_model = LoadCheckpoint(project_path / "models" / "last_transfered_lut.pth").cuda()
print(net_model)
print(lut_model)
lr_image = cv2.imread(str(project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy()
image_gt = cv2.imread(str(project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy()
# lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy()
# image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()
lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()
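# Optional: a rough PSNR comparison against GT (a plain numpy sketch, not the
# project's PSNR helper; assumes the predictions match the GT resolution):
def psnr(a, b):
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return 10 * np.log10(255.0 ** 2 / mse)
print(f"PSNR net: {psnr(image_gt, net_prediction):.2f} dB, lut: {psnr(image_gt, lut_prediction):.2f} dB")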
image_gt = cv2.putText(image_gt, 'GT', org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType=cv2.LINE_AA)
image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType=cv2.LINE_AA)
image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType=cv2.LINE_AA)
Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(project_path / "models" / 'last_transfered_demo.png')
print(datetime.now() - start_script_time)

@ -0,0 +1,199 @@
import sys
sys.path.insert(0, "../") # run under the project directory
from pickle import dump
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS
from common.validation import valid_steps
torch.backends.cudnn.benchmark = True
import argparse
from schedulefree import AdamWScheduleFree
from datetime import datetime
class TrainOptions:
def __init__(self):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
parser.add_argument('--model_path', type=str, default=None, help="Path to a model checkpoint to fine-tune.")
parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder")
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
parser.add_argument('--datasets_dir', type=str, default="../../data/")
parser.add_argument('--train_datasets', type=str, default='DIV2K')
parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
parser.add_argument('--start_iter', type=int, default=0, help='Set 0 to train from scratch; otherwise saved parameters are loaded and training continues')
parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
parser.add_argument('--display_step', type=int, default=100, help='display info every N iteration')
parser.add_argument('--val_step', type=int, default=2000, help='validate every N iteration')
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
parser.add_argument('--worker_num', '-n', type=int, default=1)
parser.add_argument('--prefetch_factor', '-p', type=int, default=16)
parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
self.parser = parser
def parse_args(self):
args = self.parser.parse_args()
args.models_dir = Path(args.models_dir)
args.model_path = Path(args.model_path) if args.model_path is not None else None
args.train_datasets = args.train_datasets.split(',')
args.val_datasets = args.val_datasets.split(',')
return args
def print_options(self, opt):
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
print()
return message
def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"One of the {config.val_datasets} was not found in {config.datasets_dir}."
config.exp_dir = config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}"
if not config.exp_dir.exists():
config.exp_dir.mkdir()
config.checkpoint_dir = config.exp_dir / "checkpoints"
if not config.checkpoint_dir.exists():
config.checkpoint_dir.mkdir()
config.valout_dir = config.exp_dir / 'val'
if not config.valout_dir.exists():
config.valout_dir.mkdir()
config.logs_dir = config.exp_dir / 'logs'
if not config.logs_dir.exists():
config.logs_dir.mkdir()
if __name__ == "__main__":
script_start_time = datetime.now()
config_inst = TrainOptions()
config = config_inst.parse_args()
if config.model_path is not None:
model = LoadCheckpoint(config.model_path)
config.model = model.__class__.__name__
else:
model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
# model = model.cuda()
optimizer = AdamWScheduleFree(model.parameters())
optimizer.train() # schedule-free optimizers expect an explicit train()/eval() mode switch before stepping
prepare_experiment_folder(config)
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=config.logs_dir)
logger_name = 'train'
logger_info(logger_name, os.path.join(config.logs_dir, logger_name + '.log'))
logger = logging.getLogger(logger_name)
config.writer = writer
config.logger = logger
config.logger.info(config_inst.print_options(config))
print(model)
# Training dataset
train_datasets = []
for train_dataset_name in config.train_datasets:
train_datasets.append(SRTrainDataset(
hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
patch_size = config.crop_size
))
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
train_loader = DataLoader(
dataset = train_dataset,
batch_size = config.batch_size,
num_workers = config.worker_num,
shuffle = False,
drop_last = False,
pin_memory = True,
prefetch_factor = config.prefetch_factor
)
train_iter = iter(train_loader)
test_datasets = {}
for test_dataset_name in config.val_datasets:
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
)
l_accum = [0., 0., 0.]
prepare_data_time = 0.
forward_backward_time = 0.
accum_samples = 0
# TRAINING
i = config.start_iter
for i in range(config.start_iter + 1, config.total_iter + 1):
torch.cuda.empty_cache()
start_time = time.time()
try:
hr_patch, lr_patch = next(train_iter)
except StopIteration:
train_iter = iter(train_loader)
hr_patch, lr_patch = next(train_iter)
# hr_patch = hr_patch.cuda()
# lr_patch = lr_patch.cuda()
prepare_data_time += time.time() - start_time
start_time = time.time()
optimizer.zero_grad()
pred = model(lr_patch)
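# scale both tensors from [0, 255] to [0, 1] so the MSE magnitude does not depend on the 8-bit value range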
loss = F.mse_loss(pred/255, hr_patch/255)
loss.backward()
optimizer.step()
forward_backward_time += time.time() - start_time
# For monitoring
accum_samples += config.batch_size
l_accum[0] += loss.item()
# Show information
if i % config.display_step == 0:
config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
config.exp_dir, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step,
forward_backward_time / config.display_step))
l_accum = [0., 0., 0.]
prepare_data_time = 0.
forward_backward_time = 0.
# Save models
if i % config.save_step == 0:
SaveCheckpoint(model=model, path=Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth")
# Validation
if i % config.val_step == 0:
config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}")

@ -0,0 +1,75 @@
import sys
sys.path.insert(0, "../") # run under the project directory
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
import models
class TransferToLutOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', '-m', type=str, default='', help="model path folder")
self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Quantization bits; the 4D LUT gets 2**bits + 1 nodes per input dimension. Value is in range [1, 8].")
self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
def parse_args(self):
args = self.parser.parse_args()
args.model_path = Path(args.model_path)
args.checkpoint_dir = Path(args.model_path).absolute().parent
return args
def print_options(self, opt):
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
print()
if __name__ == "__main__":
start_time = datetime.now()
print(start_time)
config_inst = TransferToLutOptions()
config = config_inst.parse_args()
config_inst.print_options(config)
model = models.LoadCheckpoint(config.model_path).cuda()
print(model)
print()
print("Transfering:")
lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
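# e.g. quantization_bits=4 -> quantization_interval=16 -> 256//16 + 1 = 17 LUT nodes per input dimension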
print()
print(lut_model)
lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth"
models.SaveCheckpoint(model=lut_model, path=lut_path)
lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()])
print()
print(datetime.now()-start_time)
print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")
models_root = Path(config.model_path).absolute().parent.parent.parent
models.SaveCheckpoint(model=model, path=models_root / "last_transfered_net.pth")
models.SaveCheckpoint(model=lut_model, path=models_root / "last_transfered_lut.pth")
print("Updated", models_root / "last_transfered_net.pth")
print("Updated", models_root / "last_transfered_lut.pth")

@ -0,0 +1,86 @@
import sys
sys.path.insert(0, "../") # run under the project directory
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.validation import valid_steps
from models import LoadCheckpoint
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
class ValOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../../data/")
self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
self.parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
def parse_args(self):
args = self.parser.parse_args()
args.datasets_dir = Path(args.datasets_dir).resolve()
args.val_datasets = args.val_datasets.split(',')
args.exp_dir = Path(args.model_path).absolute().parent.parent
args.model_path = Path(args.model_path)
args.model_name = args.model_path.stem
args.valout_dir = Path(args.exp_dir)/ 'val'
if not args.valout_dir.exists():
args.valout_dir.mkdir()
args.current_iter = args.model_name.split('_')[-1]
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.valout_dir)
logger_name = f'val_{args.model_path.stem}'
logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log'))
logger = logging.getLogger(logger_name)
args.writer = writer
args.logger = logger
return args
def print_options(self, opt):
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
print()
return message
# TODO: with a unified save/load function, any net or LUT model file could be tested with the same script.
if __name__ == "__main__":
config_inst = ValOptions()
config = config_inst.parse_args()
config.logger.info(config_inst.print_options(config))
model = LoadCheckpoint(config.model_path)
model = model.cuda()
print(model)
test_datasets = {}
for test_dataset_name in config.val_datasets:
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
)
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
config.logger.info("Complete")