From 96d27bfd053b161e77659bdb54ff6f9b9ff76929 Mon Sep 17 00:00:00 2001 From: vlpr Date: Wed, 15 May 2024 13:38:41 +0000 Subject: [PATCH] tmpfix --- src/models/rcnet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/models/rcnet.py b/src/models/rcnet.py index e036fd6..d15c23a 100644 --- a/src/models/rcnet.py +++ b/src/models/rcnet.py @@ -6,7 +6,7 @@ from common.utils import round_func from pathlib import Path from common import lut from . import rclut -from common.layers import DenseConvUpscaleBlock +from common import layers class ReconstructedConvCentered(nn.Module): def __init__(self, hidden_dim, window_size=7): @@ -40,7 +40,7 @@ class RCBlockCentered(nn.Module): super(RCBlockCentered, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape @@ -127,7 +127,7 @@ class RCBlockRot90(nn.Module): super(RCBlockRot90, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape @@ -390,7 +390,7 @@ class RCBlockRot90Unlutable(nn.Module): super(RCBlockRot90Unlutable, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape @@ -495,7 +495,7 @@ class RCBlockCenteredUnlutable(nn.Module): super(RCBlockRot90Unlutable, self).__init__() self.window_size = window_size self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size) - self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) + self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) def forward(self, x): b,c,hs,ws = x.shape