main
Vladimir Protsenko 8 months ago
parent fd207056f8
commit 7d29d5f22f

@ -7,7 +7,7 @@ from PIL import Image
# @ray.remote(num_cpus=1, num_gpus=0.3)
def val_image_pair(model, hr_image, lr_image, output_image_path=None):
with torch.no_grad():
with torch.inference_mode():
# prepare lr_image
lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
lr_image = lr_image.unsqueeze(0).cuda()

@ -57,6 +57,7 @@ image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1].copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
with torch.inference_mode():
net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()
lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()

@ -70,8 +70,14 @@ if __name__ == "__main__":
lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()])
print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")
Path(config.models_dir / f"last_transfered_net.pth").symlink_to(config.model_path.resolve())
Path(config.models_dir / f"last_transfered_lut.pth").symlink_to(lut_path.resolve())
link = Path(config.models_dir / f"last_transfered_net.pth")
if link.exists():
link.unlink()
link.symlink_to(config.model_path.resolve())
link = Path(config.models_dir / f"last_transfered_lut.pth")
if link.exists():
link.unlink()
link.symlink_to(lut_path.resolve())
print("Updated link", config.models_dir / f"last_transfered_net.pth")
print("Updated link", config.models_dir / f"last_transfered_lut.pth")

Loading…
Cancel
Save