main
Vladimir Protsenko 7 months ago
parent 5dacff6958
commit fd207056f8

3
.gitignore vendored

@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.png
*.0

@ -2,13 +2,11 @@
Example
```
python train.py --model SRNetRot90
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth
python transfer_to_lut.py --model_path /wd/lut_reproduce/models/last_trained_net.pth
python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth
python image_demo.py

@ -44,7 +44,7 @@ def valid_steps(model, datasets, config, log_prefix=""):
test_dataset = datasets[dataset_name]
tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_val_predictions else None
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, output_image_path)
tasks.append(task)
@ -61,5 +61,5 @@ def valid_steps(model, datasets, config, log_prefix=""):
config.logger.info(
'\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
# config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
config.writer.flush()

@ -86,8 +86,7 @@ class RCNetCentered_3x3(nn.Module):
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.scale = scale
window_size = 3
self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
def forward(self, x):
b,c,h,w = x.shape

@ -44,7 +44,7 @@ class TrainOptions:
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
self.parser = parser
def parse_args(self):
@ -207,9 +207,15 @@ if __name__ == "__main__":
# check if it is network or lut
if hasattr(model, 'get_lut_model'):
Path(config.models_dir / f"last_trained_net.pth").symlink_to(model_path)
link = Path(config.models_dir / f"last_trained_net.pth")
if link.exists():
link.unlink()
link.symlink_to(model_path)
else:
Path(config.models_dir / f"last_trained_lut.pth").symlink_to(model_path)
link = Path(config.models_dir / f"last_trained_lut.pth")
if link.exists():
link.unlink()
link.symlink_to(model_path)
total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}")

@ -82,7 +82,7 @@ if __name__ == "__main__":
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
)
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}")
total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}")
Loading…
Cancel
Save