added device to image_demo. added scale to name of experiment. fix bug srnet again.

main
Vladimir Protsenko 7 months ago
parent ae8f7b6742
commit f2eea32363

@ -20,7 +20,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu
pred_lr_image = pred_lr_image.cpu().numpy()
run_time_ns = time.perf_counter_ns() - start_time
torch.cuda.empty_cache()
if not output_image_path is None:
Image.fromarray(pred_lr_image).save(output_image_path)

@ -14,9 +14,10 @@ class ImageDemoOptions():
self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth","../models/last_transfered_lut.pth"], help="Model paths for comparison")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
def parse_args(self):
args = self.parser.parse_args()
@ -44,7 +45,7 @@ config = config_inst.parse_args()
start_script_time = datetime.now()
print(config_inst)
models = [LoadCheckpoint(x).cuda() for x in config.model_paths]
models = [LoadCheckpoint(x).to(config.device) for x in config.model_paths]
for m in models:
print(m)
@ -56,7 +57,7 @@ if config.mirror:
lr_image = lr_image.copy()
image_gt = image_gt.copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device)
predictions = []
for model in models:

@ -102,7 +102,7 @@ class SRNetR90(nn.Module):
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)

@ -76,7 +76,7 @@ def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}."
config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve()
config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
if not config.exp_dir.exists():
config.exp_dir.mkdir()
@ -180,7 +180,9 @@ if __name__ == "__main__":
loss.backward()
optimizer.step()
optimizer.zero_grad()
del hr_patch
del lr_patch
torch.cuda.empty_cache()
forward_backward_time += time.time() - start_time
# For monitoring

Loading…
Cancel
Save