commit 14a7f00245 (branch: main)
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
@@ -0,0 +1,2 @@
*
!.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
@@ -0,0 +1,12 @@

```
python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90
python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
python image_demo.py
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
```

Requirements:
- [schedulefree](https://github.com/facebookresearch/schedule_free)
- einops
- ray
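
The dependencies can be installed with pip. A minimal sketch, assuming a CUDA-enabled PyTorch build is already set up (the schedule_free repository publishes its package under the name `schedulefree`):

```
pip install schedulefree einops ray
```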
@@ -0,0 +1,112 @@
import os
import random
import sys

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path

sys.path.insert(0, "../")  # run under the project directory

image_extensions = ['.jpg', '.png']

def load_images_cached(images_dir_path):
    image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions])
    cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy"
    if not cache_path.exists():
        print("Caching to:", cache_path)
        value = {f: np.array(Image.open(f)) for f in image_paths}
        np.save(cache_path, value, allow_pickle=True)
    else:
        value = np.load(cache_path, allow_pickle=True).item()
        print("Loaded cache from:", cache_path)
    return list(value.keys()), list(value.values())


class SRTrainDataset(Dataset):
    def __init__(self, hr_dir_path, lr_dir_path, patch_size, rigid_aug=True):
        super(SRTrainDataset, self).__init__()
        self.sz = patch_size
        self.rigid_aug = rigid_aug
        self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path)
        self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path)
        assert len(self.hr_images) == len(self.lr_images)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            batch_hr, batch_lr = [], []
            for i in range(idx.start, idx.stop):
                hr_patch, lr_patch = self.__getitem__(i)
                batch_hr.append(hr_patch)
                batch_lr.append(lr_patch)
            return batch_hr, batch_lr
        else:
            idx = idx % len(self.hr_images)
            hr_image = self.hr_images[idx]
            lr_image = self.lr_images[idx]
            scale = hr_image.shape[0] // lr_image.shape[0]

            # random patch position and a single random color channel
            i = random.randint(0, lr_image.shape[0] - self.sz)
            j = random.randint(0, lr_image.shape[1] - self.sz)
            c = random.choice([0, 1, 2])

            hr_patch = hr_image[
                (i*scale):(i*scale + self.sz*scale),
                (j*scale):(j*scale + self.sz*scale),
                c
            ]
            lr_patch = lr_image[
                i:(i + self.sz),
                j:(j + self.sz),
                c
            ]

            if self.rigid_aug:
                if random.uniform(0, 1) < 0.5:
                    hr_patch = np.fliplr(hr_patch)
                    lr_patch = np.fliplr(lr_patch)

                if random.uniform(0, 1) < 0.5:
                    hr_patch = np.flipud(hr_patch)
                    lr_patch = np.flipud(lr_patch)

                k = random.choice([0, 1, 2, 3])
                hr_patch = np.rot90(hr_patch, k)
                lr_patch = np.rot90(lr_patch, k)

            hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
            lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
            hr_patch = hr_patch.unsqueeze(0)
            lr_patch = lr_patch.unsqueeze(0)

            return hr_patch, lr_patch

    def __len__(self):
        return len(self.hr_images)


class SRTestDataset(Dataset):
    def __init__(self, hr_dir_path, lr_dir_path):
        super(SRTestDataset, self).__init__()
        self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path)
        self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path)
        assert len(self.hr_images) == len(self.lr_images)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            batch_hr, batch_lr = [], []
            for i in range(idx.start, idx.stop):
                hr_image, lr_image = self.__getitem__(i)
                batch_hr.append(hr_image)
                batch_lr.append(lr_image)
            return batch_hr, batch_lr
        else:
            idx = idx % len(self.hr_images)
            return self.hr_images[idx], self.lr_images[idx], self.hr_image_paths[idx], self.lr_image_paths[idx]

    def __len__(self):
        return len(self.hr_images)

    def __iter__(self):
        for i in range(0, len(self.hr_images)):
            yield self.__getitem__(i)
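
A minimal usage sketch for the dataset classes above. The directory paths are hypothetical; substitute your own HR/LR folders (decoded images are cached as `.npy` files next to each folder on first use):

```
train_dataset = SRTrainDataset(
    hr_dir_path="/wd/data/DIV2K/HR",             # hypothetical path
    lr_dir_path="/wd/data/DIV2K/LR_bicubic/X4",  # hypothetical path
    patch_size=48,
)
hr_patch, lr_patch = train_dataset[0]  # float32 tensors, shapes [1, 192, 192] and [1, 48, 48] for x4 data
```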
@@ -0,0 +1,409 @@

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from .utils import round_func
import numpy as np

##################### TRANSFER ##########################

class Domain4DValues(Dataset):
    def __init__(self, quantization_interval=1):
        super(Domain4DValues, self).__init__()
        values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
        values1d = torch.cat([values1d, torch.tensor([255])])
        self.quantization_interval = quantization_interval
        # all 2x2 uint8 windows on the quantized grid
        self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 2, 2)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            ix1s, ix2s, ix3s, ix4s, batch = [], [], [], [], []
            for i in range(idx.start, idx.stop):
                ix1, ix2, ix3, ix4, values = self.__getitem__(i)
                ix1s.append(ix1)
                ix2s.append(ix2)
                ix3s.append(ix3)
                ix4s.append(ix4)
                batch.append(values)
            return ix1s, ix2s, ix3s, ix4s, batch
        else:
            v = self.values[idx]
            ix = v[0] // self.quantization_interval
            return ix[0,0], ix[0,1], ix[1,0], ix[1,1], v

    def __len__(self):
        return len(self.values)

    def __iter__(self):
        for i in range(0, len(self.values)):
            yield self.__getitem__(i)


def transfer_rc_conv(rc_conv, quantization_interval=1):
    receptive_field_pixel_count = rc_conv.window_size**2
    bucket_count = 256 // quantization_interval
    lut = np.full((receptive_field_pixel_count, bucket_count+1), dtype=np.uint8, fill_value=255)
    for pixel_id in range(receptive_field_pixel_count):
        for idx, value in enumerate(range(0, 256, quantization_interval)):
            inputs = torch.tensor([value]).type(torch.float32).view(1,1,1).cuda()
            with torch.no_grad():
                outputs = rc_conv.pixel_wise_forward(inputs)
            lut[:, idx] = outputs.flatten().cpu().numpy().astype(np.uint8)
            print(f"\r {rc_conv.__class__.__name__} {pixel_id*bucket_count + idx + 1}/{receptive_field_pixel_count*bucket_count}", end=" ")
    print()
    return lut


def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
    bucket_count = 256 // quantization_interval
    scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
    # 4D LUT for a simple 2x2 input window
    lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, 1, scale, scale), dtype=np.uint8, fill_value=255)
    domain_values = Domain4DValues(quantization_interval=quantization_interval)
    domain_values_loader = DataLoader(
        domain_values,
        batch_size=batch_size if quantization_interval >= 16 else 2**16,
        pin_memory=True,
        num_workers=1 if quantization_interval >= 16 else mp.cpu_count()
    )
    counter = 0
    for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
        inputs = batch.type(torch.float32).cuda()
        with torch.no_grad():
            outputs = block(inputs)
        lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8)  # TODO: the first layer should pad the image automatically
        counter += inputs.shape[0]
        print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
    print()
    lut = lut.squeeze(-3)
    return lut


##################### FORWARD ##########################

def forward_2x2_input_SxS_output(index, lut):
    b, c, hs, ws = index.shape
    scale = lut.shape[-1]
    index = F.pad(input=index, pad=[0,1,0,1], mode='replicate')  # pad right/bottom so each pixel has a full 2x2 window
    out = select_index_4dlut_tetrahedral2(
        ixA = index,
        ixB = torch.roll(index, shifts=[0, -1], dims=[2,3]),
        ixC = torch.roll(index, shifts=[-1, 0], dims=[2,3]),
        ixD = torch.roll(index, shifts=[-1,-1], dims=[2,3]),
        lut = lut
    )
    out = out[:,:,0:-1,0:-1,:,:]  # unpad
    # Pixel shuffle. Example: [3, 1, 126, 126, 4, 4] -> [3, 1, 126, 4, 126, 4] -> [3, 1, 504, 504]
    out = out.permute(0,1,2,4,3,5).reshape(b, 1, hs*scale, ws*scale)
    out = round_func(out)
    return out


def forward_rc_conv_centered(index, lut):
    window_size = lut.shape[0]
    index = F.pad(index, pad=[window_size//2]*4, mode='replicate')
    window_indexes = lut.shape[:-1]
    index = index.unsqueeze(-1)
    x = torch.zeros_like(index)
    for i in range(window_indexes[-2]):
        for j in range(window_indexes[-1]):
            shift_i, shift_j = -window_indexes[-2]//2+1 + i, -window_indexes[-1]//2+1 + j
            shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
            x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i, j])
    x /= window_indexes[-2]*window_indexes[-1]
    x = x.squeeze(-1)
    x = round_func(x)
    x = x[:,:,window_size//2:-window_size//2+1,window_size//2:-window_size//2+1]
    return x


def forward_rc_conv_rot90(index, lut):
    window_size = lut.shape[0]
    index = F.pad(index, pad=[0, window_size-1]*2, mode='replicate')
    window_indexes = lut.shape[:-1]
    index = index.unsqueeze(-1)
    x = torch.zeros_like(index)
    for i in range(window_indexes[-2]):
        for j in range(window_indexes[-1]):
            shift_i, shift_j = i, j
            shifted_index = torch.roll(index, shifts=[shift_i, shift_j], dims=[-2, -1])
            x += select_index_1dlut_linear(ixA=shifted_index, lut=lut[i, j])
    x /= window_indexes[-2]*window_indexes[-1]
    x = x.squeeze(-1)
    x = round_func(x)
    x = x[:,:,:-(window_size-1),:-(window_size-1)]
    return x


##################### UTILS ##########################

def select_index_1dlut_linear(ixA, lut):
    dimA = lut.shape[0]
    qA = 256/(dimA-1)
    outDims = lut.shape[1:]
    lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
    index_loop_indexes = ixA.shape
    lut_loop_indexes = lut.shape[:-1]
    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    msbA = torch.floor_divide(ixA, qA).type(torch.int64)
    msbB = torch.floor_divide(ixA, qA).type(torch.int64) + 1
    lsb_index = ixA % qA
    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
    outA = torch.gather(input=lut, dim=-1, index=msbA)
    outB = torch.gather(input=lut, dim=-1, index=msbB)
    out = outA + (lsb_index/qA) * (outB-outA)
    out = out.squeeze(-1)
    return out


def select_index_1dlut_msb(ixA, lut):
    dimA = lut.shape[0]
    outDims = lut.shape[1:]
    lut = lut.reshape(dimA, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
    index_loop_indexes = ixA.shape
    lut_loop_indexes = lut.shape[:-1]
    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    msb_index = torch.floor_divide(ixA, 256/(dimA-1)).type(torch.int64)
    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
    out = torch.gather(input=lut, dim=-1, index=msb_index)
    out = out.squeeze(-1)
    return out


def select_index_4dlut_msb(ixA, ixB, ixC, ixD, lut):
    dimA, dimB, dimC, dimD = lut.shape[:4]
    qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
    outDims = lut.shape[4:]
    lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
    index_loop_indexes = ixA.shape
    lut_loop_indexes = lut.shape[:-1]
    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    # flatten the 4D bucket index; the strides are dim**3, dim**2, dim**1, dim**0
    msb_index = torch.floor_divide(ixA, qA) * dimA**3
    msb_index += torch.floor_divide(ixB, qB) * dimB**2
    msb_index += torch.floor_divide(ixC, qC) * dimC**1
    msb_index += torch.floor_divide(ixD, qD) * dimD**0
    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
    out = torch.gather(input=lut, dim=-1, index=msb_index.type(torch.int64))
    out = out.squeeze(-1)
    return out


def select_index_4dlut_linear(ixA, ixB, ixC, ixD, lut):
    dimA, dimB, dimC, dimD = lut.shape[:4]
    qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
    outDims = lut.shape[4:]
    lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
    index_loop_indexes = ixA.shape
    lut_loop_indexes = lut.shape[:-1]
    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    msb_index = torch.floor_divide(ixA, qA).type(torch.int64) * dimA**3
    msb_index += torch.floor_divide(ixB, qB).type(torch.int64) * dimB**2
    msb_index += torch.floor_divide(ixC, qC).type(torch.int64) * dimC**1
    msb_index += torch.floor_divide(ixD, qD).type(torch.int64) * dimD**0
    outA = torch.gather(input=lut, dim=-1, index=msb_index)
    msb_index = (torch.floor_divide(ixA, qA).type(torch.int64) + 1) * dimA**3
    msb_index += (torch.floor_divide(ixB, qB).type(torch.int64) + 1) * dimB**2
    msb_index += (torch.floor_divide(ixC, qC).type(torch.int64) + 1) * dimC**1
    msb_index += (torch.floor_divide(ixD, qD).type(torch.int64) + 1) * dimD**0
    outB = torch.gather(input=lut, dim=-1, index=msb_index)
    lsb_coef = ((ixA+ixB+ixC+ixD)/4 % qA) / qA
    out = outA + lsb_coef*(outB-outA)
    out = out.squeeze(-1)
    return out


def barycentric_interpolate(masks, coefs, vertices):
    i = torch.all(torch.stack(masks), dim=0, keepdim=False)
    coefs = torch.stack(coefs) * i
    vertices = torch.stack(vertices)
    out = (coefs*vertices).sum(0)
    return i, out


def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
    dimA, dimB, dimC, dimD = lut.shape[:4]
    qA, qB, qC, qD = 256/(dimA-1), 256/(dimB-1), 256/(dimC-1), 256/(dimD-1)
    outDims = lut.shape[4:]
    lut = lut.reshape(dimA*dimB*dimC*dimD, *outDims).permute(*(i+1 for i in range(len(outDims))), 0)
    index_loop_indexes = ixA.shape
    lut_loop_indexes = lut.shape[:-1]
    lut = lut.view(*((1,)*len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
    ixA = ixA.view(*(ixA.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixB = ixB.view(*(ixB.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixC = ixC.view(*(ixC.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
    ixD = ixD.view(*(ixD.shape + (1,)*len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)

    msbA = torch.floor_divide(ixA, qA).type(torch.int64)
    msbB = torch.floor_divide(ixB, qB).type(torch.int64)
    msbC = torch.floor_divide(ixC, qC).type(torch.int64)
    msbD = torch.floor_divide(ixD, qD).type(torch.int64)
    fa, fb, fc, fd = ixA % qA, ixB % qB, ixC % qC, ixD % qD
    fab, fac, fad, fbc, fbd, fcd = fa>fb, fa>fc, fa>fd, fb>fc, fb>fd, fc>fd

    strides = torch.tensor([dimA**3, dimB**2, dimC**1, dimD**0], device=lut.device).view(-1, *((1,)*len(msbA.shape)))
    p0000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD ])*strides).sum(0))
    p0001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC, msbD+1])*strides).sum(0))
    p0010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD ])*strides).sum(0))
    p0011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB, msbC+1, msbD+1])*strides).sum(0))
    p0100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD ])*strides).sum(0))
    p0101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC, msbD+1])*strides).sum(0))
    p0110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD ])*strides).sum(0))
    p0111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA, msbB+1, msbC+1, msbD+1])*strides).sum(0))
    p1000 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD ])*strides).sum(0))
    p1001 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC, msbD+1])*strides).sum(0))
    p1010 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD ])*strides).sum(0))
    p1011 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB, msbC+1, msbD+1])*strides).sum(0))
    p1100 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD ])*strides).sum(0))
    p1101 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC, msbD+1])*strides).sum(0))
    p1110 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD ])*strides).sum(0))
    p1111 = torch.gather(input=lut, dim=-1, index=(torch.stack([msbA+1, msbB+1, msbC+1, msbD+1])*strides).sum(0))

    i1, out1 = barycentric_interpolate([fab, fbc, fcd], [qA-fa, fa - fb, fb - fc, fc - fd, fd], [p0000, p1000, p1100, p1110, p1111])
    i2, out2 = barycentric_interpolate([fab, fbc, fbd, ~(i1)], [qA-fa, fa - fb, fb - fd, fd - fc, fc], [p0000, p1000, p1100, p1101, p1111])
    i3, out3 = barycentric_interpolate([fab, fbc, fad, ~(i1), ~(i2)], [qA-fa, fa - fd, fd - fb, fb - fc, fc], [p0000, p1000, p1001, p1101, p1111])
    i4, out4 = barycentric_interpolate([fab, fbc, ~(i1), ~(i2), ~(i3)], [qA-fd, fd - fa, fa - fb, fb - fc, fc], [p0000, p0001, p1001, p1101, p1111])
    i5, out5 = barycentric_interpolate([fab, fac, fbd, ~(fbc)], [qA-fa, fa - fc, fc - fb, fb - fd, fd], [p0000, p1000, p1010, p1110, p1111])
    i6, out6 = barycentric_interpolate([fab, fac, fcd, ~(fbc), ~(i5)], [qA-fa, fa - fc, fc - fd, fd - fb, fb], [p0000, p1000, p1010, p1011, p1111])
    i7, out7 = barycentric_interpolate([fab, fac, fad, ~(fbc), ~(i5), ~(i6)], [qA-fa, fa - fd, fd - fc, fc - fb, fb], [p0000, p1000, p1001, p1011, p1111])
    i8, out8 = barycentric_interpolate([fab, fac, ~(fbc), ~(i5), ~(i6), ~(i7)], [qA-fd, fd - fa, fa - fc, fc - fb, fb], [p0000, p0001, p1001, p1011, p1111])
    i9, out9 = barycentric_interpolate([fab, fbd, ~(fbc), ~(fac)], [qA-fc, fc - fa, fa - fb, fb - fd, fd], [p0000, p0010, p1010, p1110, p1111])
    i10, out10 = barycentric_interpolate([fab, fad, ~(fbc), ~(fac), ~(i9)], [qA-fc, fc - fa, fa - fd, fd - fb, fb], [p0000, p0010, p1010, p1011, p1111])
    i11, out11 = barycentric_interpolate([fab, fcd, ~(fbc), ~(fac), ~(i9), ~(i10)], [qA-fc, fc - fd, fd - fa, fa - fb, fb], [p0000, p0010, p0011, p1011, p1111])
    i12, out12 = barycentric_interpolate([fab, ~(fbc), ~(fac), ~(i9), ~(i10), ~(i11)], [qA-fd, fd - fc, fc - fa, fa - fb, fb], [p0000, p0001, p0011, p1011, p1111])

    i13, out13 = barycentric_interpolate([fac, fcd, ~(fab)], [qA-fb, fb - fa, fa - fc, fc - fd, fd], [p0000, p0100, p1100, p1110, p1111])
    i14, out14 = barycentric_interpolate([fac, fad, ~(fab), ~(i13)], [qA-fb, fb - fa, fa - fd, fd - fc, fc], [p0000, p0100, p1100, p1101, p1111])
    i15, out15 = barycentric_interpolate([fac, fbd, ~(fab), ~(i13), ~(i14)], [qA-fb, fb - fd, fd - fa, fa - fc, fc], [p0000, p0100, p0101, p1101, p1111])
    i16, out16 = barycentric_interpolate([fac, ~(fab), ~(i13), ~(i14), ~(i15)], [qA-fd, fd - fb, fb - fa, fa - fc, fc], [p0000, p0001, p0101, p1101, p1111])
    i17, out17 = barycentric_interpolate([fbc, fad, ~(fab), ~(fac)], [qA-fb, fb - fc, fc - fa, fa - fd, fd], [p0000, p0100, p0110, p1110, p1111])
    i18, out18 = barycentric_interpolate([fbc, fcd, ~(fab), ~(fac), ~(i17)], [qA-fb, fb - fc, fc - fd, fd - fa, fa], [p0000, p0100, p0110, p0111, p1111])
    i19, out19 = barycentric_interpolate([fbc, fbd, ~(fab), ~(fac), ~(i17), ~(i18)], [qA-fb, fb - fd, fd - fc, fc - fa, fa], [p0000, p0100, p0101, p0111, p1111])
    i20, out20 = barycentric_interpolate([fbc, ~(fab), ~(fac), ~(i17), ~(i18), ~(i19)], [qA-fd, fd - fb, fb - fc, fc - fa, fa], [p0000, p0001, p0101, p0111, p1111])
    i21, out21 = barycentric_interpolate([fad, ~(fab), ~(fac), ~(fbc)], [qA-fc, fc - fb, fb - fa, fa - fd, fd], [p0000, p0010, p0110, p1110, p1111])
    i22, out22 = barycentric_interpolate([fbd, ~(fab), ~(fac), ~(fbc), ~(i21)], [qA-fc, fc - fb, fb - fd, fd - fa, fa], [p0000, p0010, p0110, p0111, p1111])
    i23, out23 = barycentric_interpolate([fcd, ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22)], [qA-fc, fc - fd, fd - fb, fb - fa, fa], [p0000, p0010, p0011, p0111, p1111])
    i24, out24 = barycentric_interpolate([~(fab), ~(fac), ~(fbc), ~(i21), ~(i22), ~(i23)], [qA-fd, fd - fc, fc - fb, fb - fa, fa], [p0000, p0001, p0011, p0111, p1111])

    out = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8 + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16 + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24
    out /= qA
    out = out.squeeze(-1)
    return out


def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut):
    # adapted from the SR-LUT reference implementation (original signature: self, weight, upscale, mode, img_in, bd)
    dimA, dimB, dimC, dimD = lut.shape[:4]
    q = 256/(dimA-1)
    L = dimA
    upscale = lut.shape[-1]
    weight = lut

    img_a1 = torch.floor_divide(ixA, q).type(torch.int64)
    img_b1 = torch.floor_divide(ixB, q).type(torch.int64)
    img_c1 = torch.floor_divide(ixC, q).type(torch.int64)
    img_d1 = torch.floor_divide(ixD, q).type(torch.int64)

    # Extract LSBs
    fa = ixA % q
    fb = ixB % q
    fc = ixC % q
    fd = ixD % q

    img_a2 = img_a1 + 1
    img_b2 = img_b1 + 1
    img_c2 = img_c1 + 1
    img_d2 = img_d1 + 1

    p0000 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0001 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0010 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0011 = weight[img_a1.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0100 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0101 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0110 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p0111 = weight[img_a1.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))

    p1000 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1001 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1010 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1011 = weight[img_a2.flatten() * L * L * L + img_b1.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1100 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1101 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c1.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1110 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d1.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    p1111 = weight[img_a2.flatten() * L * L * L + img_b2.flatten() * L * L + img_c2.flatten() * L + img_d2.flatten()].reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))

    out = torch.zeros((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale), dtype=weight.dtype).to(device=weight.device)
    sz = img_a1.shape[0] * img_a1.shape[1] * img_a1.shape[2] * img_a1.shape[3]
    out = out.reshape(sz, -1)

    p0000 = p0000.reshape(sz, -1)
    p0100 = p0100.reshape(sz, -1)
    p1000 = p1000.reshape(sz, -1)
    p1100 = p1100.reshape(sz, -1)
    fa = fa.reshape(-1, 1)

    p0001 = p0001.reshape(sz, -1)
    p0101 = p0101.reshape(sz, -1)
    p1001 = p1001.reshape(sz, -1)
    p1101 = p1101.reshape(sz, -1)
    fb = fb.reshape(-1, 1)
    fc = fc.reshape(-1, 1)

    p0010 = p0010.reshape(sz, -1)
    p0110 = p0110.reshape(sz, -1)
    p1010 = p1010.reshape(sz, -1)
    p1110 = p1110.reshape(sz, -1)
    fd = fd.reshape(-1, 1)

    p0011 = p0011.reshape(sz, -1)
    p0111 = p0111.reshape(sz, -1)
    p1011 = p1011.reshape(sz, -1)
    p1111 = p1111.reshape(sz, -1)

    fab = fa > fb; fac = fa > fc; fad = fa > fd
    fbc = fb > fc; fbd = fb > fd; fcd = fc > fd

    i1 = i = torch.all(torch.cat([fab, fbc, fcd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
    i2 = i = torch.all(torch.cat([~i1[:, None], fab, fbc, fbd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
    i3 = i = torch.all(torch.cat([~i1[:, None], ~i2[:, None], fab, fbc, fad], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
    i4 = i = torch.all(torch.cat([~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]

    i5 = i = torch.all(torch.cat([~(fbc), fab, fac, fbd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
    i6 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], fab, fac, fcd], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i7 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad], dim=1), dim=1); out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i8 = i = torch.all(torch.cat([~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]

    i9 = i = torch.all(torch.cat([~(fbc), ~(fac), fab, fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
    # Fixes an overflow bug in the reference SR-LUT implementation: fd must be compared with fa first.
    # i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], fab, fcd], dim=1), dim=1)
    # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    # i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad], dim=1), dim=1)
    # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    # c > a > d > b
    i10 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], fab, fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    # c > d > a > b
    i11 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i12 = i = torch.all(torch.cat([~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]

    i13 = i = torch.all(torch.cat([~(fab), fac, fcd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
    i14 = i = torch.all(torch.cat([~(fab), ~i13[:, None], fac, fad], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
    i15 = i = torch.all(torch.cat([~(fab), ~i13[:, None], ~i14[:, None], fac, fbd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
    i16 = i = torch.all(torch.cat([~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]

    i17 = i = torch.all(torch.cat([~(fab), ~(fac), fbc, fad], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
    i18 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], fbc, fcd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
    i19 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd], dim=1), dim=1); out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
    i20 = i = torch.all(torch.cat([~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]

    i21 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), fad], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
    i22 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
    i23 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd], dim=1), dim=1); out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
    i24 = i = torch.all(torch.cat([~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None]], dim=1), dim=1); out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]

    out = out.reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2], img_a1.shape[3], upscale, upscale))
    out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale))
    out = out / q
    return out
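
A small worked example of the quantization arithmetic that all of the LUT lookups above share: an 8-bit value splits into a bucket index (MSB) used to address the LUT and a remainder (LSB) used as the interpolation weight.

```
q = 16                      # quantization_interval
for v in [0, 37, 255]:
    msb, lsb = v // q, v % q
    print(v, "-> bucket", msb, "weight", lsb / q)
# 0   -> bucket 0  weight 0.0
# 37  -> bucket 2  weight 0.3125
# 255 -> bucket 15 weight 0.9375
```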
@@ -0,0 +1,112 @@
import logging

import cv2
import numpy as np
from scipy import signal
import torch
import os

def round_func(input):
    # Backward Pass Differentiable Approximation (BPDA).
    # This is equivalent to replacing the (non-differentiable) round function
    # with a (differentiable) identity function on the backward pass only.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out

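A quick check of what the straight-through trick above does: the forward value is rounded, while the gradient is that of the identity.

```
x = torch.tensor([1.2, 3.7], requires_grad=True)
y = round_func(x)   # forward: tensor([1., 4.])
y.sum().backward()
print(x.grad)       # tensor([1., 1.]) -- the gradient passes through unchanged
```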

def logger_info(logger_name, log_path='default_logger.log'):
    log = logging.getLogger(logger_name)
    if log.hasHandlers():
        print('LogHandlers exist!')
    else:
        print('LogHandlers setup!')
        level = logging.INFO
        formatter = logging.Formatter(
            '%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
        fh = logging.FileHandler(log_path, mode='a')
        fh.setFormatter(formatter)
        log.setLevel(level)
        log.addHandler(fh)
        # print(len(log.handlers))

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        log.addHandler(sh)


def modcrop(image, modulo):
    # crop height/width down to a multiple of `modulo`
    if len(image.shape) == 2:
        sz = image.shape
        sz = sz - np.mod(sz, modulo)
        image = image[0:sz[0], 0:sz[1]]
    elif image.shape[2] == 3:
        sz = image.shape[0:2]
        sz = sz - np.mod(sz, modulo)
        image = image[0:sz[0], 0:sz[1], :]
    else:
        raise NotImplementedError
    return image


def _rgb2ycbcr(img, maxVal=255):
    # BT.601 studio-swing RGB -> YCbCr conversion
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])

    if maxVal == 1:
        O = O / 255.0

    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
    t = np.dot(t, np.transpose(T))
    t[:, 0] += O[0]
    t[:, 1] += O[1]
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])

    return ycbcr


def PSNR(y_true, y_pred, shave_border=4):
    target_data = np.array(y_true, dtype=np.float32)
    ref_data = np.array(y_pred, dtype=np.float32)

    diff = ref_data - target_data
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    rmse = np.sqrt(np.mean(np.power(diff, 2)))

    return 20 * np.log10(255. / rmse)


def cal_ssim(img1, img2):
    K = [0.01, 0.03]
    L = 255
    kernelX = cv2.getGaussianKernel(11, 1.5)
    window = kernelX * kernelX.T

    M, N = np.shape(img1)

    C1 = (K[0] * L) ** 2
    C2 = (K[1] * L) ** 2
    img1 = np.float64(img1)
    img2 = np.float64(img2)

    mu1 = signal.convolve2d(img1, window, 'valid')
    mu2 = signal.convolve2d(img2, window, 'valid')

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
    sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
    sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    mssim = np.mean(ssim_map)
    return mssim
@@ -0,0 +1,66 @@
import torch
import numpy as np
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from pathlib import Path
from PIL import Image
import ray

# num_gpus=0.3 lets Ray schedule about three validation tasks per GPU concurrently
@ray.remote(num_cpus=1, num_gpus=0.3)
def val_image_pair(model, hr_image, lr_image, output_image_path=None):
    with torch.no_grad():
        # prepare lr_image
        lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
        lr_image = lr_image.unsqueeze(0).cuda()
        b, c, h, w = lr_image.shape
        lr_image = lr_image.reshape(b*c, 1, h, w)
        # predict
        pred_lr_image = model(lr_image)
        # postprocess
        pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8)
        pred_lr_image = pred_lr_image.cpu().numpy()
        torch.cuda.empty_cache()

    if output_image_path is not None:
        Image.fromarray(pred_lr_image).save(output_image_path)

    # metrics on the Y channel
    hr_image = modcrop(hr_image, model.scale)
    left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
    torch.cuda.empty_cache()
    return PSNR(left, right, model.scale), cal_ssim(left, right)


def valid_steps(model, datasets, config, log_prefix=""):
    ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
    dataset_names = list(datasets.keys())

    for i in range(len(dataset_names)):
        dataset_name = dataset_names[i]
        psnrs, ssims = [], []

        predictions_path = config.valout_dir / dataset_name
        if not predictions_path.exists():
            predictions_path.mkdir()

        test_dataset = datasets[dataset_name]
        tasks = []
        for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
            output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None
            task = val_image_pair.remote(model, hr_image, lr_image, output_image_path)
            tasks.append(task)

        ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
        while len(remaining_refs) > 0:
            print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
            ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
        print("\r", end=" ")
        tasks = [ray.get(task) for task in tasks]

        for psnr, ssim in tasks:
            psnrs.append(psnr)
            ssims.append(ssim)

        config.logger.info(
            '{} | Dataset {} | AVG Val PSNR: {:.2f}, AVG Val SSIM: {:.4f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
        config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
        config.writer.flush()
        print()
@@ -0,0 +1,193 @@
import os
from pathlib import Path
import numpy as np
import cv2
from scipy import signal
from skimage.metrics import structural_similarity
from PIL import Image
import argparse

import time
from datetime import datetime
import ray
ray.init(num_cpus=16, num_gpus=0, ignore_reinit_error=True, log_to_driver=False)

parser = argparse.ArgumentParser()
parser.add_argument("path_to_dataset", type=str)
parser.add_argument("--scale", type=int, default=4)
args = parser.parse_args()

def cal_ssim(img1, img2):
    K = [0.01, 0.03]
    L = 255
    kernelX = cv2.getGaussianKernel(11, 1.5)
    window = kernelX * kernelX.T

    M, N = np.shape(img1)

    C1 = (K[0] * L) ** 2
    C2 = (K[1] * L) ** 2
    img1 = np.float64(img1)
    img2 = np.float64(img2)

    mu1 = signal.convolve2d(img1, window, 'valid')
    mu2 = signal.convolve2d(img2, window, 'valid')

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
    sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
    sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    mssim = np.mean(ssim_map)
    return mssim

def PSNR(y_true, y_pred, shave_border=4):
    target_data = np.array(y_true, dtype=np.float32)
    ref_data = np.array(y_pred, dtype=np.float32)

    diff = ref_data - target_data
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    rmse = np.sqrt(np.mean(np.power(diff, 2)))

    return 20 * np.log10(255. / rmse)

def _rgb2ycbcr(img, maxVal=255):
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])

    if maxVal == 1:
        O = O / 255.0

    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
    t = np.dot(t, np.transpose(T))
    t[:, 0] += O[0]
    t[:, 1] += O[1]
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])

    return ycbcr

def modcrop(image, modulo):
    if len(image.shape) == 2:
        sz = image.shape
        sz = sz - np.mod(sz, modulo)
        image = image[0:sz[0], 0:sz[1]]
    elif image.shape[2] == 3:
        sz = image.shape[0:2]
        sz = sz - np.mod(sz, modulo)
        image = image[0:sz[0], 0:sz[1], :]
    else:
        raise NotImplementedError
    return image

scale = args.scale

dataset_path = Path(args.path_to_dataset)
hr_path = dataset_path / "HR/"
lr_path = dataset_path / f"LR_bicubic/X{scale}/"

print(hr_path, lr_path)

hr_files = os.listdir(hr_path)
lr_files = os.listdir(lr_path)

@ray.remote(num_cpus=1)
def benchmark_image_pair(hr_image_path, lr_image_path, interpolation_function):
    hr_image = cv2.imread(hr_image_path)
    lr_image = cv2.imread(lr_image_path)

    hr_image = hr_image[:,:,::-1]  # BGR -> RGB
    lr_image = lr_image[:,:,::-1]  # BGR -> RGB

    start_time = datetime.now()
    upscaled_lr_image = interpolation_function(lr_image, scale)
    processing_time = datetime.now() - start_time

    hr_image = modcrop(hr_image, scale)

    psnr = PSNR(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
    cpsnr = PSNR(hr_image, upscaled_lr_image)

    cv2_psnr = cv2.PSNR(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
    cv2_cpsnr = cv2.PSNR(hr_image, upscaled_lr_image)

    ssim = cal_ssim(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
    cv2_ssim = cal_ssim(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
    ssim_scikit, diff = structural_similarity(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0], full=True, data_range=255.0)
    cv2_scikit_ssim, diff = structural_similarity(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], full=True, data_range=255.0)

    return ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time.total_seconds()

def benchmark_interpolation(interpolation_function):
    psnrs, cpsnrs, ssims = [], [], []
    cv2_psnrs, cv2_cpsnrs, scikit_ssims = [], [], []
    cv2_scikit_ssims = []
    cv2_ssims = []
    processing_times = []  # was missing; collects per-image timings appended below
    tasks = []
    for hr_name, lr_name in zip(hr_files, lr_files):
        hr_image_path = str(hr_path / hr_name)
        lr_image_path = str(lr_path / lr_name)
        tasks.append(benchmark_image_pair.remote(hr_image_path, lr_image_path, interpolation_function))

    ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
    while len(remaining_refs) > 0:
        print(f"\rReady {len(ready_refs)}/{len(hr_files)}", end=" ")
        ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)

    for task in tasks:
        ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time = ray.get(task)
        ssims.append(ssim)
        cv2_ssims.append(cv2_ssim)
        scikit_ssims.append(ssim_scikit)
        cv2_scikit_ssims.append(cv2_scikit_ssim)
        psnrs.append(psnr)
        cpsnrs.append(cpsnr)
        cv2_psnrs.append(cv2_psnr)
        cv2_cpsnrs.append(cv2_cpsnr)
        processing_times.append(processing_time)

    print()
    print(f"AVG PSNR: {np.mean(psnrs):.2f} PSNR + _rgb2ycbcr")
    print(f"AVG PSNR: {np.mean(cv2_psnrs):.2f} cv2.PSNR + cv2.cvtColor")
    print(f"AVG cPSNR: {np.mean(cpsnrs):.2f} PSNR")
    print(f"AVG cPSNR: {np.mean(cv2_cpsnrs):.2f} cv2.PSNR")
    print(f"AVG SSIM: {np.mean(ssims):.4f} cal_ssim + _rgb2ycbcr")
    print(f"AVG SSIM: {np.mean(cv2_ssims):.4f} cal_ssim + cv2.cvtColor")
    print(f"AVG SSIM: {np.mean(scikit_ssims):.4f} structural_similarity + _rgb2ycbcr")
    print(f"AVG SSIM: {np.mean(cv2_scikit_ssims):.4f} structural_similarity + cv2.cvtColor")
    print(f"P90 Time s: {np.percentile(processing_times, q=90)}")  # np.percentile takes q in percent: 90, not 0.9
    print(f"{np.mean(psnrs):.2f},{np.mean(cv2_psnrs):.2f},{np.mean(cpsnrs):.2f},{np.mean(cv2_cpsnrs):.2f},{np.mean(ssims):.4f},{np.mean(cv2_ssims):.4f},{np.mean(scikit_ssims):.4f},{np.mean(cv2_scikit_ssims):.4f},{np.percentile(processing_times, q=90)}")

def cv2_interpolation(image, scale):
    scaled_image = cv2.resize(
        image,
        None, None,
        fx=scale, fy=scale,
        interpolation=cv2.INTER_CUBIC
    )
    return scaled_image

def pillow_interpolation(image, scale):
    image = Image.fromarray(image[:,:,::-1])
    width, height = int(image.width * scale), int(image.height * scale)
    scaled_image = image.resize((width, height), resample=Image.Resampling.BICUBIC)
    return np.array(scaled_image)[:,:,::-1]

print("cv2 bicubic interpolation")
benchmark_interpolation(cv2_interpolation)
print()
print("pillow bicubic interpolation")
benchmark_interpolation(pillow_interpolation)
@@ -0,0 +1,36 @@
from .rcnet import RCNetCentered_3x3, RCNetCentered_7x7, RCNetRot90_3x3, RCNetRot90_7x7, RCNetx1, RCNetx2
from .rclut import RCLutCentered_3x3, RCLutCentered_7x7, RCLutRot90_3x3, RCLutRot90_7x7, RCLutx1, RCLutx2
from .srnet import SRNet, SRNetRot90
from .srlut import SRLut, SRLutRot90
import torch
import numpy as np
from pathlib import Path

AVAILABLE_MODELS = {
    'SRNet': SRNet, 'SRLut': SRLut,
    'SRNetRot90': SRNetRot90, 'SRLutRot90': SRLutRot90,
    'RCNetCentered_3x3': RCNetCentered_3x3, 'RCLutCentered_3x3': RCLutCentered_3x3,
    'RCNetCentered_7x7': RCNetCentered_7x7, 'RCLutCentered_7x7': RCLutCentered_7x7,
    'RCNetRot90_3x3': RCNetRot90_3x3, 'RCLutRot90_3x3': RCLutRot90_3x3,
    'RCNetRot90_7x7': RCNetRot90_7x7, 'RCLutRot90_7x7': RCLutRot90_7x7,
    'RCNetx1': RCNetx1, 'RCLutx1': RCLutx1,
    'RCNetx2': RCNetx2, 'RCLutx2': RCLutx2,
}

def SaveCheckpoint(model, path):
    model_container = {
        'model': model.__class__.__name__,
        'state_dict': model.state_dict(),
        # store public, non-training attributes (e.g. scale, quantization_interval) as constructor kwargs
        **{k: v for k, v in model.__dict__.items() if k[0] != '_' and k != 'training'}
    }
    torch.save(model_container, path)

def LoadCheckpoint(model_path):
    model_path = Path(model_path).absolute()
    if model_path.exists():
        model_container = torch.load(model_path)
        model = AVAILABLE_MODELS[model_container['model']](**{k: v for k, v in model_container.items() if k != "model" and k != "state_dict"})
        model.load_state_dict(model_container['state_dict'], strict=True)
        return model
    else:
        raise Exception(f"Path {model_path} does not exist.")
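
A minimal sketch of the intended checkpoint round trip. The path is hypothetical, and `model` stands for any instance of a class from AVAILABLE_MODELS:

```
SaveCheckpoint(model, "checkpoint.pth")    # stores class name, public attributes and weights
model = LoadCheckpoint("checkpoint.pth")   # rebuilds the class, passing the stored attributes as kwargs
```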
@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import forward_2x2_input_SxS_output


class SRLut2x2(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(SRLut2x2, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        # 4D LUT: one axis per pixel of the 2x2 input window, each axis sampled at
        # 256//quantization_interval + 1 levels; every entry stores a scale x scale output patch.
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        stage_lut
    ):
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256//(stage_lut.shape[0]-1)
        lut_model = SRLut2x2(quantization_interval=quantization_interval, scale=scale)
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
        x = x.view(b, c, x.shape[-2], x.shape[-1])
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"


class SRLut3x3(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(SRLut3x3, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        stage_lut
    ):
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256//(stage_lut.shape[0]-1)
        lut_model = SRLut3x3(quantization_interval=quantization_interval, scale=scale)
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w)
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        # Rotation self-ensemble: predict on each 90-degree rotation, rotate
        # the prediction back, and average the four results.
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
            output += unrotated_prediction
        output /= 4
        output = output.view(b, c, h*self.scale, w*self.scale)
        return output

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
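
# Rough memory arithmetic for reference (illustrative, not part of the original
# file): with quantization_interval = 16 each LUT axis has 256//16 + 1 = 17
# levels, so a 4D LUT for scale 4 holds 17**4 * 4 * 4 = 1,336,336 float32
# entries, roughly 5.1 MB before any uint8 export.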
@ -0,0 +1,394 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common.lut import select_index_1dlut_msb, select_index_4dlut_msb, select_index_4dlut_tetrahedral, select_index_1dlut_linear, \
                       forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
from pathlib import Path
from einops import repeat


class RCLutCentered_3x3(nn.Module):
    def __init__(
        self,
        window_size,
        quantization_interval,
        scale
    ):
        super(RCLutCentered_3x3, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.window_size = window_size
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        # One 1D LUT per position of the window_size x window_size receptive field.
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        window_size = rc_conv_luts.shape[0]
        lut_model = RCLutCentered_3x3(window_size=window_size, quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
        x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
        # Crop the padding back off before the dense 2x2 LUT stage.
        x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
        x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
        x = x.view(b, c, x.shape[-2], x.shape[-1])
        return x

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])

class RCLutCentered_7x7(nn.Module):
    def __init__(
        self,
        window_size,
        quantization_interval,
        scale
    ):
        super(RCLutCentered_7x7, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.window_size = window_size
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        window_size = rc_conv_luts.shape[0]
        lut_model = RCLutCentered_7x7(window_size=window_size, quantization_interval=quantization_interval, scale=scale)
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
        x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
        # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
        x = x.view(b, c, x.shape[-2], x.shape[-1])
        return x

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])

class RCLutRot90_3x3(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(RCLutRot90_3x3, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        window_size = 3
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        window_size = rc_conv_luts.shape[0]
        lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
            output += unrotated_prediction
        output /= 4
        output = output.view(b, c, output.shape[-2], output.shape[-1])
        return output

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])

class RCLutRot90_7x7(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(RCLutRot90_7x7, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        window_size = 7
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        window_size = rc_conv_luts.shape[0]
        lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
            output += unrotated_prediction
        output /= 4
        output = output.view(b, c, output.shape[-2], output.shape[-1])
        return output

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])

class RCLutx1(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(RCLutx1, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        window_size = 3
        self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
        window_size = 5
        self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
        window_size = 7
        self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(window_size, window_size, 256//quantization_interval+1)).type(torch.float32))
        self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts_3x3, dense_conv_lut_3x3,
        rc_conv_luts_5x5, dense_conv_lut_5x5,
        rc_conv_luts_7x7, dense_conv_lut_7x7
    ):
        scale = int(dense_conv_lut_3x3.shape[-1])
        quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)

        lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)

        lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
        lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))

        lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
        lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))

        lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
        lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))

        return lut_model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
        x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
        return x

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
                k=-rotations_count,
                dims=[2, 3]
            )
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
                k=-rotations_count,
                dims=[2, 3]
            )
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
                k=-rotations_count,
                dims=[2, 3]
            )
        # Average over 3 receptive-field branches x 4 rotations.
        output /= 3*4
        output = output.view(b, c, output.shape[-2], output.shape[-1])
        return output

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
            f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
            f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
            f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
            f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
            f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
            ")"])


class RCLutx2(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(RCLutx2, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        # Stage 1 refines at the input resolution (1x1 output patches);
        # stage 2 performs the actual upscaling (scale x scale output patches).
        self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
        self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
        self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
        self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
        self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
        self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
        self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
        self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
        self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
        self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
        s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
        s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
        s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
        s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
        s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
    ):
        scale = int(s2_dense_conv_lut_3x3.shape[-1])
        quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)

        lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)

        lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
        lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))

        lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
        lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))

        lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
        lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))

        lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
        lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))

        lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
        lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))

        lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
        lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))

        return lut_model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
        x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
        return x

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        # Stage 1: rotation ensemble over the three receptive fields at input resolution.
        output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
                k=-rotations_count,
                dims=[2, 3]
            )
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
                k=-rotations_count,
                dims=[2, 3]
            )
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
                k=-rotations_count,
                dims=[2, 3]
            )
        output /= 3*4
        x = output
        # Stage 2: same ensemble, now producing the upscaled result.
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
                k=-rotations_count,
                dims=[2, 3]
            )
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
                k=-rotations_count,
                dims=[2, 3]
            )
            output += torch.rot90(
                self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
                k=-rotations_count,
                dims=[2, 3]
            )
        output /= 3*4
        output = output.view(b, c, output.shape[-2], output.shape[-1])
        return output

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
            f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
            f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
            f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
            f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
            f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
            f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
            f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
            f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
            f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
            f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
            f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
            ")"])
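
# The RCLut* classes above share one rotation-ensemble pattern. A standalone
# sketch of it (illustrative only; `predict` stands in for any LUT forward pass):
#
#   def rot90_self_ensemble(x, predict):
#       out = 0
#       for k in range(4):
#           out = out + torch.rot90(predict(torch.rot90(x, k=k, dims=[2, 3])),
#                                   k=-k, dims=[2, 3])
#       return out / 4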
@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from common.lut import forward_2x2_input_SxS_output


class SRLut(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(SRLut, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        stage_lut
    ):
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256//(stage_lut.shape[0]-1)
        lut_model = SRLut(quantization_interval=quantization_interval, scale=scale)
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w).type(torch.float32)
        x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
        x = x.view(b, c, x.shape[-2], x.shape[-1])
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"


class SRLutRot90(nn.Module):
    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(SRLutRot90, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        stage_lut
    ):
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256//(stage_lut.shape[0]-1)
        lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale)
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
        return lut_model

    def forward(self, x):
        b,c,h,w = x.shape
        x = x.view(b*c, 1, h, w)
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
            output += unrotated_prediction
        output /= 4
        output = output.view(b, c, h*self.scale, w*self.scale)
        return output

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
@ -0,0 +1,51 @@
from pathlib import Path
import sys
sys.path.insert(0, str(Path("../").resolve()) + "/")

from models import LoadCheckpoint
import torch
import numpy as np
import cv2
from PIL import Image
from datetime import datetime

project_path = Path("../../").resolve()

start_script_time = datetime.now()
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCLutCentered_0.pth")

# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCNetRot90_7x7_10000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCLutRot90_7x7_0.pth")

# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCNetRot90_3x3_10000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCLutRot90_3x3_0.pth")

# net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth")
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth")

net_model = LoadCheckpoint(project_path / "models" / "last_transfered_net.pth").cuda()
lut_model = LoadCheckpoint(project_path / "models" / "last_transfered_lut.pth").cuda()

print(net_model)
print(lut_model)

# cv2 loads BGR; flip to RGB and copy to get a contiguous array.
lr_image = cv2.imread(str(project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy()
image_gt = cv2.imread(str(project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy()

# lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy()
# image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy()

input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()

net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()
lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()

image_gt = cv2.putText(image_gt, 'GT', org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType=cv2.LINE_AA)
image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType=cv2.LINE_AA)
image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType=cv2.LINE_AA)

Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(project_path / "models" / 'last_transfered_demo.png')

print(datetime.now() - start_script_time)
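
# Optional quality check (a sketch, not part of the original script): a simple
# PSNR between each prediction and the ground truth. If used, compute it before
# cv2.putText draws the labels onto the arrays.
#
#   def psnr(a, b):
#       mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
#       return 10 * np.log10(255.0 ** 2 / mse)
#
#   print("net PSNR:", psnr(net_prediction, image_gt))
#   print("lut PSNR:", psnr(lut_prediction, image_gt))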
@ -0,0 +1,199 @@
import sys
sys.path.insert(0, "../")  # run under the project directory
from pickle import dump
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr

from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS
from common.validation import valid_steps

torch.backends.cudnn.benchmark = True
import argparse
from schedulefree import AdamWScheduleFree
from datetime import datetime


class TrainOptions:
    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
        parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
        parser.add_argument('--model_path', type=str, default=None, help="Path to a model to finetune.")
        parser.add_argument('--scale', '-r', type=int, default=4, help="upscale factor")
        parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
        parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder")
        parser.add_argument('--batch_size', type=int, default=16)
        parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
        parser.add_argument('--datasets_dir', type=str, default="../../data/")
        parser.add_argument('--train_datasets', type=str, default='DIV2K')
        parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
        parser.add_argument('--start_iter', type=int, default=0, help='Set 0 to train from scratch; otherwise saved params are loaded and training continues')
        parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
        parser.add_argument('--display_step', type=int, default=100, help='display info every N iterations')
        parser.add_argument('--val_step', type=int, default=2000, help='validate every N iterations')
        parser.add_argument('--save_step', type=int, default=2000, help='save models every N iterations')
        parser.add_argument('--worker_num', '-n', type=int, default=1)
        parser.add_argument('--prefetch_factor', '-p', type=int, default=16)
        parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
        self.parser = parser

    def parse_args(self):
        args = self.parser.parse_args()
        args.models_dir = Path(args.models_dir)
        args.model_path = Path(args.model_path) if args.model_path is not None else None
        args.train_datasets = args.train_datasets.split(',')
        args.val_datasets = args.val_datasets.split(',')
        return args

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)
        print()
        return message  # returned so callers can also log the options

def prepare_experiment_folder(config):
    assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
    assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"One of the {config.val_datasets} was not found in {config.datasets_dir}."

    config.exp_dir = config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}"

    if not config.exp_dir.exists():
        config.exp_dir.mkdir()

    config.checkpoint_dir = config.exp_dir / "checkpoints"
    if not config.checkpoint_dir.exists():
        config.checkpoint_dir.mkdir()

    config.valout_dir = config.exp_dir / 'val'
    if not config.valout_dir.exists():
        config.valout_dir.mkdir()

    config.logs_dir = config.exp_dir / 'logs'
    if not config.logs_dir.exists():
        config.logs_dir.mkdir()

if __name__ == "__main__":
    script_start_time = datetime.now()

    config_inst = TrainOptions()
    config = config_inst.parse_args()

    if config.model_path is not None:
        model = LoadCheckpoint(config.model_path)
        config.model = model.__class__.__name__
    else:
        model = AVAILABLE_MODELS[config.model](hidden_dim=config.hidden_dim, scale=config.scale)
    # model = model.cuda()
    optimizer = AdamWScheduleFree(model.parameters())

    prepare_experiment_folder(config)

    # Tensorboard for monitoring
    writer = SummaryWriter(log_dir=config.logs_dir)
    logger_name = 'train'
    logger_info(logger_name, os.path.join(config.logs_dir, logger_name + '.log'))
    logger = logging.getLogger(logger_name)
    config.writer = writer
    config.logger = logger

    config.logger.info(config_inst.print_options(config))
    print(model)

    # Training dataset
    train_datasets = []
    for train_dataset_name in config.train_datasets:
        train_datasets.append(SRTrainDataset(
            hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR",
            lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
            patch_size = config.crop_size
        ))
    train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    train_loader = DataLoader(
        dataset = train_dataset,
        batch_size = config.batch_size,
        num_workers = config.worker_num,
        shuffle = False,
        drop_last = False,
        pin_memory = True,
        prefetch_factor = config.prefetch_factor
    )
    train_iter = iter(train_loader)

    test_datasets = {}
    for test_dataset_name in config.val_datasets:
        test_datasets[test_dataset_name] = SRTestDataset(
            hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
            lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
        )
    l_accum = [0., 0., 0.]
    prepare_data_time = 0.
    forward_backward_time = 0.
    accum_samples = 0

    # TRAINING
    i = config.start_iter
    for i in range(config.start_iter + 1, config.total_iter + 1):
        torch.cuda.empty_cache()
        start_time = time.time()
        try:
            hr_patch, lr_patch = next(train_iter)
        except StopIteration:
            # Restart the loader once the dataset is exhausted.
            train_iter = iter(train_loader)
            hr_patch, lr_patch = next(train_iter)
        # hr_patch = hr_patch.cuda()
        # lr_patch = lr_patch.cuda()
        prepare_data_time += time.time() - start_time
        start_time = time.time()
        optimizer.zero_grad()
        pred = model(lr_patch)
        # MSE in the [0, 1] range; pixel values are stored as 0..255.
        loss = F.mse_loss(pred/255, hr_patch/255)
        loss.backward()
        optimizer.step()
        forward_backward_time += time.time() - start_time

        # For monitoring
        accum_samples += config.batch_size
        l_accum[0] += loss.item()

        # Show information
        if i % config.display_step == 0:
            config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)

            config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
                config.exp_dir, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step,
                forward_backward_time / config.display_step))
            l_accum = [0., 0., 0.]
            prepare_data_time = 0.
            forward_backward_time = 0.

        # Save models
        if i % config.save_step == 0:
            SaveCheckpoint(model=model, path=Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth")

        # Validation
        if i % config.val_step == 0:
            config.current_iter = i
            valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")

    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")
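
# Note (an assumption worth verifying against the installed schedulefree
# package): schedule-free optimizers such as AdamWScheduleFree keep separate
# train/eval parameter states, so the usual pattern is
#
#   optimizer.train()  # before training steps
#   optimizer.eval()   # before validation or checkpointing
#
# The loop above does not switch modes; if the installed version requires it,
# the calls would go around valid_steps and SaveCheckpoint.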
@ -0,0 +1,75 @@
import sys
sys.path.insert(0, "../")  # run under the project directory
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
import models


class TransferToLutOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', '-m', type=str, default='', help="model path folder")
        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets per axis, computed as 2**bits. Value is in range [1, 8].")
        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")

    def parse_args(self):
        args = self.parser.parse_args()
        args.model_path = Path(args.model_path)
        args.checkpoint_dir = Path(args.model_path).absolute().parent
        return args

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)
        print()
        return message  # returned so callers can also log the options


if __name__ == "__main__":
    start_time = datetime.now()
    print(start_time)
    config_inst = TransferToLutOptions()
    config = config_inst.parse_args()

    config_inst.print_options(config)

    model = models.LoadCheckpoint(config.model_path).cuda()
    print(model)

    print()
    print("Transferring:")
    lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
    print()
    print(lut_model)

    lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth"
    models.SaveCheckpoint(model=lut_model, path=lut_path)

    lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()])

    print()
    print(datetime.now() - start_time)
    print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")

    models.SaveCheckpoint(model=model, path=Path(config.model_path).absolute().parent.parent.parent / "last_transfered_net.pth")
    models.SaveCheckpoint(model=lut_model, path=Path(config.model_path).absolute().parent.parent.parent / "last_transfered_lut.pth")
    print("Updated", Path(config.model_path).absolute().parent.parent.parent / "last_transfered_net.pth")
    print("Updated", Path(config.model_path).absolute().parent.parent.parent / "last_transfered_lut.pth")
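
# Quantization arithmetic for reference: with the default --quantization_bits 4
# the transfer uses quantization_interval = 2**(8-4) = 16, which corresponds to
# 256//16 + 1 = 17 sampled levels per LUT axis.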
@ -0,0 +1,86 @@
import sys
sys.path.insert(0, "../")  # run under the project directory
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.validation import valid_steps
from models import LoadCheckpoint
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse


class ValOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', type=str, help="Model path.")
        self.parser.add_argument('--datasets_dir', type=str, default="../../data/")
        self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14')
        self.parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')

    def parse_args(self):
        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.val_datasets = args.val_datasets.split(',')
        args.exp_dir = Path(args.model_path).absolute().parent.parent
        args.model_path = Path(args.model_path)
        args.model_name = args.model_path.stem
        args.valout_dir = Path(args.exp_dir) / 'val'
        if not args.valout_dir.exists():
            args.valout_dir.mkdir()
        # Checkpoints are named <ModelName>_<iteration>.pth, so the iteration
        # can be recovered from the file stem.
        args.current_iter = args.model_name.split('_')[-1]
        # Tensorboard for monitoring
        writer = SummaryWriter(log_dir=args.valout_dir)
        logger_name = f'val_{args.model_path.stem}'
        logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        args.writer = writer
        args.logger = logger

        return args

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)
        print()
        return message  # returned so callers can also log the options


# TODO: with the unified save/load functions, any model file (net or lut) can be tested with this same script.
if __name__ == "__main__":
    config_inst = ValOptions()
    config = config_inst.parse_args()

    config.logger.info(config_inst.print_options(config))

    model = LoadCheckpoint(config.model_path)
    model = model.cuda()
    print(model)

    test_datasets = {}
    for test_dataset_name in config.val_datasets:
        test_datasets[test_dataset_name] = SRTestDataset(
            hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
            lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
        )

    valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")

    config.logger.info("Complete")
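
# Example invocation (the checkpoint path is illustrative; actual names follow
# the <ModelName>_<iteration>.pth convention used by train.py):
#   python validate.py --model_path ../../models/SRNet_DIV2K/checkpoints/SRNet_200000.pth --val_datasets Set5,Set14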