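"""Validate a super-resolution checkpoint on one or more benchmark datasets.

Example invocation (the file name and paths are illustrative; adjust to your
layout):

    python validate.py --model_path ../models/last.pth \
        --datasets_dir ../data/ --val_datasets Set5,Set14 --device cuda

Results are written to <exp_dir>/val/results_<model_name>_<device>.csv, where
<exp_dir> is the directory two levels above the checkpoint.
"""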
import argparse
import logging
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter

from common.data import SRTestDataset
from common.utils import logger_info
from common.validation import valid_steps
from models import LoadCheckpoint

# Let cuDNN auto-tune convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True


class ValOptions:
    """Command-line options for validating a trained SR checkpoint."""

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', type=str, default="../models/last.pth",
                                 help="Path to the model checkpoint.")
        self.parser.add_argument('--datasets_dir', type=str, default="../data/",
                                 help="Path to the datasets root.")
        self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14',
                                 help="Comma-separated names of validation datasets.")
        # NOTE: action='store_true' combined with default=True means this flag is
        # always True; kept as-is to preserve the original behavior.
        self.parser.add_argument('--save_predictions', action='store_true', default=True,
                                 help='Save model predictions to exp_dir/val/dataset_name.')
        self.parser.add_argument('--device', type=str, default='cuda',
                                 help='Device to run the model on.')

    def parse_args(self):
        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.val_datasets = args.val_datasets.split(',')
        args.model_path = Path(args.model_path).resolve()
        # The experiment directory is assumed to be two levels above the
        # checkpoint, e.g. <exp_dir>/models/last.pth.
        args.exp_dir = args.model_path.parent.parent
        args.model_name = args.model_path.stem
        args.valout_dir = args.exp_dir / 'val'
        args.valout_dir.mkdir(parents=True, exist_ok=True)
        # Checkpoint names are expected to end with the iteration number,
        # so the suffix after the last underscore is taken as the current iteration.
        args.current_iter = args.model_name.split('_')[-1]
        args.results_path = args.valout_dir / f'results_{args.model_name}_{args.device}.csv'
        # TensorBoard writer and file logger for monitoring validation runs.
        args.writer = SummaryWriter(log_dir=args.valout_dir)
        logger_name = f'val_{args.model_name}'
        logger_info(logger_name, str(args.valout_dir / f'{logger_name}.log'))
        args.logger = logging.getLogger(logger_name)
        return args

    def __repr__(self):
        # NOTE: this calls parse_args() again, so printing the options re-parses
        # sys.argv and re-creates the writer/logger side effects.
        config = self.parse_args()
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(config).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        return message
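
# Expected dataset layout under --datasets_dir, inferred from the paths built in
# the main block below (the dataset name "Set5" and scale X4 are illustrative):
#
#     <datasets_dir>/
#         Set5/
#             HR/       # ground-truth high-resolution images
#             LR/X4/    # low-resolution inputs for a x4 model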

# TODO: with a unified save/load function, any model file (network or LUT) could
# be tested with this same script.
if __name__ == "__main__":
    script_start_time = datetime.now()

    config_inst = ValOptions()
    config = config_inst.parse_args()
    config.logger.info(config_inst)

    model = LoadCheckpoint(config.model_path)
    model = model.to(torch.device(config.device))
    print(model)

    test_datasets = {}
    for test_dataset_name in config.val_datasets:
        test_datasets[test_dataset_name] = SRTestDataset(
            hr_dir_path=Path(config.datasets_dir) / test_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
        )

    results = valid_steps(model=model, datasets=test_datasets, config=config,
                          log_prefix=f"Model {config.model_name}")
    results.to_csv(config.results_path)

    print(config.exp_dir.stem)
    print(results)
    print()
    print(f"Results saved to {config.results_path}")

    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")