main
Vladimir Protsenko 8 months ago
parent 5dacff6958
commit fd207056f8

3
.gitignore vendored

@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
*.png
*.0

@ -2,13 +2,11 @@
Example Example
``` ```
python train.py --model SRNetRot90 python train.py --model SRNetRot90
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth
python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth
python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000 python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth
python image_demo.py python image_demo.py

@ -44,7 +44,7 @@ def valid_steps(model, datasets, config, log_prefix=""):
test_dataset = datasets[dataset_name] test_dataset = datasets[dataset_name]
tasks = [] tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, output_image_path) task = val_image_pair(model, hr_image, lr_image, output_image_path)
tasks.append(task) tasks.append(task)
@ -61,5 +61,5 @@ def valid_steps(model, datasets, config, log_prefix=""):
config.logger.info( config.logger.info(
'\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims)))) '\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter) # config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
config.writer.flush() config.writer.flush()

@ -86,8 +86,7 @@ class RCNetCentered_3x3(nn.Module):
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.layers_count = layers_count self.layers_count = layers_count
self.scale = scale self.scale = scale
window_size = 3 self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
def forward(self, x): def forward(self, x):
b,c,h,w = x.shape b,c,h,w = x.shape

@ -44,7 +44,7 @@ class TrainOptions:
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration') parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers") parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.") parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
self.parser = parser self.parser = parser
def parse_args(self): def parse_args(self):
@ -207,9 +207,15 @@ if __name__ == "__main__":
# check if it is network or lut # check if it is network or lut
if hasattr(model, 'get_lut_model'): if hasattr(model, 'get_lut_model'):
Path(config.models_dir / f"last_trained_net.pth").symlink_to(model_path) link = Path(config.models_dir / f"last_trained_net.pth")
if link.exists():
link.unlink()
link.symlink_to(model_path)
else: else:
Path(config.models_dir / f"last_trained_lut.pth").symlink_to(model_path) link = Path(config.models_dir / f"last_trained_lut.pth")
if link.exists():
link.unlink()
link.symlink_to(model_path)
total_script_time = datetime.now() - script_start_time total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}") config.logger.info(f"Completed after {total_script_time}")

@ -82,7 +82,7 @@ if __name__ == "__main__":
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
) )
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}")
total_script_time = datetime.now() - script_start_time total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}") config.logger.info(f"Completed after {total_script_time}")
Loading…
Cancel
Save