diff --git a/src/models/rcnet.py b/src/models/rcnet.py index e036fd6..d15c23a 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -6,7 +6,7 @@ from common.utils import round_func from pathlib import Path from common import lut from . import rclut -from common.layers import DenseConvUpscaleBlock +from common import layers class ReconstructedConvCentered(nn.Module): def __init__(self, hidden_dim, window_size=7): @@ -40,7 +40,7 @@ class RCBlockCentered(nn.Module): super(RCBlockCentered, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape @@ -127,7 +127,7 @@ class RCBlockRot90(nn.Module): super(RCBlockRot90, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape @@ -390,7 +390,7 @@ class RCBlockRot90Unlutable(nn.Module): super(RCBlockRot90Unlutable, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape @@ -495,7 +495,7 @@ class RCBlockCenteredUnlutable(nn.Module): super(RCBlockRot90Unlutable, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape