main
protsenkovi 5 months ago
parent 19e0fca745
commit c3dc32c336

File diff suppressed because one or more lines are too long

@ -22,5 +22,5 @@ python image_demo.py --help
Requirements:
```
pip install shedulefree tensorboard opencv-python-headless scipy pandas
pip install schedulefree tensorboard opencv-python-headless scipy pandas matplotlib
```

@ -25,7 +25,7 @@ class PercievePattern():
return x
class UpscaleBlock(nn.Module):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1):
def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1):
super(UpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor

@ -0,0 +1,152 @@
import torch
class FourierLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(FourierLoss, self).__init__()
@staticmethod
def perm_roll(im, axis, amount):
permutation = torch.roll(torch.arange(im.shape[axis], device=im.device), amount, dims=0)
return torch.index_select(im, axis, permutation)
@staticmethod
def shift_left(im):
tt = perm_roll(im, axis=-2, amount=-(im.shape[-2]+1)//2)
tt = perm_roll(tt, axis=-1, amount=-(im.shape[-1]+1)//2)
return tt
@staticmethod
def shift_right(im):
tt = FourierLoss.perm_roll(im, axis=-2, amount=(im.shape[-2]+1)//2)
tt = FourierLoss.perm_roll(tt, axis=-1, amount=(im.shape[-1]+1)//2)
return tt
@staticmethod
def aperture(h, w, condition=None, low_pass_frequency=False):
"""torch.where(condition(circle_dist), torch.ones_like(circle_dist), torch.zeros_like(circle_dist)) """
x, y = torch.meshgrid(torch.arange(-h//2, h//2), torch.arange(-w//2, w//2), indexing='ij')
circle_dist = torch.sqrt(x**2 + y**2)
if low_pass_frequency:
circle_aperture = torch.where(condition(circle_dist), torch.zeros_like(circle_dist), torch.ones_like(circle_dist))
else:
circle_aperture = torch.where(condition(circle_dist), torch.ones_like(circle_dist), torch.zeros_like(circle_dist))
return circle_aperture
@staticmethod
def fft(x, aperture):
r = FourierLoss.shift_right(torch.fft.fft2(x, norm='ortho'))*aperture
return r.real**2 + r.imag**2
def forward(self, inputs, targets):
b,c,h,w = inputs.shape
r = min(h, w)//12
low_pass = FourierLoss.aperture(h=h, w=w, condition=lambda circle_dist: circle_dist>r, low_pass_frequency=False).to(inputs.device)
high_pass = FourierLoss.aperture(h=h, w=w, condition=lambda circle_dist: circle_dist>r, low_pass_frequency=True).to(inputs.device)
low_frequency_loss = F.mse_loss(FourierLoss.fft(inputs, low_pass), FourierLoss.fft(targets, low_pass))*(low_pass.sum()/high_pass.sum())
high_frequency_loss = F.mse_loss(FourierLoss.fft(inputs, high_pass), FourierLoss.fft(targets, high_pass))
return low_frequency_loss + high_frequency_loss
class FocalFrequencyLoss(nn.Module):
"""The torch.nn.Module class that implements focal frequency loss - a
frequency domain loss function for optimizing generative models.
Ref:
Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
<https://arxiv.org/pdf/2012.12821.pdf>
Args:
loss_weight (float): weight for focal frequency loss. Default: 1.0
alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
"""
def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
super(FocalFrequencyLoss, self).__init__()
self.loss_weight = loss_weight
self.alpha = alpha
self.patch_factor = patch_factor
self.ave_spectrum = ave_spectrum
self.log_matrix = log_matrix
self.batch_matrix = batch_matrix
def tensor2freq(self, x):
# crop image patches
patch_factor = self.patch_factor
_, _, h, w = x.shape
assert h % patch_factor == 0 and w % patch_factor == 0, (
'Patch factor should be divisible by image height and width')
patch_list = []
patch_h = h // patch_factor
patch_w = w // patch_factor
for i in range(patch_factor):
for j in range(patch_factor):
patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
# stack to patch tensor
y = torch.stack(patch_list, 1)
# perform 2D DFT (real-to-complex, orthonormalization)
freq = torch.fft.fft2(y, norm='ortho')
freq = torch.stack([freq.real, freq.imag], -1)
return freq
def loss_formulation(self, recon_freq, real_freq, matrix=None):
# spectrum weight matrix
if matrix is not None:
# if the matrix is predefined
weight_matrix = matrix.detach()
else:
# if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
matrix_tmp = (recon_freq - real_freq) ** 2
matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
# whether to adjust the spectrum weight matrix by logarithm
if self.log_matrix:
matrix_tmp = torch.log(matrix_tmp + 1.0)
# whether to calculate the spectrum weight matrix using batch-based statistics
if self.batch_matrix:
matrix_tmp = matrix_tmp / matrix_tmp.max()
else:
matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
weight_matrix = matrix_tmp.clone().detach()
assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
'The values of spectrum weight matrix should be in the range [0, 1], '
'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
# frequency distance using (squared) Euclidean distance
tmp = (recon_freq - real_freq) ** 2
freq_distance = tmp[..., 0] + tmp[..., 1]
# dynamic spectrum weighting (Hadamard product)
loss = weight_matrix * freq_distance
return torch.mean(loss)
def forward(self, pred, target, matrix=None, **kwargs):
"""Forward function to calculate focal frequency loss.
Args:
pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
target (torch.Tensor): of shape (N, C, H, W). Target tensor.
matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
Default: None (If set to None: calculated online, dynamic).
"""
pred_freq = self.tensor2freq(pred)
target_freq = self.tensor2freq(target)
# whether to use minibatch average spectrum
if self.ave_spectrum:
pred_freq = torch.mean(pred_freq, 0, keepdim=True)
target_freq = torch.mean(target_freq, 0, keepdim=True)
# calculate focal frequency loss
return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight

@ -8,12 +8,71 @@ import numpy as np
##################### TRANSFER ##########################
class Domain4DValues(Dataset):
class Domain2DValues(Dataset):
def __init__(self, quantization_interval=1):
super(Domain4DValues, self).__init__()
super(Domain2DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*2)).view(-1, 1, 2)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, batch = [], [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
batch.append(values)
return ix1s, ix2s, ix3s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0], ix[1], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(len(self.values)):
yield self.__getitem__(i)
class Domain3DValues(Dataset):
def __init__(self, quantization_interval=1):
super(Domain3DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*3)).view(-1, 1, 3)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, ix3s, batch = [], [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, ix3, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
ix3s.append(ix3)
batch.append(values)
return ix1s, ix2s, ix3s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0], ix[1], ix[2], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(len(self.values)):
yield self.__getitem__(i)
class Domain4DValues(Dataset):
def __init__(self, quantization_interval=1, max_value=255):
super(Domain4DValues, self).__init__()
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 4)
def __getitem__(self, idx):
@ -54,11 +113,11 @@ def transfer_rc_conv(rc_conv, quantization_interval=1):
print()
return lut
def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval)
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2
domain_values = Domain4DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@ -68,7 +127,7 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2*
)
counter = 0
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
inputs = batch.type(torch.float32).to(list(block.parameters())[0].device)
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, ix4s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
@ -77,7 +136,51 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2*
print()
return lut
def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain3DValues(quantization_interval=quantization_interval)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, ix3s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
domain_values = Domain2DValues(quantization_interval=quantization_interval)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
##################### FORWARD ##########################
@ -140,11 +243,11 @@ def select_index_1dlut_linear(ixA, lut):
def select_index_3dlut_tetrahedral(index, lut):
b, hw, c = index.shape
lut = torch.clamp(lut, 0, 255)
dimA, dimB, dimC, dimD = lut.shape[:4]
dimA, dimB, dimC = lut.shape[:4]
q = 256/(dimA-1)
L = dimA
upscale = lut.shape[-1]
weight = lut.reshape(L**4, upscale, upscale)
weight = lut.reshape(L**3, upscale, upscale)
msbA = torch.floor_divide(index, q).type(torch.int64)
msbB = msbA + 1

@ -6,6 +6,12 @@ from pathlib import Path
from PIL import Image
import time
from datetime import timedelta, datetime
from matplotlib import pyplot as plt
cmap = plt.get_cmap('viridis')
cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplut = np.array(cmaplist)
cmaplut = np.round(cmaplut[:, 0:3]*255).astype(np.uint8)
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode():
@ -27,7 +33,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
Image.fromarray(pred_lr_image, mode=color_model).convert("RGB").save(output_image_path)
if pred_lr_image.shape[-1] == 1:
Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path)
Image.fromarray(cmaplut[pred_lr_image[:,:,0]]).save(output_image_path)
# metrics
hr_image = modcrop(hr_image, model.scale)
@ -64,7 +70,7 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
if print_progress:
start_datetime = datetime.now()
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task)
if print_progress:
@ -104,6 +110,8 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
'Total time'
]
config.logger.info("\n" + str(pd.DataFrame([row], columns=column_names).set_index('Dataset').T))
config.writer.add_scalar(f'{dataset_name}_PSNR', np.mean(psnrs), config.current_iter)
config.writer.add_scalar(f'{dataset_name}_SSIM', np.mean(ssims), config.current_iter)
config.writer.flush()
results = pd.DataFrame(results, columns=column_names).set_index('Dataset')

