From f2eea3236354690cc83f37d682b8bd110acc6003 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Mon, 3 Jun 2024 11:59:27 +0400 Subject: [PATCH] added device to image_demo. added scale to name of experiment. fix bug srnet again. --- src/common/validation.py | 2 +- src/image_demo.py | 11 ++++++----- src/models/srnet.py | 2 +- src/train.py | 6 ++++-- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/common/validation.py b/src/common/validation.py index 1c50d6e..f1b6fc1 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -20,7 +20,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu pred_lr_image = pred_lr_image.cpu().numpy() run_time_ns = time.perf_counter_ns() - start_time torch.cuda.empty_cache() - + if not output_image_path is None: Image.fromarray(pred_lr_image).save(output_image_path) diff --git a/src/image_demo.py b/src/image_demo.py index bd12beb..0af73c6 100644 --- a/src/image_demo.py +++ b/src/image_demo.py @@ -14,9 +14,10 @@ class ImageDemoOptions(): self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth","../models/last_transfered_lut.pth"], help="Model paths for comparison") self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path") self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path") - self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.") - self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") - self.parser.add_argument('--mirror', action='store_true', default=False) + self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.") + self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") + self.parser.add_argument('--mirror', action='store_true', default=False) + self.parser.add_argument('--device', default='cuda', help='Device of the model') def parse_args(self): args = self.parser.parse_args() @@ -44,7 +45,7 @@ config = config_inst.parse_args() start_script_time = datetime.now() print(config_inst) -models = [LoadCheckpoint(x).cuda() for x in config.model_paths] +models = [LoadCheckpoint(x).to(config.device) for x in config.model_paths] for m in models: print(m) @@ -56,7 +57,7 @@ if config.mirror: lr_image = lr_image.copy() image_gt = image_gt.copy() -input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() +input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device) predictions = [] for model in models: diff --git a/src/models/srnet.py b/src/models/srnet.py index 3623a6a..f83c1f1 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -102,7 +102,7 @@ class SRNetR90(nn.Module): x = x.permute(0,1,2,4,3,5) x = x.reshape(b, c, h*scale, w*scale) return x - + def forward(self, x): b,c,h,w = x.shape x = x.reshape(b*c, 1, h, w) diff --git a/src/train.py b/src/train.py index 216a7e8..4ac1940 100644 --- a/src/train.py +++ b/src/train.py @@ -76,7 +76,7 @@ def prepare_experiment_folder(config): assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}." - config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve() + config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve() if not config.exp_dir.exists(): config.exp_dir.mkdir() @@ -180,7 +180,9 @@ if __name__ == "__main__": loss.backward() optimizer.step() optimizer.zero_grad() - + del hr_patch + del lr_patch + torch.cuda.empty_cache() forward_backward_time += time.time() - start_time # For monitoring