# lut_reproduce/src/models/rclut.py
# 505 lines, 26 KiB, Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from common.utils import round_func
from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
from pathlib import Path
class RCLutCentered_3x3(nn.Module):
    """Centered 3x3 RC lookup followed by a dense 2x2-input, SxS-output lookup.

    Learnable tables (with ``L = 256 // quantization_interval + 1``):

    * ``rc_conv_luts``: shape ``(3, 3, L)``, consumed by ``forward_rc_conv_centered``.
    * ``dense_conv_lut``: shape ``(L, L, L, L, scale, scale)``, consumed by
      ``forward_2x2_input_SxS_output``.

    Both are initialised with random integers drawn from ``[0, 255)``
    (``torch.randint``'s upper bound is exclusive) and stored as float32.
    """

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(RCLutCentered_3x3, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        # Side length of the RC kernel; ``forward`` pads and crops by
        # ``window_size // 2``. It must be set here because ``forward`` reads it.
        self.window_size = 3
        # Dense table is created before the RC table so seeded initialisation
        # and parameter registration order stay the same.
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(self.window_size, self.window_size, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model from pre-computed tables.

        ``scale`` is read from the last axis of ``dense_conv_lut``, and the
        quantization interval is recovered as ``256 // (L - 1)`` from its first
        axis. Both tables are copied into new float32 parameters.
        """
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        return lut_model

    def forward(self, x):
        """Run the model on a ``(b, c, h, w)`` batch, one channel at a time.

        Every channel is treated as a separate single-channel image. The input
        is replicate-padded by ``window_size // 2`` on each side before the RC
        lookup, and the same margin is cropped afterwards. The result is
        reshaped back to ``(b, c, H', W')``, where ``H', W'`` are whatever
        ``forward_2x2_input_SxS_output`` returns (presumably ``h*scale,
        w*scale`` -- not visible here).
        """
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs do not raise.
        x = x.reshape(b*c, 1, h, w).type(torch.float32)
        pad = self.window_size // 2
        x = F.pad(x, pad=[pad]*4, mode='replicate')
        x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
        # Remove the padding margin again (same as [pad:-pad] for pad > 0).
        x = x[:, :, pad:x.shape[-2]-pad, pad:x.shape[-1]-pad]
        x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
        x = x.reshape(b, c, x.shape[-2], x.shape[-1])
        return x

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])
class RCLutCentered_7x7(nn.Module):
    """Centered 7x7 RC lookup followed by a dense 2x2-input, SxS-output lookup.

    Learnable tables (with ``L = 256 // quantization_interval + 1``):

    * ``rc_conv_luts``: shape ``(7, 7, L)``, consumed by ``forward_rc_conv_centered``.
    * ``dense_conv_lut``: shape ``(L, L, L, L, scale, scale)``, consumed by
      ``forward_2x2_input_SxS_output``.

    Both are initialised with random integers drawn from ``[0, 255)`` and
    stored as float32.
    """

    def __init__(
        self,
        window_size,
        quantization_interval,
        scale
    ):
        """Create randomly initialised tables.

        Args:
            window_size: accepted for interface compatibility but not read;
                the RC table is always allocated as 7x7.
            quantization_interval: input quantization step; determines ``L``.
            scale: side length of the output patch stored in ``dense_conv_lut``.
        """
        super(RCLutCentered_7x7, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model from pre-computed tables.

        ``scale`` is read from the last axis of ``dense_conv_lut``, and the
        quantization interval is recovered as ``256 // (L - 1)`` from its first
        axis. Both tables are copied into new float32 parameters.
        """
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        # ``__init__`` requires ``window_size``; omitting it made this factory
        # raise TypeError on every call.
        lut_model = RCLutCentered_7x7(window_size=7, quantization_interval=quantization_interval, scale=scale)
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        return lut_model

    def forward(self, x):
        """Run the model on a ``(b, c, h, w)`` batch, one channel at a time.

        Unlike the 3x3 variant, no explicit padding is applied here. Any border
        handling is left to ``forward_rc_conv_centered`` (not visible in this
        file -- NOTE(review): confirm this asymmetry is intended).
        """
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs do not raise.
        x = x.reshape(b*c, 1, h, w).type(torch.float32)
        x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
        x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
        x = x.reshape(b, c, x.shape[-2], x.shape[-1])
        return x

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])
class RCLutRot90_3x3(nn.Module):
    """Rotation-ensemble model built from a 3x3 RC table and a dense table.

    The input is turned by 0, 90, 180 and 270 degrees. Each turned copy goes
    through ``forward_rc_conv_rot90`` and then ``forward_2x2_input_SxS_output``.
    The prediction is turned back, and the four predictions are averaged.
    """

    def __init__(self, quantization_interval, scale):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        levels = 256 // quantization_interval + 1
        dense_shape = (levels, levels, levels, levels, scale, scale)
        # Draw the dense table first, then the RC table (random-draw order matters).
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=dense_shape).float())
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, levels)).float())

    @staticmethod
    def init_from_lut(rc_conv_luts, dense_conv_lut):
        """Return a model whose tables are float32 copies of the given arrays.

        The output patch size comes from the last axis of ``dense_conv_lut``.
        The quantization interval is ``256 // (L - 1)``, where ``L`` is the
        length of the first axis.
        """
        out_patch = int(dense_conv_lut.shape[-1])
        interval = 256 // (dense_conv_lut.shape[0] - 1)
        model = RCLutRot90_3x3(quantization_interval=interval, scale=out_patch)
        model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).float())
        model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).float())
        return model

    def forward(self, x):
        """Average the four rotated predictions for a ``(b, c, h, w)`` batch."""
        batch, channels, height, width = x.shape
        planes = x.view(batch * channels, 1, height, width).float()
        acc = torch.zeros(
            [batch * channels, 1, height * self.scale, width * self.scale],
            dtype=torch.float32,
            device=planes.device,
        )
        for quarter_turns in range(4):
            turned = torch.rot90(planes, k=quarter_turns, dims=[2, 3])
            features = forward_rc_conv_rot90(index=turned, lut=self.rc_conv_luts)
            prediction = forward_2x2_input_SxS_output(index=features, lut=self.dense_conv_lut)
            acc += torch.rot90(prediction, k=-quarter_turns, dims=[2, 3])
        acc /= 4
        return acc.view(batch, channels, acc.shape[-2], acc.shape[-1])

    def __repr__(self):
        lines = [
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")",
        ]
        return "\n".join(lines)