@ -16,6 +16,7 @@ class ImageDemoOptions():
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.")
self.parser.add_argument('--output_name', type=str, default="image_demo.png", help="Output name.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
@ -92,6 +93,6 @@ for i in range(row_count):
row.append(images[i*column_count + j])
columns.append(np.concatenate(row, axis=1))
canvas = np.concatenate(columns, axis=0).astype(np.uint8)
Image.fromarray(canvas).save(config.output_path / 'image_demo.png')
Image.fromarray(canvas).save(config.output_path / config.output_name)
print(datetime.now() - start_script_time )

@ -30,6 +30,9 @@ AVAILABLE_MODELS = {
'SRLutY': srlut.SRLutY,
'HDBNet': hdbnet.HDBNet,
'HDBLut': hdblut.HDBLut,
'HDBLNet': hdbnet.HDBLNet,
'HDBHNet': hdbnet.HDBHNet,
'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net,
# 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
# 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
# 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,

@ -45,7 +45,7 @@ class SDYLutx1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@ -63,6 +63,11 @@ class SDYLutx1(nn.Module):
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutx2(nn.Module):
def __init__(
self,
@ -107,7 +112,7 @@ class SDYLutx2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@ -135,6 +140,11 @@ class SDYLutx2(nn.Module):
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutx3(nn.Module):
def __init__(
@ -186,7 +196,7 @@ class SDYLutx3(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@ -224,6 +234,12 @@ class SDYLutx3(nn.Module):
f"\n stage3_D size: {self.stage3_D.shape}" + \
f"\n stage3_Y size: {self.stage3_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x1(nn.Module):
def __init__(
self,
@ -262,7 +278,7 @@ class SDYLutR90x1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@ -285,6 +301,10 @@ class SDYLutR90x1(nn.Module):
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x2(nn.Module):
@ -331,7 +351,7 @@ class SDYLutR90x2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@ -368,3 +388,8 @@ class SDYLutR90x2(nn.Module):
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -30,7 +30,7 @@ class SDYNetx1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@ -50,6 +50,12 @@ class SDYNetx1(nn.Module):
lut_model = sdylut.SDYLutx1.init_from_numpy(stageS, stageD, stageY)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetx2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx2, self).__init__()
@ -74,7 +80,7 @@ class SDYNetx2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@ -104,6 +110,11 @@ class SDYNetx2(nn.Module):
lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetx3(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx3, self).__init__()
@ -131,7 +142,7 @@ class SDYNetx3(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@ -171,6 +182,11 @@ class SDYNetx3(nn.Module):
lut_model = sdylut.SDYLutx3.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetR90x1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetR90x1, self).__init__()
@ -192,7 +208,7 @@ class SDYNetR90x1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@ -217,6 +233,11 @@ class SDYNetR90x1(nn.Module):
lut_model = sdylut.SDYLutR90x1.init_from_numpy(stageS, stageD, stageY)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetR90x2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetR90x2, self).__init__()
@ -241,7 +262,7 @@ class SDYNetR90x2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@ -278,3 +299,8 @@ class SDYNetR90x2(nn.Module):
stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutR90x2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -39,7 +39,7 @@ class SRLut(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w).type(torch.float32)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
@ -49,6 +49,11 @@ class SRLut(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutY(nn.Module):
def __init__(
self,
@ -83,7 +88,7 @@ class SRLutY(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@ -98,6 +103,11 @@ class SRLutY(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90(nn.Module):
def __init__(
self,
@ -130,7 +140,7 @@ class SRLutR90(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
@ -145,6 +155,11 @@ class SRLutR90(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90Y(nn.Module):
def __init__(
@ -180,7 +195,7 @@ class SRLutR90Y(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@ -199,3 +214,8 @@ class SRLutR90Y(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@ -7,6 +7,8 @@ from common import lut
from pathlib import Path
from . import srlut
from common import layers
from itertools import cycle
from common import losses
class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
@ -29,7 +31,7 @@ class SRNet(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
@ -41,6 +43,11 @@ class SRNet(nn.Module):
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetY(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetY, self).__init__()
@ -64,7 +71,7 @@ class SRNetY(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@ -82,6 +89,11 @@ class SRNetY(nn.Module):
lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetR90(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90, self).__init__()
@ -103,7 +115,7 @@ class SRNetR90(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@ -120,6 +132,11 @@ class SRNetR90(nn.Module):
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetR90Y(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90Y, self).__init__()
@ -144,7 +161,7 @@ class SRNetR90Y(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@ -166,3 +183,75 @@ class SRNetR90Y(nn.Module):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRMsbLsb4R90Net(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsb4R90Net, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(4)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(4)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
output_msb_r = round_func((output_msb_r / 255)*16) * 15
output_lsb_r = (output_lsb_r / 255) * 15
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
output_lsb /= 4
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
fourier_loss_fn = losses.FocalFrequencyLoss()
def loss_fn(pred, target):
return fourier_loss_fn(pred/255, target/255) * 1e8
return loss_fn

@ -41,7 +41,7 @@ class ValOptions():
args.valout_dir = Path(args.exp_dir).resolve() / 'val'
if not args.valout_dir.exists():
args.valout_dir.mkdir()
args.current_iter = args.model_name.split('_')[-1]
args.current_iter = int(args.model_name.split('_')[-1])
args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv')
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.valout_dir)

@ -45,7 +45,7 @@ class TrainOptions:
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
@ -59,6 +59,8 @@ class TrainOptions:
args.model_path = Path(args.model_path) if not args.model_path is None else None
args.train_datasets = args.train_datasets.split(',')
args.val_datasets = args.val_datasets.split(',')
if not args.model_path is None:
args.start_iter = int(args.model_path.stem.split("_")[-1])
return args
def __repr__(self):
@ -110,7 +112,8 @@ if __name__ == "__main__":
if 'lut' in config.model.lower():
model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
model = model.to(torch.device(config.device))
optimizer = AdamWScheduleFree(model.parameters())
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-2, betas=(0.9, 0.95))
# optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
print(optimizer)
prepare_experiment_folder(config)
@ -140,7 +143,7 @@ if __name__ == "__main__":
dataset = train_dataset,
batch_size = config.batch_size,
num_workers = config.worker_num,
shuffle = False,
shuffle = True,
drop_last = False,
pin_memory = True,
prefetch_factor = config.prefetch_factor
@ -165,7 +168,9 @@ if __name__ == "__main__":
config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
loss_fn = model.get_loss_fn()
for i in range(config.start_iter + 1, config.total_iter + 1):
config.current_iter = i
torch.cuda.empty_cache()
start_time = time.time()
try:
@ -179,8 +184,8 @@ if __name__ == "__main__":
start_time = time.time()
pred = model(lr_patch)
loss = F.mse_loss(pred/255, hr_patch/255)
pred = model(x=lr_patch, config=config)
loss = loss_fn(pred=pred, target=hr_patch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
@ -194,10 +199,15 @@ if __name__ == "__main__":
# Show information
if i % config.display_step == 0:
config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
model.__class__.__name__, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step,
forward_backward_time / config.display_step))
config.writer.add_scalar('loss', loss.item(), i)
config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, loss:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
model.__class__.__name__,
i,
accum_samples,
l_accum[0] / config.display_step,
prepare_data_time / config.display_step,
forward_backward_time / config.display_step))
l_accum = [0., 0., 0.]
prepare_data_time = 0.
forward_backward_time = 0.
@ -208,7 +218,6 @@ if __name__ == "__main__":
# Validation
if i % config.val_step == 0:
config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()

Loading…
Cancel
Save