main
Vladimir Protsenko 7 months ago
parent ca53d81a48
commit 19e0fca745

@ -5,7 +5,7 @@ from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from pathlib import Path
from PIL import Image
import time
from datetime import timedelta
from datetime import timedelta, datetime
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode():
@ -43,7 +43,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
ssim = cal_ssim(Y_left, Y_right)
return psnr, ssim, run_time_ns, lr_area
def valid_steps(model, datasets, config, log_prefix=""):
def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
dataset_names = list(datasets.keys())
results = []
@ -61,11 +61,16 @@ def valid_steps(model, datasets, config, log_prefix=""):
test_dataset = datasets[dataset_name]
tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
if print_progress:
start_datetime = datetime.now()
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task)
if print_progress:
print(f"\r{datetime.now()-start_datetime} {idx}/{len(test_dataset)} {hr_image_path}", end=" "*25)
if print_progress:
print()
total_time = time.time() - start_time
for psnr, ssim, run_time_ns, lr_area in tasks:

@ -29,6 +29,7 @@ class ValOptions():
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
self.parser.add_argument('--progress', type=bool, default=True, help='Show progres bar')
def parse_args(self):
args = self.parser.parse_args()
@ -86,7 +87,7 @@ if __name__ == "__main__":
color_model=config.color_model
)
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}", print_progress=config.progress)
results.to_csv(config.results_path)
print()

Loading…
Cancel
Save