added choice for color model

main
Vladimir Protsenko 5 months ago
parent f2eea32363
commit c4b7821001

@ -8,15 +8,16 @@ import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
from common.utils import PIL_CONVERT_COLOR, pil2numpy
image_extensions = ['.jpg', '.png']
def load_images_cached(images_dir_path, color_model):
    """Load all images of a directory as numpy arrays, backed by an on-disk cache.

    Every file directly inside ``images_dir_path`` whose lower-cased suffix is
    listed in ``image_extensions`` is opened, converted with
    ``PIL_CONVERT_COLOR[color_model]`` and turned into an array by
    ``pil2numpy``. The resulting ``{path: array}`` dict is saved next to the
    image directory as ``<directory stem>_<color_model>_cache.npy``, so each
    color model gets its own cache file. When that file already exists it is
    loaded instead and the images are not read again.

    Args:
        images_dir_path: Directory that contains the images.
        color_model: Key of ``PIL_CONVERT_COLOR`` selecting the conversion.

    Returns:
        A ``(paths, images)`` tuple of two lists in matching order.

    Raises:
        KeyError: If the cache has to be built and ``color_model`` is not a
            key of ``PIL_CONVERT_COLOR``.
    """
    images_dir = Path(images_dir_path)
    image_paths = sorted([f for f in images_dir.glob("*") if f.suffix.lower() in image_extensions])
    cache_path = images_dir.parent / f"{images_dir.stem}_{color_model}_cache.npy"
    cache_path = cache_path.resolve()
    if not cache_path.exists():
        print("Caching to:", cache_path)
        value = {}
        for f in image_paths:
            # pil2numpy copies the pixels into a fresh array, so the image
            # file can be closed as soon as the conversion is done.
            with Image.open(f) as pil_image:
                value[f] = pil2numpy(PIL_CONVERT_COLOR[color_model](pil_image))
        np.save(cache_path, value, allow_pickle=True)
    else:
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # point this at cache files that were written by this function.
        value = np.load(cache_path, allow_pickle=True).item()
    return list(value.keys()), list(value.values())
class SRTrainDataset(Dataset):
def __init__(self, hr_dir_path, lr_dir_path, patch_size, color_model = "RGB", rigid_aug=True):
    """Load the HR/LR training images into memory.

    Args:
        hr_dir_path: Directory with the high-resolution images.
        lr_dir_path: Directory with the matching low-resolution images.
        patch_size: Side length of the square LR patch cropped in
            ``__getitem__``; stored as ``self.sz``.
        color_model: Color model forwarded to ``load_images_cached`` for
            both directories.
        rigid_aug: Flag stored as ``self.rigid_aug`` and checked in
            ``__getitem__``.

    Raises:
        AssertionError: If the two directories yield a different number of
            images.
    """
    super(SRTrainDataset, self).__init__()
    self.sz = patch_size
    self.rigid_aug = rigid_aug
    # load_images_cached requires the color model, and HR and LR must use the
    # same one so that the image pairs stay comparable.
    self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path, color_model=color_model)
    self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path, color_model=color_model)
    assert len(self.hr_images) == len(self.lr_images)
def __getitem__(self, idx):
@ -48,18 +49,19 @@ class SRTrainDataset(Dataset):
i = random.randint(0, lr_image.shape[0] - self.sz)
j = random.randint(0, lr_image.shape[1] - self.sz)
# c = random.choice([0, 1, 2])
hr_patch = hr_image[
(i*scale):(i*scale + self.sz*scale),
(j*scale):(j*scale + self.sz*scale),
:
]
lr_patch = lr_image[
i:(i + self.sz),
j:(j + self.sz),
:
]
if len(hr_image.shape) == 3:
hr_patch = hr_image[
(i*scale):(i*scale + self.sz*scale),
(j*scale):(j*scale + self.sz*scale),
:
]
lr_patch = lr_image[
i:(i + self.sz),
j:(j + self.sz),
:
]
if self.rigid_aug:
if random.uniform(0, 1) < 0.5:
@ -85,10 +87,10 @@ class SRTrainDataset(Dataset):
return len(self.hr_images)
class SRTestDataset(Dataset):
def __init__(self, hr_dir_path, lr_dir_path, color_model="RGB"):
    """Load the HR/LR test images into memory.

    Args:
        hr_dir_path: Directory with the high-resolution images.
        lr_dir_path: Directory with the matching low-resolution images.
        color_model: Color model forwarded to ``load_images_cached`` for
            both directories. Defaults to ``"RGB"``, the same default
            ``SRTrainDataset`` uses, so callers that pass only the two
            directories keep working.

    Raises:
        AssertionError: If the two directories yield a different number of
            images.
    """
    super(SRTestDataset, self).__init__()
    self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path, color_model=color_model)
    self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path, color_model=color_model)
    assert len(self.hr_images) == len(self.lr_images)
def __getitem__(self, idx):

@ -6,6 +6,18 @@ from scipy import signal
import torch
import os
def _convert_unless_already(target_mode):
    """Build a converter that returns the image in ``target_mode``.

    An image that already has that mode is returned as-is (no copy).
    """
    def converter(pil_image):
        if pil_image.mode == target_mode:
            return pil_image
        return pil_image.convert(target_mode)
    return converter


def _luma_channel(pil_image):
    """Return channel 0 of the image in YCbCr mode, converting first if needed."""
    if pil_image.mode == 'YCbCr':
        ycbcr_image = pil_image
    else:
        ycbcr_image = pil_image.convert("YCbCr")
    return ycbcr_image.getchannel(0)


