From 19e0fca7455a657a966f77b1c1883b0e7e5d1e77 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Thu, 6 Jun 2024 21:13:54 +0400 Subject: [PATCH] merge --- src/common/validation.py | 13 +++++++++---- src/test.py | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/common/validation.py b/src/common/validation.py index b5fcc7f..af50c7b 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -5,7 +5,7 @@ from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop from pathlib import Path from PIL import Image import time -from datetime import timedelta +from datetime import timedelta, datetime def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'): with torch.inference_mode(): @@ -43,7 +43,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non ssim = cal_ssim(Y_left, Y_right) return psnr, ssim, run_time_ns, lr_area -def valid_steps(model, datasets, config, log_prefix=""): +def valid_steps(model, datasets, config, log_prefix="", print_progress = False): dataset_names = list(datasets.keys()) results = [] @@ -61,11 +61,16 @@ def valid_steps(model, datasets, config, log_prefix=""): test_dataset = datasets[dataset_name] tasks = [] - for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: + if print_progress: + start_datetime = datetime.now() + for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset): output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device) tasks.append(task) - + if print_progress: + print(f"\r{datetime.now()-start_datetime} {idx}/{len(test_dataset)} {hr_image_path}", end=" "*25) + if print_progress: + print() total_time = time.time() - start_time for psnr, ssim, run_time_ns, lr_area in tasks: diff --git a/src/test.py b/src/test.py index 3e7badf..08b266d 100644 --- a/src/test.py +++ b/src/test.py @@ -29,6 +29,7 @@ class ValOptions(): self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") + self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar') def parse_args(self): args = self.parser.parse_args() @@ -86,7 +87,7 @@ if __name__ == "__main__": color_model=config.color_model ) - results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") + results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress) results.to_csv(config.results_path) print()