added choice for color model

main
Vladimir Protsenko 7 months ago
parent f2eea32363
commit c4b7821001

@ -8,15 +8,16 @@ import torch
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from pathlib import Path from pathlib import Path
from common.utils import PIL_CONVERT_COLOR, pil2numpy
image_extensions = ['.jpg', '.png'] image_extensions = ['.jpg', '.png']
def load_images_cached(images_dir_path): def load_images_cached(images_dir_path, color_model):
image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions]) image_paths = sorted([f for f in Path(images_dir_path).glob("*") if f.suffix.lower() in image_extensions])
cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_cache.npy" cache_path = Path(images_dir_path).parent / f"{Path(images_dir_path).stem}_{color_model}_cache.npy"
cache_path = cache_path.resolve() cache_path = cache_path.resolve()
if not Path(cache_path).exists(): if not Path(cache_path).exists():
print("Caching to:", cache_path) print("Caching to:", cache_path)
value = {f:np.array(Image.open(f)) for f in image_paths} value = {f:pil2numpy(PIL_CONVERT_COLOR[color_model](Image.open(f))) for f in image_paths}
np.save(cache_path, value, allow_pickle=True) np.save(cache_path, value, allow_pickle=True)
else: else:
value = np.load(cache_path, allow_pickle=True).item() value = np.load(cache_path, allow_pickle=True).item()
@ -24,12 +25,12 @@ def load_images_cached(images_dir_path):
return list(value.keys()), list(value.values()) return list(value.keys()), list(value.values())
class SRTrainDataset(Dataset): class SRTrainDataset(Dataset):
def __init__(self, hr_dir_path, lr_dir_path, patch_size, rigid_aug=True): def __init__(self, hr_dir_path, lr_dir_path, patch_size, color_model = "RGB", rigid_aug=True):
super(SRTrainDataset, self).__init__() super(SRTrainDataset, self).__init__()
self.sz = patch_size self.sz = patch_size
self.rigid_aug = rigid_aug self.rigid_aug = rigid_aug
self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path) self.hr_image_names, self.hr_images = load_images_cached(hr_dir_path, color_model=color_model)
self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path) self.lr_image_names, self.lr_images = load_images_cached(lr_dir_path, color_model=color_model)
assert len(self.hr_images) == len(self.lr_images) assert len(self.hr_images) == len(self.lr_images)
def __getitem__(self, idx): def __getitem__(self, idx):
@ -48,8 +49,8 @@ class SRTrainDataset(Dataset):
i = random.randint(0, lr_image.shape[0] - self.sz) i = random.randint(0, lr_image.shape[0] - self.sz)
j = random.randint(0, lr_image.shape[1] - self.sz) j = random.randint(0, lr_image.shape[1] - self.sz)
# c = random.choice([0, 1, 2])
if len(hr_image.shape) == 3:
hr_patch = hr_image[ hr_patch = hr_image[
(i*scale):(i*scale + self.sz*scale), (i*scale):(i*scale + self.sz*scale),
(j*scale):(j*scale + self.sz*scale), (j*scale):(j*scale + self.sz*scale),
@ -61,6 +62,7 @@ class SRTrainDataset(Dataset):
: :
] ]
if self.rigid_aug: if self.rigid_aug:
if random.uniform(0, 1) < 0.5: if random.uniform(0, 1) < 0.5:
hr_patch = np.fliplr(hr_patch) hr_patch = np.fliplr(hr_patch)
@ -85,10 +87,10 @@ class SRTrainDataset(Dataset):
return len(self.hr_images) return len(self.hr_images)
class SRTestDataset(Dataset): class SRTestDataset(Dataset):
def __init__(self, hr_dir_path, lr_dir_path): def __init__(self, hr_dir_path, lr_dir_path, color_model):
super(SRTestDataset, self).__init__() super(SRTestDataset, self).__init__()
self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path) self.hr_image_paths, self.hr_images = load_images_cached(hr_dir_path, color_model=color_model)
self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path) self.lr_image_paths, self.lr_images = load_images_cached(lr_dir_path, color_model=color_model)
assert len(self.hr_images) == len(self.lr_images) assert len(self.hr_images) == len(self.lr_images)
def __getitem__(self, idx): def __getitem__(self, idx):

@ -6,6 +6,18 @@ from scipy import signal
import torch import torch
import os import os
PIL_CONVERT_COLOR = {
'RGB': lambda pil_image: pil_image.convert("RGB") if pil_image.mode != 'RGB' else pil_image,
'YCbCr': lambda pil_image: pil_image.convert("YCbCr") if pil_image.mode != 'YCbCr' else pil_image,
'Y': lambda pil_image: pil_image.convert("YCbCr").getchannel(0) if pil_image.mode != 'YCbCr' else pil_image.getchannel(0),
'L': lambda pil_image: pil_image.convert("L") if pil_image.mode != 'L' else pil_image,
}
def pil2numpy(image):
np_image = np.array(image)
if len(np_image.shape) == 2:
np_image = np_image[:,:,None]
return np_image
def round_func(input): def round_func(input):
# Backward Pass Differentiable Approximation (BPDA) # Backward Pass Differentiable Approximation (BPDA)
@ -41,7 +53,7 @@ def modcrop(image, modulo):
sz = image.shape sz = image.shape
sz = sz - np.mod(sz, modulo) sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1]] image = image[0:sz[0], 0:sz[1]]
elif image.shape[2] == 3: elif len(image.shape) == 3:
sz = image.shape[0:2] sz = image.shape[0:2]
sz = sz - np.mod(sz, modulo) sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1], :] image = image[0:sz[0], 0:sz[1], :]

