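"""Training script for super-resolution models.

Builds a model from the CLI options (or resumes it from a checkpoint), trains it
on the configured SR datasets with periodic validation, checkpointing and
TensorBoard logging, and maintains symlinks to the most recent checkpoint.
"""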
import sys
from pickle import dump
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import logger_info
from common.metrics import PSNR, cal_ssim
from common.color import _rgb2ycbcr, PIL_CONVERT_COLOR
import yaml
from models import SaveCheckpoint, LoadCheckpoint, AVAILABLE_MODELS
from common.test import test_steps
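# cuDNN benchmark mode autotunes convolution algorithms for the observed input
# shapes; beneficial here because the training patches have a fixed size.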
torch.backends.cudnn.benchmark = True
import argparse
from schedulefree import AdamWScheduleFree
from datetime import datetime
from types import SimpleNamespace
import signal
class SignalHandler:
    """Catches a signal (here SIGINT) so the training loop can stop gracefully.

    The first Ctrl-C only sets a flag that the loop checks once per iteration;
    the third Ctrl-C aborts the process immediately.
    """
    def __init__(self, signal_code):
        self.is_on = False
        self.count = 0
        signal.signal(signal_code, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print("Early stopping.")
        self.is_on = True
        self.count += 1
        if self.count == 3:
            sys.exit(1)

signal_interruption_handler = SignalHandler(signal.SIGINT)
class TrainOptions:
    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
        parser.add_argument('--model', type=str, default='SRNet', help=f"Model: {list(AVAILABLE_MODELS.keys())}")
        # parser.add_argument('--model_path', type=str, default=None, help="Path to model for finetune.")
        parser.add_argument('--train_datasets', type=str, default='DIV2K', help="Folder names of datasets to train on.")
        parser.add_argument('--test_datasets', type=str, default='Set5,Set14', help="Folder names of datasets to validate on.")
        parser.add_argument('--upscale_factor', '-s', type=int, default=4, help="Upscale factor")
        parser.add_argument('--hidden_dim', type=int, default=64, help="Number of filters of convolutional layers")
        parser.add_argument('--layers_count', type=int, default=4, help="Number of convolutional layers")
        parser.add_argument('--crop_size', type=int, default=48, help="Input LR training patch size")
        parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
        parser.add_argument('--experiment_dir', type=str, default='../experiments/', help="Experiment folder")
        parser.add_argument('--datasets_dir', type=str, default="../data/")
        parser.add_argument('--start_iter', type=int, default=0, help="Set 0 to train from scratch; otherwise saved params are loaded and training continues.")
        parser.add_argument('--total_iter', type=int, default=200000, help="Total number of training iterations")
        parser.add_argument('--display_step', type=int, default=100, help="Display info every N iterations")
        parser.add_argument('--val_step', type=int, default=2000, help="Validate every N iterations")
        parser.add_argument('--save_step', type=int, default=2000, help="Save the model every N iterations")
        parser.add_argument('--loader_worker_num', type=int, default=1, help="Number of dataloader workers")
        parser.add_argument('--test_worker_num', type=int, default=1, help="Test parallelism. Use 1 for time measurement.")
        parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
        parser.add_argument('--save_predictions', action='store_true', default=True, help="Save model predictions to exp_dir/val/dataset_name")
        parser.add_argument('--device', default='cuda', help="Device of the model")
        parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when the model is a LUT. The number of 4D LUT buckets is 2**bits. Value is in range [1, 8].")
        parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}")
        parser.add_argument('--reset_cache', action='store_true', default=False, help="Discard dataset cache")
        parser.add_argument('--learning_rate', type=float, default=0.0025, help="Learning rate")
        parser.add_argument('--grad_step', type=int, default=1, help="Gradient accumulation: call optimizer.step() every N iterations.")
        self.parser = parser
        self.config = None
    def parse_args(self):
        args = self.parser.parse_args()
        args.datasets_dir = Path(args.datasets_dir).resolve()
        args.experiment_dir = Path(args.experiment_dir).resolve()
        # --model may also point to an existing checkpoint; in that case training
        # resumes from the iteration encoded in the file name, e.g. SRNet_20000.pth.
        if Path(args.model).exists():
            args.model = Path(args.model)
            args.start_iter = int(args.model.stem.split("_")[-1])
        args.train_datasets = args.train_datasets.split(',')
        args.test_datasets = args.test_datasets.split(',')
        # With 2**bits buckets per 8-bit channel, neighbouring bucket centers are
        # 2**(8 - bits) intensity levels apart.
        args.quantization_interval = 2 ** (8 - args.quantization_bits)
        return args
    def save_config(self, config):
        with open(config.exp_dir / "config.yaml", 'w') as f:
            yaml.dump(config, f)
    def __repr__(self):
        config = self.parse_args() if self.config is None else self.config
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(config).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        return message
    def prepare_experiment(self):
        config = self.parse_args()
        if isinstance(config.model, Path):
            # Resume: restore the model from the checkpoint and keep its class name.
            model = LoadCheckpoint(config.model)
            config.model = model.__class__.__name__
        else:
            config_dict = vars(config)
            model_keys = ['hidden_dim', 'layers_count', 'upscale_factor', 'quantization_interval']
            model = AVAILABLE_MODELS[config.model](
                config=SimpleNamespace(**{k: config_dict[k] for k in model_keys})
            )
        model = model.to(torch.device(config.device))
        # model = torch.compile(model)
        assert all([name in os.listdir(config.datasets_dir) for name in config.train_datasets]), \
            f"One of the {config.train_datasets} was not found in {config.datasets_dir}."
        assert all([name in os.listdir(config.datasets_dir) for name in config.test_datasets]), \
            f"One of the {config.test_datasets} was not found in {config.datasets_dir}."
        config.exp_dir = (config.experiment_dir / f"{config.model}_{config.color_model}_{'_'.join(config.train_datasets)}_x{config.upscale_factor}").resolve()
        config.exp_dir.mkdir(parents=True, exist_ok=True)
        config.checkpoint_dir = (config.exp_dir / "checkpoints").resolve()
        config.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        config.test_dir = (config.exp_dir / 'val').resolve()
        config.test_dir.mkdir(parents=True, exist_ok=True)
        config.logs_dir = (config.exp_dir / 'logs').resolve()
        config.logs_dir.mkdir(parents=True, exist_ok=True)
        optimizer = AdamWScheduleFree(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.95))
        self.save_config(config)
        self.config = config
        # Tensorboard for monitoring
        writer = SummaryWriter(log_dir=config.logs_dir)
        logger_name = 'train'
        logger_info(logger_name, os.path.join(config.logs_dir, logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        config.writer = writer
        config.logger = logger
        config.logger.info(self)
        config.logger.info(model)
        config.logger.info(optimizer)
        return config, model, optimizer
if __name__ == "__main__":
    # torch.set_float32_matmul_precision('high')
    script_start_time = datetime.now()
    config_inst = TrainOptions()
    config, model, optimizer = config_inst.prepare_experiment()

    # Training dataset
    train_datasets = []
    for train_dataset_name in config.train_datasets:
        train_datasets.append(SRTrainDataset(
            hr_dir_path=Path(config.datasets_dir) / train_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / train_dataset_name / "LR" / f"X{config.upscale_factor}",
            patch_size=config.crop_size,
            color_model=config.color_model,
            reset_cache=config.reset_cache
        ))
    train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        num_workers=config.loader_worker_num,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        prefetch_factor=config.prefetch_factor
    )
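    # The loader is consumed as an endless stream: whenever the iterator is
    # exhausted it is re-created, so training is driven purely by the iteration count.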
    train_iter = iter(train_loader)

    test_datasets = {}
    for test_dataset_name in config.test_datasets:
        test_datasets[test_dataset_name] = SRTestDataset(
            hr_dir_path=Path(config.datasets_dir) / test_dataset_name / "HR",
            lr_dir_path=Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{config.upscale_factor}",
            color_model=config.color_model,
            reset_cache=config.reset_cache
        )

    l_accum = [0., 0., 0.]
    prepare_data_time = 0.
    forward_backward_time = 0.
    accum_samples = 0

    # TRAINING
    i = config.start_iter
    if config.start_iter > 0:
        # Resuming from a non-zero iteration: evaluate the restored model
        # before training continues.
        config.current_iter = i
        test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
    loss_fn = model.get_loss_fn()
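    # Main training loop. The loss is scaled by 1/grad_step so that gradients
    # accumulated over grad_step iterations match a single large-batch step;
    # optimizer.step() is only called every grad_step iterations.
    # Note: schedule-free optimizers (schedulefree) distinguish train/eval modes,
    # so the optimizer is switched to train mode before updates are applied.
    optimizer.train()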
    for i in range(config.start_iter + 1, config.total_iter + 1):
        if signal_interruption_handler.is_on:
            break
        config.current_iter = i
        torch.cuda.empty_cache()

        # Fetch the next batch, restarting the loader when it runs out.
        start_time = time.time()
        try:
            hr_patch, lr_patch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            hr_patch, lr_patch = next(train_iter)
        hr_patch = hr_patch.to(torch.device(config.device))
        lr_patch = lr_patch.to(torch.device(config.device))
        prepare_data_time += time.time() - start_time

        start_time = time.time()
        with torch.set_grad_enabled(True):
            pred = model(x=lr_patch, script_config=config)
            loss = loss_fn(pred=pred, target=hr_patch) / config.grad_step
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            if i % config.grad_step == 0 and i > 0:
                optimizer.step()
                optimizer.zero_grad()
            torch.cuda.empty_cache()
        forward_backward_time += time.time() - start_time

        # For monitoring
        accum_samples += config.batch_size
        l_accum[0] += loss.item()

        # Show information
        if i % config.display_step == 0:
            config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
            config.writer.add_scalar('loss', loss.item(), i)
            config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, loss:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
                model.__class__.__name__,
                i,
                accum_samples,
                l_accum[0] / config.display_step,
                prepare_data_time / config.display_step,
                forward_backward_time / config.display_step))
            l_accum = [0., 0., 0.]
            prepare_data_time = 0.
            forward_backward_time = 0.
        # Save models
        if i % config.save_step == 0:
            SaveCheckpoint(model=model, path=Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth")

        # Validation
        if i % config.val_step == 0:
            test_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")

    # Final checkpoint (also reached after early stopping via SIGINT).
    model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()
    SaveCheckpoint(model=model, path=model_path)
    print("Saved to ", model_path)
    # Check whether the model is a network or a LUT by its class name.
    if 'net' in model.__class__.__name__.lower():
        link = config.experiment_dir / "last_trained_net.pth"
    else:
        link = config.experiment_dir / "last_trained_lut.pth"
    link.unlink(missing_ok=True)
    link.symlink_to(model_path)

    link = config.experiment_dir / "last.pth"
    link.unlink(missing_ok=True)
    link.symlink_to(model_path)
    total_script_time = datetime.now() - script_start_time
    config.logger.info(f"Completed after {total_script_time}")