From 4eecec712c4a8b9157ed3dda8953dd06ea734458 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 20 Apr 2024 15:11:35 +0000 Subject: [PATCH] update --- readme.md | 24 ++++++++++++++++++------ src/scripts/image_demo.py | 8 ++++---- src/scripts/train.py | 12 ++++++------ src/scripts/transfer_to_lut.py | 10 +++++----- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/readme.md b/readme.md index 7803a15..89d259e 100644 --- a/readme.md +++ b/readme.md @@ -1,13 +1,25 @@ ``` -python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90 -python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth -python train.py --model_path /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCLutx2_0.pth --train_datasets DIV2K_pillow_bicubic --total_iter 10000 +python train.py --model SRNetRot90 -python image_demo.py -n /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCNetCentered_3x3_10000.pth -l /wd/lut_reproduce/models/RCLutCentered_3x3_DIV2K_pillow_bicubic/checkpoints/RCLutCentered_3x3_10000.pth +python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth -python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth +python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth + +python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000 + +python validate.py --val_datasets Set5,Set14,B100,Urban1[00,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth + +python image_demo.py +``` + +Help +``` +python train.py --help +python validate.py --help +python transfer_to_lut.py --help +python image_demo.py --help ``` -Requierements: +Requirements: - [shedulefree](https://github.com/facebookresearch/schedule_free) \ No newline at end of file diff --git a/src/scripts/image_demo.py b/src/scripts/image_demo.py index 2d5abff..4c091b9 100644 --- a/src/scripts/image_demo.py +++ b/src/scripts/image_demo.py @@ -28,18 +28,18 @@ class ImageDemoOptions(): args.lut_model_path = Path(args.lut_model_path).resolve() return args - def print_options(self, opt): + def __repr__(self): + config = self.parser.parse_args() message = '' message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): + for k, v in sorted(vars(config).items()): comment = '' default = self.parser.get_default(k) if v != default: comment = '\t[default: %s]' % str(default) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '----------------- End -------------------' - print(message) - print() + return message config_inst = ImageDemoOptions() config = config_inst.parse_args() diff --git a/src/scripts/train.py b/src/scripts/train.py index 28e709a..2981631 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -55,18 +55,18 @@ class TrainOptions: args.val_datasets = args.val_datasets.split(',') return args - def print_options(self, opt): + def __repr__(self): + config = self.parser.parse_args() message = '' message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): + for k, v in sorted(vars(config).items()): comment = '' default = self.parser.get_default(k) if v != default: comment = '\t[default: %s]' % str(default) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '----------------- End -------------------' - print(message) - print() + return message def prepare_experiment_folder(config): assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." @@ -114,8 +114,8 @@ if __name__ == "__main__": config.writer = writer config.logger = logger - config.logger.info(config_inst.print_options(config)) - print(model) + config.logger.info(config_inst) + config.logger.info(model) # Training dataset train_datasets = [] diff --git a/src/scripts/transfer_to_lut.py b/src/scripts/transfer_to_lut.py index 1e425a2..8b5e3ed 100644 --- a/src/scripts/transfer_to_lut.py +++ b/src/scripts/transfer_to_lut.py @@ -30,18 +30,18 @@ class TransferToLutOptions(): args.checkpoint_dir = Path(args.model_path).resolve().parent return args - def print_options(self, opt): + def __repr__(self): + config = self.parser.parse_args() message = '' message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): + for k, v in sorted(vars(config).items()): comment = '' default = self.parser.get_default(k) if v != default: comment = '\t[default: %s]' % str(default) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '----------------- End -------------------' - print(message) - print() + return message if __name__ == "__main__": @@ -50,7 +50,7 @@ if __name__ == "__main__": config_inst = TransferToLutOptions() config = config_inst.parse_args() - config_inst.print_options(config) + print(config_inst) model = models.LoadCheckpoint(config.model_path).cuda() if getattr(model, 'get_lut_model', None) is None: