import sys
import logging
import math
import os
import time
import argparse
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.validation import valid_steps
from models import LoadCheckpoint

torch.backends.cudnn.benchmark = True


class ValOptions():
    """Command-line options for validating a saved model checkpoint.

    ``parse_args`` parses the CLI, derives output paths from ``--model_path``
    and attaches a TensorBoard ``SummaryWriter`` and a logger to the returned
    namespace. The parsed namespace is cached on the instance. The side effects
    (creating the output directory, the writer and the log file) therefore
    happen only once, even though ``__repr__`` also needs the parsed options.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
        self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
        self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
        self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
        # `--save_predictions` defaults to True, so the store_true flag alone can
        # never switch it off; this flag writes False to the same destination.
        self.parser.add_argument('--no_save_predictions', dest='save_predictions', action='store_false', help='Do not save model predictions.')
        self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
        # Cache for the parsed namespace; filled by the first parse_args() call.
        self._args = None

    def parse_args(self):
        """Parse CLI arguments and derive experiment paths, writer and logger.

        Returns:
            argparse.Namespace with the raw options plus these derived fields:
            ``exp_dir`` (two directory levels above the checkpoint file),
            ``model_name`` (checkpoint file stem), ``valout_dir``
            (``exp_dir/val``, created if missing), ``current_iter`` (last
            ``_``-separated token of ``model_name``, as a string),
            ``results_path`` (CSV path inside ``valout_dir``), ``writer``
            (a ``SummaryWriter`` logging to ``valout_dir``) and ``logger``.
            Repeated calls return the same namespace instead of creating
            another writer and re-initialising the logger.
        """
        if self._args is not None:
            return self._args

        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.val_datasets = args.val_datasets.split(',')
        # Assumes the checkpoint lives in <exp_dir>/<subdir>/<file>.pth.
        args.exp_dir = Path(args.model_path).resolve().parent.parent
        args.model_path = Path(args.model_path).resolve()
        args.model_name = args.model_path.stem
        args.valout_dir = Path(args.exp_dir).resolve() / 'val'
        # parents/exist_ok: works when exp_dir does not exist yet and avoids
        # the check-then-create race of exists() + mkdir().
        args.valout_dir.mkdir(parents=True, exist_ok=True)
        args.current_iter = args.model_name.split('_')[-1]
        args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv')

        # Tensorboard for monitoring
        writer = SummaryWriter(log_dir=args.valout_dir)
        logger_name = f'val_{args.model_path.stem}'
        # NOTE(review): logger_info is a project helper; presumably it attaches
        # a file handler writing to the given .log path — confirm in common.utils.
        logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        args.writer = writer
        args.logger = logger

        self._args = args
        return args

    def __repr__(self):
        """Return a table of all options, annotating values that differ from the parser default.

        NOTE: derived fields (e.g. ``datasets_dir`` as a Path, ``val_datasets``
        as a list, ``writer``/``logger``) are compared against the raw parser
        default and are therefore always annotated.
        """
        config = self.parse_args()
        lines = ['----------------- Options ---------------']
        for k, v in sorted(vars(config).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        return '\n'.join(lines)


# TODO: with a unified save/load function, any model file (net or LUT) could be tested with this same script.
if __name__ == "__main__":
    script_start_time = datetime.now()

    config_inst = ValOptions()
    config = config_inst.parse_args()
    # Logs the options table; parse_args is cached, so this does not create a
    # second SummaryWriter or re-initialise the logger.
    config.logger.info(config_inst)

    try:
        model = LoadCheckpoint(config.model_path)
        model = model.to(torch.device(config.device))
        print(model)

        test_datasets = {}
        for test_dataset_name in config.val_datasets:
            dataset_root = Path(config.datasets_dir) / test_dataset_name
            test_datasets[test_dataset_name] = SRTestDataset(
                hr_dir_path=dataset_root / "HR",
                # LR images are looked up in a folder named after the model's upscaling factor.
                lr_dir_path=dataset_root / "LR" / f"X{model.scale}",
            )

        results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
        results.to_csv(config.results_path)
        print(config.model_name)
        print(results)
        print()
        print(f"Results saved to {config.results_path}")
    finally:
        # Flush pending TensorBoard events and release the event file even if
        # loading or validation fails.
        config.writer.close()

    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")