You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

105 lines
4.1 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
from . import srlut
from common import layers
class SRNet(nn.Module):
    """Single-stage super-resolution network with a 2x2 receptive field.

    Each channel of the input is pushed through the same
    ``layers.UpscaleBlock`` independently. The stage can be exported to a
    lookup-table model with :meth:`get_lut_model`.
    """

    def __init__(self, hidden_dim=64, layers_count=4, scale=4):
        """Build the upscale stage.

        Args:
            hidden_dim: Passed to ``layers.UpscaleBlock`` as ``hidden_dim``.
            layers_count: Passed to ``layers.UpscaleBlock`` as
                ``layers_count``.
            scale: Upscale factor; stored on the module and passed to the
                block as ``upscale_factor``.
        """
        super().__init__()
        self.scale = scale
        # Offsets of the four pixels that make up the 2x2 window.
        s_pattern = [[0, 0], [0, 1], [1, 0], [1, 1]]
        self.stage1_S = layers.UpscaleBlock(
            receptive_field_idxes=s_pattern,
            center=[0, 0],
            window_size=2,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale,
        )

    def forward(self, x):
        """Upscale a ``(B, C, H, W)`` tensor to ``(B, C, H*scale, W*scale)``.

        Args:
            x: 4-D input tensor; it may be non-contiguous.

        Returns:
            The output of the stage, reshaped back to one plane per input
            channel.
        """
        b, c, h, w = x.shape
        # Fold channels into the batch axis so the stage sees single-channel
        # planes. ``reshape`` rather than ``view``: ``view`` raises on
        # non-contiguous input (e.g. a permuted NHWC batch).
        planes = x.reshape(b * c, 1, h, w)
        upscaled = self.stage1_S(planes)
        # NOTE(review): assumes the stage returns b*c planes of size
        # (h*scale, w*scale); the reshape fails otherwise.
        return upscaled.reshape(b, c, h * self.scale, w * self.scale)

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """Export the stage to a lookup-table model.

        Args:
            quantization_interval: Forwarded to
                ``lut.transfer_2x2_input_SxS_output``.
            batch_size: Forwarded to ``lut.transfer_2x2_input_SxS_output``.

        Returns:
            The object produced by ``srlut.SRLut.init_from_lut``.
        """
        # The stage is registered as ``stage1_S``; this module never defines
        # a ``stage`` attribute, so reading it raised AttributeError.
        stage_lut = lut.transfer_2x2_input_SxS_output(
            self.stage1_S,
            quantization_interval=quantization_interval,
            batch_size=batch_size,
        )
        return srlut.SRLut.init_from_lut(stage_lut)
class SRNetR90(nn.Module):
    """Single-stage super-resolution network with a 4-way rotation ensemble.

    The same ``layers.UpscaleBlock`` is applied to the input rotated by 0,
    90, 180 and 270 degrees; each result is rotated back and the four
    results are averaged. The stage can be exported to a lookup-table model
    with :meth:`get_lut_model`.
    """

    def __init__(self, hidden_dim=64, layers_count=4, scale=4):
        """Build the upscale stage.

        Args:
            hidden_dim: Passed to ``layers.UpscaleBlock`` as ``hidden_dim``.
            layers_count: Passed to ``layers.UpscaleBlock`` as
                ``layers_count``.
            scale: Upscale factor; stored on the module and passed to the
                block as ``upscale_factor``.
        """
        super().__init__()
        self.scale = scale
        # Offsets of the four pixels that make up the 2x2 window.
        s_pattern = [[0, 0], [0, 1], [1, 0], [1, 1]]
        self.stage1_S = layers.UpscaleBlock(
            receptive_field_idxes=s_pattern,
            center=[0, 0],
            window_size=2,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale,
        )

    def forward(self, x):
        """Upscale a ``(B, C, H, W)`` tensor to ``(B, C, H*scale, W*scale)``.

        Args:
            x: 4-D input tensor; it may be non-contiguous.

        Returns:
            The mean of the four rotate / upscale / rotate-back passes.
        """
        b, c, h, w = x.shape
        out_h, out_w = h * self.scale, w * self.scale
        # Fold channels into the batch axis so the stage sees single-channel
        # planes. ``reshape`` rather than ``view``: ``view`` raises on
        # non-contiguous input (e.g. a permuted NHWC batch).
        planes = x.reshape(b * c, 1, h, w)
        output = torch.zeros(
            [b * c, 1, out_h, out_w], dtype=planes.dtype, device=planes.device
        )
        # k = 0 is the unrotated pass; rotating by -k undoes the rotation so
        # all four predictions are aligned before they are summed.
        for rotations_count in range(4):
            rotated = torch.rot90(planes, k=rotations_count, dims=[2, 3])
            output += torch.rot90(
                self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]
            )
        output /= 4
        return output.reshape(b, c, out_h, out_w)

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """Export the stage to a lookup-table model.

        Args:
            quantization_interval: Forwarded to
                ``lut.transfer_2x2_input_SxS_output``.
            batch_size: Forwarded to ``lut.transfer_2x2_input_SxS_output``.

        Returns:
            The object produced by ``srlut.SRLutRot90.init_from_lut``.
        """
        # The stage is registered as ``stage1_S``; this module never defines
        # a ``stage`` attribute, so reading it raised AttributeError.
        stage_lut = lut.transfer_2x2_input_SxS_output(
            self.stage1_S,
            quantization_interval=quantization_interval,
            batch_size=batch_size,
        )
        return srlut.SRLutRot90.init_from_lut(stage_lut)
