main
protsenkovi 6 months ago
parent 5c22148bed
commit 274abed989

@ -1,6 +1,9 @@
import torch import torch
from torch import nn from torch import nn
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
from torch import Tensor
class FourierLoss(nn.Module): class FourierLoss(nn.Module):
def __init__(self, weight=None, size_average=True): def __init__(self, weight=None, size_average=True):
@ -151,3 +154,368 @@ class FocalFrequencyLoss(nn.Module):
# calculate focal frequency loss # calculate focal frequency loss
return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
r"""Create 1-D gauss kernel
Args:
size (int): the size of gauss kernel
sigma (float): sigma of normal distribution
Returns:
torch.Tensor: 1D kernel (1 x 1 x size)
"""
coords = torch.arange(size, dtype=torch.float)
coords -= size // 2
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
g /= g.sum()
return g.unsqueeze(0).unsqueeze(0)
def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
r""" Blur input with 1-D kernel
Args:
input (torch.Tensor): a batch of tensors to be blurred
window (torch.Tensor): 1-D gauss kernel
Returns:
torch.Tensor: blurred tensors
"""
assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
if len(input.shape) == 4:
conv = F.conv2d
elif len(input.shape) == 5:
conv = F.conv3d
else:
raise NotImplementedError(input.shape)
C = input.shape[1]
out = input
for i, s in enumerate(input.shape[2:]):
if s >= win.shape[-1]:
out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
else:
warnings.warn(
f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
)
return out
def _ssim(
X: Tensor,
Y: Tensor,
data_range: float,
win: Tensor,
size_average: bool = True,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tuple[Tensor, Tensor]:
r""" Calculate ssim index for X and Y
Args:
X (torch.Tensor): images
Y (torch.Tensor): images
data_range (float or int): value range of input images. (usually 1.0 or 255)
win (torch.Tensor): 1-D gauss kernel
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
Returns:
Tuple[torch.Tensor, torch.Tensor]: ssim results.
"""
K1, K2 = K
# batch, channel, [depth,] height, width = X.shape
compensation = 1.0
C1 = (K1 * data_range) ** 2
C2 = (K2 * data_range) ** 2
win = win.to(X.device, dtype=X.dtype)
mu1 = gaussian_filter(X, win)
mu2 = gaussian_filter(Y, win)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
cs = torch.flatten(cs_map, 2).mean(-1)
return ssim_per_channel, cs
def ssim(
X: Tensor,
Y: Tensor,
data_range: float = 255,
size_average: bool = True,
win_size: int = 11,
win_sigma: float = 1.5,
win: Optional[Tensor] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
nonnegative_ssim: bool = False,
) -> Tensor:
r""" interface of ssim
Args:
X (torch.Tensor): a batch of images, (N,C,H,W)
Y (torch.Tensor): a batch of images, (N,C,H,W)
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
Returns:
torch.Tensor: ssim results
"""
if not X.shape == Y.shape:
raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
for d in range(len(X.shape) - 1, 1, -1):
X = X.squeeze(dim=d)
Y = Y.squeeze(dim=d)
if len(X.shape) not in (4, 5):
raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
#if not X.type() == Y.type():
# raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
if win is not None: # set win_size
win_size = win.shape[-1]
if not (win_size % 2 == 1):
raise ValueError("Window size should be odd.")
if win is None:
win = _fspecial_gauss_1d(win_size, win_sigma)
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
if nonnegative_ssim:
ssim_per_channel = torch.relu(ssim_per_channel)
if size_average:
return ssim_per_channel.mean()
else:
return ssim_per_channel.mean(1)
def ms_ssim(
X: Tensor,
Y: Tensor,
data_range: float = 255,
size_average: bool = True,
win_size: int = 11,
win_sigma: float = 1.5,
win: Optional[Tensor] = None,
weights: Optional[List[float]] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tensor:
r""" interface of ms-ssim
Args:
X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
weights (list, optional): weights for different levels
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
Returns:
torch.Tensor: ms-ssim results
"""
if not X.shape == Y.shape:
raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
for d in range(len(X.shape) - 1, 1, -1):
X = X.squeeze(dim=d)
Y = Y.squeeze(dim=d)
#if not X.type() == Y.type():
# raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
if len(X.shape) == 4:
avg_pool = F.avg_pool2d
elif len(X.shape) == 5:
avg_pool = F.avg_pool3d
else:
raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
if win is not None: # set win_size
win_size = win.shape[-1]
if not (win_size % 2 == 1):
raise ValueError("Window size should be odd.")
smaller_side = min(X.shape[-2:])
assert smaller_side > (win_size - 1) * (
2 ** 4
), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))
if weights is None:
weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
weights_tensor = X.new_tensor(weights)
if win is None:
win = _fspecial_gauss_1d(win_size, win_sigma)
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
levels = weights_tensor.shape[0]
mcs = []
for i in range(levels):
ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
if i < levels - 1:
mcs.append(torch.relu(cs))
padding = [s % 2 for s in X.shape[2:]]
X = avg_pool(X, kernel_size=2, padding=padding)
Y = avg_pool(Y, kernel_size=2, padding=padding)
ssim_per_channel = torch.relu(ssim_per_channel) # type: ignore # (batch, channel)
mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0) # (level, batch, channel)
ms_ssim_val = torch.prod(mcs_and_ssim ** weights_tensor.view(-1, 1, 1), dim=0)
if size_average:
return ms_ssim_val.mean()
else:
return ms_ssim_val.mean(1)
class SSIM(torch.nn.Module):
def __init__(
self,
data_range: float = 255,
size_average: bool = True,
win_size: int = 11,
win_sigma: float = 1.5,
channel: int = 3,
spatial_dims: int = 2,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
nonnegative_ssim: bool = False,
) -> None:
r""" class for ssim
Args:
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
channel (int, optional): input channels (default: 3)
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
"""
super(SSIM, self).__init__()
self.win_size = win_size
self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
self.size_average = size_average
self.data_range = data_range
self.K = K
self.nonnegative_ssim = nonnegative_ssim
def forward(self, X: Tensor, Y: Tensor) -> Tensor:
return ssim(
X,
Y,
data_range=self.data_range,
size_average=self.size_average,
win=self.win,
K=self.K,
nonnegative_ssim=self.nonnegative_ssim,
)
class MS_SSIM(torch.nn.Module):
def __init__(
self,
data_range: float = 255,
size_average: bool = True,
win_size: int = 11,
win_sigma: float = 1.5,
channel: int = 3,
spatial_dims: int = 2,
weights: Optional[List[float]] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
) -> None:
r""" class for ms-ssim
Args:
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
channel (int, optional): input channels (default: 3)
weights (list, optional): weights for different levels
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
"""
super(MS_SSIM, self).__init__()
self.win_size = win_size
self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
self.size_average = size_average
self.data_range = data_range
self.weights = weights
self.K = K
def forward(self, X: Tensor, Y: Tensor) -> Tensor:
return ms_ssim(
X,
Y,
data_range=self.data_range,
size_average=self.size_average,
win=self.win,
weights=self.weights,
K=self.K,
)
def get_outnorm(x:torch.Tensor, out_norm:str='') -> torch.Tensor:
""" Common function to get a loss normalization value. Can
normalize by either the batch size ('b'), the number of
channels ('c'), the image size ('i') or combinations
('bi', 'bci', etc)
"""
# b, c, h, w = x.size()
img_shape = x.shape
if not out_norm:
return 1
norm = 1
if 'b' in out_norm:
# normalize by batch size
# norm /= b
norm /= img_shape[0]
if 'c' in out_norm:
# normalize by the number of channels
# norm /= c
norm /= img_shape[-3]
if 'i' in out_norm:
# normalize by image/map size
# norm /= h*w
norm /= img_shape[-1]*img_shape[-2]
return norm
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-6, out_norm:str='bci'):
super(CharbonnierLoss, self).__init__()
self.eps = eps
self.out_norm = out_norm
def forward(self, x, y):
norm = get_outnorm(x, self.out_norm)
loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
return loss*norm

