|
|
@ -3,503 +3,503 @@ import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from common.utils import round_func
|
|
|
|
from common.utils import round_func
|
|
|
|
from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
|
|
|
|
# from common.lut import forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
|
|
|
|
from pathlib import Path
|
|
|
|
# from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutCentered_3x3(nn.Module):
|
|
|
|
# class RCLutCentered_3x3(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutCentered_3x3, self).__init__()
|
|
|
|
# super(RCLutCentered_3x3, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
rc_conv_luts, dense_conv_lut
|
|
|
|
# rc_conv_luts, dense_conv_lut
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(dense_conv_lut.shape[-1])
|
|
|
|
# scale = int(dense_conv_lut.shape[-1])
|
|
|
|
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
|
|
|
# x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
|
|
|
|
x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
|
|
|
|
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
|
|
|
|
x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
|
|
|
|
# x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
|
|
|
|
x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
|
|
|
|
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
|
|
|
|
x = x.view(b, c, x.shape[-2], x.shape[-1])
|
|
|
|
# x = x.view(b, c, x.shape[-2], x.shape[-1])
|
|
|
|
return x
|
|
|
|
# return x
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutCentered_7x7(nn.Module):
|
|
|
|
# class RCLutCentered_7x7(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
window_size,
|
|
|
|
# window_size,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutCentered_7x7, self).__init__()
|
|
|
|
# super(RCLutCentered_7x7, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
rc_conv_luts, dense_conv_lut
|
|
|
|
# rc_conv_luts, dense_conv_lut
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(dense_conv_lut.shape[-1])
|
|
|
|
# scale = int(dense_conv_lut.shape[-1])
|
|
|
|
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
|
|
|
|
# x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
|
|
|
|
x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
|
|
|
|
# x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut)
|
|
|
|
# x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
|
|
|
|
# # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
|
|
|
|
x = x.view(b, c, x.shape[-2], x.shape[-1])
|
|
|
|
# x = x.view(b, c, x.shape[-2], x.shape[-1])
|
|
|
|
return x
|
|
|
|
# return x
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutRot90_3x3(nn.Module):
|
|
|
|
# class RCLutRot90_3x3(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutRot90_3x3, self).__init__()
|
|
|
|
# super(RCLutRot90_3x3, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
rc_conv_luts, dense_conv_lut
|
|
|
|
# rc_conv_luts, dense_conv_lut
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(dense_conv_lut.shape[-1])
|
|
|
|
# scale = int(dense_conv_lut.shape[-1])
|
|
|
|
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
|
|
|
|
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
|
|
|
|
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
|
|
|
|
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
|
|
|
|
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
|
|
|
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
|
|
|
output += unrotated_prediction
|
|
|
|
# output += unrotated_prediction
|
|
|
|
output /= 4
|
|
|
|
# output /= 4
|
|
|
|
output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
return output
|
|
|
|
# return output
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutRot90_7x7(nn.Module):
|
|
|
|
# class RCLutRot90_7x7(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutRot90_7x7, self).__init__()
|
|
|
|
# super(RCLutRot90_7x7, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
rc_conv_luts, dense_conv_lut
|
|
|
|
# rc_conv_luts, dense_conv_lut
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(dense_conv_lut.shape[-1])
|
|
|
|
# scale = int(dense_conv_lut.shape[-1])
|
|
|
|
quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(dense_conv_lut.shape[0]-1)
|
|
|
|
window_size = rc_conv_luts.shape[0]
|
|
|
|
# window_size = rc_conv_luts.shape[0]
|
|
|
|
lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
|
|
|
|
lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
|
|
|
|
# rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
|
|
|
|
rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
|
|
|
|
# rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut)
|
|
|
|
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
|
|
|
# unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
|
|
|
|
output += unrotated_prediction
|
|
|
|
# output += unrotated_prediction
|
|
|
|
output /= 4
|
|
|
|
# output /= 4
|
|
|
|
output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
return output
|
|
|
|
# return output
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
# f" rc_conv_luts size: {self.rc_conv_luts.shape}",
|
|
|
|
f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
# f" dense_conv_lut size: {self.dense_conv_lut.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutx1(nn.Module):
|
|
|
|
# class RCLutx1(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutx1, self).__init__()
|
|
|
|
# super(RCLutx1, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
rc_conv_luts_3x3, dense_conv_lut_3x3,
|
|
|
|
# rc_conv_luts_3x3, dense_conv_lut_3x3,
|
|
|
|
rc_conv_luts_5x5, dense_conv_lut_5x5,
|
|
|
|
# rc_conv_luts_5x5, dense_conv_lut_5x5,
|
|
|
|
rc_conv_luts_7x7, dense_conv_lut_7x7
|
|
|
|
# rc_conv_luts_7x7, dense_conv_lut_7x7
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(dense_conv_lut_3x3.shape[-1])
|
|
|
|
# scale = int(dense_conv_lut_3x3.shape[-1])
|
|
|
|
quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
|
|
|
|
|
|
|
|
|
|
|
|
lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
# lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
# lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
|
|
|
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
|
|
|
x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
|
|
|
|
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
|
|
|
|
x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
|
|
|
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
|
|
|
return x
|
|
|
|
# return x
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output /= 3*4
|
|
|
|
# output /= 3*4
|
|
|
|
output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
return output
|
|
|
|
# return output
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
|
|
|
|
# f" rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
|
|
|
|
f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
|
|
|
|
# f" dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
|
|
|
|
f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
|
|
|
|
# f" rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
|
|
|
|
f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
|
|
|
|
# f" dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
|
|
|
|
f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
|
|
|
|
# f" rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
|
|
|
|
f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
|
|
|
|
# f" dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutx2(nn.Module):
|
|
|
|
# class RCLutx2(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutx2, self).__init__()
|
|
|
|
# super(RCLutx2, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
|
|
|
|
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
|
|
|
|
s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
|
|
|
|
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
|
|
|
|
s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
|
|
|
|
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
|
|
|
|
s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
|
|
|
|
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
|
|
|
|
s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
|
|
|
|
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
|
|
|
|
s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
|
|
|
|
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(s2_dense_conv_lut_3x3.shape[-1])
|
|
|
|
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
|
|
|
|
quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
|
|
|
|
|
|
|
|
|
|
|
|
lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
|
|
|
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
|
|
|
x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
|
|
|
|
# x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
|
|
|
|
x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
|
|
|
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
|
|
|
return x
|
|
|
|
# return x
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output /= 3*4
|
|
|
|
# output /= 3*4
|
|
|
|
x = output
|
|
|
|
# x = output
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output /= 3*4
|
|
|
|
# output /= 3*4
|
|
|
|
output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
return output
|
|
|
|
# return output
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
|
|
|
|
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
|
|
|
|
f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
|
|
|
|
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
|
|
|
|
f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
|
|
|
|
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
|
|
|
|
f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
|
|
|
|
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
|
|
|
|
f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
|
|
|
|
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
|
|
|
|
f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
|
|
|
|
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
|
|
|
|
f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
|
|
|
|
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
|
|
|
|
f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
|
|
|
|
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
|
|
|
|
f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
|
|
|
|
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
|
|
|
|
f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
|
|
|
|
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
|
|
|
|
f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
|
|
|
|
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
|
|
|
|
f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
|
|
|
|
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RCLutx2Centered(nn.Module):
|
|
|
|
# class RCLutx2Centered(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
# def __init__(
|
|
|
|
self,
|
|
|
|
# self,
|
|
|
|
quantization_interval,
|
|
|
|
# quantization_interval,
|
|
|
|
scale
|
|
|
|
# scale
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
super(RCLutx2Centered, self).__init__()
|
|
|
|
# super(RCLutx2Centered, self).__init__()
|
|
|
|
self.scale = scale
|
|
|
|
# self.scale = scale
|
|
|
|
self.quantization_interval = quantization_interval
|
|
|
|
# self.quantization_interval = quantization_interval
|
|
|
|
self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
# self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
# self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
# self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
|
|
|
|
self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
# self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
|
|
|
|
self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
# self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
# @staticmethod
|
|
|
|
def init_from_numpy(
|
|
|
|
# def init_from_numpy(
|
|
|
|
s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
|
|
|
|
# s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
|
|
|
|
s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
|
|
|
|
# s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
|
|
|
|
s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
|
|
|
|
# s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
|
|
|
|
s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
|
|
|
|
# s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
|
|
|
|
s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
|
|
|
|
# s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
|
|
|
|
s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
|
|
|
|
# s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
|
|
|
|
):
|
|
|
|
# ):
|
|
|
|
scale = int(s2_dense_conv_lut_3x3.shape[-1])
|
|
|
|
# scale = int(s2_dense_conv_lut_3x3.shape[-1])
|
|
|
|
quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
|
|
|
|
# quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
|
|
|
|
|
|
|
|
|
|
|
|
lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
# lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
# lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
# lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
# lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
# lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
# lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
# lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
# lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
|
|
|
|
lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
# lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
# lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
|
|
|
|
lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
# lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
# lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
|
|
|
|
lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
# lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
return lut_model
|
|
|
|
# return lut_model
|
|
|
|
|
|
|
|
|
|
|
|
def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
|
|
|
# def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
|
|
|
|
x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
|
|
|
|
# x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
|
|
|
|
x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
|
|
|
# x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
|
|
|
|
return x
|
|
|
|
# return x
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
# def forward(self, x):
|
|
|
|
b,c,h,w = x.shape
|
|
|
|
# b,c,h,w = x.shape
|
|
|
|
x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
# x = x.view(b*c, 1, h, w).type(torch.float32)
|
|
|
|
output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output /= 3*4
|
|
|
|
# output /= 3*4
|
|
|
|
x = output
|
|
|
|
# x = output
|
|
|
|
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
# output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
|
|
|
|
for rotations_count in range(4):
|
|
|
|
# for rotations_count in range(4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
# rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output += torch.rot90(
|
|
|
|
# output += torch.rot90(
|
|
|
|
self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
|
|
|
|
# self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7),
|
|
|
|
k=-rotations_count,
|
|
|
|
# k=-rotations_count,
|
|
|
|
dims=[2, 3]
|
|
|
|
# dims=[2, 3]
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
output /= 3*4
|
|
|
|
# output /= 3*4
|
|
|
|
output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
# output = output.view(b, c, output.shape[-2], output.shape[-1])
|
|
|
|
return output
|
|
|
|
# return output
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
# def __repr__(self):
|
|
|
|
return "\n".join([
|
|
|
|
# return "\n".join([
|
|
|
|
f"{self.__class__.__name__}(",
|
|
|
|
# f"{self.__class__.__name__}(",
|
|
|
|
f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
|
|
|
|
# f" s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
|
|
|
|
f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
|
|
|
|
# f" s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
|
|
|
|
f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
|
|
|
|
# f" s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
|
|
|
|
f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
|
|
|
|
# f" s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
|
|
|
|
f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
|
|
|
|
# f" s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
|
|
|
|
f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
|
|
|
|
# f" s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
|
|
|
|
f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
|
|
|
|
# f" s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
|
|
|
|
f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
|
|
|
|
# f" s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
|
|
|
|
f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
|
|
|
|
# f" s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
|
|
|
|
f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
|
|
|
|
# f" s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
|
|
|
|
f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
|
|
|
|
# f" s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
|
|
|
|
f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
|
|
|
|
# f" s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
|
|
|
|
")"])
|
|
|
|
# ")"])
|