@ -7,7 +7,7 @@ from PIL import Image
import time import time
from datetime import timedelta from datetime import timedelta
def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'): def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode(): with torch.inference_mode():
start_time = time.perf_counter_ns() start_time = time.perf_counter_ns()
# prepare lr_image # prepare lr_image
@ -22,13 +22,24 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu
torch.cuda.empty_cache() torch.cuda.empty_cache()
if not output_image_path is None: if not output_image_path is None:
Image.fromarray(pred_lr_image).save(output_image_path) if pred_lr_image.shape[-1] == 3:
Image.fromarray(pred_lr_image, mode=color_model).save(output_image_path)
if pred_lr_image.shape[-1] == 1:
Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path)
# metrics # metrics
hr_image = modcrop(hr_image, model.scale) hr_image = modcrop(hr_image, model.scale)
if pred_lr_image.shape[-1] == 3 and color_model == 'RGB':
Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0]
if pred_lr_image.shape[-1] == 1:
Y_left, Y_right = pred_lr_image[:, :, 0], hr_image[:, :, 0]
lr_area = np.prod(lr_image.shape[-2:]) lr_area = np.prod(lr_image.shape[-2:])
return PSNR(Y_left, Y_right, model.scale), cal_ssim(Y_left, Y_right), run_time_ns, lr_area psnr = PSNR(Y_left, Y_right, model.scale)
ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area
def valid_steps(model, datasets, config, log_prefix=""): def valid_steps(model, datasets, config, log_prefix=""):
dataset_names = list(datasets.keys()) dataset_names = list(datasets.keys())
@ -50,7 +61,7 @@ def valid_steps(model, datasets, config, log_prefix=""):
tasks = [] tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device) task = val_image_pair(model, hr_image, lr_image, output_image_path, color_model=config.color_model, device=config.device)
tasks.append(task) tasks.append(task)
total_time = time.time() - start_time total_time = time.time() - start_time

@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
import sys import sys
from common.utils import PIL_CONVERT_COLOR, pil2numpy
from models import LoadCheckpoint from models import LoadCheckpoint
import torch import torch
import numpy as np import numpy as np
@ -8,6 +8,7 @@ import cv2
from PIL import Image from PIL import Image
from datetime import datetime from datetime import datetime
import argparse import argparse
class ImageDemoOptions(): class ImageDemoOptions():
def __init__(self): def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@ -18,6 +19,7 @@ class ImageDemoOptions():
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False) self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model') self.parser.add_argument('--device', default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for input image.")
def parse_args(self): def parse_args(self):
args = self.parser.parse_args() args = self.parser.parse_args()
@ -49,25 +51,26 @@ models = [LoadCheckpoint(x).to(config.device) for x in config.model_paths]
for m in models: for m in models:
print(m) print(m)
lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1] lr_image = pil2numpy(PIL_CONVERT_COLOR[config.color_model](Image.open(str(config.lr_image_path))))
image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1] image_gt = pil2numpy(PIL_CONVERT_COLOR[config.color_model](Image.open(str(config.hr_image_path))))
image_gt = np.concatenate([image_gt]*3, axis=-1) if image_gt.shape[-1] == 1 else image_gt
if config.mirror: if config.mirror:
lr_image = lr_image[:,::-1,:] lr_image = lr_image[:,::-1,:].copy()
image_gt = image_gt[:,::-1,:] image_gt = image_gt[:,::-1,:].copy()
lr_image = lr_image.copy()
image_gt = image_gt.copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device) input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device)
predictions = [] predictions = []
for model in models: for model in models:
with torch.inference_mode(): with torch.inference_mode():
prediction = model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() prediction = model(input_image).cpu().type(torch.uint8).squeeze(0).permute(1,2,0).numpy().copy()
predictions.append(prediction) predictions.append(prediction)
image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
images_predicted = [] images_predicted = []
for model_path, model, prediction in zip(config.model_paths, models, predictions): for model_path, model, prediction in zip(config.model_paths, models, predictions):
prediction = np.concatenate([prediction]*3, axis=-1) if prediction.shape[-1] == 1 else prediction
prediction = cv2.putText(prediction, model_path.stem, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) prediction = cv2.putText(prediction, model_path.stem, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
images_predicted.append(prediction) images_predicted.append(prediction)

@ -47,6 +47,7 @@ class TrainOptions:
parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model') parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].") parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser = parser self.parser = parser
@ -130,7 +131,8 @@ if __name__ == "__main__":
train_datasets.append(SRTrainDataset( train_datasets.append(SRTrainDataset(
hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR", hr_dir_path = Path(config.datasets_dir) / train_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}", lr_dir_path = Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
patch_size = config.crop_size patch_size = config.crop_size,
color_model = config.color_model
)) ))
train_dataset = torch.utils.data.ConcatDataset(train_datasets) train_dataset = torch.utils.data.ConcatDataset(train_datasets)
train_loader = DataLoader( train_loader = DataLoader(
@ -149,6 +151,7 @@ if __name__ == "__main__":
test_datasets[test_dataset_name] = SRTestDataset( test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
color_model = config.color_model
) )
l_accum = [0., 0., 0.] l_accum = [0., 0., 0.]
prepare_data_time = 0. prepare_data_time = 0.
@ -180,8 +183,6 @@ if __name__ == "__main__":
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
del hr_patch
del lr_patch
torch.cuda.empty_cache() torch.cuda.empty_cache()
forward_backward_time += time.time() - start_time forward_backward_time += time.time() - start_time

@ -28,6 +28,7 @@ class ValOptions():
self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.") self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
def parse_args(self): def parse_args(self):
args = self.parser.parse_args() args = self.parser.parse_args()
@ -82,6 +83,7 @@ if __name__ == "__main__":
test_datasets[test_dataset_name] = SRTestDataset( test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
color_model=config.color_model
) )
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")

Loading…
Cancel
Save