You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
4.5 KiB
Python
122 lines
4.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from common.lut import forward_2x2_input_SxS_output
|
|
from common import layers
|
|
|
|
class SRLut(nn.Module):
    """Single-stage look-up-table super-resolution model.

    ``stage_lut`` is a learnable float32 tensor of shape
    ``(B, B, B, B, scale, scale)`` with ``B = 256 // quantization_interval + 1``.
    The lookup itself is delegated to ``forward_2x2_input_SxS_output``
    (imported from ``common.lut``). Judging by that helper's name, the four
    leading axes are indexed by a quantized 2x2 input patch and the trailing
    axes hold a ``scale x scale`` output block. NOTE(review): confirm against
    ``common.lut``.
    """

    def __init__(self, quantization_interval, scale):
        """Create a model with a randomly initialized LUT.

        Args:
            quantization_interval: step between LUT sample points; the LUT has
                ``256 // quantization_interval + 1`` bins per input axis.
            scale: upscaling factor; size of the two trailing LUT axes.
        """
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        bins = 256 // quantization_interval + 1
        # torch.randint's ``high`` is exclusive, so 256 is needed for the
        # random values to cover the full 8-bit range [0, 255]. Sampling
        # directly as float32 avoids a temporary int64 copy of what can be a
        # very large table.
        self.stage_lut = nn.Parameter(
            torch.randint(0, 256, size=(bins,) * 4 + (scale, scale), dtype=torch.float32)
        )

    @classmethod
    def init_from_numpy(cls, stage_lut):
        """Build a model whose LUT is a float32 copy of ``stage_lut``.

        ``scale`` is read from the last axis. ``quantization_interval`` is
        recovered from the first axis length as ``256 // (len - 1)``, the
        inverse of the sizing used in ``__init__``.

        Args:
            stage_lut: array of shape ``(B, B, B, B, scale, scale)``.

        Returns:
            A new instance of ``cls`` holding the given table.
        """
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256 // (stage_lut.shape[0] - 1)
        lut_model = cls(quantization_interval=quantization_interval, scale=scale)
        # One conversion straight to float32; torch.tensor always copies, so
        # training never writes back into the caller's array.
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut, dtype=torch.float32))
        return lut_model

    def forward(self, x):
        """Run every channel of ``x`` independently through the LUT.

        Args:
            x: tensor of shape ``(b, c, h, w)``; cast to float32 before lookup.

        Returns:
            Tensor of shape ``(b, c, H, W)``, where ``H, W`` are the spatial
            dims returned by ``forward_2x2_input_SxS_output``.
        """
        b, c, h, w = x.shape
        # reshape instead of view so non-contiguous inputs (channel slices,
        # permuted tensors) are accepted; each channel becomes a separate
        # single-channel image.
        x = x.reshape(b * c, 1, h, w).to(torch.float32)
        x = forward_2x2_input_SxS_output(index=x, lut=self.stage_lut)
        return x.reshape(b, c, x.shape[-2], x.shape[-1])

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
|
|
|
|
|
|
|
class SRLutR90(nn.Module):
    """Look-up-table super-resolution with a four-rotation ensemble.

    Each single-channel image is rotated by 0, 90, 180 and 270 degrees, run
    through the same LUT, rotated back, and the four predictions are averaged.

    ``stage_lut`` is a learnable float32 tensor of shape
    ``(B, B, B, B, scale, scale)`` with ``B = 256 // quantization_interval + 1``.
    The lookup itself is delegated to ``forward_2x2_input_SxS_output``
    (imported from ``common.lut``).
    """

    def __init__(self, quantization_interval, scale):
        """Create a model with a randomly initialized LUT.

        Args:
            quantization_interval: step between LUT sample points; the LUT has
                ``256 // quantization_interval + 1`` bins per input axis.
            scale: upscaling factor; size of the two trailing LUT axes.
        """
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        bins = 256 // quantization_interval + 1
        # torch.randint's ``high`` is exclusive, so 256 is needed for the
        # random values to cover the full 8-bit range [0, 255]. Sampling
        # directly as float32 avoids a temporary int64 copy of what can be a
        # very large table.
        self.stage_lut = nn.Parameter(
            torch.randint(0, 256, size=(bins,) * 4 + (scale, scale), dtype=torch.float32)
        )

    @classmethod
    def init_from_numpy(cls, stage_lut):
        """Build a model whose LUT is a float32 copy of ``stage_lut``.

        ``scale`` is read from the last axis. ``quantization_interval`` is
        recovered from the first axis length as ``256 // (len - 1)``, the
        inverse of the sizing used in ``__init__``.

        Args:
            stage_lut: array of shape ``(B, B, B, B, scale, scale)``.

        Returns:
            A new instance of ``cls`` holding the given table.
        """
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256 // (stage_lut.shape[0] - 1)
        lut_model = cls(quantization_interval=quantization_interval, scale=scale)
        # One conversion straight to float32; torch.tensor always copies, so
        # training never writes back into the caller's array.
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut, dtype=torch.float32))
        return lut_model

    def forward(self, x):
        """Upscale each channel of ``x`` with the rotation-ensembled LUT.

        Requires ``forward_2x2_input_SxS_output`` to map an ``(h, w)`` image
        to ``(h * scale, w * scale)``; the accumulation into the preallocated
        buffer below depends on it.

        Args:
            x: tensor of shape ``(b, c, h, w)``; cast to float32 before lookup.

        Returns:
            float32 tensor of shape ``(b, c, h * scale, w * scale)``.
        """
        b, c, h, w = x.shape
        out_h, out_w = h * self.scale, w * self.scale
        # reshape instead of view so non-contiguous inputs are accepted; the
        # float32 cast matches the dtype of the accumulation buffer.
        x = x.reshape(b * c, 1, h, w).to(torch.float32)
        output = torch.zeros([b * c, 1, out_h, out_w], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
            # Undo the rotation so every prediction is aligned with the input.
            output += torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
        output /= 4
        return output.reshape(b, c, out_h, out_w)

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
|
|
|
|
|
class SRLutR90Y(nn.Module):
    """Rotation-ensembled LUT super-resolution applied to the Y channel only.

    The input is converted with ``layers.RgbToYcbcr``. Channel 0 (Y) is
    upscaled by averaging LUT predictions over the four 90-degree rotations.
    Channels 1: (Cb, Cr) are upscaled with bilinear interpolation. The result
    is converted back with ``layers.YcbcrToRgb`` and clamped to ``[0, 255]``.

    ``stage_lut`` is a learnable float32 tensor of shape
    ``(B, B, B, B, scale, scale)`` with ``B = 256 // quantization_interval + 1``.
    """

    def __init__(self, quantization_interval, scale):
        """Create a model with a randomly initialized LUT and colour converters.

        Args:
            quantization_interval: step between LUT sample points; the LUT has
                ``256 // quantization_interval + 1`` bins per input axis.
            scale: upscaling factor; size of the two trailing LUT axes.
        """
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        bins = 256 // quantization_interval + 1
        # torch.randint's ``high`` is exclusive, so 256 is needed for the
        # random values to cover the full 8-bit range [0, 255]. Sampling
        # directly as float32 avoids a temporary int64 copy of what can be a
        # very large table.
        self.stage_lut = nn.Parameter(
            torch.randint(0, 256, size=(bins,) * 4 + (scale, scale), dtype=torch.float32)
        )
        self.rgb_to_ycbcr = layers.RgbToYcbcr()
        self.ycbcr_to_rgb = layers.YcbcrToRgb()

    @classmethod
    def init_from_numpy(cls, stage_lut):
        """Build a model whose LUT is a float32 copy of ``stage_lut``.

        ``scale`` is read from the last axis. ``quantization_interval`` is
        recovered from the first axis length as ``256 // (len - 1)``, the
        inverse of the sizing used in ``__init__``.

        Args:
            stage_lut: array of shape ``(B, B, B, B, scale, scale)``.

        Returns:
            A new instance of ``cls`` holding the given table.
        """
        scale = int(stage_lut.shape[-1])
        quantization_interval = 256 // (stage_lut.shape[0] - 1)
        lut_model = cls(quantization_interval=quantization_interval, scale=scale)
        # One conversion straight to float32; torch.tensor always copies, so
        # training never writes back into the caller's array.
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut, dtype=torch.float32))
        return lut_model

    def forward(self, x):
        """Upscale an RGB batch by ``scale``.

        NOTE(review): assumes ``x`` is ``(b, 3, h, w)`` RGB in the 0-255
        range (suggested by the final clamp). Confirm against ``layers``.

        Args:
            x: tensor of shape ``(b, c, h, w)``.

        Returns:
            RGB tensor of spatial size ``(h * scale, w * scale)``, clamped to
            ``[0, 255]``.
        """
        b, _, h, w = x.shape
        out_h, out_w = h * self.scale, w * self.scale
        x = self.rgb_to_ycbcr(x)
        y = x[:, 0:1, :, :]
        cbcr = x[:, 1:, :, :]
        # Chroma is upscaled with plain bilinear interpolation; align_corners
        # is spelled out explicitly (False is PyTorch's default).
        cbcr_scaled = F.interpolate(cbcr, size=[out_h, out_w], mode='bilinear', align_corners=False)

        output = torch.zeros([b, 1, out_h, out_w], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.stage_lut)
            # Undo the rotation so every prediction is aligned with the input.
            output += torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
        output /= 4

        output = torch.cat([output, cbcr_scaled], dim=1)
        return self.ycbcr_to_rgb(output).clamp(0, 255)

    def __repr__(self):
        return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"