From c4b7821001d969e57a813a6992efb5e7434cc74b Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Tue, 4 Jun 2024 19:45:09 +0400 Subject: [PATCH] added choice for color model --- src/common/data.py | 42 +++++++++++++++++++++------------------- src/common/utils.py | 14 +++++++++++++- src/common/validation.py | 21 +++++++++++++++----- src/image_demo.py | 19 ++++++++++-------- src/train.py | 7 ++++--- src/validate.py | 2 ++ 6 files changed, 68 insertions(+), 37 deletions(-) diff --git a/src/common/data.py b/src/common/data.py index d2e44e2..b34f5a8 100644 --- a/src/common/data.py +++ b/src/common/data.py @@ -8,15 +8,16 @@ import torch from torch.utils.data import Dataset, DataLoader from torch.utils.data.distributed import DistributedSampler from pathlib import Path +from common.utils import PIL_CONVERT_COLOR, pil2numpy image_extensions = ['.jpg', '.png'] -def load_images_cached(images_dir_path): +def load_images_cached(images_dir_path, color_model): image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions]) - cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy" + cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_{color_model}_cache.npy" cache_path = cache_path.resolve() if not Path(cache_path).exists(): print("Caching to:", cache_path) - value = {f:np.array(Image.open(f)) for f in image_paths} + value = {f:pil2numpy(PIL_CONVERT_COLOR[color_model](Image.open(f))) for f in image_paths} np.save(cache_path, value, allow_pickle=True) else: value = np.load(cache_path, allow_pickle=True).item() @@ -24,12 +25,12 @@ def load_images_cached(images_dir_path): return list(value.keys()), list(value.values()) class SRTrainDataset(Dataset): - def __init__(self, hr_dir_path, lr_dir_path, patch_size, rigid_aug=True): + def __init__(self, hr_dir_path, lr_dir_path, patch_size, color_model = "RGB", rigid_aug=True): super(SRTrainDataset, self).__init__() self.sz = patch_size self.rigid_aug = rigid_aug - self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path) - self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path) + self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path, color_model=color_model) + self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path, color_model=color_model) assert len(self.hr_images) == len(self.lr_images) def __getitem__(self, idx): @@ -48,18 +49,19 @@ class SRTrainDataset(Dataset): i = random.randint(0, lr_image.shape[0] - self.sz) j = random.randint(0, lr_image.shape[1] - self.sz) - # c = random.choice([0, 1, 2]) - hr_patch = hr_image[ - (i*scale):(i*scale + self.sz*scale), - (j*scale):(j*scale + self.sz*scale), - : - ] - lr_patch = lr_image[ - i:(i + self.sz), - j:(j + self.sz), - : - ] + if len(hr_image.shape) == 3: + hr_patch = hr_image[ + (i*scale):(i*scale + self.sz*scale), + (j*scale):(j*scale + self.sz*scale), + : + ] + lr_patch = lr_image[ + i:(i + self.sz), + j:(j + self.sz), + : + ] + if self.rigid_aug: if random.uniform(0, 1) < 0.5: @@ -85,10 +87,10 @@ class SRTrainDataset(Dataset): return len(self.hr_images) class SRTestDataset(Dataset): - def __init__(self, hr_dir_path, lr_dir_path): + def __init__(self, hr_dir_path, lr_dir_path, color_model): super(SRTestDataset, self).__init__() - self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path) - self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path) + self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path, color_model=color_model) + self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path, color_model=color_model) assert len(self.hr_images) == len(self.lr_images) def __getitem__(self, idx): diff --git a/src/common/utils.py b/src/common/utils.py index be0811c..03c37eb 100644 --- a/src/common/utils.py +++ b/src/common/utils.py @@ -6,6 +6,18 @@ from scipy import signal import torch import os +PIL_CONVERT_COLOR = { + 'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image, + 'YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image, + 'Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0), + 'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image, +} + +def pil2numpy(image): + np_image = np.array(image) + if len(np_image.shape) == 2: + np_image = np_image[:,:,None] + return np_image def round_func(input): # Backward Pass Differentiable Approximation (BPDA) @@ -41,7 +53,7 @@ def modcrop(image, modulo): sz = image.shape sz = sz - np.mod(sz, modulo) image = image[0:sz[0], 0:sz[1]] - elif image.shape[2] == 3: + elif len(image.shape) == 3: sz = image.shape[0:2] sz = sz - np.mod(sz, modulo) image = image[0:sz[0], 0:sz[1], :] diff --git a/src/common/validation.py b/src/common/validation.py index f1b6fc1..aa9b0a8 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -7,7 +7,7 @@ from PIL import Image import time from datetime import timedelta -def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'): +def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'): with torch.inference_mode(): start_time = time.perf_counter_ns() # prepare lr_image @@ -22,13 +22,24 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu torch.cuda.empty_cache() if not output_image_path is None: - Image.fromarray(pred_lr_image).save(output_image_path) + if pred_lr_image.shape[-1] == 3: + Image.fromarray(pred_lr_image, mode=color_model).save(output_image_path) + if pred_lr_image.shape[-1] == 1: + Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path) # metrics hr_image = modcrop(hr_image, model.scale) - Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] + if pred_lr_image.shape[-1] == 3 and color_model == 'RGB': + Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] + if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr': + Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0] + if pred_lr_image.shape[-1] == 1: + Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0] + lr_area = np.prod(lr_image.shape[-2:]) - return PSNR(Y_left, Y_right, model.scale), cal_ssim(Y_left, Y_right), run_time_ns, lr_area + psnr = PSNR(Y_left, Y_right, model.scale) + ssim = cal_ssim(Y_left, Y_right) + return psnr, ssim, run_time_ns, lr_area def valid_steps(model, datasets, config, log_prefix=""): dataset_names = list(datasets.keys()) @@ -50,7 +61,7 @@ def valid_steps(model, datasets, config, log_prefix=""): tasks = [] for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None - task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device) + task = val_image_pair(model, hr_image, lr_image, output_image_path, color_model=config.color_model, device=config.device) tasks.append(task) total_time = time.time() - start_time diff --git a/src/image_demo.py b/src/image_demo.py index 0af73c6..f4dde28 100644 --- a/src/image_demo.py +++ b/src/image_demo.py @@ -1,6 +1,6 @@ from pathlib import Path import sys - +from common.utils import PIL_CONVERT_COLOR, pil2numpy from models import LoadCheckpoint import torch import numpy as np @@ -8,6 +8,7 @@ import cv2 from PIL import Image from datetime import datetime import argparse + class ImageDemoOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -18,6 +19,7 @@ class ImageDemoOptions(): self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") self.parser.add_argument('--mirror', action='store_true', default=False) self.parser.add_argument('--device', default='cuda', help='Device of the model') + self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for input image.") def parse_args(self): args = self.parser.parse_args() @@ -49,25 +51,26 @@ models = [LoadCheckpoint(x).to(config.device) for x in config.model_paths] for m in models: print(m) -lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1] -image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1] +lr_image = pil2numpy(PIL_CONVERT_COLOR[config.color_model](Image.open(str(config.lr_image_path)))) +image_gt = pil2numpy(PIL_CONVERT_COLOR[config.color_model](Image.open(str(config.hr_image_path)))) +image_gt = np.concatenate([image_gt]*3, axis=-1) if image_gt.shape[-1] == 1 else image_gt + if config.mirror: - lr_image = lr_image[:,::-1,:] - image_gt = image_gt[:,::-1,:] -lr_image = lr_image.copy() -image_gt = image_gt.copy() + lr_image = lr_image[:,::-1,:].copy() + image_gt = image_gt[:,::-1,:].copy() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device) predictions = [] for model in models: with torch.inference_mode(): - prediction = model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() + prediction = model(input_image).cpu().type(torch.uint8).squeeze(0).permute(1,2,0).numpy().copy() predictions.append(prediction) image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) images_predicted = [] for model_path, model, prediction in zip(config.model_paths, models, predictions): + prediction = np.concatenate([prediction]*3, axis=-1) if prediction.shape[-1] == 1 else prediction prediction = cv2.putText(prediction, model_path.stem, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) images_predicted.append(prediction) diff --git a/src/train.py b/src/train.py index 4ac1940..edef717 100644 --- a/src/train.py +++ b/src/train.py @@ -47,6 +47,7 @@ class TrainOptions: parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') parser.add_argument('--device', default='cuda', help='Device of the model') parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") + parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") self.parser = parser @@ -130,7 +131,8 @@ if __name__ == "__main__": train_datasets.append(SRTrainDataset( hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR", lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}", - patch_size = config.crop_size + patch_size = config.crop_size, + color_model = config.color_model )) train_dataset = torch.utils.data.ConcatDataset(train_datasets) train_loader = DataLoader( @@ -149,6 +151,7 @@ if __name__ == "__main__": test_datasets[test_dataset_name] = SRTestDataset( hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}", + color_model = config.color_model ) l_accum = [0., 0., 0.] prepare_data_time = 0. @@ -180,8 +183,6 @@ if __name__ == "__main__": loss.backward() optimizer.step() optimizer.zero_grad() - del hr_patch - del lr_patch torch.cuda.empty_cache() forward_backward_time += time.time() - start_time diff --git a/src/validate.py b/src/validate.py index f2d6d2f..5757ca0 100644 --- a/src/validate.py +++ b/src/validate.py @@ -28,6 +28,7 @@ class ValOptions(): self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.") self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') + self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") def parse_args(self): args = self.parser.parse_args() @@ -82,6 +83,7 @@ if __name__ == "__main__": test_datasets[test_dataset_name] = SRTestDataset( hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", + color_model=config.color_model ) results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")