class RCLutRot90_7x7(nn.Module):
    """Rotation-ensemble model built from a 7x7 RC table and a dense table.

    Learnable tables (with ``L = 256 // quantization_interval + 1``):

    * ``rc_conv_luts``: shape ``(7, 7, L)``, consumed by ``forward_rc_conv_rot90``.
    * ``dense_conv_lut``: shape ``(L, L, L, L, scale, scale)``, consumed by
      ``forward_2x2_input_SxS_output``.

    ``forward`` averages the predictions for the four 90-degree rotations of
    the input, each rotated back before accumulation.
    """

    def __init__(
        self,
        quantization_interval,
        scale
    ):
        super(RCLutRot90_7x7, self).__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
        self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))

    @staticmethod
    def init_from_lut(
        rc_conv_luts, dense_conv_lut
    ):
        """Build a model from pre-computed tables.

        ``scale`` and the quantization interval are derived from
        ``dense_conv_lut``'s shape. ``rc_conv_luts`` only needs to be accepted
        by ``torch.tensor`` (an array, tensor or nested list); its shape is not
        inspected. Both tables are copied into new float32 parameters.
        """
        scale = int(dense_conv_lut.shape[-1])
        quantization_interval = 256//(dense_conv_lut.shape[0]-1)
        lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
        lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
        lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
        return lut_model

    def forward(self, x):
        """Average the four rotated predictions for a ``(b, c, h, w)`` batch.

        The accumulator is pre-allocated at ``(h*scale, w*scale)``, so
        ``forward_2x2_input_SxS_output`` is expected to upsample by ``scale``.
        """
        b,c,h,w = x.shape
        # reshape (not view) so non-contiguous inputs do not raise.
        x = x.reshape(b*c, 1, h, w).type(torch.float32)
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
        for rotations_count in range(4):
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
            rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
            rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
            output += unrotated_prediction
        output /= 4
        output = output.view(b, c, output.shape[-2], output.shape[-1])
        return output

    def __repr__(self):
        return "\n".join([
            f"{self.__class__.__name__}(",
            f" rc_conv_luts size: {self.rc_conv_luts.shape}",
            f" dense_conv_lut size: {self.dense_conv_lut.shape}",
            ")"])