# Maps a color model name to a callable that takes a PIL image and returns it
# in that color model ('Y' yields only the first channel of YCbCr).
PIL_CONVERT_COLOR = {
    'RGB': _convert_unless_already("RGB"),
    'YCbCr': _convert_unless_already("YCbCr"),
    'Y': _luma_channel,
    'L': _convert_unless_already("L"),
}
def pil2numpy(image):
    """Convert an image to a numpy array that always carries a channel axis.

    Args:
        image: Anything ``np.array`` accepts; the callers in this project
            pass PIL images.

    Returns:
        A new array. A 2-D result (single-channel image) gets a trailing axis
        of length 1, so ``(H, W)`` becomes ``(H, W, 1)``. Arrays of any other
        rank are returned exactly as ``np.array`` produced them.
    """
    np_image = np.array(image)
    if np_image.ndim == 2:
        # Single-channel data has no channel axis; add one so that callers
        # can index [:, :, c] regardless of the color model.
        np_image = np_image[:, :, None]
    return np_image
def round_func(input):
# Backward Pass Differentiable Approximation (BPDA)
@ -41,7 +53,7 @@ def modcrop(image, modulo):
sz = image.shape
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1]]
elif image.shape[2] == 3:
elif len(image.shape) == 3:
sz = image.shape[0:2]
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1], :]

@ -7,7 +7,7 @@ from PIL import Image
import time
from datetime import timedelta
def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'):
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode():
start_time = time.perf_counter_ns()
# prepare lr_image
@ -22,13 +22,24 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu
torch.cuda.empty_cache()
if not output_image_path is None:
Image.fromarray(pred_lr_image).save(output_image_path)
if pred_lr_image.shape[-1] == 3:
Image.fromarray(pred_lr_image, mode=color_model).save(output_image_path)
if pred_lr_image.shape[-1] == 1:
Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path)
# metrics
hr_image = modcrop(hr_image, model.scale)
Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
if pred_lr_image.shape[-1] == 3 and color_model == 'RGB':
Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0]
if pred_lr_image.shape[-1] == 1:
Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0]
lr_area = np.prod(lr_image.shape[-2:])
return PSNR(Y_left, Y_right, model.scale), cal_ssim(Y_left, Y_right), run_time_ns, lr_area
psnr = PSNR(Y_left, Y_right, model.scale)
ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area
def valid_steps(model, datasets, config, log_prefix=""):
dataset_names = list(datasets.keys())
@ -50,7 +61,7 @@ def valid_steps(model, datasets, config, log_prefix=""):
tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
# Everything after lr_image is passed by keyword: in val_image_pair's signature color_model comes before output_image_path, so a positional output_image_path would be bound to color_model and clash with the color_model keyword (TypeError: multiple values).
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task)
total_time = time.time() - start_time

@ -1,6 +1,6 @@
from pathlib import Path
import sys
from common.utils import PIL_CONVERT_COLOR, pil2numpy
from models import LoadCheckpoint
import torch
import numpy as np
@ -8,6 +8,7 @@ import cv2
from PIL import Image
from datetime import datetime
import argparse
class ImageDemoOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@ -18,6 +19,7 @@ class ImageDemoOptions():
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for input image.")
def parse_args(self):
args = self.parser.parse_args()
@ -49,25 +51,26 @@ models = [LoadCheckpoint(x).to(config.device) for x in config.model_paths]
for m in models:
print(m)
lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1]
image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1]
lr_image = pil2numpy(PIL_CONVERT_COLOR[config.color_model](Image.open(str(config.lr_image_path))))
image_gt = pil2numpy(PIL_CONVERT_COLOR[config.color_model](Image.open(str(config.hr_image_path))))
image_gt = np.concatenate([image_gt]*3, axis=-1) if image_gt.shape[-1] == 1 else image_gt
if config.mirror:
lr_image = lr_image[:,::-1,:]
image_gt = image_gt[:,::-1,:]
lr_image = lr_image.copy()
image_gt = image_gt.copy()
lr_image = lr_image[:,::-1,:].copy()
image_gt = image_gt[:,::-1,:].copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device)
predictions = []
for model in models:
with torch.inference_mode():
prediction = model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()
prediction = model(input_image).cpu().type(torch.uint8).squeeze(0).permute(1,2,0).numpy().copy()
predictions.append(prediction)
image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
images_predicted = []
for model_path, model, prediction in zip(config.model_paths, models, predictions):
prediction = np.concatenate([prediction]*3, axis=-1) if prediction.shape[-1] == 1 else prediction
prediction = cv2.putText(prediction, model_path.stem, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
images_predicted.append(prediction)

@ -47,6 +47,7 @@ class TrainOptions:
parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser = parser
@ -130,7 +131,8 @@ if __name__ == "__main__":
train_datasets.append(SRTrainDataset(
hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
patch_size = config.crop_size
patch_size = config.crop_size,
color_model = config.color_model
))
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
train_loader = DataLoader(
@ -149,6 +151,7 @@ if __name__ == "__main__":
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
color_model = config.color_model
)
l_accum = [0., 0., 0.]
prepare_data_time = 0.
@ -180,8 +183,6 @@ if __name__ == "__main__":
loss.backward()
optimizer.step()
optimizer.zero_grad()
del hr_patch
del lr_patch
torch.cuda.empty_cache()
forward_backward_time += time.time() - start_time

@ -28,6 +28,7 @@ class ValOptions():
self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
def parse_args(self):
args = self.parser.parse_args()
@ -82,6 +83,7 @@ if __name__ == "__main__":
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
color_model=config.color_model
)
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")

Loading…
Cancel
Save