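"""Training script for the super-resolution models in AVAILABLE_MODELS.

Builds (or loads, for fine-tuning) a network or LUT model, trains it with a
schedule-free AdamW optimizer, and periodically logs, validates, checkpoints,
and updates the last*.pth symlinks under models_dir.

Example invocation (the script file name is assumed here; the defaults expect
../data/DIV2K, ../data/Set5 and ../data/Set14):

    python train.py --model RCNetx1 --scale 4 --batch_size 16
"""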
import argparse
import logging
import math
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from pickle import dump

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from schedulefree import AdamWScheduleFree
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr
from common.validation import valid_steps
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS

torch.backends.cudnn.benchmark = True

class TrainOptions:
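    """Command-line options for the training script (see --help for all defaults)."""
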
    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
        parser.add_argument('--model', type=str, default='RCNetx1', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
        parser.add_argument('--model_path', type=str, default=None, help="Path to a saved model to fine-tune.")
        parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Comma-separated folder names of datasets to train on.")
        parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Comma-separated folder names of datasets to validate on.")
        parser.add_argument('--scale', '-r', type=int, default=4, help="Upscale factor")
        parser.add_argument('--hidden_dim', type=int, default=64, help="Number of filters in the convolutional layers")
        parser.add_argument('--crop_size', type=int, default=48, help="Input LR training patch size")
        parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
        parser.add_argument('--models_dir', type=str, default='../models/', help="Experiment folder")
        parser.add_argument('--datasets_dir', type=str, default="../data/")
        parser.add_argument('--start_iter', type=int, default=0, help="Set 0 for training from scratch; otherwise saved params are loaded and training continues")
        parser.add_argument('--total_iter', type=int, default=200000, help="Total number of training iterations")
        parser.add_argument('--display_step', type=int, default=100, help="Display info every N iterations")
        parser.add_argument('--val_step', type=int, default=2000, help="Validate every N iterations")
        parser.add_argument('--save_step', type=int, default=2000, help="Save the model every N iterations")
        parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
        parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers")
        parser.add_argument('--save_predictions', action='store_true', default=False, help="Save model predictions to exp_dir/val/dataset_name")
        parser.add_argument('--device', default='cuda', help="Device of the model")
        parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when the model is a LUT. The number of 4D LUT buckets is 2**bits. Value must be in range [1, 8].")
        self.parser = parser

    def parse_args(self):
        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.models_dir = Path(args.models_dir).resolve()
        args.model_path = Path(args.model_path) if args.model_path is not None else None
        args.train_datasets = args.train_datasets.split(',')
        args.val_datasets = args.val_datasets.split(',')
        return args

    def __repr__(self):
        config = self.parse_args()
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(config).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        return message

def prepare_experiment_folder(config):
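    """Check that the requested datasets exist and create exp_dir/{checkpoints,val,logs}."""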
    assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), f"One of the train datasets {config.train_datasets} was not found in {config.datasets_dir}."
    assert all([name in os.listdir(config.datasets_dir) for name in config.val_datasets]), f"One of the validation datasets {config.val_datasets} was not found in {config.datasets_dir}."
    config.exp_dir = (config.models_dir / f"{config.model}_{'_'.join(config.train_datasets)}").resolve()
    if not config.exp_dir.exists():
        config.exp_dir.mkdir()
    config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve()
    if not config.checkpoint_dir.exists():
        config.checkpoint_dir.mkdir()
    config.valout_dir = (config.exp_dir / 'val').resolve()
    if not config.valout_dir.exists():
        config.valout_dir.mkdir()
    config.logs_dir = (config.exp_dir / 'logs').resolve()
    if not config.logs_dir.exists():
        config.logs_dir.mkdir()

if __name__ == "__main__":
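    # Configuration, model, and optimizer setup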
    script_start_time = datetime.now()

    config_inst = TrainOptions()
    config = config_inst.parse_args()

    if config.model_path is not None:
        model = LoadCheckpoint(config.model_path)
        config.model = model.__class__.__name__
    else:
        if 'lut' in config.model.lower():
            model = AVAILABLE_MODELS[config.model](quantization_interval=2**(8 - config.quantization_bits), scale=config.scale)
        elif 'net' in config.model.lower():
            model = AVAILABLE_MODELS[config.model](hidden_dim=config.hidden_dim, scale=config.scale)
        else:
            raise ValueError(f"Unknown model {config.model}. Available models: {list(AVAILABLE_MODELS.keys())}")
    model = model.to(torch.device(config.device))
    optimizer = AdamWScheduleFree(model.parameters())
    # Schedule-free optimizers keep separate "train" and "eval" parameter states;
    # they must be put in train mode before the first optimizer.step().
    optimizer.train()
    print(optimizer)

    prepare_experiment_folder(config)
    # Tensorboard for monitoring
    writer = SummaryWriter(log_dir=config.logs_dir)
    logger_name = 'train'
    logger_info(logger_name, os.path.join(config.logs_dir, logger_name + '.log'))
    logger = logging.getLogger(logger_name)
    config.writer = writer
    config.logger = logger
    config.logger.info(config_inst)
    config.logger.info(model)

    # Training dataset
    train_datasets = []
    for train_dataset_name in config.train_datasets:
        train_datasets.append(SRTrainDataset(
            hr_dir_path=Path(config.datasets_dir) / train_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.scale}",
            patch_size=config.crop_size
        ))
    train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        num_workers=config.worker_num,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        prefetch_factor=config.prefetch_factor
    )
    train_iter = iter(train_loader)

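    # Validation datasets, evaluated with valid_steps every val_step iterations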
    test_datasets = {}
    for test_dataset_name in config.val_datasets:
        test_datasets[test_dataset_name] = SRTestDataset(
            hr_dir_path=Path(config.datasets_dir) / test_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.scale}",
        )

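    # Running loss and timing accumulators, reset every display_step iterations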
    l_accum = [0., 0., 0.]
    prepare_data_time = 0.
    forward_backward_time = 0.
    accum_samples = 0

    # TRAINING
    i = config.start_iter
    if config.model_path is not None:
        # When fine-tuning, report the metrics of the loaded model before training starts.
        config.current_iter = i
        valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")

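    # Main training loop: draw one batch per iteration, restarting the loader when it is exhausted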
    for i in range(config.start_iter + 1, config.total_iter + 1):
        torch.cuda.empty_cache()
        start_time = time.time()
        try:
            hr_patch, lr_patch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            hr_patch, lr_patch = next(train_iter)
        hr_patch = hr_patch.to(torch.device(config.device))
        lr_patch = lr_patch.to(torch.device(config.device))
        prepare_data_time += time.time() - start_time

        start_time = time.time()
        pred = model(lr_patch)
        # Patches are stored in [0, 255]; normalize to [0, 1] for the MSE loss.
        loss = F.mse_loss(pred / 255, hr_patch / 255)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        forward_backward_time += time.time() - start_time

        # For monitoring
        accum_samples += config.batch_size
        l_accum[0] += loss.item()

        # Show information
        if i % config.display_step == 0:
            config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
            config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, loss_Pixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
                model.__class__.__name__, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step,
                forward_backward_time / config.display_step))
            l_accum = [0., 0., 0.]
            prepare_data_time = 0.
            forward_backward_time = 0.
        # Save models
        if i % config.save_step == 0:
            # Schedule-free optimizers: switch to eval mode so the averaged
            # weights are checkpointed, then back to train mode to continue.
            optimizer.eval()
            SaveCheckpoint(model=model, path=Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth")
            optimizer.train()

        # Validation
        if i % config.val_step == 0:
            config.current_iter = i
            optimizer.eval()
            valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
            optimizer.train()

    # Final checkpoint and convenience symlinks
    optimizer.eval()
    model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()
    SaveCheckpoint(model=model, path=model_path)
    print("Saved to", model_path)

    # Check whether the trained model is a network or a LUT
    if hasattr(model, 'get_lut_model'):
        link = config.models_dir / "last_trained_net.pth"
        if link.exists():
            link.unlink()
        link.symlink_to(model_path)
    else:
        link = config.models_dir / "last_trained_lut.pth"
        if link.exists():
            link.unlink()
        link.symlink_to(model_path)
    link = config.models_dir / "last.pth"
    if link.exists():
        link.unlink()
    link.symlink_to(model_path)

    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")