main
Vladimir Protsenko 8 months ago
parent 3885617c8f
commit 4eecec712c

@ -1,13 +1,25 @@
``` ```
python train.py --train_datasets DIV2K_pillow_bicubic --val_datasets Set5,Set14 --scale 4 --total_iter 10000 --model SRNetRot90 python train.py --model SRNetRot90
python transfer_to_lut.py --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth
python train.py --model_path /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCLutx2_0.pth --train_datasets DIV2K_pillow_bicubic --total_iter 10000
python image_demo.py -n /wd/lut_reproduce/models/RCNetx2_DIV2K_pillow_bicubic/checkpoints/RCNetCentered_3x3_10000.pth -l /wd/lut_reproduce/models/RCLutCentered_3x3_DIV2K_pillow_bicubic/checkpoints/RCLutCentered_3x3_10000.pth python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/luts/models/SRNet_DIV2K_pillow_bicubic/checkpoints/SRNet_10000.pth python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth
python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000
python validate.py --val_datasets Set5,Set14,B100,Urban1[00,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth
python image_demo.py
```
Help
```
python train.py --help
python validate.py --help
python transfer_to_lut.py --help
python image_demo.py --help
``` ```
Requierements: Requirements:
- [shedulefree](https://github.com/facebookresearch/schedule_free) - [shedulefree](https://github.com/facebookresearch/schedule_free)

@ -28,18 +28,18 @@ class ImageDemoOptions():
args.lut_model_path = Path(args.lut_model_path).resolve() args.lut_model_path = Path(args.lut_model_path).resolve()
return args return args
def print_options(self, opt): def __repr__(self):
config = self.parser.parse_args()
message = '' message = ''
message += '----------------- Options ---------------\n' message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()): for k, v in sorted(vars(config).items()):
comment = '' comment = ''
default = self.parser.get_default(k) default = self.parser.get_default(k)
if v != default: if v != default:
comment = '\t[default: %s]' % str(default) comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------' message += '----------------- End -------------------'
print(message) return message
print()
config_inst = ImageDemoOptions() config_inst = ImageDemoOptions()
config = config_inst.parse_args() config = config_inst.parse_args()

@ -55,18 +55,18 @@ class TrainOptions:
args.val_datasets = args.val_datasets.split(',') args.val_datasets = args.val_datasets.split(',')
return args return args
def print_options(self, opt): def __repr__(self):
config = self.parser.parse_args()
message = '' message = ''
message += '----------------- Options ---------------\n' message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()): for k, v in sorted(vars(config).items()):
comment = '' comment = ''
default = self.parser.get_default(k) default = self.parser.get_default(k)
if v != default: if v != default:
comment = '\t[default: %s]' % str(default) comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------' message += '----------------- End -------------------'
print(message) return message
print()
def prepare_experiment_folder(config): def prepare_experiment_folder(config):
assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}." assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"On of the {config.train_datasets} was not found in {config.datasets_dir}."
@ -114,8 +114,8 @@ if __name__ == "__main__":
config.writer = writer config.writer = writer
config.logger = logger config.logger = logger
config.logger.info(config_inst.print_options(config)) config.logger.info(config_inst)
print(model) config.logger.info(model)
# Training dataset # Training dataset
train_datasets = [] train_datasets = []

@ -30,18 +30,18 @@ class TransferToLutOptions():
args.checkpoint_dir = Path(args.model_path).resolve().parent args.checkpoint_dir = Path(args.model_path).resolve().parent
return args return args
def print_options(self, opt): def __repr__(self):
config = self.parser.parse_args()
message = '' message = ''
message += '----------------- Options ---------------\n' message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()): for k, v in sorted(vars(config).items()):
comment = '' comment = ''
default = self.parser.get_default(k) default = self.parser.get_default(k)
if v != default: if v != default:
comment = '\t[default: %s]' % str(default) comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------' message += '----------------- End -------------------'
print(message) return message
print()
if __name__ == "__main__": if __name__ == "__main__":
@ -50,7 +50,7 @@ if __name__ == "__main__":
config_inst = TransferToLutOptions() config_inst = TransferToLutOptions()
config = config_inst.parse_args() config = config_inst.parse_args()
config_inst.print_options(config) print(config_inst)
model = models.LoadCheckpoint(config.model_path).cuda() model = models.LoadCheckpoint(config.model_path).cuda()
if getattr(model, 'get_lut_model', None) is None: if getattr(model, 'get_lut_model', None) is None:

Loading…
Cancel
Save