import sys
from pickle import dump
import logging
import math
import os
import time
import argparse
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from schedulefree import AdamWScheduleFree

from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr
from common.validation import valid_steps
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS

torch.backends.cudnn.benchmark = True


class TrainOptions:
    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
        parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
        parser.add_argument('--model_path', type=str, default=None, help="Path to a model checkpoint to finetune.")
        parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
        parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
        parser.add_argument('--scale', '-r', type=int, default=4, help="Upscaling factor.")
        parser.add_argument('--hidden_dim', type=int, default=64, help="Number of filters in the convolutional layers.")
        parser.add_argument('--crop_size', type=int, default=48, help='Input LR training patch size.')
        parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training.")
        parser.add_argument('--models_dir', type=str, default='../models/', help="Root folder for experiments.")
        parser.add_argument('--datasets_dir', type=str, default="../data/", help="Root folder with datasets.")
        parser.add_argument('--start_iter', type=int, default=0, help='Set 0 to train from scratch; otherwise saved parameters are loaded and training continues.')
        parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations.')
        parser.add_argument('--display_step', type=int, default=100, help='Display info every N iterations.')
        parser.add_argument('--val_step', type=int, default=2000, help='Validate every N iterations.')
        parser.add_argument('--save_step', type=int, default=2000, help='Save models every N iterations.')
        parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers.")
        parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
        parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
        parser.add_argument('--device', default='cuda', help='Device of the model.')
        parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when the model is a LUT. Number of 4D LUT buckets defined as 2**bits. Value is in range [1, 8].")
        self.parser = parser

    def parse_args(self):
        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.models_dir = Path(args.models_dir).resolve()
        args.model_path = Path(args.model_path) if args.model_path is not None else None
        args.train_datasets = args.train_datasets.split(',')
        args.val_datasets = args.val_datasets.split(',')
        return args

    def __repr__(self):
        config = self.parse_args()
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(config).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        return message
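
# A hypothetical invocation (the script file name below is an assumption; the argument values
# are the defaults declared above). Each dataset is expected under <datasets_dir>/<name>/HR and
# <datasets_dir>/<name>/LR/X<scale>, as those paths are constructed in the dataset setup below.
#
#   python train.py --model RCNetx1 --scale 4 --crop_size 48 --batch_size 16 \
#       --train_datasets DIV2K --val_datasets Set5,Set14 \
#       --datasets_dir ../data/ --models_dir ../models/ --total_iter 200000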

def prepare_experiment_folder(config):
    assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), \
        f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
    assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), \
        f"One of the {config.val_datasets} was not found in {config.datasets_dir}."
    config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve()
    if not config.exp_dir.exists():
        config.exp_dir.mkdir()
    config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve()
    if not config.checkpoint_dir.exists():
        config.checkpoint_dir.mkdir()
    config.valout_dir = (config.exp_dir / 'val').resolve()
    if not config.valout_dir.exists():
        config.valout_dir.mkdir()
    config.logs_dir = (config.exp_dir / 'logs').resolve()
    if not config.logs_dir.exists():
        config.logs_dir.mkdir()


if __name__ == "__main__":
    script_start_time = datetime.now()

    config_inst = TrainOptions()
    config = config_inst.parse_args()

    if config.model_path is not None:
        model = LoadCheckpoint(config.model_path)
        config.model = model.__class__.__name__
    else:
        if 'net' in config.model.lower():
            model = AVAILABLE_MODELS[config.model](hidden_dim=config.hidden_dim, scale=config.scale)
        if 'lut' in config.model.lower():
            model = AVAILABLE_MODELS[config.model](quantization_interval=2**(8 - config.quantization_bits), scale=config.scale)
    model = model.to(torch.device(config.device))

    optimizer = AdamWScheduleFree(model.parameters())
    print(optimizer)
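    # NOTE (assumption, not part of the original script): schedule-free optimizers such as
    # AdamWScheduleFree expose train()/eval() modes, and the documented usage of the
    # `schedulefree` package is to put the optimizer in train mode before calling step()
    # (and in eval mode around validation/checkpointing). Calling train() here is harmless
    # if mode switching is already handled elsewhere.
    optimizer.train()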
    prepare_experiment_folder(config)

    # Tensorboard for monitoring
    writer = SummaryWriter(log_dir=config.logs_dir)
    logger_name = 'train'
    logger_info(logger_name, os.path.join(config.logs_dir, logger_name + '.log'))
    logger = logging.getLogger(logger_name)
    config.writer = writer
    config.logger = logger
    config.logger.info(config_inst)
    config.logger.info(model)

    # Training dataset
    train_datasets = []
    for train_dataset_name in config.train_datasets:
        train_datasets.append(SRTrainDataset(
            hr_dir_path=Path(config.datasets_dir) / train_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
            patch_size=config.crop_size
        ))
    train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        num_workers=config.worker_num,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        prefetch_factor=config.prefetch_factor
    )
    train_iter = iter(train_loader)

    # Validation datasets
    test_datasets = {}
    for test_dataset_name in config.val_datasets:
        test_datasets[test_dataset_name] = SRTestDataset(
            hr_dir_path=Path(config.datasets_dir) / test_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
        )

    l_accum = [0., 0., 0.]
    prepare_data_time = 0.
    forward_backward_time = 0.
    accum_samples = 0

    # TRAINING
    i = config.start_iter
    if config.model_path is not None:
        config.current_iter = i
        valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")

    for i in range(config.start_iter + 1, config.total_iter + 1):
        torch.cuda.empty_cache()
        start_time = time.time()
        try:
            hr_patch, lr_patch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            hr_patch, lr_patch = next(train_iter)
        hr_patch = hr_patch.to(torch.device(config.device))
        lr_patch = lr_patch.to(torch.device(config.device))
        prepare_data_time += time.time() - start_time

        start_time = time.time()
        pred = model(lr_patch)
        loss = F.mse_loss(pred / 255, hr_patch / 255)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        forward_backward_time += time.time() - start_time

        # For monitoring
        accum_samples += config.batch_size
        l_accum[0] += loss.item()

        # Show information
        if i % config.display_step == 0:
            config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
            config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
                model.__class__.__name__, i, accum_samples, l_accum[0] / config.display_step,
                prepare_data_time / config.display_step, forward_backward_time / config.display_step))
            l_accum = [0., 0., 0.]
            prepare_data_time = 0.
            forward_backward_time = 0.

        # Save models
        if i % config.save_step == 0:
            SaveCheckpoint(model=model, path=Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth")

        # Validation
        if i % config.val_step == 0:
            config.current_iter = i
            valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")

    model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()
    SaveCheckpoint(model=model, path=model_path)
    print("Saved to ", model_path)

    # Check if it is a network or a LUT and update the convenience symlinks
    if hasattr(model, 'get_lut_model'):
        link = Path(config.models_dir / f"last_trained_net.pth")
        if link.exists():
            link.unlink()
        link.symlink_to(model_path)
    else:
        link = Path(config.models_dir / f"last_trained_lut.pth")
        if link.exists():
            link.unlink()
        link.symlink_to(model_path)
    link = Path(config.models_dir / f"last.pth")
    if link.exists():
        link.unlink()
    link.symlink_to(model_path)

    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")