|
|
@ -44,7 +44,7 @@ class TrainOptions:
|
|
|
|
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
|
|
|
|
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
|
|
|
|
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
|
|
|
|
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
|
|
|
|
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
|
|
|
|
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
|
|
|
|
parser.add_argument('--save_val_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
|
|
|
|
parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
|
|
|
|
self.parser = parser
|
|
|
|
self.parser = parser
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(self):
|
|
|
|
def parse_args(self):
|
|
|
@ -207,9 +207,15 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
# check if it is network or lut
|
|
|
|
# check if it is network or lut
|
|
|
|
if hasattr(model, 'get_lut_model'):
|
|
|
|
if hasattr(model, 'get_lut_model'):
|
|
|
|
Path(config.models_dir / f"last_trained_net.pth").symlink_to(model_path)
|
|
|
|
link = Path(config.models_dir / f"last_trained_net.pth")
|
|
|
|
|
|
|
|
if link.exists():
|
|
|
|
|
|
|
|
link.unlink()
|
|
|
|
|
|
|
|
link.symlink_to(model_path)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
Path(config.models_dir / f"last_trained_lut.pth").symlink_to(model_path)
|
|
|
|
link = Path(config.models_dir / f"last_trained_lut.pth")
|
|
|
|
|
|
|
|
if link.exists():
|
|
|
|
|
|
|
|
link.unlink()
|
|
|
|
|
|
|
|
link.symlink_to(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
total_script_time = datetime.now() - script_start_time
|
|
|
|
total_script_time = datetime.now() - script_start_time
|
|
|
|
config.logger.info(f"Completed after {total_script_time}")
|
|
|
|
config.logger.info(f"Completed after {total_script_time}")
|