added srnet for y channel, remove channel toss in dataloader, remove junk

main
protsenkovi 6 months ago
parent 632d1ab7d6
commit 58dab59a11

@ -50,17 +50,17 @@ class SRTrainDataset(Dataset):
i = random.randint(0, lr_image.shape[0] - self.sz) i = random.randint(0, lr_image.shape[0] - self.sz)
j = random.randint(0, lr_image.shape[1] - self.sz) j = random.randint(0, lr_image.shape[1] - self.sz)
c = random.choice([0, 1, 2]) # c = random.choice([0, 1, 2])
hr_patch = hr_image[ hr_patch = hr_image[
(i*scale):(i*scale + self.sz*scale), (i*scale):(i*scale + self.sz*scale),
(j*scale):(j*scale + self.sz*scale), (j*scale):(j*scale + self.sz*scale),
c:(c+1) :
] ]
lr_patch = lr_image[ lr_patch = lr_image[
i:(i + self.sz), i:(i + self.sz),
j:(j + self.sz), j:(j + self.sz),
c:(c+1) :
] ]
if self.rigid_aug: if self.rigid_aug:

@ -99,4 +99,67 @@ class ConvUpscaleBlock(nn.Module):
x = torch.tanh(x) x = torch.tanh(x)
x = x*127.5 + 127.5 x = x*127.5 + 127.5
x = round_func(x) x = round_func(x)
return x return x
class RgbToYcbcr(nn.Module):
r"""Convert an image from RGB to YCbCr.
The image data is assumed to be in the range of (0, 1).
Returns:
YCbCr version of the image.
Shape:
- image: :math:`(*, 3, H, W)`
- output: :math:`(*, 3, H, W)`
Examples:
>>> input = torch.rand(2, 3, 4, 5)
>>> ycbcr = RgbToYcbcr()
>>> output = ycbcr(input) # 2x3x4x5
"""
def forward(self, image):
r = image[..., 0, :, :]
g = image[..., 1, :, :]
b = image[..., 2, :, :]
delta = 0.5
y = 0.299 * r + 0.587 * g + 0.114 * b
cb = (b - y) * 0.564 + delta
cr = (r - y) * 0.713 + delta
return torch.stack([y, cb, cr], -3)
class YcbcrToRgb(nn.Module):
r"""Convert an image from YCbCr to Rgb.
The image data is assumed to be in the range of (0, 1).
Returns:
RGB version of the image.
Shape:
- image: :math:`(*, 3, H, W)`
- output: :math:`(*, 3, H, W)`
Examples:
>>> input = torch.rand(2, 3, 4, 5)
>>> rgb = YcbcrToRgb()
>>> output = rgb(input) # 2x3x4x5
"""
def forward(self, image):
y = image[..., 0, :, :]
cb = image[..., 1, :, :]
cr = image[..., 2, :, :]
delta = 0.5
cb_shifted = cb - delta
cr_shifted = cr - delta
r = y + 1.403 * cr_shifted
g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
b = y + 1.773 * cb_shifted
return torch.stack([r, g, b], -3)

@ -6,6 +6,7 @@ from scipy import signal
import torch import torch
import os import os
def round_func(input): def round_func(input):
# Backward Pass Differentiable Approximation (BPDA) # Backward Pass Differentiable Approximation (BPDA)
# This is equivalent to replacing round function (non-differentiable) # This is equivalent to replacing round function (non-differentiable)

@ -14,11 +14,11 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu
lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1) lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
lr_image = lr_image.unsqueeze(0).to(torch.device(device)) lr_image = lr_image.unsqueeze(0).to(torch.device(device))
b, c, h, w = lr_image.shape b, c, h, w = lr_image.shape
lr_image = lr_image.reshape(b*c, 1, h, w) lr_image = lr_image.reshape(b, c, h, w)
# predict # predict
pred_lr_image = model(lr_image) pred_lr_image = model(lr_image)
# postprocess # postprocess
pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8) pred_lr_image = pred_lr_image.squeeze(0).permute(1,2,0).type(torch.uint8)
pred_lr_image = pred_lr_image.cpu().numpy() pred_lr_image = pred_lr_image.cpu().numpy()
run_time_ns = time.perf_counter_ns() - start_time run_time_ns = time.perf_counter_ns() - start_time
torch.cuda.empty_cache() torch.cuda.empty_cache()