@ -6,6 +6,7 @@ from scipy import signal
import torch import torch
import os import os
def pil2numpy(image): def pil2numpy(image):
np_image = np.array(image) np_image = np.array(image)
if len(np_image.shape) == 2: if len(np_image.shape) == 2:

@ -1,193 +0,0 @@
import os
from pathlib import Path
import numpy as np
import cv2
from scipy import signal
from skimage.metrics import structural_similarity
from PIL import Image
import argparse
import time
from datetime import datetime
import ray
ray.init(num_cpus=16, num_gpus=0, ignore_reinit_error=True, log_to_driver=False)
parser = argparse.ArgumentParser()
parser.add_argument("path_to_dataset", type=str)
parser.add_argument("--scale", type=int, default=4)
args = parser.parse_args()
def cal_ssim(img1, img2):
K = [0.01, 0.03]
L = 255
kernelX = cv2.getGaussianKernel(11, 1.5)
window = kernelX * kernelX.T
M, N = np.shape(img1)
C1 = (K[0] * L) ** 2
C2 = (K[1] * L) ** 2
img1 = np.float64(img1)
img2 = np.float64(img2)
mu1 = signal.convolve2d(img1, window, 'valid')
mu2 = signal.convolve2d(img2, window, 'valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
mssim = np.mean(ssim_map)
return mssim
def PSNR(y_true, y_pred, shave_border=4):
target_data = np.array(y_true, dtype=np.float32)
ref_data = np.array(y_pred, dtype=np.float32)
diff = ref_data - target_data
if shave_border > 0:
diff = diff[shave_border:-shave_border, shave_border:-shave_border]
rmse = np.sqrt(np.mean(np.power(diff, 2)))
return 20 * np.log10(255. / rmse)
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
def modcrop(image, modulo):
if len(image.shape) == 2:
sz = image.shape
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1]]
elif image.shape[2] == 3:
sz = image.shape[0:2]
sz = sz - np.mod(sz, modulo)
image = image[0:sz[0], 0:sz[1], :]
else:
raise NotImplementedError
return image
scale = args.scale
dataset_path = Path(args.path_to_dataset)
hr_path = dataset_path / "HR/"
lr_path = dataset_path / f"LR_bicubic/X{scale}/"
print(hr_path, lr_path)
hr_files = os.listdir(hr_path)
lr_files = os.listdir(lr_path)
@ray.remote(num_cpus=1)
def benchmark_image_pair(hr_image_path, lr_image_path, interpolation_function):
hr_image = cv2.imread(hr_image_path)
lr_image = cv2.imread(lr_image_path)
hr_image = hr_image[:,:,::-1] # BGR -> RGB
lr_image = lr_image[:,:,::-1] # BGR -> RGB
start_time = datetime.now()
upscaled_lr_image = interpolation_function(lr_image, scale)
processing_time = datetime.now() - start_time
hr_image = modcrop(hr_image, scale)
upscaled_lr_image = upscaled_lr_image
psnr = PSNR(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
cpsnr = PSNR(hr_image, upscaled_lr_image)
cv2_psnr = cv2.PSNR(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
cv2_cpsnr = cv2.PSNR(hr_image, upscaled_lr_image)
ssim = cal_ssim(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0])
cv2_ssim = cal_ssim(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0])
ssim_scikit, diff = structural_similarity(_rgb2ycbcr(hr_image)[:,:,0], _rgb2ycbcr(upscaled_lr_image)[:,:,0], full=True, data_range=255.0)
cv2_scikit_ssim, diff = structural_similarity(cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:,:,0], full=True, data_range=255.0)
return ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time.total_seconds()
def benchmark_interpolation(interpolation_function):
psnrs, cpsnrs, ssims = [], [], []
cv2_psnrs, cv2_cpsnrs, scikit_ssims = [], [], []
cv2_scikit_ssims = []
cv2_ssims = []
tasks = []
for hr_name, lr_name in zip(hr_files, lr_files):
hr_image_path = str(hr_path / hr_name)
lr_image_path = str(lr_path / lr_name)
tasks.append(benchmark_image_pair.remote(hr_image_path, lr_image_path, interpolation_function))
ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
while len(remaining_refs) > 0:
print(f"\rReady {len(ready_refs)}/{len(hr_files)}", end=" ")
ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
for task in tasks:
ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time = ray.get(task)
ssims.append(ssim)
cv2_ssims.append(cv2_ssim)
scikit_ssims.append(ssim_scikit)
cv2_scikit_ssims.append(cv2_scikit_ssim)
psnrs.append(psnr)
cpsnrs.append(cpsnr)
cv2_psnrs.append(cv2_psnr)
cv2_cpsnrs.append(cv2_cpsnr)
processing_times.append(processing_time)
print()
print(f"AVG PSNR: {np.mean(psnrs):.2f} PSNR + _rgb2ycbcr")
print(f"AVG PSNR: {np.mean(cv2_psnrs):.2f} cv2.PSNR + cv2.cvtColor")
print(f"AVG cPSNR: {np.mean(cpsnrs):.2f} PSNR")
print(f"AVG cPSNR: {np.mean(cv2_cpsnrs):.2f} cv2.PSNR ")
print(f"AVG SSIM: {np.mean(ssims):.4f} cal_ssim + _rgb2ycbcr")
print(f"AVG SSIM: {np.mean(cv2_ssims):.4f} cal_ssim + cv2.cvtColor")
print(f"AVG SSIM: {np.mean(scikit_ssims):.4f} structural_similarity + _rgb2ycbcr")
print(f"AVG SSIM: {np.mean(cv2_scikit_ssims):.4f} structural_similarity + cv2.cvtColor")
print(f"AVG Time s: {np.percentile(processing_times, q=0.9)}")
print(f"{np.mean(psnrs):.2f},{np.mean(cv2_psnrs):.2f},{np.mean(cpsnrs):.2f},{np.mean(cv2_cpsnrs):.2f},{np.mean(ssims):.4f},{np.mean(cv2_ssims):.4f},{np.mean(scikit_ssims):.4f},{np.mean(cv2_scikit_ssims):.4f},{np.percentile(processing_times, q=0.9)}")
def cv2_interpolation(image, scale):
scaled_image = cv2.resize(
image,
None, None,
fx=scale, fy=scale,
interpolation=cv2.INTER_CUBIC
)
return scaled_image
def pillow_interpolation(image, scale):
image = Image.fromarray(image[:,:,::-1])
width, height = int(image.width * scale), int(image.height * scale)
scaled_image = image.resize((width, height), resample=Image.Resampling.BICUBIC)
return np.array(scaled_image)[:,:,::-1]
print("cv2 bicubic interpolation")
benchmark_interpolation(cv2_interpolation)
print()
print("pillow bicubic interpolation")
benchmark_interpolation(pillow_interpolation)

