import torch
import pandas as pd
import numpy as np
from common.utils import logger_info, modcrop
from common.color import _rgb2ycbcr
from common.metrics import PSNR, cal_ssim
from pathlib import Path
from PIL import Image
import time
from datetime import timedelta, datetime
from matplotlib import pyplot as plt

# Colour lookup table built from matplotlib's 'viridis' colormap: row i holds
# the uint8 RGB colour of colormap entry i. Indexing it with a uint8 image
# produces a colour-mapped RGB image; used when saving 1-channel predictions.
cmap = plt.get_cmap('viridis')
cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplut = np.array(cmaplist)
cmaplut = np.round(cmaplut[:, 0:3] * 255).astype(np.uint8)


def _save_prediction(pred, color_model, output_image_path):
    """Write a uint8 H x W x C prediction to ``output_image_path`` as an RGB image.

    Single-channel predictions are colour-mapped through ``cmaplut``;
    3-channel predictions are interpreted in ``color_model`` ('RGB' or
    'YCbCr') and YCbCr is converted to RGB before saving.
    """
    if pred.shape[-1] == 1:
        Image.fromarray(cmaplut[pred[:, :, 0]]).save(output_image_path)
    elif color_model == 'RGB':
        Image.fromarray(pred, mode=color_model).save(output_image_path)
    else:
        # 'YCbCr': tell PIL how to interpret the channels, then convert so the
        # file on disk is a regular RGB image.
        Image.fromarray(pred, mode=color_model).convert("RGB").save(output_image_path)