class SRNetR90Y(nn.Module):
    """Rotation-ensemble super-resolution applied to the first YCbCr channel.

    The input is converted with ``layers.RgbToYcbcr``. Channel 0 goes
    through the learned stage with a 4-way rotation ensemble, the remaining
    channels are resized bilinearly, and the result is converted back with
    ``layers.YcbcrToRgb`` and clamped to ``[0, 255]``. The stage can be
    exported to a lookup-table model with :meth:`get_lut_model`.
    """

    def __init__(self, hidden_dim=64, layers_count=4, scale=4):
        """Build the upscale stage and the colour-space converters.

        Args:
            hidden_dim: Passed to ``layers.UpscaleBlock`` as ``hidden_dim``.
            layers_count: Passed to ``layers.UpscaleBlock`` as
                ``layers_count``.
            scale: Upscale factor; stored on the module and passed to the
                block as ``upscale_factor``.
        """
        super().__init__()
        self.scale = scale
        # Offsets of the four pixels that make up the 2x2 window.
        s_pattern = [[0, 0], [0, 1], [1, 0], [1, 1]]
        self.stage1_S = layers.UpscaleBlock(
            receptive_field_idxes=s_pattern,
            center=[0, 0],
            window_size=2,
            hidden_dim=hidden_dim,
            layers_count=layers_count,
            upscale_factor=self.scale,
        )
        self.rgb_to_ycbcr = layers.RgbToYcbcr()
        self.ycbcr_to_rgb = layers.YcbcrToRgb()

    def forward(self, x):
        """Upscale a ``(B, C, H, W)`` image batch by ``scale``.

        Args:
            x: 4-D input tensor.
                NOTE(review): presumably RGB in the 0-255 range, given the
                converters and the final clamp — confirm against callers.

        Returns:
            The converted-back tensor, clamped to ``[0, 255]``.
        """
        b, _, h, w = x.shape
        out_h, out_w = h * self.scale, w * self.scale

        ycbcr = self.rgb_to_ycbcr(x)
        y = ycbcr[:, 0:1, :, :]
        cbcr = ycbcr[:, 1:, :, :]
        # The remaining channels skip the learned stage and are only resized.
        cbcr_scaled = F.interpolate(cbcr, size=[out_h, out_w], mode='bilinear')

        output = torch.zeros(
            [b, 1, out_h, out_w], dtype=y.dtype, device=y.device
        )
        # k = 0 is the unrotated pass; rotating by -k undoes the rotation so
        # all four predictions are aligned before they are summed.
        for rotations_count in range(4):
            rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
            output += torch.rot90(
                self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]
            )
        output /= 4

        output = torch.cat([output, cbcr_scaled], dim=1)
        return self.ycbcr_to_rgb(output).clamp(0, 255)

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """Export the stage to a lookup-table model.

        Args:
            quantization_interval: Forwarded to
                ``lut.transfer_2x2_input_SxS_output``.
            batch_size: Forwarded to ``lut.transfer_2x2_input_SxS_output``.

        Returns:
            The object produced by ``srlut.SRLutRot90Y.init_from_lut``.
        """
        # The stage is registered as ``stage1_S``; this module never defines
        # a ``stage`` attribute, so reading it raised AttributeError.
        stage_lut = lut.transfer_2x2_input_SxS_output(
            self.stage1_S,
            quantization_interval=quantization_interval,
            batch_size=batch_size,
        )
        return srlut.SRLutRot90Y.init_from_lut(stage_lut)