@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import sys import sys
from common.utils import PIL_CONVERT_COLOR, pil2numpy from common.color import PIL_CONVERT_COLOR
from common.utils import pil2numpy
from models import LoadCheckpoint from models import LoadCheckpoint
import torch import torch
import numpy as np import numpy as np
@ -12,10 +13,10 @@ import argparse
class ImageDemoOptions(): class ImageDemoOptions():
def __init__(self): def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth","../models/last_transfered_lut.pth"], help="Model paths for comparison") self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../experiments/last_transfered_net.pth","../experiments/last_transfered_lut.pth"], help="Model paths for comparison")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path") self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path") self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.") self.parser.add_argument('--output_path', type=str, default="../experiments/", help="Output path.")
self.parser.add_argument('--output_name', type=str, default="image_demo.png", help="Output name.") self.parser.add_argument('--output_name', type=str, default="image_demo.png", help="Output name.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False) self.parser.add_argument('--mirror', action='store_true', default=False)

@ -5,6 +5,7 @@ import numpy as np
from common.utils import round_func from common.utils import round_func
from common import lut from common import lut
from common import layers from common import layers
from common import losses
from pathlib import Path from pathlib import Path
from . import sdylut from . import sdylut
from models.base import SRNetBase from models.base import SRNetBase
@ -535,3 +536,260 @@ class SDYMixNetx1v3(SRNetBase):
def loss_fn(pred, target): def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255) return F.mse_loss(pred/255, target/255)
return loss_fn return loss_fn
class SDYMixNetx1v4(SRNetBase):
"""
22
12 23 32 21
11 13 33 31
10 14 34 30
01 03 43 41
00 04 44 40
"""
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYMixNetx1v4, self).__init__()
self.scale = scale
self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[3,3]], center=[3,3], window_size=7)
self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[2,3],[3,4],[4,3],[3,2]], center=[3,3], window_size=7)
self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[2,2],[2,4],[4,4],[4,2]], center=[3,3], window_size=7)
self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[1,2],[1,4],[5,4],[5,2]], center=[3,3], window_size=7)
self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes= [[1,1],[1,5],[5,5],[5,1]], center=[3,3], window_size=7)
self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes= [[2,1],[2,5],[4,5],[4,1]], center=[3,3], window_size=7)
self._extract_pattern_7 = layers.PercievePattern(receptive_field_idxes= [[3,1],[1,3],[3,5],[5,3]], center=[3,3], window_size=7)
self._extract_pattern_8 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,6],[6,6],[6,0]], center=[3,3], window_size=7)
self._extract_pattern_9 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
self._extract_pattern_10 = layers.PercievePattern(receptive_field_idxes=[[2,0],[1,0],[0,1],[0,2]], center=[3,3], window_size=7)
self._extract_pattern_11 = layers.PercievePattern(receptive_field_idxes=[[0,4],[0,5],[1,6],[2,6]], center=[3,3], window_size=7)
self._extract_pattern_12 = layers.PercievePattern(receptive_field_idxes=[[4,6],[5,6],[6,5],[6,4]], center=[3,3], window_size=7)
self._extract_pattern_13 = layers.PercievePattern(receptive_field_idxes=[[6,2],[6,1],[5,0],[4,0]], center=[3,3], window_size=7)
self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_7 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_8 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_9 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_10 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_11 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_12 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_13 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=13)
self.stage1_Mix = layers.UpscaleBlockChebyKAN(in_features=13, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_7, self.stage1_7)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_8, self.stage1_8)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_9, self.stage1_9)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_10, self.stage1_10)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_11, self.stage1_11)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_12, self.stage1_12)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_13, self.stage1_13)], dim=1)
output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYMixNetx1v5(SRNetBase):
"""
22
12 23 32 21
11 13 33 31
10 14 34 30
01 03 43 41
00 04 44 40
"""
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYMixNetx1v5, self).__init__()
self.scale = scale
self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[0,0]], center=[0,0], window_size=1)
self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[0,1],[1,2],[2,1],[1,0]], center=[1,1], window_size=3)
self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,4],[4,4],[4,0]], center=[2,2], window_size=5)
self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=4)
self.stage1_Mix = layers.UpscaleBlock(in_features=4, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYMixNetx1v6(SRNetBase):
"""
22
12 23 32 21
11 13 33 31
10 14 34 30
01 03 43 41
00 04 44 40
"""
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYMixNetx1v6, self).__init__()
self.scale = scale
self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[3,3]], center=[3,3], window_size=7)
self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[2,3],[3,4],[4,3],[3,2]], center=[3,3], window_size=7)
self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[2,2],[2,4],[4,4],[4,2]], center=[3,3], window_size=7)
self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[1,2],[1,4],[5,4],[5,2]], center=[3,3], window_size=7)
self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes= [[1,1],[1,5],[5,5],[5,1]], center=[3,3], window_size=7)
self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes= [[2,1],[2,5],[4,5],[4,1]], center=[3,3], window_size=7)
self._extract_pattern_7 = layers.PercievePattern(receptive_field_idxes= [[3,1],[1,3],[3,5],[5,3]], center=[3,3], window_size=7)
self._extract_pattern_8 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,6],[6,6],[6,0]], center=[3,3], window_size=7)
self._extract_pattern_9 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
self._extract_pattern_10 = layers.PercievePattern(receptive_field_idxes=[[2,0],[1,0],[0,1],[0,2]], center=[3,3], window_size=7)
self._extract_pattern_11 = layers.PercievePattern(receptive_field_idxes=[[0,4],[0,5],[1,6],[2,6]], center=[3,3], window_size=7)
self._extract_pattern_12 = layers.PercievePattern(receptive_field_idxes=[[4,6],[5,6],[6,5],[6,4]], center=[3,3], window_size=7)
self._extract_pattern_13 = layers.PercievePattern(receptive_field_idxes=[[6,2],[6,1],[5,0],[4,0]], center=[3,3], window_size=7)
self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_7 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_8 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_9 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_10 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_11 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_12 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self.stage1_13 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=1, upscale_factor=scale)
self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=13)
self.stage1_Mix = layers.UpscaleBlock(in_features=13, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_7, self.stage1_7)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_8, self.stage1_8)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_9, self.stage1_9)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_10, self.stage1_10)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_11, self.stage1_11)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_12, self.stage1_12)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_13, self.stage1_13)], dim=1)
output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255) + ssim_loss
return loss_fn
class SDYMixNetx1v7(SRNetBase):
"""
22
12 23 32 21
11 13 33 31
10 14 34 30
01 03 43 41
00 04 44 40
"""
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYMixNetx1v7, self).__init__()
self.scale = scale
self._extract_pattern_1 = layers.PercievePattern(receptive_field_idxes= [[3,3]], center=[3,3], window_size=7)
self._extract_pattern_2 = layers.PercievePattern(receptive_field_idxes= [[2,3],[3,4],[4,3],[3,2]], center=[3,3], window_size=7)
self._extract_pattern_3 = layers.PercievePattern(receptive_field_idxes= [[2,2],[2,4],[4,4],[4,2]], center=[3,3], window_size=7)
self._extract_pattern_4 = layers.PercievePattern(receptive_field_idxes= [[1,2],[1,4],[5,4],[5,2]], center=[3,3], window_size=7)
self._extract_pattern_5 = layers.PercievePattern(receptive_field_idxes= [[1,1],[1,5],[5,5],[5,1]], center=[3,3], window_size=7)
self._extract_pattern_6 = layers.PercievePattern(receptive_field_idxes= [[2,1],[2,5],[4,5],[4,1]], center=[3,3], window_size=7)
self._extract_pattern_7 = layers.PercievePattern(receptive_field_idxes= [[3,1],[1,3],[3,5],[5,3]], center=[3,3], window_size=7)
self._extract_pattern_8 = layers.PercievePattern(receptive_field_idxes= [[0,0],[0,6],[6,6],[6,0]], center=[3,3], window_size=7)
self._extract_pattern_9 = layers.PercievePattern(receptive_field_idxes= [[0,3],[3,6],[6,3],[3,0]], center=[3,3], window_size=7)
self._extract_pattern_10 = layers.PercievePattern(receptive_field_idxes=[[2,0],[1,0],[0,1],[0,2]], center=[3,3], window_size=7)
self._extract_pattern_11 = layers.PercievePattern(receptive_field_idxes=[[0,4],[0,5],[1,6],[2,6]], center=[3,3], window_size=7)
self._extract_pattern_12 = layers.PercievePattern(receptive_field_idxes=[[4,6],[5,6],[6,5],[6,4]], center=[3,3], window_size=7)
self._extract_pattern_13 = layers.PercievePattern(receptive_field_idxes=[[6,2],[6,1],[5,0],[4,0]], center=[3,3], window_size=7)
self.stage1_1 = layers.UpscaleBlock(in_features=1, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_2 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_3 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_4 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_5 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_6 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_7 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_8 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_9 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_10 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_11 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_12 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self.stage1_13 = layers.UpscaleBlock(in_features=4, hidden_dim=32, layers_count=2, upscale_factor=scale)
self._extract_pattern_mix = layers.PercievePattern(receptive_field_idxes=[[0,0]], center=[0,0], window_size=1, channels=13)
self.stage1_Mix = layers.UpscaleBlock(in_features=13, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = self.forward_stage(x, self._extract_pattern_1, self.stage1_1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_2, self.stage1_2)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_3, self.stage1_3)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_4, self.stage1_4)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_5, self.stage1_5)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_6, self.stage1_6)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_7, self.stage1_7)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_8, self.stage1_8)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_9, self.stage1_9)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_10, self.stage1_10)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_11, self.stage1_11)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_12, self.stage1_12)], dim=1)
output = torch.cat([output, self.forward_stage(x, self._extract_pattern_13, self.stage1_13)], dim=1)
output = self.forward_stage(output, self._extract_pattern_mix, self.stage1_Mix)
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
ssim_loss = losses.SSIM(channel=1, data_range=255)
l1_loss = losses.CharbonnierLoss()
def loss_fn(pred, target):
# return F.mse_loss(pred/255, target/255)# + ssim_loss
return ssim_loss(pred, target) + l1_loss(pred, target)
return loss_fn

