You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
124 lines
4.9 KiB
Python
124 lines
4.9 KiB
Python
import torch
|
|
import pandas as pd
|
|
import numpy as np
|
|
from common.utils import logger_info, modcrop
|
|
from common.color import _rgb2ycbcr
|
|
from common.metrics import PSNR, cal_ssim
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
import time
|
|
from datetime import timedelta, datetime
|
|
from matplotlib import pyplot as plt
|
|
|
|
# Colour lookup table for single-channel outputs: entry i of matplotlib's 'viridis'
# colormap, reduced to its RGB components and scaled to 0-255 uint8.
cmap = plt.get_cmap('viridis')
cmaplist = list(map(cmap, range(cmap.N)))  # one RGBA float tuple per colormap entry
cmaplut = (np.array(cmaplist)[:, :3] * 255).round().astype(np.uint8)
|
|
|
|
def _prediction_to_uint8(pred):
    """Convert a model output batch into an (H, W, C) uint8 numpy array on the host.

    ``squeeze(0)`` followed by ``permute(1, 2, 0)`` expects a single-image batch
    shaped (1, C, H, W). Floating-point values are rounded and clamped to
    [0, 255] before the cast: a bare float -> uint8 cast truncates fractions and
    gives out-of-range values (e.g. 256.0 or -1.0) undefined, platform-dependent
    results instead of saturating.
    """
    pred = pred.squeeze(0).permute(1, 2, 0)
    if pred.is_floating_point():
        pred = pred.round()
    return pred.clamp(0, 255).type(torch.uint8).cpu().numpy()


def _prediction_kind(pred, color_model):
    """Classify an (H, W, C) prediction as 'gray', 'RGB' or 'YCbCr'.

    1-channel predictions are 'gray' regardless of ``color_model``; 3-channel
    predictions take their kind from ``color_model``.

    Raises:
        ValueError: for any other channel count / color_model combination.
    """
    channels = pred.shape[-1]
    if channels == 1:
        return 'gray'
    if channels == 3 and color_model in ('RGB', 'YCbCr'):
        return color_model
    raise ValueError(
        f"Unsupported prediction: {channels} channel(s) with color_model={color_model!r}; "
        "expected 1 channel, or 3 channels with color_model 'RGB' or 'YCbCr'."
    )


def _save_prediction(pred, kind, output_image_path):
    """Write ``pred`` (uint8, H x W x C) to ``output_image_path`` as an RGB image.

    1-channel predictions are colorized through the module-level ``cmaplut``
    table; YCbCr predictions are converted to RGB before saving.
    """
    if kind == 'gray':
        Image.fromarray(cmaplut[pred[:, :, 0]]).save(output_image_path)
    elif kind == 'RGB':
        Image.fromarray(pred, mode='RGB').save(output_image_path)
    else:  # 'YCbCr'
        Image.fromarray(pred, mode='YCbCr').convert("RGB").save(output_image_path)


