You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

98 lines
4.1 KiB
Python

import sys
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.validation import valid_steps
from models import LoadCheckpoint
# NOTE(review): per the PyTorch docs this enables cuDNN's convolution-algorithm
# autotuner, which helps when input shapes are fixed; validation images of
# differing sizes may each trigger a new benchmark run -- confirm this is intended.
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
class ValOptions():
    """Command-line options for validating a saved model checkpoint.

    ``parse_args`` parses ``sys.argv``, derives experiment/output paths from
    ``--model_path``, creates the validation output directory, and attaches a
    TensorBoard ``SummaryWriter`` and a file logger to the returned namespace.
    ``repr(instance)`` renders the most recently parsed config as a table.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
        self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
        self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
        # `store_true` with default=True can never be switched off on its own, so the
        # original flag is kept (still accepted) and an explicit negation shares its dest.
        self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
        self.parser.add_argument('--no_save_predictions', dest='save_predictions', action='store_false', help='Disable saving of model predictions (sets save_predictions to False).')
        self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
        self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
        # Namespace returned by the most recent parse_args() call. __repr__ reuses it so
        # printing the options does not re-parse, open another SummaryWriter, or set up
        # the logger a second time.
        self._parsed_args = None
        # Option values exactly as argparse produced them (before conversion to Path/list),
        # used by __repr__ to compare against the parser defaults.
        self._raw_args = {}

    def parse_args(self):
        """Parse the command line and build the validation config.

        Returns:
            argparse.Namespace with the parsed options plus derived fields:
            ``datasets_dir``/``model_path`` as resolved ``Path`` objects,
            ``val_datasets`` as a list of names, ``exp_dir`` (grandparent
            directory of the model file), ``model_name`` (model file stem),
            ``valout_dir`` (``exp_dir / 'val'``, created if missing),
            ``current_iter`` (last ``_``-separated part of ``model_name``),
            ``results_path`` (CSV path as ``str``), ``writer`` (a
            ``SummaryWriter`` logging to ``valout_dir``) and ``logger``.
        """
        args = self.parser.parse_args()
        # Snapshot the raw values before they are converted below.
        raw_args = dict(vars(args))
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.val_datasets = args.val_datasets.split(',')
        args.model_path = Path(args.model_path).resolve()
        # Presumably checkpoints live in <exp_dir>/<subdir>/<file> -- TODO confirm layout.
        args.exp_dir = args.model_path.parent.parent
        args.model_name = args.model_path.stem
        args.valout_dir = args.exp_dir / 'val'
        # parents/exist_ok: do not fail when exp_dir itself is missing or the dir already exists.
        args.valout_dir.mkdir(parents=True, exist_ok=True)
        args.current_iter = args.model_name.split('_')[-1]
        args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv')
        # Tensorboard for monitoring
        writer = SummaryWriter(log_dir=args.valout_dir)
        logger_name = f'val_{args.model_path.stem}'
        logger_info(logger_name, os.path.join(args.valout_dir, logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        args.writer = writer
        args.logger = logger
        self._parsed_args = args
        self._raw_args = raw_args
        return args

    def __repr__(self):
        """Return a table of all config values.

        Options whose command-line value differs from the parser default are
        annotated with that default. Derived fields (paths, writer, logger, ...)
        are listed without annotation. Parses the command line only if
        ``parse_args`` has not been called yet.
        """
        config = getattr(self, '_parsed_args', None)
        if config is None:
            config = self.parse_args()
        raw_args = getattr(self, '_raw_args', {})
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(config).items()):
            comment = ''
            # Only real CLI options have a meaningful default to compare against.
            if key in raw_args:
                default = self.parser.get_default(key)
                if raw_args[key] != default:
                    comment = '\t[default: %s]' % str(default)
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        return '\n'.join(lines)
# TODO: with a unified save/load function, any net or LUT model file could be tested with this same script.
if __name__ == "__main__":
    # Validate a saved checkpoint on the requested datasets, write a results CSV,
    # and log the total run time.
    script_start_time = datetime.now()
    config_inst = ValOptions()
    config = config_inst.parse_args()
    try:
        config.logger.info(config_inst)
        model = LoadCheckpoint(config.model_path)
        model = model.to(torch.device(config.device))
        print(model)
        # LR inputs are looked up under LR/X<scale>, using the loaded model's scale.
        test_datasets = {
            test_dataset_name: SRTestDataset(
                hr_dir_path=Path(config.datasets_dir) / test_dataset_name / "HR",
                lr_dir_path=Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
                color_model=config.color_model,
            )
            for test_dataset_name in config.val_datasets
        }
        results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
        results.to_csv(config.results_path)
        print(config.exp_dir.stem, config.model_name)
        print(results)
        print()
        print(f"Results saved to {config.results_path}")
        total_script_time = datetime.now() - script_start_time
        config.logger.info(f"Completed after {total_script_time}")
    finally:
        # Flush and close the TensorBoard writer created by parse_args() so pending
        # events are written to disk even if validation raises.
        config.writer.close()