added device to image_demo. added scale to name of experiment. fix bug srnet again.

main
Vladimir Protsenko 7 months ago
parent ae8f7b6742
commit f2eea32363

@ -17,6 +17,7 @@ class ImageDemoOptions():
self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
def parse_args(self):
args = self.parser.parse_args()
@ -44,7 +45,7 @@ config = config_inst.parse_args()
start_script_time = datetime.now()
print(config_inst)
models = [LoadCheckpoint(x).cuda() for x in config.model_paths]
models = [LoadCheckpoint(x).to(config.device) for x in config.model_paths]
for m in models:
print(m)
@ -56,7 +57,7 @@ if config.mirror:
lr_image = lr_image.copy()
image_gt = image_gt.copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].to(config.device)
predictions = []
for model in models:

@ -76,7 +76,7 @@ def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"On of the {config.val_datasets} was not found in {config.datasets_dir}."
config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve()
config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}_x{config.scale}").resolve()
if not config.exp_dir.exists():
config.exp_dir.mkdir()
@ -180,7 +180,9 @@ if __name__ == "__main__":
loss.backward()
optimizer.step()
optimizer.zero_grad()
del hr_patch
del lr_patch
torch.cuda.empty_cache()
forward_backward_time += time.time() - start_time
# For monitoring

Loading…
Cancel
Save