From 7d29d5f22f0bcca6f62a0cc9ed792993df60992f Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sun, 21 Apr 2024 11:09:42 +0000 Subject: [PATCH] update --- src/common/validation.py | 2 +- src/scripts/image_demo.py | 5 +++-- src/scripts/transfer_to_lut.py | 10 ++++++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/common/validation.py b/src/common/validation.py index 3de8f03..385147b 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -7,7 +7,7 @@ from PIL import Image # @ray.remote(num_cpus=1, num_gpus=0.3) def val_image_pair(model, hr_image, lr_image, output_image_path=None): - with torch.no_grad(): + with torch.inference_mode(): # prepare lr_image lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1) lr_image = lr_image.unsqueeze(0).cuda() diff --git a/src/scripts/image_demo.py b/src/scripts/image_demo.py index 1872782..48a7c1c 100644 --- a/src/scripts/image_demo.py +++ b/src/scripts/image_demo.py @@ -57,8 +57,9 @@ image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1].copy() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() -net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() -lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() +with torch.inference_mode(): + net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() + lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) diff --git a/src/scripts/transfer_to_lut.py b/src/scripts/transfer_to_lut.py index 8b5e3ed..4e314fe 100644 --- a/src/scripts/transfer_to_lut.py +++ b/src/scripts/transfer_to_lut.py @@ -70,8 +70,14 @@ if __name__ == "__main__": lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()]) print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB") - Path(config.models_dir / f"last_transfered_net.pth").symlink_to(config.model_path.resolve()) - Path(config.models_dir / f"last_transfered_lut.pth").symlink_to(lut_path.resolve()) + link = Path(config.models_dir / f"last_transfered_net.pth") + if link.exists(): + link.unlink() + link.symlink_to(config.model_path.resolve()) + link = Path(config.models_dir / f"last_transfered_lut.pth") + if link.exists(): + link.unlink() + link.symlink_to(lut_path.resolve()) print("Updated link", config.models_dir / f"last_transfered_net.pth") print("Updated link", config.models_dir / f"last_transfered_lut.pth")