"""Evaluate a saved model checkpoint on one or more super-resolution test sets.

The checkpoint is loaded with ``models.LoadCheckpoint`` and evaluated with
``common.test.test_steps`` on every dataset named in ``--test_datasets``.
The results returned by ``test_steps`` are printed and written as CSV into
``<exp_dir>/test``, where a TensorBoard writer and a log file are also created.
"""
import argparse
import logging
import math
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.test import test_steps
from models import LoadCheckpoint

torch.backends.cudnn.benchmark = True

_TRUE_STRINGS = frozenset({'1', 'true', 't', 'yes', 'y', 'on'})
_FALSE_STRINGS = frozenset({'0', 'false', 'f', 'no', 'n', 'off'})


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--progress False`` would silently keep the progress bar enabled.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


class TestOptions():
    """Command-line options for evaluating a trained model checkpoint.

    ``parse_args`` has side effects: it creates the test directory, a
    TensorBoard ``SummaryWriter`` and a file logger. The parsed configuration
    is therefore cached, and later calls (including ``__repr__``) reuse it
    instead of creating duplicate writers and handlers.
    """

    def __init__(self):
        self._config = None
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
        self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
        self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.")
        self.parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
        self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
        self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
        self.parser.add_argument('--progress', type=_str2bool, default=True, help='Show progress bar')
        self.parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')

    def parse_args(self):
        """Parse ``sys.argv`` and derive paths, the writer and the logger.

        Only the first call does any work; subsequent calls return the same
        namespace so that no second ``SummaryWriter`` or log handler is made.

        Returns:
            argparse.Namespace with the raw options plus derived fields:
            ``exp_dir``, ``model_name``, ``test_dir``, ``results_path``,
            ``current_iter``, ``writer`` and ``logger``.
        """
        if self._config is not None:
            return self._config

        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        # Tolerate "Set5, Set14" and trailing commas.
        args.test_datasets = [name.strip() for name in args.test_datasets.split(',') if name.strip()]
        args.model_path = Path(args.model_path).resolve()
        # The experiment directory is taken to be two levels above the
        # checkpoint file.
        args.exp_dir = args.model_path.parent.parent
        args.model_name = args.model_path.stem
        args.test_dir = args.exp_dir / 'test'
        args.test_dir.mkdir(exist_ok=True)
        args.results_path = os.path.join(args.test_dir, f'results_{args.model_name}_{args.device}.csv')

        # Tensorboard for monitoring
        args.writer = SummaryWriter(log_dir=str(args.test_dir))
        logger_name = f'test_{args.model_name}'
        logger_info(logger_name, os.path.join(args.test_dir, logger_name + '.log'))
        args.logger = logging.getLogger(logger_name)

        args.current_iter = self._parse_iteration(args.model_name, args.logger)

        self._config = args
        return args

    @staticmethod
    def _parse_iteration(model_name, logger=None):
        """Return the integer after the last ``_`` in ``model_name``.

        For example, ``model_10000`` -> 10000. A name without a numeric
        suffix, such as the default checkpoint ``last``, falls back to 0
        instead of aborting the run.
        """
        try:
            return int(model_name.split('_')[-1])
        except ValueError:
            # NOTE(review): 0 is used as a neutral step value; confirm that
            # test_steps / TensorBoard logging are fine with it.
            if logger is not None:
                logger.warning(
                    "Cannot parse iteration from model name %r; using 0.", model_name
                )
            return 0

    def __repr__(self):
        """Render the parsed options as a table.

        Values that differ from a parser default are annotated with that
        default.
        """
        config = self.parse_args()
        # Defaults of real command-line options only; derived fields such as
        # writer/logger/exp_dir have no default to compare against.
        defaults = vars(self.parser.parse_args([]))
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(config).items()):
            comment = ''
            if key in defaults and value != defaults[key]:
                comment = '\t[default: %s]' % str(defaults[key])
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        return '\n'.join(lines)


# TODO with unified save/load function any model file of net or lut can be tested with the same script.
if __name__ == "__main__":
    script_start_time = datetime.now()
    config_inst = TestOptions()
    config = config_inst.parse_args()
    config.logger.info(config_inst)

    try:
        model = LoadCheckpoint(config.model_path)
        model = model.to(torch.device(config.device))
        print(model)

        test_datasets = {}
        for test_dataset_name in config.test_datasets:
            dataset_dir = config.datasets_dir / test_dataset_name
            test_datasets[test_dataset_name] = SRTestDataset(
                hr_dir_path=dataset_dir / "HR",
                lr_dir_path=dataset_dir / "LR" / f"X{model.scale}",
                color_model=config.color_model,
                reset_cache=config.reset_cache,
            )

        results = test_steps(
            model=model,
            datasets=test_datasets,
            config=config,
            log_prefix=f"Model {config.model_name}",
            print_progress=config.progress,
        )
        results.to_csv(config.results_path)

        print()
        print(f"experiment dir: {config.exp_dir.name}, model: {config.model_name}, test color model: {config.color_model}")
        print(results)
        print()
        print(f"Results saved to {config.results_path}")
    finally:
        # Flush pending TensorBoard events and release the event file even
        # when evaluation fails.
        config.writer.close()

    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")