main
Vladimir Protsenko 5 months ago
parent ca53d81a48
commit 19e0fca745

@ -5,7 +5,7 @@ from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
import time import time
from datetime import timedelta from datetime import timedelta, datetime
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'): def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode(): with torch.inference_mode():
@ -43,7 +43,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
ssim = cal_ssim(Y_left, Y_right) ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area return psnr, ssim, run_time_ns, lr_area
def valid_steps(model, datasets, config, log_prefix=""): def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
dataset_names = list(datasets.keys()) dataset_names = list(datasets.keys())
results = [] results = []
@ -61,11 +61,16 @@ def valid_steps(model, datasets, config, log_prefix=""):
test_dataset = datasets[dataset_name] test_dataset = datasets[dataset_name]
tasks = [] tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: if print_progress:
start_datetime = datetime.now()
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device) task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task) tasks.append(task)
if print_progress:
print(f"\r{datetime.now()-start_datetime} {idx}/{len(test_dataset)} {hr_image_path}", end=" "*25)
if print_progress:
print()
total_time = time.time() - start_time total_time = time.time() - start_time
for psnr, ssim, run_time_ns, lr_area in tasks: for psnr, ssim, run_time_ns, lr_area in tasks:

@ -29,6 +29,7 @@ class ValOptions():
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')
def parse_args(self): def parse_args(self):
args = self.parser.parse_args() args = self.parser.parse_args()
@ -86,7 +87,7 @@ if __name__ == "__main__":
color_model=config.color_model color_model=config.color_model
) )
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress)
results.to_csv(config.results_path) results.to_csv(config.results_path)
print() print()

Loading…
Cancel
Save