def test_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
    """Run ``model`` on one LR image, optionally save the result, and score it against ``hr_image``.

    Args:
        model: callable network; must expose ``model.scale`` (used for modcrop and PSNR).
        hr_image: ground-truth (H, W, C) image; assumed to share the prediction's
            channel layout (TODO confirm against the dataset).
        lr_image: low-resolution (H, W, C) image. It is cast to float32 without
            rescaling, so the model receives the raw pixel values.
        color_model: 'RGB' or 'YCbCr' for 3-channel predictions; ignored for 1-channel ones.
        output_image_path: path to write the prediction to, or None to skip saving.
        device: torch device string the input tensor is moved to.

    Returns:
        (psnr, ssim, run_time_ns, lr_area). run_time_ns covers input preparation,
        the forward pass and the copy back to the host. lr_area is H * W of the LR input.

    Raises:
        ValueError: if the prediction's channel count / color_model combination is unsupported.
    """
    with torch.inference_mode():
        start_time = time.perf_counter_ns()

        # (H, W, C) image -> (1, C, H, W) float32 tensor on the requested device.
        lr_tensor = torch.tensor(lr_image).type(torch.float32).permute(2, 0, 1)
        lr_tensor = lr_tensor.unsqueeze(0).to(torch.device(device))

        pred = _prediction_to_uint8(model(lr_tensor))

        # The blocking device-to-host copy inside _prediction_to_uint8 makes the
        # timing include the complete forward pass, not just kernel launches.
        run_time_ns = time.perf_counter_ns() - start_time

    torch.cuda.empty_cache()

    # Validate before writing anything, so unsupported inputs fail loudly instead
    # of silently skipping the save and crashing later in the metrics step.
    kind = _prediction_kind(pred, color_model)

    if output_image_path is not None:
        _save_prediction(pred, kind, output_image_path)

    # metrics
    hr_image = modcrop(hr_image, model.scale)
    if kind == 'RGB':
        # _rgb2ycbcr is assumed to return (Y, Cb, Cr) channels; metrics use channel 0 only.
        y_pred, y_ref = _rgb2ycbcr(pred)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
    else:
        # 'YCbCr' and 'gray' predictions are compared on their first channel directly.
        y_pred, y_ref = pred[:, :, 0], hr_image[:, :, 0]

    lr_area = np.prod(lr_tensor.shape[-2:])
    psnr = PSNR(y_pred, y_ref, model.scale)
    ssim = cal_ssim(y_pred, y_ref)
    return psnr, ssim, run_time_ns, lr_area
|
|
|
|
def test_steps(model, datasets, config, log_prefix="", print_progress = False):
|
|
dataset_names = list(datasets.keys())
|
|
|
|
results = []
|
|
for i in range(len(dataset_names)):
|
|
dataset_name = dataset_names[i]
|
|
psnrs, ssims = [], []
|
|
run_times_ns = []
|
|
lr_areas = []
|
|
total_area = 0
|
|
start_time = time.time()
|
|
|
|
predictions_path = config.test_dir / dataset_name
|
|
if not predictions_path.exists():
|
|
predictions_path.mkdir()
|
|
|
|
test_dataset = datasets[dataset_name]
|
|
tasks = []
|
|
if print_progress:
|
|
start_datetime = datetime.now()
|
|
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
|
|
if config.save_predictions:
|
|
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png'
|
|
else:
|
|
output_image_path = None
|
|
task = test_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
|
|
tasks.append(task)
|
|
if print_progress:
|
|
print(f"\r{datetime.now()-start_datetime} {idx+1}/{len(test_dataset)} {hr_image_path}", end=" "*25)
|
|
if print_progress:
|
|
print()
|
|
total_time = time.time() - start_time
|
|
|
|
for psnr, ssim, run_time_ns, lr_area in tasks:
|
|
psnrs.append(psnr)
|
|
ssims.append(ssim)
|
|
run_times_ns.append(run_time_ns)
|
|
lr_areas.append(lr_area)
|
|
total_area += lr_area
|
|
|
|
row = [
|
|
f"{dataset_name} {config.color_model}",
|
|
np.mean(psnrs),
|
|
np.mean(ssims),
|
|
np.mean(run_times_ns)*1e-9,
|
|
np.percentile(run_times_ns, q=95)*1e-9,
|
|
len(test_dataset),
|
|
np.mean(lr_areas),
|
|
total_area,
|
|
timedelta(seconds=total_time)
|
|
]
|
|
results.append(row)
|
|
column_names = [
|
|
'Dataset',
|
|
'AVG PSNR',
|
|
'AVG SSIM',
|
|
f'AVG {config.device} time, s',
|
|
f'P95 {config.device} time, s',
|
|
'Image count',
|
|
'AVG image area',
|
|
'Total area',
|
|
'Total time'
|
|
]
|
|
config.logger.info("\n" + str(pd.DataFrame([row], columns=column_names).set_index('Dataset').T))
|
|
config.writer.add_scalar(f'{dataset_name}_PSNR', np.mean(psnrs), config.current_iter)
|
|
config.writer.add_scalar(f'{dataset_name}_SSIM', np.mean(ssims), config.current_iter)
|
|
config.writer.flush()
|
|
|
|
results = pd.DataFrame(results, columns=column_names).set_index('Dataset')
|
|
|
|
return results |