val2test rename, info about color mode in exp folder name and test output.

main
Vladimir Protsenko 7 months ago
parent c4b7821001
commit 9992763c9f

@ -22,8 +22,10 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
torch.cuda.empty_cache()
if not output_image_path is None:
if pred_lr_image.shape[-1] == 3:
if pred_lr_image.shape[-1] == 3 and color_model == 'RGB':
Image.fromarray(pred_lr_image, mode=color_model).save(output_image_path)
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
Image.fromarray(pred_lr_image, mode=color_model).convert("RGB").save(output_image_path)
if pred_lr_image.shape[-1] == 1:
Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path)
@ -61,7 +63,7 @@ def valid_steps(model, datasets, config, log_prefix=""):
tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, output_image_path, color_model=config.color_model, device=config.device)
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task)
total_time = time.time() - start_time

@ -25,7 +25,7 @@ class ValOptions():
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
@ -33,7 +33,7 @@ class ValOptions():
def parse_args(self):
args = self.parser.parse_args()
args.datasets_dir = Path(args.datasets_dir).resolve()
args.val_datasets = args.val_datasets.split(',')
args.test_datasets = args.test_datasets.split(',')
args.exp_dir = Path(args.model_path).resolve().parent.parent
args.model_path = Path(args.model_path).resolve()
args.model_name = args.model_path.stem
@ -79,7 +79,7 @@ if __name__ == "__main__":
print(model)
test_datasets = {}
for test_dataset_name in config.val_datasets:
for test_dataset_name in config.test_datasets:
test_datasets[test_dataset_name] = SRTestDataset(
hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR",
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
@ -89,7 +89,8 @@ if __name__ == "__main__":
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
results.to_csv(config.results_path)
print(config.exp_dir.stem, config.model_name)
print()
print(f"experiment dir: {config.exp_dir.stem}, model: {config.model_name}, test color model: {config.color_model}")
print(results)
print()
print(f"Results saved to {config.results_path}")

@ -77,7 +77,7 @@ def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}."
config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
if not config.exp_dir.exists():
config.exp_dir.mkdir()

Loading…
Cancel
Save