From 274abed989705d0e8560f6c705a6991a35a31992 Mon Sep 17 00:00:00 2001
From: protsenkovi
Date: Fri, 5 Jul 2024 00:12:55 +0400
Subject: [PATCH] update

---
 src/common/losses.py        | 369 ++++++++++++++++++++++++++++++++++++
 src/common/utils.py         |   1 +
 src/eval_bicubic_metrics.py | 193 -------------------
 src/image_demo.py           |   7 +-
 src/models/sdynet.py        | 258 +++++++++++++++++++++++++
 src/train.py                |  29 ++-
 6 files changed, 656 insertions(+), 201 deletions(-)
 delete mode 100644 src/eval_bicubic_metrics.py

diff --git a/src/common/losses.py b/src/common/losses.py
index 69f9aab..00ac9c6 100644
--- a/src/common/losses.py
+++ b/src/common/losses.py
@@ -1,6 +1,10 @@
 import torch
 from torch import nn
+import warnings
+from typing import List, Optional, Tuple, Union
+import torch.nn.functional as F
+from torch import Tensor
 
 class FourierLoss(nn.Module):
     def __init__(self, weight=None, size_average=True):
@@ -151,3 +155,369 @@ class FocalFrequencyLoss(nn.Module):
 
         # calculate focal frequency loss
         return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
+
+
+
+
+def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
+    r"""Create 1-D gauss kernel
+    Args:
+        size (int): the size of gauss kernel
+        sigma (float): sigma of normal distribution
+    Returns:
+        torch.Tensor: 1D kernel (1 x 1 x size)
+    """
+    coords = torch.arange(size, dtype=torch.float)
+    coords -= size // 2
+
+    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
+    g /= g.sum()
+
+    return g.unsqueeze(0).unsqueeze(0)
+
+
+def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
+    r""" Blur input with 1-D kernel
+    Args:
+        input (torch.Tensor): a batch of tensors to be blurred
+        win (torch.Tensor): 1-D gauss kernel
+    Returns:
+        torch.Tensor: blurred tensors
+    """
+    assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
+    if len(input.shape) == 4:
+        conv = F.conv2d
+    elif len(input.shape) == 5:
+        conv = F.conv3d
+    else:
+        raise NotImplementedError(input.shape)
+
+    C = input.shape[1]
+    out = input
+    for i, s in enumerate(input.shape[2:]):
+        if s >= win.shape[-1]:
+            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
+        else:
+            warnings.warn(
+                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
+            )
+
+    return out
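For orientation, the two helpers above implement a separable Gaussian blur: the 1-D kernel is swept once along each spatial dimension, which is equivalent to (and cheaper than) convolving with the full 2-D kernel. A minimal sketch of that equivalence, assuming the repository's src/ directory is on PYTHONPATH so common.losses resolves:

import torch
import torch.nn.functional as F
from common.losses import _fspecial_gauss_1d, gaussian_filter

x = torch.rand(1, 1, 32, 32)
win = _fspecial_gauss_1d(size=11, sigma=1.5)                # (1, 1, 11)
win = win.repeat([x.shape[1]] + [1] * (len(x.shape) - 1))   # (1, 1, 1, 11) for 4-D input

# Two 1-D passes (rows, then columns).
separable = gaussian_filter(x, win)

# One 2-D pass with the outer-product kernel.
k = win.flatten()
kernel_2d = torch.outer(k, k).view(1, 1, 11, 11)
full = F.conv2d(x, kernel_2d)

assert torch.allclose(separable, full, atol=1e-5)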
+ """ + K1, K2 = K + # batch, channel, [depth,] height, width = X.shape + compensation = 1.0 + + C1 = (K1 * data_range) ** 2 + C2 = (K2 * data_range) ** 2 + + win = win.to(X.device, dtype=X.dtype) + + mu1 = gaussian_filter(X, win) + mu2 = gaussian_filter(Y, win) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) + sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) + sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) + + cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 + ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map + + ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1) + cs = torch.flatten(cs_map, 2).mean(-1) + return ssim_per_channel, cs + + +def ssim( + X: Tensor, + Y: Tensor, + data_range: float = 255, + size_average: bool = True, + win_size: int = 11, + win_sigma: float = 1.5, + win: Optional[Tensor] = None, + K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), + nonnegative_ssim: bool = False, +) -> Tensor: + r""" interface of ssim + Args: + X (torch.Tensor): a batch of images, (N,C,H,W) + Y (torch.Tensor): a batch of images, (N,C,H,W) + data_range (float or int, optional): value range of input images. (usually 1.0 or 255) + size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar + win_size: (int, optional): the size of gauss kernel + win_sigma: (float, optional): sigma of normal distribution + win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma + K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu + + Returns: + torch.Tensor: ssim results + """ + if not X.shape == Y.shape: + raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.") + + for d in range(len(X.shape) - 1, 1, -1): + X = X.squeeze(dim=d) + Y = Y.squeeze(dim=d) + + if len(X.shape) not in (4, 5): + raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") + + #if not X.type() == Y.type(): + # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.") + + if win is not None: # set win_size + win_size = win.shape[-1] + + if not (win_size % 2 == 1): + raise ValueError("Window size should be odd.") + + if win is None: + win = _fspecial_gauss_1d(win_size, win_sigma) + win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) + + ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K) + if nonnegative_ssim: + ssim_per_channel = torch.relu(ssim_per_channel) + + if size_average: + return ssim_per_channel.mean() + else: + return ssim_per_channel.mean(1) + + +def ms_ssim( + X: Tensor, + Y: Tensor, + data_range: float = 255, + size_average: bool = True, + win_size: int = 11, + win_sigma: float = 1.5, + win: Optional[Tensor] = None, + weights: Optional[List[float]] = None, + K: Union[Tuple[float, float], List[float]] = (0.01, 0.03) +) -> Tensor: + r""" interface of ms-ssim + Args: + X (torch.Tensor): a batch of images, (N,C,[T,]H,W) + Y (torch.Tensor): a batch of images, (N,C,[T,]H,W) + data_range (float or int, optional): value range of input images. 
(usually 1.0 or 255) + size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar + win_size: (int, optional): the size of gauss kernel + win_sigma: (float, optional): sigma of normal distribution + win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma + weights (list, optional): weights for different levels + K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + Returns: + torch.Tensor: ms-ssim results + """ + if not X.shape == Y.shape: + raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.") + + for d in range(len(X.shape) - 1, 1, -1): + X = X.squeeze(dim=d) + Y = Y.squeeze(dim=d) + + #if not X.type() == Y.type(): + # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.") + + if len(X.shape) == 4: + avg_pool = F.avg_pool2d + elif len(X.shape) == 5: + avg_pool = F.avg_pool3d + else: + raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") + + if win is not None: # set win_size + win_size = win.shape[-1] + + if not (win_size % 2 == 1): + raise ValueError("Window size should be odd.") + + smaller_side = min(X.shape[-2:]) + assert smaller_side > (win_size - 1) * ( + 2 ** 4 + ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4)) + + if weights is None: + weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] + weights_tensor = X.new_tensor(weights) + + if win is None: + win = _fspecial_gauss_1d(win_size, win_sigma) + win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) + + levels = weights_tensor.shape[0] + mcs = [] + for i in range(levels): + ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K) + + if i < levels - 1: + mcs.append(torch.relu(cs)) + padding = [s % 2 for s in X.shape[2:]] + X = avg_pool(X, kernel_size=2, padding=padding) + Y = avg_pool(Y, kernel_size=2, padding=padding) + + ssim_per_channel = torch.relu(ssim_per_channel) # type: ignore # (batch, channel) + mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0) # (level, batch, channel) + ms_ssim_val = torch.prod(mcs_and_ssim ** weights_tensor.view(-1, 1, 1), dim=0) + + if size_average: + return ms_ssim_val.mean() + else: + return ms_ssim_val.mean(1) + + +class SSIM(torch.nn.Module): + def __init__( + self, + data_range: float = 255, + size_average: bool = True, + win_size: int = 11, + win_sigma: float = 1.5, + channel: int = 3, + spatial_dims: int = 2, + K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), + nonnegative_ssim: bool = False, + ) -> None: + r""" class for ssim + Args: + data_range (float or int, optional): value range of input images. (usually 1.0 or 255) + size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar + win_size: (int, optional): the size of gauss kernel + win_sigma: (float, optional): sigma of normal distribution + channel (int, optional): input channels (default: 3) + K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu. 
+ """ + + super(SSIM, self).__init__() + self.win_size = win_size + self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims) + self.size_average = size_average + self.data_range = data_range + self.K = K + self.nonnegative_ssim = nonnegative_ssim + + def forward(self, X: Tensor, Y: Tensor) -> Tensor: + return ssim( + X, + Y, + data_range=self.data_range, + size_average=self.size_average, + win=self.win, + K=self.K, + nonnegative_ssim=self.nonnegative_ssim, + ) + + +class MS_SSIM(torch.nn.Module): + def __init__( + self, + data_range: float = 255, + size_average: bool = True, + win_size: int = 11, + win_sigma: float = 1.5, + channel: int = 3, + spatial_dims: int = 2, + weights: Optional[List[float]] = None, + K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), + ) -> None: + r""" class for ms-ssim + Args: + data_range (float or int, optional): value range of input images. (usually 1.0 or 255) + size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar + win_size: (int, optional): the size of gauss kernel + win_sigma: (float, optional): sigma of normal distribution + channel (int, optional): input channels (default: 3) + weights (list, optional): weights for different levels + K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + """ + + super(MS_SSIM, self).__init__() + self.win_size = win_size + self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims) + self.size_average = size_average + self.data_range = data_range + self.weights = weights + self.K = K + + def forward(self, X: Tensor, Y: Tensor) -> Tensor: + return ms_ssim( + X, + Y, + data_range=self.data_range, + size_average=self.size_average, + win=self.win, + weights=self.weights, + K=self.K, + ) + +def get_outnorm(x:torch.Tensor, out_norm:str='') -> torch.Tensor: + """ Common function to get a loss normalization value. 
+def get_outnorm(x:torch.Tensor, out_norm:str='') -> torch.Tensor:
+    """ Common function to get a loss normalization value. Can
+    normalize by either the batch size ('b'), the number of
+    channels ('c'), the image size ('i'), or combinations
+    ('bi', 'bci', etc).
+    """
+    # b, c, h, w = x.size()
+    img_shape = x.shape
+
+    if not out_norm:
+        return 1
+
+    norm = 1
+    if 'b' in out_norm:
+        # normalize by batch size
+        # norm /= b
+        norm /= img_shape[0]
+    if 'c' in out_norm:
+        # normalize by the number of channels
+        # norm /= c
+        norm /= img_shape[-3]
+    if 'i' in out_norm:
+        # normalize by image/map size
+        # norm /= h*w
+        norm /= img_shape[-1]*img_shape[-2]
+
+    return norm
+
+
+class CharbonnierLoss(nn.Module):
+    """Charbonnier loss (a smooth, differentiable variant of L1)."""
+    def __init__(self, eps=1e-6, out_norm:str='bci'):
+        super(CharbonnierLoss, self).__init__()
+        self.eps = eps
+        self.out_norm = out_norm
+
+    def forward(self, x, y):
+        norm = get_outnorm(x, self.out_norm)
+        loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
+        return loss*norm
\ No newline at end of file
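A short numeric check of the normalization logic, assuming the two definitions above: with the default out_norm='bci' the summed Charbonnier penalty is divided by batch size, channel count, and pixel count, so the loss magnitude stays independent of input shape:

import torch
from common.losses import CharbonnierLoss, get_outnorm

x = torch.zeros(2, 3, 8, 8)
y = torch.ones(2, 3, 8, 8)

print(get_outnorm(x, 'bci'))            # 1 / (2 * 3 * 64) ~= 0.0026
print(CharbonnierLoss(eps=1e-6)(x, y))  # ~1.0: sqrt(1 + eps^2) per element, averaged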
diff --git a/src/common/utils.py b/src/common/utils.py
index a7514d7..c1bedc7 100644
--- a/src/common/utils.py
+++ b/src/common/utils.py
@@ -6,6 +6,7 @@ from scipy import signal
 import torch
 import os
 
+
 def pil2numpy(image):
     np_image = np.array(image)
     if len(np_image.shape) == 2:
diff --git a/src/eval_bicubic_metrics.py b/src/eval_bicubic_metrics.py
deleted file mode 100644
index 18b0dc3..0000000
--- a/src/eval_bicubic_metrics.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import os
-from pathlib import Path
-import numpy as np
-import cv2
-from scipy import signal
-from skimage.metrics import structural_similarity
-from PIL import Image
-import argparse
-
-import time
-from datetime import datetime
-import ray
-ray.init(num_cpus=16, num_gpus=0, ignore_reinit_error=True, log_to_driver=False)
-
-parser = argparse.ArgumentParser()
-parser.add_argument("path_to_dataset", type=str)
-parser.add_argument("--scale", type=int, default=4)
-args = parser.parse_args()
-
-def cal_ssim(img1, img2):
-    K = [0.01, 0.03]
-    L = 255
-    kernelX = cv2.getGaussianKernel(11, 1.5)
-    window = kernelX * kernelX.T
-
-    M, N = np.shape(img1)
-
-    C1 = (K[0] * L) ** 2
-    C2 = (K[1] * L) ** 2
-    img1 = np.float64(img1)
-    img2 = np.float64(img2)
-
-    mu1 = signal.convolve2d(img1, window, 'valid')
-    mu2 = signal.convolve2d(img2, window, 'valid')
-
-    mu1_sq = mu1 * mu1
-    mu2_sq = mu2 * mu2
-    mu1_mu2 = mu1 * mu2
-
-    sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
-    sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
-    sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
-
-    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
-    mssim = np.mean(ssim_map)
-    return mssim
-
-def PSNR(y_true, y_pred, shave_border=4):
-    target_data = np.array(y_true, dtype=np.float32)
-    ref_data = np.array(y_pred, dtype=np.float32)
-
-    diff = ref_data - target_data
-    if shave_border > 0:
-        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
-    rmse = np.sqrt(np.mean(np.power(diff, 2)))
-
-    return 20 * np.log10(255. / rmse)
-
-def _rgb2ycbcr(img, maxVal=255):
-    O = np.array([[16],
-                  [128],
-                  [128]])
-    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
-                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
-                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])
-
-    if maxVal == 1:
-        O = O / 255.0
-
-    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
-    t = np.dot(t, np.transpose(T))
-    t[:, 0] += O[0]
-    t[:, 1] += O[1]
-    t[:, 2] += O[2]
-    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
-
-    return ycbcr
-
-
-def modcrop(image, modulo):
-    if len(image.shape) == 2:
-        sz = image.shape
-        sz = sz - np.mod(sz, modulo)
-        image = image[0:sz[0], 0:sz[1]]
-    elif image.shape[2] == 3:
-        sz = image.shape[0:2]
-        sz = sz - np.mod(sz, modulo)
-        image = image[0:sz[0], 0:sz[1], :]
-    else:
-        raise NotImplementedError
-    return image
-
-scale = args.scale
-
-dataset_path = Path(args.path_to_dataset)
-hr_path = dataset_path / "HR/"
-lr_path = dataset_path / f"LR_bicubic/X{scale}/"
-
-
-print(hr_path, lr_path)
-
-hr_files = os.listdir(hr_path)
-lr_files = os.listdir(lr_path)
-
-@ray.remote(num_cpus=1)
-def benchmark_image_pair(hr_image_path, lr_image_path, interpolation_function):
-    hr_image = cv2.imread(hr_image_path)
-    lr_image = cv2.imread(lr_image_path)
-
-    hr_image = hr_image[:,:,::-1] # BGR -> RGB
-    lr_image = lr_image[:,:,::-1] # BGR -> RGB
-
-    start_time = datetime.now()
-    upscaled_lr_image = interpolation_function(lr_image, scale)
-    processing_time = datetime.now() - start_time
-
-    hr_image = modcrop(hr_image, scale)
-    upscaled_lr_image = upscaled_lr_image
-
-    psnr = PSNR(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
-    cpsnr = PSNR(hr_image, upscaled_lr_image)
-
-    cv2_psnr = cv2.PSNR(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
-    cv2_cpsnr = cv2.PSNR(hr_image, upscaled_lr_image)
-
-    ssim = cal_ssim(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
-    cv2_ssim = cal_ssim(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
-    ssim_scikit, diff = structural_similarity(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0], full=True, data_range=255.0)
-    cv2_scikit_ssim, diff = structural_similarity(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], full=True, data_range=255.0)
-
-    return ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time.total_seconds()
-
-
-def benchmark_interpolation(interpolation_function):
-    psnrs, cpsnrs, ssims = [], [], []
-    cv2_psnrs, cv2_cpsnrs, scikit_ssims = [], [], []
-    cv2_scikit_ssims = []
-    cv2_ssims = []
-    tasks = []
-    for hr_name, lr_name in zip(hr_files, lr_files):
-        hr_image_path = str(hr_path / hr_name)
-        lr_image_path = str(lr_path / lr_name)
-        tasks.append(benchmark_image_pair.remote(hr_image_path, lr_image_path, interpolation_function))
-
-    ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
-    while len(remaining_refs) > 0:
-        print(f"\rReady {len(ready_refs)}/{len(hr_files)}", end=" ")
-        ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
-
-    for task in tasks:
-        ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time = ray.get(task)
-        ssims.append(ssim)
-        cv2_ssims.append(cv2_ssim)
-        scikit_ssims.append(ssim_scikit)
-        cv2_scikit_ssims.append(cv2_scikit_ssim)
-        psnrs.append(psnr)
-        cpsnrs.append(cpsnr)
-        cv2_psnrs.append(cv2_psnr)
-        cv2_cpsnrs.append(cv2_cpsnr)
-        processing_times.append(processing_time)
-
-    print()
-    print(f"AVG PSNR: {np.mean(psnrs):.2f} PSNR + _rgb2ycbcr")
-    print(f"AVG PSNR: {np.mean(cv2_psnrs):.2f} cv2.PSNR + cv2.cvtColor")
-    print(f"AVG cPSNR: {np.mean(cpsnrs):.2f} PSNR")
-    print(f"AVG cPSNR: {np.mean(cv2_cpsnrs):.2f} cv2.PSNR ")
-    print(f"AVG SSIM: {np.mean(ssims):.4f} cal_ssim + _rgb2ycbcr")
-    print(f"AVG SSIM: {np.mean(cv2_ssims):.4f} cal_ssim + cv2.cvtColor")
-    print(f"AVG SSIM: {np.mean(scikit_ssims):.4f} structural_similarity + _rgb2ycbcr")
-    print(f"AVG SSIM: {np.mean(cv2_scikit_ssims):.4f} structural_similarity + cv2.cvtColor")
-    print(f"AVG Time s: {np.percentile(processing_times, q=0.9)}")
-    print(f"{np.mean(psnrs):.2f},{np.mean(cv2_psnrs):.2f},{np.mean(cpsnrs):.2f},{np.mean(cv2_cpsnrs):.2f},{np.mean(ssims):.4f},{np.mean(cv2_ssims):.4f},{np.mean(scikit_ssims):.4f},{np.mean(cv2_scikit_ssims):.4f},{np.percentile(processing_times, q=0.9)}")
-
-def cv2_interpolation(image, scale):
-    scaled_image = cv2.resize(
-        image,
-        None, None,
-        fx=scale, fy=scale,
-        interpolation=cv2.INTER_CUBIC
-    )
-    return scaled_image
-
-def pillow_interpolation(image, scale):
-    image = Image.fromarray(image[:,:,::-1])
-    width, height = int(image.width * scale), int(image.height * scale)
-    scaled_image = image.resize((width, height), resample=Image.Resampling.BICUBIC)
-    return np.array(scaled_image)[:,:,::-1]
-
-print("cv2 bicubic interpolation")
-benchmark_interpolation(cv2_interpolation)
-print()
-print("pillow bicubic interpolation")
-benchmark_interpolation(pillow_interpolation)
\ No newline at end of file
diff --git a/src/image_demo.py b/src/image_demo.py
index eb1c38b..27aa11e 100644
--- a/src/image_demo.py
+++ b/src/image_demo.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import sys
-from common.utils import PIL_CONVERT_COLOR, pil2numpy
+from common.color import PIL_CONVERT_COLOR
+from common.utils import pil2numpy
 from models import LoadCheckpoint
 import torch
 import numpy as np
@@ -12,10 +13,10 @@ import argparse
 class ImageDemoOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth","../models/last_transfered_lut.pth"], help="Model paths for comparison")
+        self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../experiments/last_transfered_net.pth","../experiments/last_transfered_lut.pth"], help="Model paths for comparison")
         self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
         self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
-        self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.")
+        self.parser.add_argument('--output_path', type=str, default="../experiments/", help="Output path.")
         self.parser.add_argument('--output_name', type=str, default="image_demo.png", help="Output name.")
         self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
         self.parser.add_argument('--mirror', action='store_true', default=False)
diff --git a/src/models/sdynet.py b/src/models/sdynet.py
index cbac125..14e133e 100644
--- a/src/models/sdynet.py
+++ b/src/models/sdynet.py
@@ -5,6 +5,7 @@ import numpy as np
 from common.utils import round_func
 from common import lut
 from common import layers
+from common import losses
 from pathlib import Path
 from . import sdylut
 from models.base import SRNetBase
@@ -534,4 +535,261 @@ class SDYMixNetx1v3(SRNetBase):
     def get_loss_fn(self):
         def loss_fn(pred, target):
             return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+class SDYMixNetx1v4(SRNetBase):
+    """
+    22
+    12 23 32 21
+    11 13 33 31
+    10 14 34 30
+    01 03 43 41
+    00 04 44 40
+    """
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1v4, self).__init__()
+        self.scale = scale
+
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[3,3]], center=[3,3], window_size=7)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[2,3],[3,4],[4,3],[3,2]], center=[3,3], window_size=7)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[2,2],[2,4],[4,4],[4,2]], center=[3,3], window_size=7)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[1,2],[1,4],[5,4],[5,2]], center=[3,3], window_size=7)
+        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes= [[1,1],[1,5],[5,5],[5,1]], center=[3,3], window_size=7)
+        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes= [[2,1],[2,5],[4,5],[4,1]], center=[3,3], window_size=7)
+        self._extract_pattern_7 = layers.PercievePattern(receptive_field_idxes= [[3,1],[1,3],[3,5],[5,3]], center=[3,3], window_size=7)
+        self._extract_pattern_8 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,6],[6,6],[6,0]], center=[3,3], window_size=7)
+        self._extract_pattern_9 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
+        self._extract_pattern_10 = layers.PercievePattern(receptive_field_idxes=[[2,0],[1,0],[0,1],[0,2]], center=[3,3], window_size=7)
+        self._extract_pattern_11 = layers.PercievePattern(receptive_field_idxes=[[0,4],[0,5],[1,6],[2,6]], center=[3,3], window_size=7)
+        self._extract_pattern_12 = layers.PercievePattern(receptive_field_idxes=[[4,6],[5,6],[6,5],[6,4]], center=[3,3], window_size=7)
+        self._extract_pattern_13 = layers.PercievePattern(receptive_field_idxes=[[6,2],[6,1],[5,0],[4,0]], center=[3,3], window_size=7)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_7 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_8 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_9 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_10 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_11 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_12 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_13 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+
+        self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=13)
+        self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=13, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_7, self.stage1_7)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_8, self.stage1_8)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_9, self.stage1_9)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_10, self.stage1_10)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_11, self.stage1_11)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_12, self.stage1_12)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_13, self.stage1_13)], dim=1)
+        output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+class SDYMixNetx1v5(SRNetBase):
+    """
+    22
+    12 23 32 21
+    11 13 33 31
+    10 14 34 30
+    01 03 43 41
+    00 04 44 40
+    """
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1v5, self).__init__()
+        self.scale = scale
+
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[0,0]], center=[0,0], window_size=1)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[0,1],[1,2],[2,1],[1,0]], center=[1,1], window_size=3)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,4],[4,4],[4,0]], center=[2,2], window_size=5)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+        self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=4)
+        self.stage1_Mix = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
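The PercievePattern/UpscaleBlock pairs used by these models each look at a small, fixed set of offsets inside a square window. The real implementation lives in common/layers.py and may differ; purely as an illustration of the idea, such a pattern can be gathered with F.unfold:

import torch
import torch.nn.functional as F

def extract_pattern(x, idxes, window_size):
    # x: (B, 1, H, W); idxes: (row, col) offsets inside the sliding window.
    b, _, h, w = x.shape
    patches = F.unfold(x, kernel_size=window_size)        # (B, window_size**2, L)
    flat = [r * window_size + c for r, c in idxes]
    sel = patches[:, flat, :]                             # (B, len(idxes), L)
    return sel.view(b, len(idxes), h - window_size + 1, w - window_size + 1)

x = torch.rand(1, 1, 16, 16)
ring = extract_pattern(x, [[2, 3], [3, 4], [4, 3], [3, 2]], window_size=7)
print(ring.shape)  # torch.Size([1, 4, 10, 10])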
+class SDYMixNetx1v6(SRNetBase):
+    """
+    22
+    12 23 32 21
+    11 13 33 31
+    10 14 34 30
+    01 03 43 41
+    00 04 44 40
+    """
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1v6, self).__init__()
+        self.scale = scale
+
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[3,3]], center=[3,3], window_size=7)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[2,3],[3,4],[4,3],[3,2]], center=[3,3], window_size=7)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[2,2],[2,4],[4,4],[4,2]], center=[3,3], window_size=7)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[1,2],[1,4],[5,4],[5,2]], center=[3,3], window_size=7)
+        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes= [[1,1],[1,5],[5,5],[5,1]], center=[3,3], window_size=7)
+        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes= [[2,1],[2,5],[4,5],[4,1]], center=[3,3], window_size=7)
+        self._extract_pattern_7 = layers.PercievePattern(receptive_field_idxes= [[3,1],[1,3],[3,5],[5,3]], center=[3,3], window_size=7)
+        self._extract_pattern_8 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,6],[6,6],[6,0]], center=[3,3], window_size=7)
+        self._extract_pattern_9 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
+        self._extract_pattern_10 = layers.PercievePattern(receptive_field_idxes=[[2,0],[1,0],[0,1],[0,2]], center=[3,3], window_size=7)
+        self._extract_pattern_11 = layers.PercievePattern(receptive_field_idxes=[[0,4],[0,5],[1,6],[2,6]], center=[3,3], window_size=7)
+        self._extract_pattern_12 = layers.PercievePattern(receptive_field_idxes=[[4,6],[5,6],[6,5],[6,4]], center=[3,3], window_size=7)
+        self._extract_pattern_13 = layers.PercievePattern(receptive_field_idxes=[[6,2],[6,1],[5,0],[4,0]], center=[3,3], window_size=7)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_7 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_8 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_9 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_10 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_11 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_12 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+        self.stage1_13 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
+
+        self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=13)
+        self.stage1_Mix = layers.UpscaleBlock(in_features=13, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_7, self.stage1_7)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_8, self.stage1_8)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_9, self.stage1_9)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_10, self.stage1_10)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_11, self.stage1_11)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_12, self.stage1_12)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_13, self.stage1_13)], dim=1)
+        output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.mse_loss(pred/255, target/255)
+        return loss_fn
+
+
+class SDYMixNetx1v7(SRNetBase):
+    """
+    22
+    12 23 32 21
+    11 13 33 31
+    10 14 34 30
+    01 03 43 41
+    00 04 44 40
+    """
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SDYMixNetx1v7, self).__init__()
+        self.scale = scale
+
+        self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[3,3]], center=[3,3], window_size=7)
+        self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[2,3],[3,4],[4,3],[3,2]], center=[3,3], window_size=7)
+        self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[2,2],[2,4],[4,4],[4,2]], center=[3,3], window_size=7)
+        self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[1,2],[1,4],[5,4],[5,2]], center=[3,3], window_size=7)
+        self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes= [[1,1],[1,5],[5,5],[5,1]], center=[3,3], window_size=7)
+        self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes= [[2,1],[2,5],[4,5],[4,1]], center=[3,3], window_size=7)
+        self._extract_pattern_7 = layers.PercievePattern(receptive_field_idxes= [[3,1],[1,3],[3,5],[5,3]], center=[3,3], window_size=7)
+        self._extract_pattern_8 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,6],[6,6],[6,0]], center=[3,3], window_size=7)
+        self._extract_pattern_9 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
+        self._extract_pattern_10 = layers.PercievePattern(receptive_field_idxes=[[2,0],[1,0],[0,1],[0,2]], center=[3,3], window_size=7)
+        self._extract_pattern_11 = layers.PercievePattern(receptive_field_idxes=[[0,4],[0,5],[1,6],[2,6]], center=[3,3], window_size=7)
+        self._extract_pattern_12 = layers.PercievePattern(receptive_field_idxes=[[4,6],[5,6],[6,5],[6,4]], center=[3,3], window_size=7)
+        self._extract_pattern_13 = layers.PercievePattern(receptive_field_idxes=[[6,2],[6,1],[5,0],[4,0]], center=[3,3], window_size=7)
+
+        self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_7 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_8 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_9 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_10 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_11 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_12 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+        self.stage1_13 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
+
+        self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=13)
+        self.stage1_Mix = layers.UpscaleBlock(in_features=13, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+
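The final stage1_Mix step fuses the thirteen branch outputs pointwise. A shape-only illustration (a plain 1x1 convolution stands in for the UpscaleBlock with upscale_factor=1; sizes are assumed):

import torch
from torch import nn

branches = [torch.rand(1, 1, 64, 64) for _ in range(13)]  # upscaled branch outputs
stacked = torch.cat(branches, dim=1)                      # (1, 13, 64, 64)
mix = nn.Conv2d(13, 1, kernel_size=1)                     # stand-in for stage1_Mix
print(mix(stacked).shape)                                 # torch.Size([1, 1, 64, 64])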
+    def forward(self, x, config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_7, self.stage1_7)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_8, self.stage1_8)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_9, self.stage1_9)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_10, self.stage1_10)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_11, self.stage1_11)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_12, self.stage1_12)], dim=1)
+        output = torch.cat([output, self.forward_stage(x, self._extract_pattern_13, self.stage1_13)], dim=1)
+        output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
+        x = output
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_loss_fn(self):
+        ssim_loss = losses.SSIM(channel=1, data_range=255)
+        l1_loss = losses.CharbonnierLoss()
+        def loss_fn(pred, target):
+            # SSIM is a similarity (1.0 means identical), so invert it to use it as a loss.
+            return (1 - ssim_loss(pred, target)) + l1_loss(pred, target)
         return loss_fn
\ No newline at end of file
diff --git a/src/train.py b/src/train.py
index f5f1606..e0e86a5 100644
--- a/src/train.py
+++ b/src/train.py
@@ -27,6 +27,19 @@ import argparse
 from schedulefree import AdamWScheduleFree
 from datetime import datetime
 
+import signal
+
+class SignalHandler:
+    def __init__(self, signal_code):
+        self.is_on = False
+        signal.signal(signal_code, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print("Early stopping.")
+        self.is_on = True
+
+signal_interruption_handler = SignalHandler(signal.SIGINT)
+
 class TrainOptions:
     def __init__(self):
         parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
@@ -55,6 +68,7 @@ class TrainOptions:
         parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}")
         parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
         parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate')
+        parser.add_argument('--grad_step', type=int, default=1, help='Number of batches to accumulate gradients over before each optimizer step.')
 
         self.parser = parser
 
@@ -186,6 +200,8 @@ if __name__ == "__main__":
     loss_fn = model.get_loss_fn()
 
     for i in range(config.start_iter + 1, config.total_iter + 1):
+        if signal_interruption_handler.is_on:
+            break
         config.current_iter = i
         torch.cuda.empty_cache()
         start_time = time.time()
@@ -199,12 +215,15 @@ if __name__ == "__main__":
         prepare_data_time += time.time() - start_time
 
         start_time = time.time()
+        with torch.set_grad_enabled(True):
+            pred = model(x=lr_patch, config=config)
+            loss = loss_fn(pred=pred, target=hr_patch) / config.grad_step
+            loss.backward()
+
+        if i % config.grad_step == 0 and i > 0:
+            optimizer.step()
+            optimizer.zero_grad()
-        pred = model(x=lr_patch, config=config)
-        loss = loss_fn(pred=pred, target=hr_patch)
-        loss.backward()
-        optimizer.step()
-        optimizer.zero_grad()
 
         torch.cuda.empty_cache()
         forward_backward_time += time.time() - start_time
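For context, the training-loop change above is the standard gradient-accumulation pattern: each micro-batch loss is scaled by 1/grad_step, gradients are summed over grad_step backward passes, and a single optimizer step then approximates training with a batch grad_step times larger. A standalone sketch of the same idea (names are illustrative, not from the repository):

import torch
from torch import nn
import torch.nn.functional as F

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_step = 4  # mirrors --grad_step

for i in range(1, 101):
    x, y = torch.rand(16, 8), torch.rand(16, 1)
    loss = F.mse_loss(model(x), y) / grad_step  # scale so accumulated grads average out
    loss.backward()                             # gradients accumulate in parameter .grad
    if i % grad_step == 0:
        optimizer.step()
        optimizer.zero_grad()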