class RCLutx1(nn.Module):
    """Rotation ensemble over three RC branches (3x3, 5x5 and 7x7 windows).

    Each branch is ``forward_rc_conv_rot90`` followed by
    ``forward_2x2_input_SxS_output`` with its own pair of tables. The input is
    turned by each of the four quarter-turns and fed to all three branches.
    Every prediction is turned back and accumulated, and the total of the 12
    predictions is divided by 12.
    """

    def __init__(self, quantization_interval, scale):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        levels = 256 // quantization_interval + 1

        def rc_table(size):
            # (size, size, levels) table of random ints in [0, 255), as float32.
            return nn.Parameter(torch.randint(0, 255, size=(size, size, levels)).float())

        def dense_table():
            # (levels,)*4 + (scale, scale) table of random ints in [0, 255), as float32.
            return nn.Parameter(torch.randint(0, 255, size=(levels,) * 4 + (scale, scale)).float())

        # All RC tables first, then all dense tables: registration and RNG order.
        self.rc_conv_luts_3x3 = rc_table(3)
        self.rc_conv_luts_5x5 = rc_table(5)
        self.rc_conv_luts_7x7 = rc_table(7)
        self.dense_conv_lut_3x3 = dense_table()
        self.dense_conv_lut_5x5 = dense_table()
        self.dense_conv_lut_7x7 = dense_table()

    @staticmethod
    def init_from_lut(
        rc_conv_luts_3x3, dense_conv_lut_3x3,
        rc_conv_luts_5x5, dense_conv_lut_5x5,
        rc_conv_luts_7x7, dense_conv_lut_7x7
    ):
        """Return a model whose six tables are float32 copies of the arguments.

        ``scale`` and the quantization interval are derived from the shape of
        ``dense_conv_lut_3x3``.
        """
        out_patch = int(dense_conv_lut_3x3.shape[-1])
        interval = 256 // (dense_conv_lut_3x3.shape[0] - 1)
        model = RCLutx1(quantization_interval=interval, scale=out_patch)

        def as_param(table):
            return nn.Parameter(torch.tensor(table).float())

        model.rc_conv_luts_3x3 = as_param(rc_conv_luts_3x3)
        model.dense_conv_lut_3x3 = as_param(dense_conv_lut_3x3)
        model.rc_conv_luts_5x5 = as_param(rc_conv_luts_5x5)
        model.dense_conv_lut_5x5 = as_param(dense_conv_lut_5x5)
        model.rc_conv_luts_7x7 = as_param(rc_conv_luts_7x7)
        model.dense_conv_lut_7x7 = as_param(dense_conv_lut_7x7)
        return model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        """Apply one branch: the RC lookup, then the dense lookup."""
        features = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
        return forward_2x2_input_SxS_output(index=features, lut=dense_conv_lut)

    def forward(self, x):
        """Average 4 rotations x 3 branches for a ``(b, c, h, w)`` batch."""
        batch, channels, height, width = x.shape
        planes = x.view(batch * channels, 1, height, width).float()
        branches = (
            (self.rc_conv_luts_3x3, self.dense_conv_lut_3x3),
            (self.rc_conv_luts_5x5, self.dense_conv_lut_5x5),
            (self.rc_conv_luts_7x7, self.dense_conv_lut_7x7),
        )
        acc = torch.zeros(
            [batch * channels, 1, height * self.scale, width * self.scale],
            dtype=torch.float32,
            device=planes.device,
        )
        for quarter_turns in range(4):
            turned = torch.rot90(planes, k=quarter_turns, dims=[2, 3])
            for rc_lut, dense_lut in branches:
                prediction = self._forward_rcblock(index=turned, rc_conv_lut=rc_lut, dense_conv_lut=dense_lut)
                acc += torch.rot90(prediction, k=-quarter_turns, dims=[2, 3])
        acc /= len(branches) * 4
        return acc.view(batch, channels, acc.shape[-2], acc.shape[-1])

    def __repr__(self):
        fields = (
            ("rc_conv_luts_3x3", self.rc_conv_luts_3x3),
            ("dense_conv_lut_3x3", self.dense_conv_lut_3x3),
            ("rc_conv_luts_5x5", self.rc_conv_luts_5x5),
            ("dense_conv_lut_5x5", self.dense_conv_lut_5x5),
            ("rc_conv_luts_7x7", self.rc_conv_luts_7x7),
            ("dense_conv_lut_7x7", self.dense_conv_lut_7x7),
        )
        body = [f" {label} size: {table.shape}" for label, table in fields]
        return "\n".join([f"{self.__class__.__name__}(", *body, ")"])
