main
protsenkovi 5 months ago
parent 19e0fca745
commit c3dc32c336

File diff suppressed because one or more lines are too long

@@ -22,5 +22,5 @@ python image_demo.py --help
Requirements:
```
-pip install shedulefree tensorboard opencv-python-headless scipy pandas
+pip install schedulefree tensorboard opencv-python-headless scipy pandas matplotlib
```

@@ -25,7 +25,7 @@ class PercievePattern():
return x
class UpscaleBlock(nn.Module):
-def __init__(self, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1):
+def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1):
super(UpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor

@@ -0,0 +1,152 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FourierLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(FourierLoss, self).__init__()
@staticmethod
def perm_roll(im, axis, amount):
permutation = torch.roll(torch.arange(im.shape[axis], device=im.device), amount, dims=0)
return torch.index_select(im, axis, permutation)
@staticmethod
def shift_left(im):
tt = FourierLoss.perm_roll(im, axis=-2, amount=-(im.shape[-2]+1)//2)
tt = FourierLoss.perm_roll(tt, axis=-1, amount=-(im.shape[-1]+1)//2)
return tt
@staticmethod
def shift_right(im):
tt = FourierLoss.perm_roll(im, axis=-2, amount=(im.shape[-2]+1)//2)
tt = FourierLoss.perm_roll(tt, axis=-1, amount=(im.shape[-1]+1)//2)
return tt
@staticmethod
def aperture(h, w, condition=None, low_pass_frequency=False):
"""torch.where(condition(circle_dist), torch.ones_like(circle_dist), torch.zeros_like(circle_dist)) """
x, y = torch.meshgrid(torch.arange(-h//2, h//2), torch.arange(-w//2, w//2), indexing='ij')
circle_dist = torch.sqrt(x**2 + y**2)
if low_pass_frequency:
circle_aperture = torch.where(condition(circle_dist), torch.zeros_like(circle_dist), torch.ones_like(circle_dist))
else:
circle_aperture = torch.where(condition(circle_dist), torch.ones_like(circle_dist), torch.zeros_like(circle_dist))
return circle_aperture
@staticmethod
def fft(x, aperture):
r = FourierLoss.shift_right(torch.fft.fft2(x, norm='ortho'))*aperture
return r.real**2 + r.imag**2
def forward(self, inputs, targets):
b,c,h,w = inputs.shape
r = min(h, w)//12
low_pass = FourierLoss.aperture(h=h, w=w, condition=lambda circle_dist: circle_dist>r, low_pass_frequency=False).to(inputs.device)
high_pass = FourierLoss.aperture(h=h, w=w, condition=lambda circle_dist: circle_dist>r, low_pass_frequency=True).to(inputs.device)
low_frequency_loss = F.mse_loss(FourierLoss.fft(inputs, low_pass), FourierLoss.fft(targets, low_pass))*(low_pass.sum()/high_pass.sum())
high_frequency_loss = F.mse_loss(FourierLoss.fft(inputs, high_pass), FourierLoss.fft(targets, high_pass))
return low_frequency_loss + high_frequency_loss
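A quick sanity check of the aperture logic (editor's sketch, not part of the commit): with the condition `circle_dist > r` the mask keeps coefficients far from the centre of the shifted spectrum, and `low_pass_frequency=True` inverts the selection, so the two masks used in `forward` partition the spectrum by radius.

```python
# Hedged example; assumes the FourierLoss class defined above is importable.
import torch

r = 8
far = FourierLoss.aperture(64, 64, condition=lambda d: d > r, low_pass_frequency=False)
near = FourierLoss.aperture(64, 64, condition=lambda d: d > r, low_pass_frequency=True)
assert torch.all(far + near == 1)  # complementary binary masks over the 64x64 grid
```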
class FocalFrequencyLoss(nn.Module):
"""The torch.nn.Module class that implements focal frequency loss - a
frequency domain loss function for optimizing generative models.
Ref:
Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
<https://arxiv.org/pdf/2012.12821.pdf>
Args:
loss_weight (float): weight for focal frequency loss. Default: 1.0
alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
"""
def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
super(FocalFrequencyLoss, self).__init__()
self.loss_weight = loss_weight
self.alpha = alpha
self.patch_factor = patch_factor
self.ave_spectrum = ave_spectrum
self.log_matrix = log_matrix
self.batch_matrix = batch_matrix
def tensor2freq(self, x):
# crop image patches
patch_factor = self.patch_factor
_, _, h, w = x.shape
assert h % patch_factor == 0 and w % patch_factor == 0, (
'Image height and width must be divisible by patch factor')
patch_list = []
patch_h = h // patch_factor
patch_w = w // patch_factor
for i in range(patch_factor):
for j in range(patch_factor):
patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
# stack to patch tensor
y = torch.stack(patch_list, 1)
# perform 2D DFT (real-to-complex, orthonormalization)
freq = torch.fft.fft2(y, norm='ortho')
freq = torch.stack([freq.real, freq.imag], -1)
return freq
def loss_formulation(self, recon_freq, real_freq, matrix=None):
# spectrum weight matrix
if matrix is not None:
# if the matrix is predefined
weight_matrix = matrix.detach()
else:
# if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
matrix_tmp = (recon_freq - real_freq) ** 2
matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
# whether to adjust the spectrum weight matrix by logarithm
if self.log_matrix:
matrix_tmp = torch.log(matrix_tmp + 1.0)
# whether to calculate the spectrum weight matrix using batch-based statistics
if self.batch_matrix:
matrix_tmp = matrix_tmp / matrix_tmp.max()
else:
matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
weight_matrix = matrix_tmp.clone().detach()
assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
'The values of spectrum weight matrix should be in the range [0, 1], '
'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
# frequency distance using (squared) Euclidean distance
tmp = (recon_freq - real_freq) ** 2
freq_distance = tmp[..., 0] + tmp[..., 1]
# dynamic spectrum weighting (Hadamard product)
loss = weight_matrix * freq_distance
return torch.mean(loss)
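In the notation of the referenced paper, `loss_formulation` computes

$$\mathrm{FFL} = \frac{1}{MN}\sum_{u,v} w(u,v)\,\bigl|F_r(u,v)-F_f(u,v)\bigr|^2, \qquad w(u,v)=\bigl|F_r(u,v)-F_f(u,v)\bigr|^{\alpha},$$

where the weight matrix $w$ is optionally log-scaled, normalized into $[0,1]$, and detached from the graph, so gradients flow only through the frequency-distance term.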
def forward(self, pred, target, matrix=None, **kwargs):
"""Forward function to calculate focal frequency loss.
Args:
pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
target (torch.Tensor): of shape (N, C, H, W). Target tensor.
matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
Default: None (If set to None: calculated online, dynamic).
"""
pred_freq = self.tensor2freq(pred)
target_freq = self.tensor2freq(target)
# whether to use minibatch average spectrum
if self.ave_spectrum:
pred_freq = torch.mean(pred_freq, 0, keepdim=True)
target_freq = torch.mean(target_freq, 0, keepdim=True)
# calculate focal frequency loss
return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
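A minimal usage sketch for the class above (editor's example, not part of the commit; the /255 scaling mirrors how `SRMsbLsb4R90Net.get_loss_fn` calls it later in this diff):

```python
# Hedged smoke test; assumes FocalFrequencyLoss as defined above.
import torch

ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
pred = torch.rand(2, 3, 32, 32, requires_grad=True)  # (N, C, H, W) in [0, 1]
target = torch.rand(2, 3, 32, 32)
loss = ffl(pred, target)  # spectrum weight matrix is computed online
loss.backward()           # gradients flow through the frequency-distance term
```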

@@ -8,12 +8,71 @@ import numpy as np
##################### TRANSFER ##########################
-class Domain4DValues(Dataset):
+class Domain2DValues(Dataset):
def __init__(self, quantization_interval=1):
-super(Domain4DValues, self).__init__()
+super(Domain2DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*2)).view(-1, 1, 2)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, batch = [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
batch.append(values)
return ix1s, ix2s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0], ix[1], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(len(self.values)):
yield self.__getitem__(i)
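For scale intuition (editor's note): with the default quantization_interval=1 the grid has 257 levels per pixel, so Domain2DValues holds 257**2 = 66 049 pairs; with quantization_interval=16 it holds 17**2 = 289. Each item is a (bucket_index_1, bucket_index_2, value_pair) triple:

```python
# Hedged usage of Domain2DValues as defined above.
dv = Domain2DValues(quantization_interval=16)
ix1, ix2, v = dv[0]
print(len(dv), v.shape)  # 289 items, each value tensor of shape (1, 2)
```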
class Domain3DValues(Dataset):
def __init__(self, quantization_interval=1):
super(Domain3DValues, self).__init__()
values1d = torch.arange(0, 256, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([256])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*3)).view(-1, 1, 3)
def __getitem__(self, idx):
if isinstance(idx, slice):
ix1s, ix2s, ix3s, batch = [], [], [], []
for i in range(idx.start, idx.stop):
ix1, ix2, ix3, values = self.__getitem__(i)
ix1s.append(ix1)
ix2s.append(ix2)
ix3s.append(ix3)
batch.append(values)
return ix1s, ix2s, ix3s, batch
else:
v = self.values[idx]
ix = v[0]//self.quantization_interval
return ix[0], ix[1], ix[2], v
def __len__(self):
return len(self.values)
def __iter__(self):
for i in range(len(self.values)):
yield self.__getitem__(i)
class Domain4DValues(Dataset):
-def __init__(self, quantization_interval=1):
+def __init__(self, quantization_interval=1, max_value=255):
super(Domain4DValues, self).__init__()
values1d = torch.arange(0, max_value+1, quantization_interval, dtype=torch.uint8)
values1d = torch.cat([values1d, torch.tensor([max_value+1])])
self.quantization_interval = quantization_interval
self.values = torch.cartesian_prod(*([values1d]*4)).view(-1, 1, 4)
def __getitem__(self, idx):
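Size check for the new max_value parameter (editor's sketch): the defaults reproduce the old behaviour while making the sentinel bucket explicit.

```python
# Hypothetical arithmetic matching transfer_2x2_input_SxS_output below.
quantization_interval, max_value = 16, 255
bucket_count = (max_value + 1) // quantization_interval  # 16
entries = (bucket_count + 1) ** 4                        # 17**4 = 83521 quantized 2x2 windows
```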
@@ -54,11 +113,11 @@ def transfer_rc_conv(rc_conv, quantization_interval=1):
print()
return lut
-def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
+def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
-bucket_count = 256//quantization_interval
+bucket_count = (max_value+1)//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
-lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 4DLUT for simple input window 2x2
+lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=max_value) # 4DLUT for simple input window 2x2
-domain_values = Domain4DValues(quantization_interval=quantization_interval)
+domain_values = Domain4DValues(quantization_interval=quantization_interval, max_value=max_value)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
@@ -68,7 +127,7 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
)
counter = 0
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
-inputs = batch.type(torch.float32).cuda()
+inputs = batch.type(torch.float32).to(list(block.parameters())[0].device)
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, ix4s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
@@ -77,7 +136,51 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2**10, max_value=255):
print()
return lut
def transfer_3_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 3DLUT for a 3-pixel input window
domain_values = Domain3DValues(quantization_interval=quantization_interval)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, ix3s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, ix3s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
def transfer_2_input_SxS_output(block, quantization_interval=16, batch_size=2**10):
bucket_count = 256//quantization_interval
scale = block.upscale_factor if hasattr(block, 'upscale_factor') else 1
lut = np.full((bucket_count+1, bucket_count+1, scale, scale), dtype=np.uint8, fill_value=255) # 2DLUT for a 2-pixel input window
domain_values = Domain2DValues(quantization_interval=quantization_interval)
domain_values_loader = DataLoader(
domain_values,
batch_size=batch_size if quantization_interval >= 16 else 2**16,
pin_memory=True,
num_workers=1 if quantization_interval >= 16 else mp.cpu_count(),
shuffle=False,
)
counter = 0
for idx, (ix1s, ix2s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda()
with torch.no_grad():
outputs = block(inputs)
lut[ix1s, ix2s, :, :] = outputs.reshape(len(ix1s), scale, scale).cpu().numpy().astype(np.uint8)
counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print()
return lut
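Taken together, these transfer helpers exhaustively evaluate a trained block over its quantized input domain and tabulate the outputs. A typical call (editor's sketch, mirroring the get_lut_model methods in the network classes below):

```python
# Hedged example; assumes this repo's layers module and a CUDA device.
block = layers.UpscaleBlock(in_features=4, hidden_dim=64, layers_count=4, upscale_factor=4).cuda()
lut4d = transfer_2x2_input_SxS_output(block, quantization_interval=16)
print(lut4d.shape)  # (17, 17, 17, 17, 4, 4): one 4x4 output patch per quantized 2x2 window
```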
##################### FORWARD ##########################
@@ -140,11 +243,11 @@ def select_index_1dlut_linear(ixA, lut):
def select_index_3dlut_tetrahedral(index, lut):
b, hw, c = index.shape
lut = torch.clamp(lut, 0, 255)
-dimA, dimB, dimC, dimD = lut.shape[:4]
+dimA, dimB, dimC = lut.shape[:3]
q = 256/(dimA-1)
L = dimA
upscale = lut.shape[-1]
-weight = lut.reshape(L**4, upscale, upscale)
+weight = lut.reshape(L**3, upscale, upscale)
msbA = torch.floor_divide(index, q).type(torch.int64)
msbB = msbA + 1

@@ -6,6 +6,12 @@ from pathlib import Path
from PIL import Image
import time
from datetime import timedelta, datetime
from matplotlib import pyplot as plt
cmap = plt.get_cmap('viridis')
cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplut = np.array(cmaplist)
cmaplut = np.round(cmaplut[:, 0:3]*255).astype(np.uint8)
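The lines above bake matplotlib's viridis colormap into a 256-entry uint8 table, so single-channel predictions can be colorized with one NumPy gather (editor's note):

```python
# Hedged example of applying the baked colormap LUT.
import numpy as np

gray = np.arange(256, dtype=np.uint8).reshape(16, 16)
rgb = cmaplut[gray]  # shape (16, 16, 3), dtype uint8
```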
def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
with torch.inference_mode(): with torch.inference_mode():
@@ -27,7 +33,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):
if pred_lr_image.shape[-1] == 3 and color_model == 'YCbCr':
Image.fromarray(pred_lr_image, mode=color_model).convert("RGB").save(output_image_path)
if pred_lr_image.shape[-1] == 1:
-Image.fromarray(pred_lr_image[:,:,0]).save(output_image_path)
+Image.fromarray(cmaplut[pred_lr_image[:,:,0]]).save(output_image_path)
# metrics
hr_image = modcrop(hr_image, model.scale)
@@ -64,7 +70,7 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
if print_progress:
start_datetime = datetime.now()
for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
-output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
+output_image_path = predictions_path / f'{Path(hr_image_path).stem}_{config.current_iter:06d}.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
tasks.append(task)
if print_progress:
@@ -104,6 +110,8 @@ def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
'Total time'
]
config.logger.info("\n" + str(pd.DataFrame([row], columns=column_names).set_index('Dataset').T))
config.writer.add_scalar(f'{dataset_name}_PSNR', np.mean(psnrs), config.current_iter)
config.writer.add_scalar(f'{dataset_name}_SSIM', np.mean(ssims), config.current_iter)
config.writer.flush()
results = pd.DataFrame(results, columns=column_names).set_index('Dataset')

@@ -16,6 +16,7 @@ class ImageDemoOptions():
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--output_path', type=str, default="../models/", help="Output path.")
self.parser.add_argument('--output_name', type=str, default="image_demo.png", help="Output name.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False)
self.parser.add_argument('--device', default='cuda', help='Device of the model')
@@ -92,6 +93,6 @@ for i in range(row_count):
row.append(images[i*column_count + j])
columns.append(np.concatenate(row, axis=1))
canvas = np.concatenate(columns, axis=0).astype(np.uint8)
-Image.fromarray(canvas).save(config.output_path / 'image_demo.png')
+Image.fromarray(canvas).save(config.output_path / config.output_name)
print(datetime.now() - start_script_time)

@@ -30,6 +30,9 @@ AVAILABLE_MODELS = {
'SRLutY': srlut.SRLutY,
'HDBNet': hdbnet.HDBNet,
'HDBLut': hdblut.HDBLut,
'HDBLNet': hdbnet.HDBLNet,
'HDBHNet': hdbnet.HDBHNet,
'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net,
# 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
# 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
# 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,

@@ -45,7 +45,7 @@ class SDYLutx1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@@ -63,6 +63,11 @@ class SDYLutx1(nn.Module):
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutx2(nn.Module):
def __init__(
self,
@@ -107,7 +112,7 @@ class SDYLutx2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@@ -135,6 +140,11 @@ class SDYLutx2(nn.Module):
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutx3(nn.Module):
def __init__(
@@ -186,7 +196,7 @@ class SDYLutx3(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@@ -224,6 +234,12 @@ class SDYLutx3(nn.Module):
f"\n stage3_D size: {self.stage3_D.shape}" + \
f"\n stage3_Y size: {self.stage3_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x1(nn.Module):
def __init__(
self,
@@ -262,7 +278,7 @@ class SDYLutR90x1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@@ -285,6 +301,10 @@ class SDYLutR90x1(nn.Module):
f"\n stageD size: {self.stageD.shape}" + \
f"\n stageY size: {self.stageY.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYLutR90x2(nn.Module):
@@ -331,7 +351,7 @@ class SDYLutR90x2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w).type(torch.float32)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@@ -368,3 +388,8 @@ class SDYLutR90x2(nn.Module):
f"\n stage2_S size: {self.stage2_S.shape}" + \
f"\n stage2_D size: {self.stage2_D.shape}" + \
f"\n stage2_Y size: {self.stage2_Y.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@@ -30,7 +30,7 @@ class SDYNetx1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@@ -50,6 +50,12 @@ class SDYNetx1(nn.Module):
lut_model = sdylut.SDYLutx1.init_from_numpy(stageS, stageD, stageY)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetx2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx2, self).__init__()
@@ -74,7 +80,7 @@ class SDYNetx2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@@ -104,6 +110,11 @@ class SDYNetx2(nn.Module):
lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetx3(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx3, self).__init__()
@@ -131,7 +142,7 @@ class SDYNetx3(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@@ -171,6 +182,11 @@ class SDYNetx3(nn.Module):
lut_model = sdylut.SDYLutx3.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetR90x1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetR90x1, self).__init__()
@@ -192,7 +208,7 @@ class SDYNetR90x1(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@@ -217,6 +233,11 @@ class SDYNetR90x1(nn.Module):
lut_model = sdylut.SDYLutR90x1.init_from_numpy(stageS, stageD, stageY)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SDYNetR90x2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetR90x2, self).__init__()
@@ -241,7 +262,7 @@ class SDYNetR90x2(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output_1 = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
@@ -278,3 +299,8 @@ class SDYNetR90x2(nn.Module):
stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutR90x2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@@ -39,7 +39,7 @@ class SRLut(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w).type(torch.float32)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
@@ -49,6 +49,11 @@ class SRLut(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutY(nn.Module):
def __init__(
self,
@@ -83,7 +88,7 @@ class SRLutY(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@@ -98,6 +103,11 @@ class SRLutY(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90(nn.Module):
def __init__(
self,
@@ -130,7 +140,7 @@ class SRLutR90(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
@@ -145,6 +155,11 @@ class SRLutR90(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRLutR90Y(nn.Module):
def __init__(
@@ -180,7 +195,7 @@ class SRLutR90Y(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@@ -199,3 +214,8 @@ class SRLutR90Y(nn.Module):
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn

@@ -7,6 +7,8 @@ from common import lut
from pathlib import Path
from . import srlut
from common import layers
from itertools import cycle
from common import losses
class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
@@ -29,7 +31,7 @@ class SRNet(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
@@ -41,6 +43,11 @@ class SRNet(nn.Module):
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetY(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetY, self).__init__()
@@ -64,7 +71,7 @@ class SRNetY(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@@ -82,6 +89,11 @@ class SRNetY(nn.Module):
lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetR90(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90, self).__init__()
@@ -103,7 +115,7 @@ class SRNetR90(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
@@ -120,6 +132,11 @@ class SRNetR90(nn.Module):
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetR90Y(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90Y, self).__init__()
@@ -144,7 +161,7 @@ class SRNetR90Y(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
-def forward(self, x):
+def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
@@ -166,3 +183,75 @@ class SRNetR90Y(nn.Module):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRMsbLsb4R90Net(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsb4R90Net, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(4)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(4)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
output_msb_r = round_func((output_msb_r / 255)*16) * 15
output_lsb_r = (output_lsb_r / 255) * 15
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
output_lsb /= 4
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
fourier_loss_fn = losses.FocalFrequencyLoss()
def loss_fn(pred, target):
return fourier_loss_fn(pred/255, target/255) * 1e8
return loss_fn
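The MSB/LSB split in forward is plain integer arithmetic (editor's note): lsb = x % 16 keeps the low nibble and msb = x - lsb the high one, each half is super-resolved by its own rotation ensemble, and the two outputs are summed back.

```python
# Hedged check of the nibble decomposition used above.
import torch

x = torch.tensor([203.])
lsb = x % 16   # tensor([11.])
msb = x - lsb  # tensor([192.])
assert torch.equal(msb + lsb, x)
```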

@@ -41,7 +41,7 @@ class ValOptions():
args.valout_dir = Path(args.exp_dir).resolve() / 'val'
if not args.valout_dir.exists():
args.valout_dir.mkdir()
-args.current_iter = args.model_name.split('_')[-1]
+args.current_iter = int(args.model_name.split('_')[-1])
args.results_path = os.path.join(args.valout_dir, f'results_{args.model_name}_{args.device}.csv')
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.valout_dir)

@@ -45,7 +45,7 @@ class TrainOptions:
parser.add_argument('--save_step', type=int, default=2000, help='save models every N iteration')
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
-parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
+parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model')
parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
parser.add_argument('--color_model', type=str, default="RGB", help="Color model for train and test dataset.")
@@ -59,6 +59,8 @@ class TrainOptions:
args.model_path = Path(args.model_path) if not args.model_path is None else None
args.train_datasets = args.train_datasets.split(',')
args.val_datasets = args.val_datasets.split(',')
if not args.model_path is None:
args.start_iter = int(args.model_path.stem.split("_")[-1])
return args
def __repr__(self):
@@ -110,7 +112,8 @@ if __name__ == "__main__":
if 'lut' in config.model.lower():
model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
model = model.to(torch.device(config.device))
-optimizer = AdamWScheduleFree(model.parameters())
+optimizer = AdamWScheduleFree(model.parameters(), lr=1e-2, betas=(0.9, 0.95))
# optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
print(optimizer)
prepare_experiment_folder(config)
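One caveat with the switch to AdamWScheduleFree (editor's note): schedule-free optimizers keep separate training and evaluation iterates and, per the schedulefree library's documentation, expect explicit mode toggles that a plain AdamW loop does not need; the hunks shown here do not include them.

```python
# Hypothetical placement of the mode toggles; not part of this commit.
optimizer.train()  # before taking optimizer.step() calls in the training loop
# ... training steps ...
optimizer.eval()   # before valid_steps() or torch.save(), then optimizer.train() again
```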
@@ -140,7 +143,7 @@ if __name__ == "__main__":
dataset = train_dataset,
batch_size = config.batch_size,
num_workers = config.worker_num,
-shuffle = False,
+shuffle = True,
drop_last = False,
pin_memory = True,
prefetch_factor = config.prefetch_factor
@@ -165,7 +168,9 @@ if __name__ == "__main__":
config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
loss_fn = model.get_loss_fn()
for i in range(config.start_iter + 1, config.total_iter + 1):
config.current_iter = i
torch.cuda.empty_cache()
start_time = time.time()
try:
@@ -179,8 +184,8 @@ if __name__ == "__main__":
start_time = time.time()
-pred = model(lr_patch)
+pred = model(x=lr_patch, config=config)
-loss = F.mse_loss(pred/255, hr_patch/255)
+loss = loss_fn(pred=pred, target=hr_patch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
@@ -194,9 +199,14 @@ if __name__ == "__main__":
# Show information
if i % config.display_step == 0:
config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
config.writer.add_scalar('loss', loss.item(), i)
config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, GPixel:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
model.__class__.__name__, i, accum_samples, l_accum[0] / config.display_step, prepare_data_time / config.display_step, config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, loss:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
model.__class__.__name__,
i,
accum_samples,
l_accum[0] / config.display_step,
prepare_data_time / config.display_step,
forward_backward_time / config.display_step)) forward_backward_time / config.display_step))
l_accum = [0., 0., 0.]
prepare_data_time = 0.
@@ -208,7 +218,6 @@ if __name__ == "__main__":
# Validation
if i % config.val_step == 0:
-config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
model_path = (Path(config.checkpoint_dir) / f"{model.__class__.__name__}_{i}.pth").resolve()
