hklut and sdy lut experiments.
parent
9992763c9f
commit
0dd154d0cd
File diff suppressed because one or more lines are too long
@ -0,0 +1,201 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from common.lut import select_index_4dlut_tetrahedral
|
||||
from common import layers
|
||||
from common.utils import round_func
|
||||
|
||||
class HDBLut(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
quantization_interval,
|
||||
scale
|
||||
):
|
||||
super(HDBLut, self).__init__()
|
||||
self.scale = scale
|
||||
self.quantization_interval = quantization_interval
|
||||
self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
@staticmethod
|
||||
def init_from_numpy(
|
||||
stage_lut
|
||||
):
|
||||
scale = int(stage_lut.shape[-1])
|
||||
quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
lut_model = HDBLut(quantization_interval=quantization_interval, scale=scale)
|
||||
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
return lut_model
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w).type(torch.float32)
|
||||
x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
# class SRLutY(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(SRLutY, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
# self.rgb_to_ycbcr = layers.RgbToYcbcr()
|
||||
# self.ycbcr_to_rgb = layers.YcbcrToRgb()
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# stage_lut
|
||||
# ):
|
||||
# scale = int(stage_lut.shape[-1])
|
||||
# quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
# lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
# b,c,h,w = x.shape
|
||||
# x = percieve_pattern(x)
|
||||
# x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
# x = round_func(x)
|
||||
# x = x.reshape(b, c, h, w, scale, scale)
|
||||
# x = x.permute(0,1,2,4,3,5)
|
||||
# x = x.reshape(b, c, h*scale, w*scale)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = self.rgb_to_ycbcr(x)
|
||||
# y = x[:,0:1,:,:]
|
||||
# cbcr = x[:,1:,:,:]
|
||||
# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
|
||||
# output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
# output = torch.cat([output, cbcr_scaled], dim=1)
|
||||
# output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
# class SRLutR90(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(SRLutR90, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# stage_lut
|
||||
# ):
|
||||
# scale = int(stage_lut.shape[-1])
|
||||
# quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
# lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
# b,c,h,w = x.shape
|
||||
# x = percieve_pattern(x)
|
||||
# x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
# x = round_func(x)
|
||||
# x = x.reshape(b, c, h, w, scale, scale)
|
||||
# x = x.permute(0,1,2,4,3,5)
|
||||
# x = x.reshape(b, c, h*scale, w*scale)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.reshape(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
# for rotations_count in range(1, 4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 4
|
||||
# output = output.reshape(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
||||
|
||||
|
||||
# class SRLutR90Y(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# quantization_interval,
|
||||
# scale
|
||||
# ):
|
||||
# super(SRLutR90Y, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.quantization_interval = quantization_interval
|
||||
# self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
||||
# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
# self.rgb_to_ycbcr = layers.RgbToYcbcr()
|
||||
# self.ycbcr_to_rgb = layers.YcbcrToRgb()
|
||||
|
||||
# @staticmethod
|
||||
# def init_from_numpy(
|
||||
# stage_lut
|
||||
# ):
|
||||
# scale = int(stage_lut.shape[-1])
|
||||
# quantization_interval = 256//(stage_lut.shape[0]-1)
|
||||
# lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
|
||||
# lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
||||
# return lut_model
|
||||
|
||||
# def forward_stage(self, x, scale, percieve_pattern, lut):
|
||||
# b,c,h,w = x.shape
|
||||
# x = percieve_pattern(x)
|
||||
# x = select_index_4dlut_tetrahedral(index=x, lut=lut)
|
||||
# x = round_func(x)
|
||||
# x = x.reshape(b, c, h, w, scale, scale)
|
||||
# x = x.permute(0,1,2,4,3,5)
|
||||
# x = x.reshape(b, c, h*scale, w*scale)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = self.rgb_to_ycbcr(x)
|
||||
# y = x[:,0:1,:,:]
|
||||
# cbcr = x[:,1:,:,:]
|
||||
# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
|
||||
# output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
||||
# output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
|
||||
# for rotations_count in range(1,4):
|
||||
# rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 4
|
||||
# output = torch.cat([output, cbcr_scaled], dim=1)
|
||||
# output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
||||
# return output
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
|
@ -0,0 +1,281 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from common.utils import round_func
|
||||
from common import lut
|
||||
from pathlib import Path
|
||||
# from . import srlut
|
||||
from common import layers
|
||||
|
||||
class HDBNet(nn.Module):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
super(HDBNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
|
||||
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
|
||||
self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4*2
|
||||
output_msb = output_msb + output_lsb
|
||||
x = output_msb
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4*2
|
||||
output_msb = output_msb + output_lsb
|
||||
x = output_msb
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
|
||||
return lut_model
|
||||
|
||||
class HDBLNet(nn.Module):
|
||||
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
super(HDBLNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
|
||||
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
self.stage2_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
|
||||
|
||||
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
|
||||
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=3)
|
||||
|
||||
def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
b,c,h,w = x.shape
|
||||
x = percieve_pattern(x)
|
||||
x = stage(x)
|
||||
x = round_func(x)
|
||||
x = x.reshape(b, c, h, w, scale, scale)
|
||||
x = x.permute(0,1,2,4,3,5)
|
||||
x = x.reshape(b, c, h*scale, w*scale)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
b,c,h,w = x.shape
|
||||
x = x.reshape(b*c, 1, h, w)
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_3L, self.stage1_3L), k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4
|
||||
output_msb = output_msb + output_lsb
|
||||
x = output_msb
|
||||
lsb = x % 16
|
||||
msb = x - lsb
|
||||
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
|
||||
for rotations_count in range(4):
|
||||
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
|
||||
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
|
||||
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
|
||||
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_3L, self.stage2_3L), k=-rotations_count, dims=[2, 3])
|
||||
output_msb /= 4*3
|
||||
output_lsb /= 4
|
||||
output_msb = output_msb + output_lsb
|
||||
x = output_msb
|
||||
x = x.reshape(b, c, h*self.scale, w*self.scale)
|
||||
return x
|
||||
|
||||
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
|
||||
return lut_model
|
||||
|
||||
|
||||
# class SRNetY(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(SRNetY, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.stage1_S = layers.UpscaleBlock(
|
||||
# hidden_dim=hidden_dim,
|
||||
# layers_count=layers_count,
|
||||
# upscale_factor=self.scale
|
||||
# )
|
||||
# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
# self.rgb_to_ycbcr = layers.RgbToYcbcr()
|
||||
# self.ycbcr_to_rgb = layers.YcbcrToRgb()
|
||||
|
||||
# def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
# b,c,h,w = x.shape
|
||||
# x = percieve_pattern(x)
|
||||
# x = stage(x)
|
||||
# x = round_func(x)
|
||||
# x = x.reshape(b, c, h, w, scale, scale)
|
||||
# x = x.permute(0,1,2,4,3,5)
|
||||
# x = x.reshape(b, c, h*scale, w*scale)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = self.rgb_to_ycbcr(x)
|
||||
# y = x[:,0:1,:,:]
|
||||
# cbcr = x[:,1:,:,:]
|
||||
# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
|
||||
# x = y.view(b, 1, h, w)
|
||||
# output = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
|
||||
# output = torch.cat([output, cbcr_scaled], dim=1)
|
||||
# output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
||||
# return output
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
|
||||
# return lut_model
|
||||
|
||||
# class SRNetR90(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(SRNetR90, self).__init__()
|
||||
# self.scale = scale
|
||||
# self.stage1_S = layers.UpscaleBlock(
|
||||
# hidden_dim=hidden_dim,
|
||||
# layers_count=layers_count,
|
||||
# upscale_factor=self.scale
|
||||
# )
|
||||
# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
|
||||
# def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
# b,c,h,w = x.shape
|
||||
# x = percieve_pattern(x)
|
||||
# x = stage(x)
|
||||
# x = round_func(x)
|
||||
# x = x.reshape(b, c, h, w, scale, scale)
|
||||
# x = x.permute(0,1,2,4,3,5)
|
||||
# x = x.reshape(b, c, h*scale, w*scale)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = x.reshape(b*c, 1, h, w)
|
||||
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
|
||||
# for rotations_count in range(1,4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 4
|
||||
# output = output.reshape(b, c, h*self.scale, w*self.scale)
|
||||
# return output
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
|
||||
# return lut_model
|
||||
|
||||
# class SRNetR90Y(nn.Module):
|
||||
# def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
|
||||
# super(SRNetR90Y, self).__init__()
|
||||
# self.scale = scale
|
||||
# s_pattern=[[0,0],[0,1],[1,0],[1,1]]
|
||||
# self.stage1_S = layers.UpscaleBlock(
|
||||
# hidden_dim=hidden_dim,
|
||||
# layers_count=layers_count,
|
||||
# upscale_factor=self.scale
|
||||
# )
|
||||
# self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
|
||||
# self.rgb_to_ycbcr = layers.RgbToYcbcr()
|
||||
# self.ycbcr_to_rgb = layers.YcbcrToRgb()
|
||||
|
||||
# def forward_stage(self, x, scale, percieve_pattern, stage):
|
||||
# b,c,h,w = x.shape
|
||||
# x = percieve_pattern(x)
|
||||
# x = stage(x)
|
||||
# x = round_func(x)
|
||||
# x = x.reshape(b, c, h, w, scale, scale)
|
||||
# x = x.permute(0,1,2,4,3,5)
|
||||
# x = x.reshape(b, c, h*scale, w*scale)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# b,c,h,w = x.shape
|
||||
# x = self.rgb_to_ycbcr(x)
|
||||
# y = x[:,0:1,:,:]
|
||||
# cbcr = x[:,1:,:,:]
|
||||
# cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
|
||||
|
||||
# x = y.view(b, 1, h, w)
|
||||
# output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
|
||||
# output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S)
|
||||
# for rotations_count in range(1,4):
|
||||
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
||||
# output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
|
||||
# output /= 4
|
||||
# output = torch.cat([output, cbcr_scaled], dim=1)
|
||||
# output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
||||
# return output
|
||||
|
||||
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
|
||||
# stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
|
||||
# lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)
|
||||
# return lut_model
|
Loading…
Reference in New Issue