class RCLutx2(nn.Module):
    """Two-stage rotation ensemble of RC branches (3x3, 5x5, 7x7 windows).

    Stage ``s1`` uses dense tables whose last two axes are ``(1, 1)`` and is
    accumulated at the input resolution. Stage ``s2`` uses dense tables ending
    in ``(scale, scale)`` and is accumulated at ``scale`` times the input
    resolution. Within a stage, each of the four quarter-turns of the input is
    passed through the three branches (``forward_rc_conv_rot90`` then
    ``forward_2x2_input_SxS_output``). The predictions are turned back, summed,
    and divided by 12.
    """

    def __init__(self, quantization_interval, scale):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        levels = 256 // quantization_interval + 1

        def rc_table(size):
            # (size, size, levels) table of random ints in [0, 255), as float32.
            return nn.Parameter(torch.randint(0, 255, size=(size, size, levels)).float())

        def dense_table(out_size):
            # (levels,)*4 + (out_size, out_size) table of random ints in [0, 255), as float32.
            return nn.Parameter(torch.randint(0, 255, size=(levels,) * 4 + (out_size, out_size)).float())

        # Stage 1: RC tables, then dense tables with (1, 1) output patches.
        self.s1_rc_conv_luts_3x3 = rc_table(3)
        self.s1_rc_conv_luts_5x5 = rc_table(5)
        self.s1_rc_conv_luts_7x7 = rc_table(7)
        self.s1_dense_conv_lut_3x3 = dense_table(1)
        self.s1_dense_conv_lut_5x5 = dense_table(1)
        self.s1_dense_conv_lut_7x7 = dense_table(1)
        # Stage 2: RC tables, then dense tables with (scale, scale) output patches.
        self.s2_rc_conv_luts_3x3 = rc_table(3)
        self.s2_rc_conv_luts_5x5 = rc_table(5)
        self.s2_rc_conv_luts_7x7 = rc_table(7)
        self.s2_dense_conv_lut_3x3 = dense_table(scale)
        self.s2_dense_conv_lut_5x5 = dense_table(scale)
        self.s2_dense_conv_lut_7x7 = dense_table(scale)

    @staticmethod
    def init_from_lut(
        s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
        s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
        s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
        s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
        s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
        s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
    ):
        """Return a model whose twelve tables are float32 copies of the arguments.

        ``scale`` and the quantization interval are derived from the shape of
        ``s2_dense_conv_lut_3x3``.
        """
        out_patch = int(s2_dense_conv_lut_3x3.shape[-1])
        interval = 256 // (s2_dense_conv_lut_3x3.shape[0] - 1)
        model = RCLutx2(quantization_interval=interval, scale=out_patch)

        def as_param(table):
            return nn.Parameter(torch.tensor(table).float())

        model.s1_rc_conv_luts_3x3 = as_param(s1_rc_conv_luts_3x3)
        model.s1_dense_conv_lut_3x3 = as_param(s1_dense_conv_lut_3x3)
        model.s1_rc_conv_luts_5x5 = as_param(s1_rc_conv_luts_5x5)
        model.s1_dense_conv_lut_5x5 = as_param(s1_dense_conv_lut_5x5)
        model.s1_rc_conv_luts_7x7 = as_param(s1_rc_conv_luts_7x7)
        model.s1_dense_conv_lut_7x7 = as_param(s1_dense_conv_lut_7x7)
        model.s2_rc_conv_luts_3x3 = as_param(s2_rc_conv_luts_3x3)
        model.s2_dense_conv_lut_3x3 = as_param(s2_dense_conv_lut_3x3)
        model.s2_rc_conv_luts_5x5 = as_param(s2_rc_conv_luts_5x5)
        model.s2_dense_conv_lut_5x5 = as_param(s2_dense_conv_lut_5x5)
        model.s2_rc_conv_luts_7x7 = as_param(s2_rc_conv_luts_7x7)
        model.s2_dense_conv_lut_7x7 = as_param(s2_dense_conv_lut_7x7)
        return model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        """Apply one branch: the RC lookup, then the dense lookup."""
        features = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
        return forward_2x2_input_SxS_output(index=features, lut=dense_conv_lut)

    def _forward_stage(self, planes, branches, out_height, out_width):
        """Average every branch over the four quarter-turns of ``planes``."""
        acc = torch.zeros(
            [planes.shape[0], 1, out_height, out_width],
            dtype=torch.float32,
            device=planes.device,
        )
        for quarter_turns in range(4):
            turned = torch.rot90(planes, k=quarter_turns, dims=[2, 3])
            for rc_lut, dense_lut in branches:
                prediction = self._forward_rcblock(index=turned, rc_conv_lut=rc_lut, dense_conv_lut=dense_lut)
                acc += torch.rot90(prediction, k=-quarter_turns, dims=[2, 3])
        acc /= len(branches) * 4
        return acc

    def forward(self, x):
        """Run both stages on a ``(b, c, h, w)`` batch, one channel at a time."""
        batch, channels, height, width = x.shape
        planes = x.view(batch * channels, 1, height, width).float()
        stage1 = self._forward_stage(
            planes,
            (
                (self.s1_rc_conv_luts_3x3, self.s1_dense_conv_lut_3x3),
                (self.s1_rc_conv_luts_5x5, self.s1_dense_conv_lut_5x5),
                (self.s1_rc_conv_luts_7x7, self.s1_dense_conv_lut_7x7),
            ),
            height,
            width,
        )
        stage2 = self._forward_stage(
            stage1,
            (
                (self.s2_rc_conv_luts_3x3, self.s2_dense_conv_lut_3x3),
                (self.s2_rc_conv_luts_5x5, self.s2_dense_conv_lut_5x5),
                (self.s2_rc_conv_luts_7x7, self.s2_dense_conv_lut_7x7),
            ),
            height * self.scale,
            width * self.scale,
        )
        return stage2.view(batch, channels, stage2.shape[-2], stage2.shape[-1])

    def __repr__(self):
        fields = (
            ("s1_rc_conv_luts_3x3", self.s1_rc_conv_luts_3x3),
            ("s1_dense_conv_lut_3x3", self.s1_dense_conv_lut_3x3),
            ("s1_rc_conv_luts_5x5", self.s1_rc_conv_luts_5x5),
            ("s1_dense_conv_lut_5x5", self.s1_dense_conv_lut_5x5),
            ("s1_rc_conv_luts_7x7", self.s1_rc_conv_luts_7x7),
            ("s1_dense_conv_lut_7x7", self.s1_dense_conv_lut_7x7),
            ("s2_rc_conv_luts_3x3", self.s2_rc_conv_luts_3x3),
            ("s2_dense_conv_lut_3x3", self.s2_dense_conv_lut_3x3),
            ("s2_rc_conv_luts_5x5", self.s2_rc_conv_luts_5x5),
            ("s2_dense_conv_lut_5x5", self.s2_dense_conv_lut_5x5),
            ("s2_rc_conv_luts_7x7", self.s2_rc_conv_luts_7x7),
            ("s2_dense_conv_lut_7x7", self.s2_dense_conv_lut_7x7),
        )
        body = [f" {label} size: {table.shape}" for label, table in fields]
        return "\n".join([f"{self.__class__.__name__}(", *body, ")"])
