From 9992763c9fbbf732c196ad174de58e043f35f452 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Tue, 4 Jun 2024 23:19:19 +0400 Subject: [PATCH] val2test rename, info about color mode in exp folder name and test output. --- src/common/validation.py | 6 ++++-- src/{validate.py => test.py} | 9 +++++---- src/train.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) rename src/{validate.py => test.py} (89%) diff --git a/src/common/validation.py b/src/common/validation.py index aa9b0a8..b5fcc7f 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -22,8 +22,10 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non torch.cuda.empty_cache() if not output_image_path is None: - if pred_lr_image.shape[-1] == 3: + if pred_lr_image.shape[-1] == 3 and color_model == 'RGB': Image.fromarray(pred_lr_image, mode=color_model).save(output_image_path) + if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr': + Image.fromarray(pred_lr_image, mode=color_model).convert("RGB").save(output_image_path) if pred_lr_image.shape[-1] == 1: Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path) @@ -61,7 +63,7 @@ def valid_steps(model, datasets, config, log_prefix=""): tasks = [] for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None - task = val_image_pair(model, hr_image, lr_image, output_image_path, color_model=config.color_model, device=config.device) + task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device) tasks.append(task) total_time = time.time() - start_time diff --git a/src/validate.py b/src/test.py similarity index 89% rename from src/validate.py rename to src/test.py index 5757ca0..3e7badf 100644 --- a/src/validate.py +++ b/src/test.py @@ -25,7 +25,7 @@ class ValOptions(): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.") self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.") - self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.") + self.parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Names of test datasets.") self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') self.parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.") @@ -33,7 +33,7 @@ class ValOptions(): def parse_args(self): args = self.parser.parse_args() args.datasets_dir = Path(args.datasets_dir).resolve() - args.val_datasets = args.val_datasets.split(',') + args.test_datasets = args.test_datasets.split(',') args.exp_dir = Path(args.model_path).resolve().parent.parent args.model_path = Path(args.model_path).resolve() args.model_name = args.model_path.stem @@ -79,7 +79,7 @@ if __name__ == "__main__": print(model) test_datasets = {} - for test_dataset_name in config.val_datasets: + for test_dataset_name in config.test_datasets: test_datasets[test_dataset_name] = SRTestDataset( hr_dir_path = Path(config.datasets_dir) / test_dataset_name / "HR", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", @@ -89,7 +89,8 @@ if __name__ == "__main__": results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") results.to_csv(config.results_path) - print(config.exp_dir.stem, config.model_name) + print() + print(f"experiment dir: {config.exp_dir.stem}, model: {config.model_name}, test color model: {config.color_model}") print(results) print() print(f"Results saved to {config.results_path}") diff --git a/src/train.py b/src/train.py index edef717..5aa13b2 100644 --- a/src/train.py +++ b/src/train.py @@ -77,7 +77,7 @@ def prepare_experiment_folder(config): assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}." - config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve() + config.exp_dir = (config.models_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve() if not config.exp_dir.exists(): config.exp_dir.mkdir()