From fd207056f8fbf07fbab1da15033889aa13f430b3 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 20 Apr 2024 22:08:02 +0000 Subject: [PATCH] update --- .gitignore | 3 +++ readme.md | 2 -- src/common/validation.py | 4 ++-- src/models/rcnet.py | 3 +-- src/scripts/train.py | 12 +++++++++--- src/scripts/validate.py | 2 +- 6 files changed, 16 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 68bc17f..2323b12 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +*.png +*.0 \ No newline at end of file diff --git a/readme.md b/readme.md index 7b9a0fb..9af4c4f 100644 --- a/readme.md +++ b/readme.md @@ -2,13 +2,11 @@ Example ``` python train.py --model SRNetRot90 - python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000 - python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth python image_demo.py diff --git a/src/common/validation.py b/src/common/validation.py index 7bea7f5..3de8f03 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -44,7 +44,7 @@ def valid_steps(model, datasets, config, log_prefix=""): test_dataset = datasets[dataset_name] tasks = [] for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: - output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None + output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None task = val_image_pair(model, hr_image, lr_image, output_image_path) tasks.append(task) @@ -61,5 +61,5 @@ def valid_steps(model, datasets, config, log_prefix=""): config.logger.info( '\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims)))) - config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter) + # config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter) config.writer.flush() \ No newline at end of file diff --git a/src/models/rcnet.py b/src/models/rcnet.py index 1f93478..3c8dd32 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -86,8 +86,7 @@ class RCNetCentered_3x3(nn.Module): self.hidden_dim = hidden_dim self.layers_count = layers_count self.scale = scale - window_size = 3 - self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size) + self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3) def forward(self, x): b,c,h,w = x.shape diff --git a/src/scripts/train.py b/src/scripts/train.py index 2981631..4d1dbec 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -44,7 +44,7 @@ class TrainOptions: parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration') parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers") parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.") - parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') + parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') self.parser = parser def parse_args(self): @@ -207,9 +207,15 @@ if __name__ == "__main__": # check if it is network or lut if hasattr(model, 'get_lut_model'): - Path(config.models_dir / f"last_trained_net.pth").symlink_to(model_path) + link = Path(config.models_dir / f"last_trained_net.pth") + if link.exists(): + link.unlink() + link.symlink_to(model_path) else: - Path(config.models_dir / f"last_trained_lut.pth").symlink_to(model_path) + link = Path(config.models_dir / f"last_trained_lut.pth") + if link.exists(): + link.unlink() + link.symlink_to(model_path) total_script_time = datetime.now() - script_start_time config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file diff --git a/src/scripts/validate.py b/src/scripts/validate.py index acde257..396aab0 100644 --- a/src/scripts/validate.py +++ b/src/scripts/validate.py @@ -82,7 +82,7 @@ if __name__ == "__main__": lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", ) - valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") + valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}") total_script_time = datetime.now() - script_start_time config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file