class RCLutx2Centered(nn.Module):
    """Two-stage rotation ensemble of centered RC branches (3x3, 5x5, 7x7).

    Same layout as the two-stage rotation model, but each branch uses
    ``forward_rc_conv_centered`` before ``forward_2x2_input_SxS_output``.
    Stage ``s1`` dense tables end in ``(1, 1)`` and are accumulated at the input
    resolution. Stage ``s2`` dense tables end in ``(scale, scale)`` and are
    accumulated at ``scale`` times that resolution. Each stage sums 4
    quarter-turns x 3 branches and divides by 12.
    """

    def __init__(self, quantization_interval, scale):
        super().__init__()
        self.scale = scale
        self.quantization_interval = quantization_interval
        levels = 256 // quantization_interval + 1

        def rc_table(size):
            # (size, size, levels) table of random ints in [0, 255), as float32.
            return nn.Parameter(torch.randint(0, 255, size=(size, size, levels)).float())

        def dense_table(out_size):
            # (levels,)*4 + (out_size, out_size) table of random ints in [0, 255), as float32.
            return nn.Parameter(torch.randint(0, 255, size=(levels,) * 4 + (out_size, out_size)).float())

        # Stage 1: RC tables, then dense tables with (1, 1) output patches.
        self.s1_rc_conv_luts_3x3 = rc_table(3)
        self.s1_rc_conv_luts_5x5 = rc_table(5)
        self.s1_rc_conv_luts_7x7 = rc_table(7)
        self.s1_dense_conv_lut_3x3 = dense_table(1)
        self.s1_dense_conv_lut_5x5 = dense_table(1)
        self.s1_dense_conv_lut_7x7 = dense_table(1)
        # Stage 2: RC tables, then dense tables with (scale, scale) output patches.
        self.s2_rc_conv_luts_3x3 = rc_table(3)
        self.s2_rc_conv_luts_5x5 = rc_table(5)
        self.s2_rc_conv_luts_7x7 = rc_table(7)
        self.s2_dense_conv_lut_3x3 = dense_table(scale)
        self.s2_dense_conv_lut_5x5 = dense_table(scale)
        self.s2_dense_conv_lut_7x7 = dense_table(scale)

    @staticmethod
    def init_from_lut(
        s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
        s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
        s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
        s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
        s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
        s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
    ):
        """Return a model whose twelve tables are float32 copies of the arguments.

        ``scale`` and the quantization interval are derived from the shape of
        ``s2_dense_conv_lut_3x3``.
        """
        out_patch = int(s2_dense_conv_lut_3x3.shape[-1])
        interval = 256 // (s2_dense_conv_lut_3x3.shape[0] - 1)
        model = RCLutx2Centered(quantization_interval=interval, scale=out_patch)

        def as_param(table):
            return nn.Parameter(torch.tensor(table).float())

        model.s1_rc_conv_luts_3x3 = as_param(s1_rc_conv_luts_3x3)
        model.s1_dense_conv_lut_3x3 = as_param(s1_dense_conv_lut_3x3)
        model.s1_rc_conv_luts_5x5 = as_param(s1_rc_conv_luts_5x5)
        model.s1_dense_conv_lut_5x5 = as_param(s1_dense_conv_lut_5x5)
        model.s1_rc_conv_luts_7x7 = as_param(s1_rc_conv_luts_7x7)
        model.s1_dense_conv_lut_7x7 = as_param(s1_dense_conv_lut_7x7)
        model.s2_rc_conv_luts_3x3 = as_param(s2_rc_conv_luts_3x3)
        model.s2_dense_conv_lut_3x3 = as_param(s2_dense_conv_lut_3x3)
        model.s2_rc_conv_luts_5x5 = as_param(s2_rc_conv_luts_5x5)
        model.s2_dense_conv_lut_5x5 = as_param(s2_dense_conv_lut_5x5)
        model.s2_rc_conv_luts_7x7 = as_param(s2_rc_conv_luts_7x7)
        model.s2_dense_conv_lut_7x7 = as_param(s2_dense_conv_lut_7x7)
        return model

    def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
        """Apply one branch: the centered RC lookup, then the dense lookup."""
        features = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
        return forward_2x2_input_SxS_output(index=features, lut=dense_conv_lut)

    def _forward_stage(self, planes, branches, out_height, out_width):
        """Average every branch over the four quarter-turns of ``planes``."""
        acc = torch.zeros(
            [planes.shape[0], 1, out_height, out_width],
            dtype=torch.float32,
            device=planes.device,
        )
        for quarter_turns in range(4):
            turned = torch.rot90(planes, k=quarter_turns, dims=[2, 3])
            for rc_lut, dense_lut in branches:
                prediction = self._forward_rcblock(index=turned, rc_conv_lut=rc_lut, dense_conv_lut=dense_lut)
                acc += torch.rot90(prediction, k=-quarter_turns, dims=[2, 3])
        acc /= len(branches) * 4
        return acc

    def forward(self, x):
        """Run both stages on a ``(b, c, h, w)`` batch, one channel at a time."""
        batch, channels, height, width = x.shape
        planes = x.view(batch * channels, 1, height, width).float()
        stage1 = self._forward_stage(
            planes,
            (
                (self.s1_rc_conv_luts_3x3, self.s1_dense_conv_lut_3x3),
                (self.s1_rc_conv_luts_5x5, self.s1_dense_conv_lut_5x5),
                (self.s1_rc_conv_luts_7x7, self.s1_dense_conv_lut_7x7),
            ),
            height,
            width,
        )
        stage2 = self._forward_stage(
            stage1,
            (
                (self.s2_rc_conv_luts_3x3, self.s2_dense_conv_lut_3x3),
                (self.s2_rc_conv_luts_5x5, self.s2_dense_conv_lut_5x5),
                (self.s2_rc_conv_luts_7x7, self.s2_dense_conv_lut_7x7),
            ),
            height * self.scale,
            width * self.scale,
        )
        return stage2.view(batch, channels, stage2.shape[-2], stage2.shape[-1])

    def __repr__(self):
        fields = (
            ("s1_rc_conv_luts_3x3", self.s1_rc_conv_luts_3x3),
            ("s1_dense_conv_lut_3x3", self.s1_dense_conv_lut_3x3),
            ("s1_rc_conv_luts_5x5", self.s1_rc_conv_luts_5x5),
            ("s1_dense_conv_lut_5x5", self.s1_dense_conv_lut_5x5),
            ("s1_rc_conv_luts_7x7", self.s1_rc_conv_luts_7x7),
            ("s1_dense_conv_lut_7x7", self.s1_dense_conv_lut_7x7),
            ("s2_rc_conv_luts_3x3", self.s2_rc_conv_luts_3x3),
            ("s2_dense_conv_lut_3x3", self.s2_dense_conv_lut_3x3),
            ("s2_rc_conv_luts_5x5", self.s2_rc_conv_luts_5x5),
            ("s2_dense_conv_lut_5x5", self.s2_dense_conv_lut_5x5),
            ("s2_rc_conv_luts_7x7", self.s2_rc_conv_luts_7x7),
            ("s2_dense_conv_lut_7x7", self.s2_dense_conv_lut_7x7),
        )
        body = [f" {label} size: {table.shape}" for label, table in fields]
        return "\n".join([f"{self.__class__.__name__}(", *body, ")"])