main
protsenkovi 5 months ago
parent 39f830d6e5
commit e8eed1ad73

@ -0,0 +1,34 @@
import cv2
import numpy as np
from scipy import signal
import torch
import os
PIL_CONVERT_COLOR = {
'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image,
'full_YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
'full_Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
'sdtv_Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0] if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
}
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr

@ -8,7 +8,8 @@ import torch
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from pathlib import Path from pathlib import Path
from common.utils import PIL_CONVERT_COLOR, pil2numpy from common.color import PIL_CONVERT_COLOR
from common.utils import pil2numpy
image_extensions = ['.jpg', '.png'] image_extensions = ['.jpg', '.png']
def load_images_cached(images_dir_path, color_model, reset_cache): def load_images_cached(images_dir_path, color_model, reset_cache):

@ -0,0 +1,46 @@
import cv2
import numpy as np
from scipy import signal
import torch
import os
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim

@ -1,7 +1,9 @@
import torch import torch
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop from common.utils import logger_info, modcrop
from common.color import _rgb2ycbcr
from common.metrics import PSNR, cal_ssim
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
import time import time
@ -13,7 +15,7 @@ cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplut = np.array(cmaplist) cmaplut = np.array(cmaplist)
cmaplut = np.round(cmaplut[:, 0:3]*255).astype(np.uint8) cmaplut = np.round(cmaplut[:, 0:3]*255).astype(np.uint8)
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'): def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode(): with torch.inference_mode():
start_time = time.perf_counter_ns() start_time = time.perf_counter_ns()
# prepare lr_image # prepare lr_image
@ -49,7 +51,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
ssim = cal_ssim(Y_left, Y_right) ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area return psnr, ssim, run_time_ns, lr_area
def valid_steps(model, datasets, config, log_prefix="", print_progress = False): def test_steps(model, datasets, config, log_prefix="", print_progress = False):
dataset_names = list(datasets.keys()) dataset_names = list(datasets.keys())
results = [] results = []
@ -61,7 +63,7 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
total_area = 0 total_area = 0
start_time = time.time() start_time = time.time()
predictions_path = config.valout_dir / dataset_name predictions_path = config.test_dir / dataset_name
if not predictions_path.exists(): if not predictions_path.exists():
predictions_path.mkdir() predictions_path.mkdir()
@ -70,8 +72,11 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
if print_progress: if print_progress:
start_datetime = datetime.now() start_datetime = datetime.now()
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset): for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png' if config.save_predictions else None if config.save_predictions:
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device) output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png'
else:
output_image_path = None
task = test_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task) tasks.append(task)
if print_progress: if print_progress:
print(f"\r{datetime.now()-start_datetime} {idx+1}/{len(test_dataset)} {hr_image_path}", end=" "*25) print(f"\r{datetime.now()-start_datetime} {idx+1}/{len(test_dataset)} {hr_image_path}", end=" "*25)

@ -6,14 +6,6 @@ from scipy import signal
import torch import torch
import os import os
PIL_CONVERT_COLOR = {
'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image,
'YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
'Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
# 'Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0],# if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
}
def pil2numpy(image): def pil2numpy(image):
np_image = np.array(image) np_image = np.array(image)
if len(np_image.shape) == 2: if len(np_image.shape) == 2:
@ -42,7 +34,6 @@ def logger_info(logger_name, log_path='default_logger.log'):
fh.setFormatter(formatter) fh.setFormatter(formatter)
log.setLevel(level) log.setLevel(level)
log.addHandler(fh) log.addHandler(fh)
# print(len(log.handlers))
sh = logging.StreamHandler() sh = logging.StreamHandler()
sh.setFormatter(formatter) sh.setFormatter(formatter)
@ -61,66 +52,3 @@ def modcrop(image, modulo):
else: else:
raise NotImplementedError raise NotImplementedError
return image return image
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim

@ -14,13 +14,13 @@ from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.validation import valid_steps from common.test import test_steps
from models import LoadCheckpoint from models import LoadCheckpoint
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
from datetime import datetime from datetime import datetime
import argparse import argparse
class ValOptions(): class TestOptions():
def __init__(self): def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.") self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
@ -39,15 +39,15 @@ class ValOptions():
args.exp_dir = Path(args.model_path).resolve().parent.parent args.exp_dir = Path(args.model_path).resolve().parent.parent
args.model_path = Path(args.model_path).resolve() args.model_path = Path(args.model_path).resolve()
args.model_name = args.model_path.stem args.model_name = args.model_path.stem
args.valout_dir = Path(args.exp_dir).resolve() / 'val' args.test_dir = Path(args.exp_dir).resolve() / 'test'
if not args.valout_dir.exists(): if not args.test_dir.exists():
args.valout_dir.mkdir() args.test_dir.mkdir()
args.current_iter = int(args.model_name.split('_')[-1]) args.current_iter = int(args.model_name.split('_')[-1])
args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv') args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.device}.csv')
# Tensorboard for monitoring # Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.valout_dir) writer = SummaryWriter(log_dir=args.test_dir)
logger_name = f'val_{args.model_path.stem}' logger_name = f'test_{args.model_path.stem}'
logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log')) logger_info(logger_name, os.path.join(args.test_dir, logger_name + '.log'))
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
args.writer = writer args.writer = writer
args.logger = logger args.logger = logger
@ -71,7 +71,7 @@ class ValOptions():
# TODO with unified save/load function any model file of net or lut can be tested with the same script. # TODO with unified save/load function any model file of net or lut can be tested with the same script.
if __name__ == "__main__": if __name__ == "__main__":
script_start_time = datetime.now() script_start_time = datetime.now()
config_inst = ValOptions() config_inst = TestOptions()
config = config_inst.parse_args() config = config_inst.parse_args()
config.logger.info(config_inst) config.logger.info(config_inst)
@ -89,7 +89,7 @@ if __name__ == "__main__":
reset_cache=config.reset_cache, reset_cache=config.reset_cache,
) )
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress) results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress,)
results.to_csv(config.results_path) results.to_csv(config.results_path)
print() print()

@ -14,10 +14,12 @@ from pathlib import Path
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr from common.utils import logger_info
from common.metrics import PSNR, cal_ssim
from common.color import _rgb2ycbcr, PIL_CONVERT_COLOR
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS
from common.validation import valid_steps from common.test import test_steps
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
import argparse import argparse
@ -30,7 +32,7 @@ class TrainOptions:
parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}") parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.") parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.")
parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.") parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.") parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor") parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers") parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers") parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers")
@ -48,7 +50,7 @@ class TrainOptions:
parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model') parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}")
parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache') parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
self.parser = parser self.parser = parser
@ -59,7 +61,7 @@ class TrainOptions:
args.models_dir = Path(args.models_dir).resolve() args.models_dir = Path(args.models_dir).resolve()
args.model_path = Path(args.model_path) if not args.model_path is None else None args.model_path = Path(args.model_path) if not args.model_path is None else None
args.train_datasets = args.train_datasets.split(',') args.train_datasets = args.train_datasets.split(',')
args.val_datasets = args.val_datasets.split(',') args.test_datasets = args.test_datasets.split(',')
if not args.model_path is None: if not args.model_path is None:
args.start_iter = int(args.model_path.stem.split("_")[-1]) args.start_iter = int(args.model_path.stem.split("_")[-1])
return args return args
@ -79,7 +81,7 @@ class TrainOptions:
def prepare_experiment_folder(config): def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}." assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"On of the {config.test_datasets} was not found in {config.datasets_dir}."
config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve() config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
@ -90,9 +92,9 @@ def prepare_experiment_folder(config):
if not config.checkpoint_dir.exists(): if not config.checkpoint_dir.exists():
config.checkpoint_dir.mkdir() config.checkpoint_dir.mkdir()
config.valout_dir = (config.exp_dir / 'val').resolve() config.test_dir = (config.exp_dir / 'val').resolve()
if not config.valout_dir.exists(): if not config.test_dir.exists():
config.valout_dir.mkdir() config.test_dir.mkdir()
config.logs_dir = (config.exp_dir / 'logs').resolve() config.logs_dir = (config.exp_dir / 'logs').resolve()
if not config.logs_dir.exists(): if not config.logs_dir.exists():
@ -154,7 +156,7 @@ if __name__ == "__main__":
train_iter = iter(train_loader) train_iter = iter(train_loader)
test_datasets = {} test_datasets = {}
for test_dataset_name in config.val_datasets: for test_dataset_name in config.test_datasets:
test_datasets[test_dataset_name] = SRTestDataset( test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
@ -222,7 +224,7 @@ if __name__ == "__main__":
# Validation # Validation
if i % config.val_step == 0: if i % config.val_step == 0:
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve() model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()
SaveCheckpoint(model=model, path=model_path) SaveCheckpoint(model=model, path=model_path)

Loading…
Cancel
Save