@ -27,6 +27,19 @@ import argparse
from schedulefree import AdamWScheduleFree from schedulefree import AdamWScheduleFree
from datetime import datetime from datetime import datetime
import signal
class SignalHandler:
def __init__(self, signal_code):
self.is_on = False
signal.signal(signal_code, self.exit_gracefully)
def exit_gracefully(self, signum, frame):
print("Early stopping.")
self.is_on = True
signal_interraption_handler = SignalHandler(signal.SIGINT)
class TrainOptions: class TrainOptions:
def __init__(self): def __init__(self):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False)
@ -55,6 +68,7 @@ class TrainOptions:
parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}") parser.add_argument('--color_model', type=str, default="RGB", help=f"Color model for train and test dataset. Choose from: {list(PIL_CONVERT_COLOR.keys())}")
parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache') parser.add_argument('--reset_cache', action='store_true', default=False, help='Discard datasets cache')
parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate') parser.add_argument('--learning_rate', type=float, default=0.0025, help='Learning rate')
parser.add_argument('--grad_step', type=int, default=1, help='Optimizer step.')
self.parser = parser self.parser = parser
@ -186,6 +200,8 @@ if __name__ == "__main__":
loss_fn = model.get_loss_fn() loss_fn = model.get_loss_fn()
for i in range(config.start_iter + 1, config.total_iter + 1): for i in range(config.start_iter + 1, config.total_iter + 1):
if signal_interraption_handler.is_on:
break
config.current_iter = i config.current_iter = i
torch.cuda.empty_cache() torch.cuda.empty_cache()
start_time = time.time() start_time = time.time()
@ -199,12 +215,15 @@ if __name__ == "__main__":
prepare_data_time += time.time() - start_time prepare_data_time += time.time() - start_time
start_time = time.time() start_time = time.time()
with torch.set_grad_enabled(True):
pred = model(x=lr_patch, config=config)
loss = loss_fn(pred=pred, target=hr_patch) / config.grad_step
loss.backward()
if i % config.grad_step == 0 and i > 0:
optimizer.step()
optimizer.zero_grad()
pred = model(x=lr_patch, config=config)
loss = loss_fn(pred=pred, target=hr_patch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
torch.cuda.empty_cache() torch.cuda.empty_cache()
forward_backward_time += time.time() - start_time forward_backward_time += time.time() - start_time

Loading…
Cancel
Save