diff --git a/readme.md b/readme.md index 16fb3ed..7803a15 100644 --- a/readme.md +++ b/readme.md @@ -10,5 +10,4 @@ python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path ``` Requierements: -- [shedulefree](https://github.com/facebookresearch/schedule_free) -- einops \ No newline at end of file +- [shedulefree](https://github.com/facebookresearch/schedule_free) \ No newline at end of file diff --git a/src/scripts/image_demo.py b/src/scripts/image_demo.py index 1c0c49a..2d5abff 100644 --- a/src/scripts/image_demo.py +++ b/src/scripts/image_demo.py @@ -9,25 +9,23 @@ import cv2 from PIL import Image from datetime import datetime import argparse -class DemoOptions(): +class ImageDemoOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - self.parser.add_argument('--net_model_path', '-n', type=str, default=None, help="Net model path folder") - self.parser.add_argument('--lut_model_path', '-l', type=str, default=None, help="Lut model path folder") - self.parser.add_argument('--project_path', '-q', type=str, default="../../", help="Project path.") - self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.") + self.parser.add_argument('--net_model_path', '-n', type=str, default="../../models/last_transfered_net.pth", help="Net model path folder") + self.parser.add_argument('--lut_model_path', '-l', type=str, default="../../models/last_transfered_lut.pth", help="Lut model path folder") + self.parser.add_argument('--hr_image_path', '-a', type=str, default="../../data/Set14/HR/monarch.png", help="HR image path") + self.parser.add_argument('--lr_image_path', '-b', type=str, default="../../data/Set14/LR/X4/monarch.png", help="LR image path") + self.parser.add_argument('--project_path', type=str, default="../../", help="Project path.") + self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") def parse_args(self): args = self.parser.parse_args() args.project_path = Path(args.project_path).resolve() - if args.net_model_path is None: - args.project_path / "models" / "last_transfered_net.pth" - else: - args.net_model_path = Path(args.net_model_path).resolve() - if args.lut_model_path is None: - args.project_path / "models" / "last_transfered_lut.pth" - else: - args.lut_model_path = Path(args.lut_model_path).resolve() + args.hr_image_path = Path(args.hr_image_path).resolve() + args.lr_image_path = Path(args.lr_image_path).resolve() + args.net_model_path = Path(args.net_model_path).resolve() + args.lut_model_path = Path(args.lut_model_path).resolve() return args def print_options(self, opt): @@ -43,21 +41,10 @@ class DemoOptions(): print(message) print() -config_inst = DemoOptions() +config_inst = ImageDemoOptions() config = config_inst.parse_args() start_script_time = datetime.now() -# net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth") -# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCLutCentered_0.pth") - -# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCNetRot90_7x7_10000.pth") -# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCLutRot90_7x7_0.pth") - -# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCNetRot90_3x3_10000.pth") -# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCLutRot90_3x3_0.pth") - -# net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth") -# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth") net_model = LoadCheckpoint(config.net_model_path).cuda() lut_model = LoadCheckpoint(config.lut_model_path).cuda() @@ -65,11 +52,8 @@ lut_model = LoadCheckpoint(config.lut_model_path).cuda() print(net_model) print(lut_model) -lr_image = cv2.imread(str(config.project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy() -image_gt = cv2.imread(str(config.project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy() - -# lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy() -# image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy() +lr_image = cv2.imread(str(config.project_path / "data" / "Set14/LR/X4/monarch.png"))[:,::-1,::-1].copy() +image_gt = cv2.imread(str(config.project_path / "data" / "Set14/HR/monarch.png"))[:,::-1,::-1].copy() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() diff --git a/src/scripts/train.py b/src/scripts/train.py index 85a3010..28e709a 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -201,5 +201,15 @@ if __name__ == "__main__": config.current_iter = i valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") + model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve() + SaveCheckpoint(model=model, path=model_path) + print("Saved to ", model_path) + + # check if it is network or lut + if hasattr(model, 'get_lut_model'): + Path(config.models_dir / f"last_trained_net.pth").symlink_to(model_path) + else: + Path(config.models_dir / f"last_trained_lut.pth").symlink_to(model_path) + total_script_time = datetime.now() - script_start_time config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file diff --git a/src/scripts/transfer_to_lut.py b/src/scripts/transfer_to_lut.py index e0ee5a8..1e425a2 100644 --- a/src/scripts/transfer_to_lut.py +++ b/src/scripts/transfer_to_lut.py @@ -26,7 +26,8 @@ class TransferToLutOptions(): def parse_args(self): args = self.parser.parse_args() args.model_path = Path(args.model_path) - args.checkpoint_dir = Path(args.model_path).absolute().parent + args.models_dir = Path(args.model_path).resolve().parent.parent.parent + args.checkpoint_dir = Path(args.model_path).resolve().parent return args def print_options(self, opt): @@ -65,14 +66,14 @@ if __name__ == "__main__": lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth" models.SaveCheckpoint(model=lut_model, path=lut_path) - + lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()]) - - print() - print(datetime.now()-start_time) print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB") - models.SaveCheckpoint(model=model, path=Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_net.pth") - models.SaveCheckpoint(model=lut_model, path=Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_lut.pth") - print("Updated", Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_net.pth") - print("Updated", Path(config.model_path).absolute().parent.parent.parent / f"last_transfered_lut.pth") \ No newline at end of file + Path(config.models_dir / f"last_transfered_net.pth").symlink_to(config.model_path.resolve()) + Path(config.models_dir / f"last_transfered_lut.pth").symlink_to(lut_path.resolve()) + print("Updated link", config.models_dir / f"last_transfered_net.pth") + print("Updated link", config.models_dir / f"last_transfered_lut.pth") + + print() + print("Completed after", datetime.now()-start_time) \ No newline at end of file diff --git a/src/scripts/validate.py b/src/scripts/validate.py index 48c5266..938a196 100644 --- a/src/scripts/validate.py +++ b/src/scripts/validate.py @@ -26,7 +26,7 @@ class ValOptions(): self.parser.add_argument('--model_path', type=str, help="Model path.") self.parser.add_argument('--datasets_dir', type=str, default="../../data/", help="Path to datasets.") self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.") - self.parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') + self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') def parse_args(self): args = self.parser.parse_args() @@ -65,6 +65,7 @@ class ValOptions(): # TODO with unified save/load function any model file of net or lut can be tested with the same script. if __name__ == "__main__": + script_start_time = datetime.now() config_inst = ValOptions() config = config_inst.parse_args() @@ -83,4 +84,5 @@ if __name__ == "__main__": valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") - config.logger.info("Complete") \ No newline at end of file + total_script_time = datetime.now() - script_start_time + config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file