@ -11,6 +11,7 @@ AVAILABLE_MODELS = {
'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut, 'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut,
'SRNetDense': srnet.SRNetDense, 'SRNetDense': srnet.SRNetDense,
'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90, 'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90,
'SRNetDenseRot90Y': srnet.SRNetDenseRot90Y, 'SRLutRot90Y': srlut.SRLutRot90Y,
'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,

@ -4,6 +4,7 @@ import torch.nn.functional as F
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from common.lut import forward_2x2_input_SxS_output from common.lut import forward_2x2_input_SxS_output
from common.layers import RgbToYcbcr, YcbcrToRgb
class SRLut(nn.Module): class SRLut(nn.Module):
def __init__( def __init__(
@ -72,5 +73,50 @@ class SRLutRot90(nn.Module):
output = output.view(b, c, h*self.scale, w*self.scale) output = output.view(b, c, h*self.scale, w*self.scale)
return output return output
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
class SRLutRot90Y(nn.Module):
def __init__(
self,
quantization_interval,
scale
):
super(SRLutRot90Y, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.rgb_to_ycbcr = RgbToYcbcr()
self.ycbcr_to_rgb = YcbcrToRgb()
@staticmethod
def init_from_lut(
stage_lut
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model
def forward(self, x):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output)
return output
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}" return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"

@ -5,8 +5,9 @@ import numpy as np
from common.utils import round_func from common.utils import round_func
from common import lut from common import lut
from pathlib import Path from pathlib import Path
from .srlut import SRLut, SRLutRot90 from . import srlut
from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock, RgbToYcbcr, YcbcrToRgb
class SRNet(nn.Module): class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
@ -26,7 +27,7 @@ class SRNet(nn.Module):
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = SRLut.init_from_lut(stage_lut) lut_model = srlut.SRLut.init_from_lut(stage_lut)
return lut_model return lut_model
@ -48,7 +49,7 @@ class SRNetDense(nn.Module):
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = SRLut.init_from_lut(stage_lut) lut_model = srlut.SRLut.init_from_lut(stage_lut)
return lut_model return lut_model
class SRNetDenseRot90(nn.Module): class SRNetDenseRot90(nn.Module):
@ -75,5 +76,40 @@ class SRNetDenseRot90(nn.Module):
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = SRLutRot90.init_from_lut(stage_lut) lut_model = srlut.SRLutRot90.init_from_lut(stage_lut)
return lut_model return lut_model
class SRNetDenseRot90Y(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetDenseRot90Y, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.rgb_to_ycbcr = RgbToYcbcr()
self.ycbcr_to_rgb = YcbcrToRgb()
def forward(self, x):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
y = x[:,0:1,:,:]
cbcr = x[:,1:,:,:]
cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rx = torch.rot90(x, k=rotations_count, dims=[2, 3])
_,_,rh,rw = rx.shape
rx = self._extract_pattern_S(rx)
rx = self.stage(rx)
rx = rx.view(b, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b, 1, rh*self.scale, rw*self.scale)
output += torch.rot90(rx, k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutRot90Y.init_from_lut(stage_lut)
return lut_model

@ -91,21 +91,6 @@ def prepare_experiment_folder(config):
if not config.logs_dir.exists(): if not config.logs_dir.exists():
config.logs_dir.mkdir() config.logs_dir.mkdir()
def dice_loss(inputs, targets, smooth=1):
#comment out if your model contains a sigmoid or equivalent activation layer
inputs = F.sigmoid(inputs)
#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
return 1 - dice
if __name__ == "__main__": if __name__ == "__main__":
script_start_time = datetime.now() script_start_time = datetime.now()

Loading…
Cancel
Save