main
protsenkovi 5 months ago
parent 39f830d6e5
commit e8eed1ad73

@ -0,0 +1,34 @@
import cv2
import numpy as np
from scipy import signal
import torch
import os
PIL_CONVERT_COLOR = {
'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image,
'full_YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
'full_Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
'sdtv_Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0] if pil_image.mode == 'RGB' else NotImplementedError(f"{pil_image.mode} to Y"),
'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
}
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr

@ -8,7 +8,8 @@ import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
from common.utils import PIL_CONVERT_COLOR, pil2numpy
from common.color import PIL_CONVERT_COLOR
from common.utils import pil2numpy
image_extensions = ['.jpg', '.png']
def load_images_cached(images_dir_path, color_model, reset_cache):

@ -0,0 +1,46 @@
import cv2
import numpy as np
from scipy import signal
import torch
import os
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim

@ -1,7 +1,9 @@
import torch
import pandas as pd
import numpy as np
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.utils import logger_info, modcrop
from common.color import _rgb2ycbcr
from common.metrics import PSNR, cal_ssim
from pathlib import Path
from PIL import Image
import time
@ -13,7 +15,7 @@ cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplut = np.array(cmaplist)
cmaplut = np.round(cmaplut[:, 0:3]*255).astype(np.uint8)
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode():
start_time = time.perf_counter_ns()
# prepare lr_image
@ -49,7 +51,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area
def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
def test_steps(model, datasets, config, log_prefix="", print_progress = False):
dataset_names = list(datasets.keys())
results = []
@ -61,7 +63,7 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
total_area = 0
start_time = time.time()
predictions_path = config.valout_dir / dataset_name
predictions_path = config.test_dir / dataset_name
if not predictions_path.exists():
predictions_path.mkdir()
@ -70,8 +72,11 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
if print_progress:
start_datetime = datetime.now()
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
if config.save_predictions:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png'
else:
output_image_path = None
task = test_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task)
if print_progress:
print(f"\r{datetime.now()-start_datetime} {idx+1}/{len(test_dataset)} {hr_image_path}", end=" "*25)

@ -6,14 +6,6 @@ from scipy import signal
import torch
import os
PIL_CONVERT_COLOR = {
'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image,
'YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
'Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
# 'Y': lambda pil_image: _rgb2ycbcr(np.array(pil_image))[:,:,0],# if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
}
def pil2numpy(image):
np_image = np.array(image)
if len(np_image.shape) == 2:
@ -42,7 +34,6 @@ def logger_info(logger_name, log_path='default_logger.log'):
fh.setFormatter(formatter)
log.setLevel(level)
log.addHandler(fh)
# print(len(log.handlers))
sh = logging.StreamHandler()
sh.setFormatter(formatter)
@ -61,66 +52,3 @@ def modcrop(image, modulo):
else:
raise NotImplementedError
return image
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim

@ -14,13 +14,13 @@ from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.validation import valid_steps
from common.test import test_steps
from models import LoadCheckpoint
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
class ValOptions():
class TestOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
@ -39,15 +39,15 @@ class ValOptions():
args.exp_dir = Path(args.model_path).resolve().parent.parent
args.model_path = Path(args.model_path).resolve()
args.model_name = args.model_path.stem
args.valout_dir = Path(args.exp_dir).resolve() / 'val'
if not args.valout_dir.exists():
args.valout_dir.mkdir()
args.test_dir = Path(args.exp_dir).resolve() / 'test'
if not args.test_dir.exists():
args.test_dir.mkdir()
args.current_iter = int(args.model_name.split('_')[-1])
args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv')
args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.device}.csv')
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.valout_dir)
logger_name = f'val_{args.model_path.stem}'
logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log'))
writer = SummaryWriter(log_dir=args.test_dir)
logger_name = f'test_{args.model_path.stem}'
logger_info(logger_name, os.path.join(args.test_dir, logger_name + '.log'))
logger = logging.getLogger(logger_name)
args.writer = writer
args.logger = logger
@ -71,7 +71,7 @@ class ValOptions():
# TODO with unified save/load function any model file of net or lut can be tested with the same script.
if __name__ == "__main__":
script_start_time = datetime.now()
config_inst = ValOptions()
config_inst = TestOptions()
config = config_inst.parse_args()
config.logger.info(config_inst)
@ -89,7 +89,7 @@ if __name__ == "__main__":
reset_cache=config.reset_cache,
)
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress)
results = test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress,)
results.to_csv(config.results_path)
print()

@ -14,10 +14,12 @@ from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr
from common.utils import logger_info
from common.metrics import PSNR, cal_ssim
from common.color import _rgb2ycbcr, PIL_CONVERT_COLOR
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS
from common.validation import valid_steps
from common.test import test_steps
torch.backends.cudnn.benchmark = True
import argparse
@ -30,7 +32,7 @@ class TrainOptions:
parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
parser.add_argument('--model_path', type=str, default=None, help=f"Path to model for finetune.")
parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--layers_count', type=int, default=4, help="number of convolutional layers")
@ -48,7 +50,7 @@ class TrainOptions:
parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}")
parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
self.parser = parser
@ -59,7 +61,7 @@ class TrainOptions:
args.models_dir = Path(args.models_dir).resolve()
args.model_path = Path(args.model_path) if not args.model_path is None else None
args.train_datasets = args.train_datasets.split(',')
args.val_datasets = args.val_datasets.split(',')
args.test_datasets = args.test_datasets.split(',')
if not args.model_path is None:
args.start_iter = int(args.model_path.stem.split("_")[-1])
return args
@ -79,7 +81,7 @@ class TrainOptions:
def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), f"On of the {config.test_datasets} was not found in {config.datasets_dir}."
config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
@ -90,9 +92,9 @@ def prepare_experiment_folder(config):
if not config.checkpoint_dir.exists():
config.checkpoint_dir.mkdir()
config.valout_dir = (config.exp_dir / 'val').resolve()
if not config.valout_dir.exists():
config.valout_dir.mkdir()
config.test_dir = (config.exp_dir / 'val').resolve()
if not config.test_dir.exists():
config.test_dir.mkdir()
config.logs_dir = (config.exp_dir / 'logs').resolve()
if not config.logs_dir.exists():
@ -154,7 +156,7 @@ if __name__ == "__main__":
train_iter = iter(train_loader)
test_datasets = {}
for test_dataset_name in config.val_datasets:
for test_dataset_name in config.test_datasets:
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
@ -222,7 +224,7 @@ if __name__ == "__main__":
# Validation
if i % config.val_step == 0:
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()
SaveCheckpoint(model=model, path=model_path)

Loading…
Cancel
Save