def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
    """Run ``model`` on one low-resolution image and score it against its HR reference.

    Args:
        model: callable network. It must expose a ``scale`` attribute, which is
            passed to ``modcrop`` (to crop ``hr_image``) and to ``PSNR``.
        hr_image: high-resolution reference array, indexed as H x W x C.
        lr_image: low-resolution input array, indexed as h x w x C. It is cast
            to float32 without rescaling, so the model receives the raw pixel
            values.
        color_model: 'RGB' or 'YCbCr' for 3-channel predictions; not consulted
            for 1-channel predictions.
        output_image_path: if not None, the prediction is saved there (see
            ``_save_prediction``).
        device: torch device string the input tensor is moved to.

    Returns:
        ``(psnr, ssim, run_time_ns, lr_area)``. ``psnr``/``ssim`` are computed
        on channel 0 (the Y channel after an RGB->YCbCr conversion for 'RGB'
        predictions). ``run_time_ns`` covers tensor conversion, the transfer
        to ``device``, the forward pass and the copy back to a numpy array.
        ``lr_area`` is h * w of the input.

    Raises:
        ValueError: if the prediction has 3 channels and ``color_model`` is not
            'RGB' or 'YCbCr', or if it has neither 1 nor 3 channels.
    """
    with torch.inference_mode():
        start_time = time.perf_counter_ns()
        # H x W x C array -> 1 x C x H x W float tensor on the target device.
        lr_tensor = torch.tensor(lr_image).type(torch.float32).permute(2, 0, 1)
        lr_tensor = lr_tensor.unsqueeze(0).to(torch.device(device))
        pred = model(lr_tensor)
        # Clamp before the uint8 cast: a float -> uint8 cast of values outside
        # [0, 255] is not guaranteed to saturate and can wrap around (e.g. a
        # slight overshoot above 255 turning into a dark pixel).
        pred = pred.squeeze(0).permute(1, 2, 0).clamp(0, 255).type(torch.uint8)
        pred = pred.cpu().numpy()
        run_time_ns = time.perf_counter_ns() - start_time
        lr_area = np.prod(lr_tensor.shape[-2:])
    torch.cuda.empty_cache()

    channels = pred.shape[-1]
    if not (channels == 1 or (channels == 3 and color_model in ('RGB', 'YCbCr'))):
        raise ValueError(
            f"Unsupported prediction: {channels} channel(s) with color_model={color_model!r}; "
            "expected 1 channel, or 3 channels with 'RGB' or 'YCbCr'"
        )

    if output_image_path is not None:
        _save_prediction(pred, color_model, output_image_path)

    # Metrics are computed on a single (luma) channel of the cropped reference.
    hr_image = modcrop(hr_image, model.scale)
    if channels == 3 and color_model == 'RGB':
        y_pred, y_hr = _rgb2ycbcr(pred)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
    else:
        # YCbCr input already carries Y in channel 0; 1-channel input is used as-is.
        y_pred, y_hr = pred[:, :, 0], hr_image[:, :, 0]

    psnr = PSNR(y_pred, y_hr, model.scale)
    ssim = cal_ssim(y_pred, y_hr)
    return psnr, ssim, run_time_ns, lr_area


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN for an empty sequence (avoids numpy's empty-mean warning)."""
    return np.mean(values) if len(values) else np.nan


def test_steps(model, datasets, config, log_prefix="", print_progress=False):
    """Evaluate ``model`` on every dataset and summarise the results.

    For each dataset the model is run on every ``(hr_image, lr_image,
    hr_image_path, lr_image_path)`` item via ``test_image_pair``. A one-row
    summary is logged through ``config.logger``, and mean PSNR/SSIM are written
    to ``config.writer`` at step ``config.current_iter``.

    Args:
        model: network passed through to ``test_image_pair``.
        datasets: mapping of dataset name -> iterable of
            ``(hr_image, lr_image, hr_image_path, lr_image_path)`` tuples.
        config: object providing ``test_dir`` (a ``Path``), ``save_predictions``,
            ``current_iter``, ``color_model``, ``device``, ``logger`` and ``writer``.
        log_prefix: accepted for API compatibility; currently unused.
        print_progress: if True, print a single updating progress line per dataset.

    Returns:
        ``pandas.DataFrame`` indexed by 'Dataset' with one row per dataset
        (empty if ``datasets`` is empty). Statistics of a dataset with no items
        are NaN.
    """
    # Defined up front so an empty ``datasets`` mapping still yields a
    # well-formed (empty) table instead of an unbound-name error.
    column_names = [
        'Dataset',
        'AVG PSNR',
        'AVG SSIM',
        f'AVG {config.device} time, s',
        f'P95 {config.device} time, s',
        'Image count',
        'AVG image area',
        'Total area',
        'Total time',
    ]
    results = []
    for dataset_name, test_dataset in datasets.items():
        start_time = time.time()
        predictions_path = config.test_dir / dataset_name
        predictions_path.mkdir(parents=True, exist_ok=True)

        records = []
        start_datetime = datetime.now()
        for idx, (hr_image, lr_image, hr_image_path, _lr_image_path) in enumerate(test_dataset):
            if config.save_predictions:
                output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png'
            else:
                output_image_path = None
            records.append(
                test_image_pair(
                    model,
                    hr_image,
                    lr_image,
                    color_model=config.color_model,
                    output_image_path=output_image_path,
                    device=config.device,
                )
            )
            if print_progress:
                print(f"\r{datetime.now()-start_datetime} {idx+1}/{len(test_dataset)} {hr_image_path}", end=" "*25)
        if print_progress:
            print()
        total_time = time.time() - start_time

        psnrs = [record[0] for record in records]
        ssims = [record[1] for record in records]
        run_times_ns = [record[2] for record in records]
        lr_areas = [record[3] for record in records]

        mean_psnr = _mean_or_nan(psnrs)
        mean_ssim = _mean_or_nan(ssims)
        # np.percentile raises on an empty sequence, so guard it explicitly.
        p95_time_s = np.percentile(run_times_ns, q=95) * 1e-9 if run_times_ns else np.nan
        row = [
            f"{dataset_name} {config.color_model}",
            mean_psnr,
            mean_ssim,
            _mean_or_nan(run_times_ns) * 1e-9,
            p95_time_s,
            len(records),
            _mean_or_nan(lr_areas),
            sum(lr_areas),
            timedelta(seconds=total_time),
        ]
        results.append(row)

        config.logger.info("\n" + str(pd.DataFrame([row], columns=column_names).set_index('Dataset').T))
        config.writer.add_scalar(f'{dataset_name}_PSNR', mean_psnr, config.current_iter)
        config.writer.add_scalar(f'{dataset_name}_SSIM', mean_ssim, config.current_iter)
        config.writer.flush()

    return pd.DataFrame(results, columns=column_names).set_index('Dataset')