|
|
@ -6,7 +6,7 @@ from common.utils import round_func
|
|
|
|
from pathlib import Path
|
|
|
|
from pathlib import Path
|
|
|
|
from common import lut
|
|
|
|
from common import lut
|
|
|
|
from . import rclut
|
|
|
|
from . import rclut
|
|
|
|
from common.layers import DenseConvUpscaleBlock
|
|
|
|
from common import layers
|
|
|
|
|
|
|
|
|
|
|
|
class ReconstructedConvCentered(nn.Module):
|
|
|
|
class ReconstructedConvCentered(nn.Module):
|
|
|
|
def __init__(self, hidden_dim, window_size=7):
|
|
|
|
def __init__(self, hidden_dim, window_size=7):
|
|
|
@ -40,7 +40,7 @@ class RCBlockCentered(nn.Module):
|
|
|
|
super(RCBlockCentered, self).__init__()
|
|
|
|
super(RCBlockCentered, self).__init__()
|
|
|
|
self.window_size = window_size
|
|
|
|
self.window_size = window_size
|
|
|
|
self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
@ -127,7 +127,7 @@ class RCBlockRot90(nn.Module):
|
|
|
|
super(RCBlockRot90, self).__init__()
|
|
|
|
super(RCBlockRot90, self).__init__()
|
|
|
|
self.window_size = window_size
|
|
|
|
self.window_size = window_size
|
|
|
|
self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
@ -390,7 +390,7 @@ class RCBlockRot90Unlutable(nn.Module):
|
|
|
|
super(RCBlockRot90Unlutable, self).__init__()
|
|
|
|
super(RCBlockRot90Unlutable, self).__init__()
|
|
|
|
self.window_size = window_size
|
|
|
|
self.window_size = window_size
|
|
|
|
self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
@ -495,7 +495,7 @@ class RCBlockCenteredUnlutable(nn.Module):
|
|
|
|
super(RCBlockRot90Unlutable, self).__init__()
|
|
|
|
super(RCBlockRot90Unlutable, self).__init__()
|
|
|
|
self.window_size = window_size
|
|
|
|
self.window_size = window_size
|
|
|
|
self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
|
|
|
|
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|
b,c,hs,ws = x.shape
|
|
|
|