You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
3.2 KiB
Python

import torch
import pandas as pd
import numpy as np
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from pathlib import Path
from PIL import Image
import time
def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'):
    """Super-resolve one LR image with ``model`` and score it against its HR pair.

    Args:
        model: Callable network applied to a ``(1, C, H, W)`` float32 tensor.
            It must expose a ``scale`` attribute, which is used to crop the HR
            image (``modcrop``) and is passed to ``PSNR``.
        hr_image: High-resolution reference image (array-like accepted by
            ``modcrop`` / ``_rgb2ycbcr``; presumably ``(H, W, 3)`` uint8 —
            TODO confirm).
        lr_image: Low-resolution input in ``(H, W, C)`` layout (it is permuted
            with ``(2, 0, 1)``). It is converted to float32 WITHOUT rescaling,
            so the model sees raw pixel values.
        output_image_path: If not ``None``, the prediction is written there
            with PIL.
        device: Torch device string on which inference runs.

    Returns:
        Tuple ``(psnr, ssim, run_time_ns, lr_area)``:
            * PSNR and SSIM computed on channel 0 of ``_rgb2ycbcr`` output
              (presumably the luma / Y channel),
            * time in nanoseconds covering tensor creation, host->device
              transfer, the forward pass and the copy back to host,
            * ``lr_area``: height * width of the LR input.
    """
    with torch.inference_mode():
        start_time = time.perf_counter_ns()
        # (H, W, C) array -> (1, C, H, W) float32 tensor on the target device.
        lr_tensor = torch.tensor(lr_image).type(torch.float32).permute(2, 0, 1)
        lr_tensor = lr_tensor.unsqueeze(0).to(torch.device(device))

        sr_tensor = model(lr_tensor)

        # Back to (H, W, C). Casting float -> uint8 is undefined for values
        # outside 0..255 and truncates otherwise, so clamp to the pixel range
        # and round to the nearest integer first. Assumes the model outputs
        # values on the same 0..255 scale as its unnormalised input — TODO
        # confirm.
        sr_tensor = sr_tensor.squeeze(0).permute(1, 2, 0)
        sr_tensor = sr_tensor.clamp(0, 255).round().to(torch.uint8)

        # The blocking copy to host waits for queued device work, so the
        # measured time includes the whole forward pass.
        sr_image = sr_tensor.cpu().numpy()
        run_time_ns = time.perf_counter_ns() - start_time
    torch.cuda.empty_cache()

    if output_image_path is not None:
        Image.fromarray(sr_image).save(output_image_path)

    # Metrics: crop the reference to match the model scale, compare on the
    # first channel of the YCbCr conversion.
    hr_image = modcrop(hr_image, model.scale)
    y_pred = _rgb2ycbcr(sr_image)[:, :, 0]
    y_true = _rgb2ycbcr(hr_image)[:, :, 0]
    lr_area = np.prod(lr_tensor.shape[-2:])
    return PSNR(y_pred, y_true, model.scale), cal_ssim(y_pred, y_true), run_time_ns, lr_area
def _summarize_dataset(dataset_name, metrics, total_time):
    """Aggregate per-image ``val_image_pair`` results into one summary row.

    Args:
        dataset_name: Value for the ``'Dataset'`` column.
        metrics: List of ``(psnr, ssim, run_time_ns, lr_area)`` tuples.
        total_time: Wall-clock seconds spent evaluating the dataset.

    Returns:
        List of values in the column order used by ``valid_steps``. When
        ``metrics`` is empty the statistics are NaN and the count/area are 0,
        because ``np.percentile`` raises on an empty sequence.
    """
    if not metrics:
        nan = float('nan')
        return [dataset_name, nan, nan, nan, nan, 0, nan, 0, total_time]
    psnrs, ssims, run_times_ns, lr_areas = map(list, zip(*metrics))
    return [
        dataset_name,
        np.mean(psnrs),
        np.mean(ssims),
        np.mean(run_times_ns) * 1e-9,  # ns -> s
        np.percentile(run_times_ns, q=95) * 1e-9,  # ns -> s
        len(metrics),
        np.mean(lr_areas),
        sum(lr_areas),
        total_time,
    ]


def valid_steps(model, datasets, config, log_prefix=""):
    """Evaluate ``model`` on every dataset and return a per-dataset summary.

    Each item of each dataset is scored with ``val_image_pair``. After every
    dataset, its summary row is logged (transposed) through ``config.logger``
    and ``config.writer.flush()`` is called.

    Args:
        model: Model passed through to ``val_image_pair``.
        datasets: Mapping of dataset name -> iterable of
            ``(hr_image, lr_image, hr_image_path, lr_image_path)`` tuples.
        config: Object providing ``valout_dir`` (supports ``/`` with a str,
            e.g. ``pathlib.Path``), ``save_predictions``, ``device``,
            ``logger`` and ``writer``.
        log_prefix: Text placed before each logged table (default: none).

    Returns:
        ``pd.DataFrame`` indexed by ``'Dataset'``, one row per dataset. It has
        no rows (columns only) when ``datasets`` is empty.
    """
    # Built before the loop so the final DataFrame can be created even when
    # ``datasets`` is empty.
    column_names = [
        'Dataset',
        'AVG PSNR',
        'AVG SSIM',
        f'AVG {config.device} time, s',
        f'P95 {config.device} time, s',
        'Image count',
        'AVG image area',
        'Total area',
        'Total time, s',
    ]
    results = []
    for dataset_name, test_dataset in datasets.items():
        predictions_path = config.valout_dir / dataset_name
        # Tolerate a missing valout_dir and directories created concurrently.
        predictions_path.mkdir(parents=True, exist_ok=True)

        start_time = time.time()
        metrics = []
        for hr_image, lr_image, hr_image_path, _lr_image_path in test_dataset:
            output_image_path = (
                predictions_path / f'{Path(hr_image_path).stem}.png'
                if config.save_predictions
                else None
            )
            metrics.append(
                val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device)
            )
        total_time = time.time() - start_time

        row = _summarize_dataset(dataset_name, metrics, total_time)
        results.append(row)
        table = pd.DataFrame([row], columns=column_names).set_index('Dataset').T
        config.logger.info(f"{log_prefix}\n{table}")
        config.writer.flush()
    return pd.DataFrame(results, columns=column_names).set_index('Dataset')