diff --git a/src/common/color.py b/src/common/color.py new file mode 100644 index 0000000..b06a928 --- /dev/null +++ b/src/common/color.py @@ -0,0 +1,34 @@ +import cv2 +import numpy as np +from scipy import signal +import torch +import os + + +PIL_CONVERT_COLOR = { + 'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image, + 'full_YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image, + 'full_Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0), + 'sdtv_Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0] if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"), + 'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image, +} + +def _rgb2ycbcr(img, maxVal=255): + O = np.array([[16], + [128], + [128]]) + T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941], + [-0.148223529411765, -0.290992156862745, 0.439215686274510], + [0.439215686274510, -0.367788235294118, -0.071427450980392]]) + + if maxVal == 1: + O = O / 255.0 + + t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2])) + t = np.dot(t, np.transpose(T)) + t[:, 0] += O[0] + t[:, 1] += O[1] + t[:, 2] += O[2] + ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]]) + + return ycbcr \ No newline at end of file diff --git a/src/common/data.py b/src/common/data.py index 01c6059..789e2ca 100644 --- a/src/common/data.py +++ b/src/common/data.py @@ -8,7 +8,8 @@ import torch from torch.utils.data import Dataset, DataLoader from torch.utils.data.distributed import DistributedSampler from pathlib import Path -from common.utils import PIL_CONVERT_COLOR, pil2numpy +from common.color import PIL_CONVERT_COLOR +from common.utils import pil2numpy image_extensions = ['.jpg', '.png'] def load_images_cached(images_dir_path, color_model, reset_cache): diff --git a/src/common/metrics.py b/src/common/metrics.py new file mode 100644 index 0000000..53997bb --- /dev/null +++ b/src/common/metrics.py @@ -0,0 +1,46 @@ +import cv2 +import numpy as np +from scipy import signal +import torch +import os + + +def PSNR(y_true, y_pred, shave_border=4): + target_data = np.array(y_true, dtype=np.float32) + ref_data = np.array(y_pred, dtype=np.float32) + + diff = ref_data - target_data + if shave_border > 0: + diff = diff[shave_border:-shave_border, shave_border:-shave_border] + rmse = np.sqrt(np.mean(np.power(diff, 2))) + + return 20 * np.log10(255. / rmse) + + +def cal_ssim(img1, img2): + K = [0.01, 0.03] + L = 255 + kernelX = cv2.getGaussianKernel(11, 1.5) + window = kernelX * kernelX.T + + M, N = np.shape(img1) + + C1 = (K[0] * L) ** 2 + C2 = (K[1] * L) ** 2 + img1 = np.float64(img1) + img2 = np.float64(img2) + + mu1 = signal.convolve2d(img1, window, 'valid') + mu2 = signal.convolve2d(img2, window, 'valid') + + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + + sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq + sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq + sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + mssim = np.mean(ssim_map) + return mssim diff --git a/src/common/validation.py b/src/common/test.py similarity index 85% rename from src/common/validation.py rename to src/common/test.py index 8fc989e..569ca07 100644 --- a/src/common/validation.py +++ b/src/common/test.py @@ -1,7 +1,9 @@ import torch import pandas as pd import numpy as np -from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop +from common.utils import logger_info, modcrop +from common.color import _rgb2ycbcr +from common.metrics import PSNR, cal_ssim from pathlib import Path from PIL import Image import time @@ -13,7 +15,7 @@ cmaplist = [cmap(i) for i in range(cmap.N)] cmaplut = np.array(cmaplist) cmaplut = np.round(cmaplut[:, 0:3]*255).astype(np.uint8) -def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'): +def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'): with torch.inference_mode(): start_time = time.perf_counter_ns() # prepare lr_image @@ -49,7 +51,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non ssim = cal_ssim(Y_left, Y_right) return psnr, ssim, run_time_ns, lr_area -def valid_steps(model, datasets, config, log_prefix="", print_progress = False): +def test_steps(model, datasets, config, log_prefix="", print_progress = False): dataset_names = list(datasets.keys()) results = [] @@ -61,7 +63,7 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False): total_area = 0 start_time = time.time() - predictions_path = config.valout_dir / dataset_name + predictions_path = config.test_dir / dataset_name if not predictions_path.exists(): predictions_path.mkdir() @@ -70,8 +72,11 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False): if print_progress: start_datetime = datetime.now() for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset): - output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png' if config.save_predictions else None - task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device) + if config.save_predictions: + output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png' + else: + output_image_path = None + task = test_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device) tasks.append(task) if print_progress: print(f"\r{datetime.now()-start_datetime} {idx+1}/{len(test_dataset)} {hr_image_path}", end=" "*25) diff --git a/src/common/utils.py b/src/common/utils.py index 1f8ef27..a7514d7 100644 --- a/src/common/utils.py +++ b/src/common/utils.py @@ -6,14 +6,6 @@ from scipy import signal import torch import os -PIL_CONVERT_COLOR = { - 'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image, - 'YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image, - 'Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0), - # 'Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0],# if pil_image.mode != 'YCbCr' else pil_image.getchannel(0), - 'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image, -} - def pil2numpy(image): np_image = np.array(image) if len(np_image.shape) == 2: @@ -42,7 +34,6 @@ def logger_info(logger_name, log_path='default_logger.log'): fh.setFormatter(formatter) log.setLevel(level) log.addHandler(fh) - # print(len(log.handlers)) sh = logging.StreamHandler() sh.setFormatter(formatter) @@ -60,67 +51,4 @@ def modcrop(image, modulo): image = image[0:sz[0], 0:sz[1], :] else: raise NotImplementedError - return image - - -def _rgb2ycbcr(img, maxVal=255): - O = np.array([[16], - [128], - [128]]) - T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941], - [-0.148223529411765, -0.290992156862745, 0.439215686274510], - [0.439215686274510, -0.367788235294118, -0.071427450980392]]) - - if maxVal == 1: - O = O / 255.0 - - t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2])) - t = np.dot(t, np.transpose(T)) - t[:, 0] += O[0] - t[:, 1] += O[1] - t[:, 2] += O[2] - ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]]) - - return ycbcr - - -def PSNR(y_true, y_pred, shave_border=4): - target_data = np.array(y_true, dtype=np.float32) - ref_data = np.array(y_pred, dtype=np.float32) - - diff = ref_data - target_data - if shave_border > 0: - diff = diff[shave_border:-shave_border, shave_border:-shave_border] - rmse = np.sqrt(np.mean(np.power(diff, 2))) - - return 20 * np.log10(255. / rmse) - - -def cal_ssim(img1, img2): - K = [0.01, 0.03] - L = 255 - kernelX = cv2.getGaussianKernel(11, 1.5) - window = kernelX * kernelX.T - - M, N = np.shape(img1) - - C1 = (K[0] * L) ** 2 - C2 = (K[1] * L) ** 2 - img1 = np.float64(img1) - img2 = np.float64(img2) - - mu1 = signal.convolve2d(img1, window, 'valid') - mu2 = signal.convolve2d(img2, window, 'valid') - - mu1_sq = mu1 * mu1 - mu2_sq = mu2 * mu2 - mu1_mu2 = mu1 * mu2 - - sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq - sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq - sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) - mssim = np.mean(ssim_map) - return mssim - + return image \ No newline at end of file diff --git a/src/test.py b/src/test.py index f0cd138..7cdb134 100644 --- a/src/test.py +++ b/src/test.py @@ -14,13 +14,13 @@ from torch.utils.tensorboard import SummaryWriter from torch.utils.data import Dataset, DataLoader from common.data import SRTrainDataset, SRTestDataset from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop -from common.validation import valid_steps +from common.test import test_steps from models import LoadCheckpoint torch.backends.cudnn.benchmark = True from datetime import datetime import argparse -class ValOptions(): +class TestOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.") @@ -39,15 +39,15 @@ class ValOptions(): args.exp_dir = Path(args.model_path).resolve().parent.parent args.model_path = Path(args.model_path).resolve() args.model_name = args.model_path.stem - args.valout_dir = Path(args.exp_dir).resolve() / 'val' - if not args.valout_dir.exists(): - args.valout_dir.mkdir() + args.test_dir = Path(args.exp_dir).resolve() / 'test' + if not args.test_dir.exists(): + args.test_dir.mkdir() args.current_iter = int(args.model_name.split('_')[-1]) - args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv') + args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.device}.csv') # Tensorboard for monitoring - writer = SummaryWriter(log_dir=args.valout_dir) - logger_name = f'val_{args.model_path.stem}' - logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log')) + writer = SummaryWriter(log_dir=args.test_dir) + logger_name = f'test_{args.model_path.stem}' + logger_info(logger_name, os.path.join(args.test_dir, logger_name + '.log')) logger = logging.getLogger(logger_name) args.writer = writer args.logger = logger @@ -71,7 +71,7 @@ class ValOptions(): # TODO with unified save/load function any model file of net or lut can be tested with the same script. if __name__ == "__main__": script_start_time = datetime.now() - config_inst = ValOptions() + config_inst = TestOptions() config = config_inst.parse_args() config.logger.info(config_inst) @@ -89,7 +89,7 @@ if __name__ == "__main__": reset_cache=config.reset_cache, ) - results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress) + results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress,) results.to_csv(config.results_path) print() diff --git a/src/train.py b/src/train.py index 4103f31..0a3ea38 100644 --- a/src/train.py +++ b/src/train.py @@ -14,10 +14,12 @@ from pathlib import Path from torch.utils.tensorboard import SummaryWriter from torch.utils.data import Dataset, DataLoader from common.data import SRTrainDataset, SRTestDataset -from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr +from common.utils import logger_info +from common.metrics import PSNR, cal_ssim +from common.color import _rgb2ycbcr, PIL_CONVERT_COLOR from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS -from common.validation import valid_steps +from common.test import test_steps torch.backends.cudnn.benchmark = True import argparse @@ -30,7 +32,7 @@ class TrainOptions: parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}") parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.") parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.") - parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.") + parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.") parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor") parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers") parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers") @@ -48,7 +50,7 @@ class TrainOptions: parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') parser.add_argument('--device', default='cuda', help='Device of the model') parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") - parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") + parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}") parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache') self.parser = parser @@ -59,7 +61,7 @@ class TrainOptions: args.models_dir = Path(args.models_dir).resolve() args.model_path = Path(args.model_path) if not args.model_path is None else None args.train_datasets = args.train_datasets.split(',') - args.val_datasets = args.val_datasets.split(',') + args.test_datasets = args.test_datasets.split(',') if not args.model_path is None: args.start_iter = int(args.model_path.stem.split("_")[-1]) return args @@ -79,7 +81,7 @@ class TrainOptions: def prepare_experiment_folder(config): assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." - assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}." + assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"On of the {config.test_datasets} was not found in {config.datasets_dir}." config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve() @@ -90,9 +92,9 @@ def prepare_experiment_folder(config): if not config.checkpoint_dir.exists(): config.checkpoint_dir.mkdir() - config.valout_dir = (config.exp_dir / 'val').resolve() - if not config.valout_dir.exists(): - config.valout_dir.mkdir() + config.test_dir = (config.exp_dir / 'val').resolve() + if not config.test_dir.exists(): + config.test_dir.mkdir() config.logs_dir = (config.exp_dir / 'logs').resolve() if not config.logs_dir.exists(): @@ -154,7 +156,7 @@ if __name__ == "__main__": train_iter = iter(train_loader) test_datasets = {} - for test_dataset_name in config.val_datasets: + for test_dataset_name in config.test_datasets: test_datasets[test_dataset_name] = SRTestDataset( hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}", @@ -222,7 +224,7 @@ if __name__ == "__main__": # Validation if i % config.val_step == 0: - valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") + test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve() SaveCheckpoint(model=model, path=model_path)